Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow iterating over custom number of records with asyncpg #414

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions databases/backends/aiopg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import typing
import uuid
from functools import partial

import aiopg
from aiopg.sa.engine import APGCompiler_psycopg2
Expand Down Expand Up @@ -174,22 +175,26 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
cursor.close()

async def iterate(
self, query: ClauseElement
self, query: ClauseElement, *, n: int = 1
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
cursor = await self._connection.cursor()
try:
await cursor.execute(query_str, args)
metadata = CursorResultMetaData(context, cursor.description)
async for row in cursor:
yield Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
row_func = partial(
Row,
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style
)
while True:
rows = await cursor.fetchmany(n)
if not len(rows): break
records = list(map(row_func, rows))
yield records[0] if n == 1 else records
finally:
cursor.close()

Expand Down
23 changes: 14 additions & 9 deletions databases/backends/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import typing
import uuid
from functools import partial

import aiomysql
from sqlalchemy.dialects.mysql import pymysql
Expand Down Expand Up @@ -164,22 +165,26 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
await cursor.close()

async def iterate(
self, query: ClauseElement
self, query: ClauseElement, *, n: int = 1
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
cursor = await self._connection.cursor()
try:
await cursor.execute(query_str, args)
metadata = CursorResultMetaData(context, cursor.description)
async for row in cursor:
yield Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
row_func = partial(
Row,
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
)
while True:
rows = await cursor.fetchmany(n)
if not len(rows): break
records = list(map(row_func, rows))
yield records[0] if n == 1 else records
finally:
await cursor.close()

Expand Down
17 changes: 14 additions & 3 deletions databases/backends/postgres.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import typing
from collections.abc import Sequence
from functools import partial

import asyncpg
from sqlalchemy.dialects.postgresql import pypostgresql
Expand Down Expand Up @@ -219,13 +220,23 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
await self._connection.execute(single_query, *args)

async def iterate(
self, query: ClauseElement
self, query: ClauseElement, *, n: int = 1
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns = self._compile(query)
column_maps = self._create_column_maps(result_columns)
async for row in self._connection.cursor(query_str, *args):
yield Record(row, result_columns, self._dialect, column_maps)
record_func = partial(
Record,
result_columns=result_columns,
dialect=self._dialect,
column_maps=column_maps
)
cursor = await self._connection.cursor(query_str, *args)
while True:
rows = await cursor.fetch(n)
if not rows: break
records = list(map(record_func, rows))
yield records[0] if n == 1 else records

def transaction(self) -> TransactionBackend:
return PostgresTransaction(connection=self)
Expand Down
23 changes: 14 additions & 9 deletions databases/backends/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import typing
import uuid
from functools import partial

import aiosqlite
from sqlalchemy.dialects.sqlite import pysqlite
Expand Down Expand Up @@ -136,20 +137,24 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
await self.execute(single_query)

async def iterate(
self, query: ClauseElement
self, query: ClauseElement, n: int = 1
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
async with self._connection.execute(query_str, args) as cursor:
metadata = CursorResultMetaData(context, cursor.description)
async for row in cursor:
yield Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
row_func = partial(
Row,
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style
)
while True:
rows = await cursor.fetchmany(n)
if not len(rows): break
records = list(map(row_func, rows))
yield records[0] if n == 1 else records

def transaction(self) -> TransactionBackend:
return SQLiteTransaction(self)
Expand Down
8 changes: 4 additions & 4 deletions databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,10 @@ async def execute_many(
return await connection.execute_many(query, values)

async def iterate(
self, query: typing.Union[ClauseElement, str], values: dict = None
self, query: typing.Union[ClauseElement, str], values: dict = None, **kwargs
) -> typing.AsyncGenerator[typing.Mapping, None]:
async with self.connection() as connection:
async for record in connection.iterate(query, values):
async for record in connection.iterate(query, values, **kwargs):
yield record

def _new_connection(self) -> "Connection":
Expand Down Expand Up @@ -301,12 +301,12 @@ async def execute_many(
await self._connection.execute_many(queries)

async def iterate(
self, query: typing.Union[ClauseElement, str], values: dict = None
self, query: typing.Union[ClauseElement, str], values: dict = None, **kwargs
) -> typing.AsyncGenerator[typing.Any, None]:
built_query = self._build_query(query, values)
async with self.transaction():
async with self._query_lock:
async for record in self._connection.iterate(built_query):
async for record in self._connection.iterate(built_query, **kwargs):
yield record

def transaction(
Expand Down
8 changes: 7 additions & 1 deletion docs/database_queries.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,17 @@ row = await database.fetch_one(query=query)
query = notes.select()
value = await database.fetch_val(query=query)

# Fetch multiple rows without loading them all into memory at once
# Fetch multiple rows without loading them all into memory at once (1 at a time)
query = notes.select()
async for row in database.iterate(query=query):
...

# Fetch multiple rows loading n of them into memory at a time
query = notes.select()
async for rows in database.iterate(query=query, n=10):
assert 1 <= len(rows) <= 10
...

# Close all connection in the connection pool
await database.disconnect()
```
Expand Down
24 changes: 24 additions & 0 deletions tests/test_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,3 +1139,27 @@ async def test_postcompile_queries(database_url):
results = await database.fetch_all(query=query)

assert len(results) == 0


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_iterate_n(database_url):
"""
Test fetching multiple records per iteration.
"""
async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
query = "INSERT INTO notes(text, completed) VALUES (:text, :completed)"
values = [
{"text": "example1", "completed": True},
{"text": "example2", "completed": False},
{"text": "example3", "completed": True},
]
await database.execute_many(query, values)

async for records in database.iterate(notes.select(), n=2):
assert len(records) == 2
assert records[0] == values[0]
assert records[1] == values[1]
assert len(records) == 1
assert records[0] == values[2]