diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a219999f15..2cacc7dd6c 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -55,6 +55,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
+from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -986,7 +987,7 @@ class DatabasePool:
self,
table: str,
keys: Collection[str],
- values: Iterable[Iterable[Any]],
+ values: Collection[Collection[Any]],
desc: str,
) -> None:
"""Executes an INSERT query on the named table.
@@ -1427,7 +1428,7 @@ class DatabasePool:
self,
table: str,
keyvalues: Dict[str, Any],
- retcols: Iterable[str],
+ retcols: Collection[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one",
) -> Dict[str, Any]:
@@ -1438,7 +1439,7 @@ class DatabasePool:
self,
table: str,
keyvalues: Dict[str, Any],
- retcols: Iterable[str],
+ retcols: Collection[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
@@ -1448,7 +1449,7 @@ class DatabasePool:
self,
table: str,
keyvalues: Dict[str, Any],
- retcols: Iterable[str],
+ retcols: Collection[str],
allow_none: bool = False,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
@@ -1618,7 +1619,7 @@ class DatabasePool:
self,
table: str,
keyvalues: Optional[Dict[str, Any]],
- retcols: Iterable[str],
+ retcols: Collection[str],
desc: str = "simple_select_list",
) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
@@ -1681,7 +1682,7 @@ class DatabasePool:
table: str,
column: str,
iterable: Iterable[Any],
- retcols: Iterable[str],
+ retcols: Collection[str],
keyvalues: Optional[Dict[str, Any]] = None,
desc: str = "simple_select_many_batch",
batch_size: int = 100,
@@ -1704,16 +1705,7 @@ class DatabasePool:
results: List[Dict[str, Any]] = []
- if not iterable:
- return results
-
- # iterables can not be sliced, so convert it to a list first
- it_list = list(iterable)
-
- chunks = [
- it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
- ]
- for chunk in chunks:
+ for chunk in batch_iter(iterable, batch_size):
rows = await self.runInteraction(
desc,
self.simple_select_many_txn,
@@ -1853,7 +1845,7 @@ class DatabasePool:
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
- retcols: Iterable[str],
+ retcols: Collection[str],
allow_none: bool = False,
) -> Optional[Dict[str, Any]]:
select_sql = "SELECT %s FROM %s WHERE %s" % (
@@ -2146,7 +2138,7 @@ class DatabasePool:
table: str,
term: Optional[str],
col: str,
- retcols: Iterable[str],
+ retcols: Collection[str],
desc="simple_search_list",
) -> Optional[List[Dict[str, Any]]]:
"""Executes a SELECT query on the named table, which may return zero or
|