From 2871fdb8b1fecf3926c075174483456975ece221 Mon Sep 17 00:00:00 2001 From: Micheal Gendy Date: Tue, 22 Feb 2022 18:04:07 +0200 Subject: [PATCH 1/2] use execute_many from backends directly --- databases/backends/aiopg.py | 4 +++- databases/backends/asyncmy.py | 4 +++- databases/backends/mysql.py | 21 +++++++++++++++++---- databases/backends/postgres.py | 25 +++++++++++++++++-------- databases/backends/sqlite.py | 21 ++++++++++++++++++--- databases/core.py | 2 +- databases/interfaces.py | 4 +++- 7 files changed, 62 insertions(+), 19 deletions(-) diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py index 9ad12f63..d2ad98a6 100644 --- a/databases/backends/aiopg.py +++ b/databases/backends/aiopg.py @@ -168,7 +168,9 @@ async def execute(self, query: ClauseElement) -> typing.Any: finally: cursor.close() - async def execute_many(self, queries: typing.List[ClauseElement]) -> None: + async def execute_many( + self, queries: typing.List[ClauseElement], values: typing.List[dict] + ) -> None: assert self._connection is not None, "Connection is not acquired" cursor = await self._connection.cursor() try: diff --git a/databases/backends/asyncmy.py b/databases/backends/asyncmy.py index e15dfa45..8c327c71 100644 --- a/databases/backends/asyncmy.py +++ b/databases/backends/asyncmy.py @@ -158,7 +158,9 @@ async def execute(self, query: ClauseElement) -> typing.Any: finally: await cursor.close() - async def execute_many(self, queries: typing.List[ClauseElement]) -> None: + async def execute_many( + self, queries: typing.List[ClauseElement], values: typing.List[dict] + ) -> None: assert self._connection is not None, "Connection is not acquired" async with self._connection.cursor() as cursor: try: diff --git a/databases/backends/mysql.py b/databases/backends/mysql.py index 2a0a8425..bf24e3b5 100644 --- a/databases/backends/mysql.py +++ b/databases/backends/mysql.py @@ -158,13 +158,14 @@ async def execute(self, query: ClauseElement) -> typing.Any: finally: await cursor.close() - async def execute_many(self, queries: typing.List[ClauseElement]) -> None: + async def execute_many( + self, queries: typing.List[ClauseElement], values: typing.List[dict] + ) -> None: assert self._connection is not None, "Connection is not acquired" cursor = await self._connection.cursor() + query_str, values = self._compile_many(queries, values) try: - for single_query in queries: - single_query, args, context = self._compile(single_query) - await cursor.execute(single_query, args) + await cursor.executemany(query_str, values) finally: await cursor.close() @@ -220,6 +221,18 @@ def _compile( logger.debug("Query: %s Args: %s", query_message, repr(args), extra=LOG_EXTRA) return compiled.string, args, CompilationContext(execution_context) + def _compile_many( + self, queries: typing.List[ClauseElement], values: typing.List[dict] + ) -> typing.Tuple[str, list]: + compiled = queries[0].compile( + dialect=self._dialect, compile_kwargs={"render_postcompile": True} + ) + for args in values: + for key, val in args.items(): + if key in compiled._bind_processors: + args[key] = compiled._bind_processors[key](val) + return compiled.string, values + @property def raw_connection(self) -> aiomysql.connection.Connection: assert self._connection is not None, "Connection is not acquired" diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index 3d0a36f2..69e65fb2 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -1,6 +1,5 @@ import logging import typing -from collections.abc import Sequence import asyncpg from sqlalchemy.dialects.postgresql import pypostgresql @@ -217,14 +216,12 @@ async def execute(self, query: ClauseElement) -> typing.Any: query_str, args, result_columns = self._compile(query) return await self._connection.fetchval(query_str, *args) - async def execute_many(self, queries: typing.List[ClauseElement]) -> None: + async def execute_many( + self, queries: typing.List[ClauseElement], values: typing.List[dict] + ) -> None: assert self._connection is not None, "Connection is not acquired" - # asyncpg uses prepared statements under the hood, so we just - # loop through multiple executes here, which should all end up - # using the same prepared statement. - for single_query in queries: - single_query, args, result_columns = self._compile(single_query) - await self._connection.execute(single_query, *args) + query_str, values = self._compile_many(queries, values) + await self._connection.executemany(query_str, values) async def iterate( self, query: ClauseElement @@ -269,6 +266,18 @@ def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: ) return compiled_query, args, result_map + def _compile_many( + self, queries: typing.List[ClauseElement], values: typing.List[dict] + ) -> typing.Tuple[str, list]: + compiled = queries[0].compile( + dialect=self._dialect, compile_kwargs={"render_postcompile": True} + ) + for args in values: + for key, val in args.items(): + if key in compiled._bind_processors: + args[key] = compiled._bind_processors[key](val) + return compiled.string, values + @staticmethod def _create_column_maps( result_columns: tuple, diff --git a/databases/backends/sqlite.py b/databases/backends/sqlite.py index 9626dcf8..8f6c9748 100644 --- a/databases/backends/sqlite.py +++ b/databases/backends/sqlite.py @@ -135,10 +135,13 @@ async def execute(self, query: ClauseElement) -> typing.Any: return cursor.rowcount return cursor.lastrowid - async def execute_many(self, queries: typing.List[ClauseElement]) -> None: + async def execute_many( + self, queries: typing.List[ClauseElement], values: typing.List[dict] + ) -> None: assert self._connection is not None, "Connection is not acquired" - for single_query in queries: - await self.execute(single_query) + query_str, values = self._compile_many(queries, values) + async with self._connection.cursor() as cursor: + await cursor.executemany(query_str, values) async def iterate( self, query: ClauseElement @@ -194,6 +197,18 @@ def _compile( ) return compiled.string, args, CompilationContext(execution_context) + def _compile_many( + self, queries: typing.List[ClauseElement], values: typing.List[dict] + ) -> typing.Tuple[str, list]: + compiled = queries[0].compile( + dialect=self._dialect, compile_kwargs={"render_postcompile": True} + ) + for args in values: + for key, val in args.items(): + if key in compiled._bind_processors: + args[key] = compiled._bind_processors[key](val) + return compiled.string, values + @property def raw_connection(self) -> aiosqlite.core.Connection: assert self._connection is not None, "Connection is not acquired" diff --git a/databases/core.py b/databases/core.py index 7005281c..939ba87b 100644 --- a/databases/core.py +++ b/databases/core.py @@ -304,7 +304,7 @@ async def execute_many( ) -> None: queries = [self._build_query(query, values_set) for values_set in values] async with self._query_lock: - await self._connection.execute_many(queries) + await self._connection.execute_many(queries, values) async def iterate( self, query: typing.Union[ClauseElement, str], values: dict = None diff --git a/databases/interfaces.py b/databases/interfaces.py index c2109a23..6c7d649c 100644 --- a/databases/interfaces.py +++ b/databases/interfaces.py @@ -37,7 +37,9 @@ async def fetch_val( async def execute(self, query: ClauseElement) -> typing.Any: raise NotImplementedError() # pragma: no cover - async def execute_many(self, queries: typing.List[ClauseElement]) -> None: + async def execute_many( + self, queries: typing.List[ClauseElement], values: typing.List[dict] + ) -> None: raise NotImplementedError() # pragma: no cover async def iterate( From 2107de31fac5f96ee388cb228f760d1eba40ce51 Mon Sep 17 00:00:00 2001 From: Micheal Gendy Date: Tue, 22 Feb 2022 19:57:41 +0200 Subject: [PATCH 2/2] fix issues --- databases/backends/mysql.py | 11 +++++++---- databases/backends/postgres.py | 22 +++++++++++++++++----- databases/backends/sqlite.py | 18 +++++++++++++----- 3 files changed, 37 insertions(+), 14 deletions(-) diff --git a/databases/backends/mysql.py b/databases/backends/mysql.py index bf24e3b5..2922f3f2 100644 --- a/databases/backends/mysql.py +++ b/databases/backends/mysql.py @@ -227,10 +227,13 @@ def _compile_many( compiled = queries[0].compile( dialect=self._dialect, compile_kwargs={"render_postcompile": True} ) - for args in values: - for key, val in args.items(): - if key in compiled._bind_processors: - args[key] = compiled._bind_processors[key](val) + if not isinstance(queries[0], DDLElement): + for args in values: + for key, val in args.items(): + if key in compiled._bind_processors: + args[key] = compiled._bind_processors[key](val) + else: + values = [] return compiled.string, values @property diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index 69e65fb2..77a40171 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -272,11 +272,23 @@ def _compile_many( compiled = queries[0].compile( dialect=self._dialect, compile_kwargs={"render_postcompile": True} ) - for args in values: - for key, val in args.items(): - if key in compiled._bind_processors: - args[key] = compiled._bind_processors[key](val) - return compiled.string, values + new_values = [] + if not isinstance(queries[0], DDLElement): + for args in values: + sorted_args = sorted(args.items()) + mapping = { + key: "$" + str(i) for i, (key, _) in enumerate(sorted_args, start=1) + } + compiled_query = compiled.string % mapping + processors = compiled._bind_processors + values = [ + processors[key](val) if key in processors else val + for key, val in sorted_args + ] + new_values.append(values) + else: + compiled_query = compiled.string + return compiled_query, new_values @staticmethod def _create_column_maps( diff --git a/databases/backends/sqlite.py b/databases/backends/sqlite.py index 8f6c9748..5e88c061 100644 --- a/databases/backends/sqlite.py +++ b/databases/backends/sqlite.py @@ -203,11 +203,19 @@ def _compile_many( compiled = queries[0].compile( dialect=self._dialect, compile_kwargs={"render_postcompile": True} ) - for args in values: - for key, val in args.items(): - if key in compiled._bind_processors: - args[key] = compiled._bind_processors[key](val) - return compiled.string, values + new_values = [] + if not isinstance(queries[0], DDLElement): + for args in values: + temp_arr = [] + for key in compiled.positiontup: + raw_val = args[key] + if key in compiled._bind_processors: + val = compiled._bind_processors[key](raw_val) + else: + val = raw_val + temp_arr.append(val) + new_values.append(temp_arr) + return compiled.string, new_values @property def raw_connection(self) -> aiosqlite.core.Connection: