diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py index 1b73fa62..99b1670f 100644 --- a/databases/backends/aiopg.py +++ b/databases/backends/aiopg.py @@ -3,6 +3,7 @@ import logging import typing import uuid +from functools import partial import aiopg from aiopg.sa.engine import APGCompiler_psycopg2 @@ -174,7 +175,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: cursor.close() async def iterate( - self, query: ClauseElement + self, query: ClauseElement, *, n: int = 1 ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" query_str, args, context = self._compile(query) @@ -182,14 +183,18 @@ async def iterate( try: await cursor.execute(query_str, args) metadata = CursorResultMetaData(context, cursor.description) - async for row in cursor: - yield Row( - metadata, - metadata._processors, - metadata._keymap, - Row._default_key_style, - row, - ) + row_func = partial( + Row, + metadata, + metadata._processors, + metadata._keymap, + Row._default_key_style + ) + while True: + rows = await cursor.fetchmany(n) + if not len(rows): break + records = list(map(row_func, rows)) + yield records[0] if n == 1 else records finally: cursor.close() diff --git a/databases/backends/mysql.py b/databases/backends/mysql.py index 4c490d71..e3ffc7a4 100644 --- a/databases/backends/mysql.py +++ b/databases/backends/mysql.py @@ -2,6 +2,7 @@ import logging import typing import uuid +from functools import partial import aiomysql from sqlalchemy.dialects.mysql import pymysql @@ -164,7 +165,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: await cursor.close() async def iterate( - self, query: ClauseElement + self, query: ClauseElement, *, n: int = 1 ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" query_str, args, context = self._compile(query) @@ -172,14 +173,18 @@ async def iterate( try: await cursor.execute(query_str, args) metadata = CursorResultMetaData(context, cursor.description) - async for row in cursor: - yield Row( - metadata, - metadata._processors, - metadata._keymap, - Row._default_key_style, - row, - ) + row_func = partial( + Row, + metadata, + metadata._processors, + metadata._keymap, + Row._default_key_style, + ) + while True: + rows = await cursor.fetchmany(n) + if not len(rows): break + records = list(map(row_func, rows)) + yield records[0] if n == 1 else records finally: await cursor.close() diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index ed12c2b0..cdf423ca 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -1,6 +1,7 @@ import logging import typing from collections.abc import Sequence +from functools import partial import asyncpg from sqlalchemy.dialects.postgresql import pypostgresql @@ -219,13 +220,23 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: await self._connection.execute(single_query, *args) async def iterate( - self, query: ClauseElement + self, query: ClauseElement, *, n: int = 1 ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns = self._compile(query) column_maps = self._create_column_maps(result_columns) - async for row in self._connection.cursor(query_str, *args): - yield Record(row, result_columns, self._dialect, column_maps) + record_func = partial( + Record, + result_columns=result_columns, + dialect=self._dialect, + column_maps=column_maps + ) + cursor = await self._connection.cursor(query_str, *args) + while True: + rows = await cursor.fetch(n) + if not rows: break + records = list(map(record_func, rows)) + yield records[0] if n == 1 else records def transaction(self) -> TransactionBackend: return PostgresTransaction(connection=self) diff --git a/databases/backends/sqlite.py b/databases/backends/sqlite.py index e7e1bad6..09a07522 100644 --- a/databases/backends/sqlite.py +++ b/databases/backends/sqlite.py @@ -1,6 +1,7 @@ import logging import typing import uuid +from functools import partial import aiosqlite from sqlalchemy.dialects.sqlite import pysqlite @@ -136,20 +137,24 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: await self.execute(single_query) async def iterate( - self, query: ClauseElement + self, query: ClauseElement, n: int = 1 ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" query_str, args, context = self._compile(query) async with self._connection.execute(query_str, args) as cursor: metadata = CursorResultMetaData(context, cursor.description) - async for row in cursor: - yield Row( - metadata, - metadata._processors, - metadata._keymap, - Row._default_key_style, - row, - ) + row_func = partial( + Row, + metadata, + metadata._processors, + metadata._keymap, + Row._default_key_style + ) + while True: + rows = await cursor.fetchmany(n) + if not len(rows): break + records = list(map(row_func, rows)) + yield records[0] if n == 1 else records def transaction(self) -> TransactionBackend: return SQLiteTransaction(self) diff --git a/databases/core.py b/databases/core.py index 9c43ae76..46cb906e 100644 --- a/databases/core.py +++ b/databases/core.py @@ -175,10 +175,10 @@ async def execute_many( return await connection.execute_many(query, values) async def iterate( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, query: typing.Union[ClauseElement, str], values: dict = None, **kwargs ) -> typing.AsyncGenerator[typing.Mapping, None]: async with self.connection() as connection: - async for record in connection.iterate(query, values): + async for record in connection.iterate(query, values, **kwargs): yield record def _new_connection(self) -> "Connection": @@ -301,12 +301,12 @@ async def execute_many( await self._connection.execute_many(queries) async def iterate( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, query: typing.Union[ClauseElement, str], values: dict = None, **kwargs ) -> typing.AsyncGenerator[typing.Any, None]: built_query = self._build_query(query, values) async with self.transaction(): async with self._query_lock: - async for record in self._connection.iterate(built_query): + async for record in self._connection.iterate(built_query, **kwargs): yield record def transaction( diff --git a/docs/database_queries.md b/docs/database_queries.md index 7ca4d173..3f31cb0e 100644 --- a/docs/database_queries.md +++ b/docs/database_queries.md @@ -65,11 +65,17 @@ row = await database.fetch_one(query=query) query = notes.select() value = await database.fetch_val(query=query) -# Fetch multiple rows without loading them all into memory at once +# Fetch multiple rows without loading them all into memory at once (1 at a time) query = notes.select() async for row in database.iterate(query=query): ... +# Fetch multiple rows loading n of them into memory at a time +query = notes.select() +async for rows in database.iterate(query=query, n=10): + assert 1 <= len(rows) <= 10 + ... + # Close all connection in the connection pool await database.disconnect() ``` diff --git a/tests/test_databases.py b/tests/test_databases.py index 8fde4387..09d508f4 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -1139,3 +1139,27 @@ async def test_postcompile_queries(database_url): results = await database.fetch_all(query=query) assert len(results) == 0 + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_iterate_n(database_url): + """ + Test fetching multiple records per iteration. + """ + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): + query = "INSERT INTO notes(text, completed) VALUES (:text, :completed)" + values = [ + {"text": "example1", "completed": True}, + {"text": "example2", "completed": False}, + {"text": "example3", "completed": True}, + ] + await database.execute_many(query, values) + + async for records in database.iterate(notes.select(), n=2): + assert len(records) == 2 + assert records[0] == values[0] + assert records[1] == values[1] + assert len(records) == 1 + assert records[0] == values[2]