summary refs log tree commit diff
path: root/synapse/storage/database.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-12-15 12:00:50 -0500
committerGitHub <noreply@github.com>2021-12-15 17:00:50 +0000
commitf901f8b70eef7c1ac62b68587c0d6cd2e1e0febe (patch)
treef89b42ab0dd7f57e51c81c9d36ddebcee4bc396b /synapse/storage/database.py
parentConvert EventStreamResult to attrs. (#11574) (diff)
downloadsynapse-f901f8b70eef7c1ac62b68587c0d6cd2e1e0febe.tar.xz
Require Collections as the parameters for simple_* methods. (#11580)
Instead of Iterable since the generators are not allowed due
to the potential for their re-use.
Diffstat (limited to 'synapse/storage/database.py')
-rw-r--r--synapse/storage/database.py28
1 files changed, 10 insertions, 18 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a219999f15..2cacc7dd6c 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -55,6 +55,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
+from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -986,7 +987,7 @@ class DatabasePool:
         self,
         table: str,
         keys: Collection[str],
-        values: Iterable[Iterable[Any]],
+        values: Collection[Collection[Any]],
         desc: str,
     ) -> None:
         """Executes an INSERT query on the named table.
@@ -1427,7 +1428,7 @@ class DatabasePool:
         self,
         table: str,
         keyvalues: Dict[str, Any],
-        retcols: Iterable[str],
+        retcols: Collection[str],
         allow_none: Literal[False] = False,
         desc: str = "simple_select_one",
     ) -> Dict[str, Any]:
@@ -1438,7 +1439,7 @@ class DatabasePool:
         self,
         table: str,
         keyvalues: Dict[str, Any],
-        retcols: Iterable[str],
+        retcols: Collection[str],
         allow_none: Literal[True] = True,
         desc: str = "simple_select_one",
     ) -> Optional[Dict[str, Any]]:
@@ -1448,7 +1449,7 @@ class DatabasePool:
         self,
         table: str,
         keyvalues: Dict[str, Any],
-        retcols: Iterable[str],
+        retcols: Collection[str],
         allow_none: bool = False,
         desc: str = "simple_select_one",
     ) -> Optional[Dict[str, Any]]:
@@ -1618,7 +1619,7 @@ class DatabasePool:
         self,
         table: str,
         keyvalues: Optional[Dict[str, Any]],
-        retcols: Iterable[str],
+        retcols: Collection[str],
         desc: str = "simple_select_list",
     ) -> List[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which may return zero or
@@ -1681,7 +1682,7 @@ class DatabasePool:
         table: str,
         column: str,
         iterable: Iterable[Any],
-        retcols: Iterable[str],
+        retcols: Collection[str],
         keyvalues: Optional[Dict[str, Any]] = None,
         desc: str = "simple_select_many_batch",
         batch_size: int = 100,
@@ -1704,16 +1705,7 @@ class DatabasePool:
 
         results: List[Dict[str, Any]] = []
 
-        if not iterable:
-            return results
-
-        # iterables can not be sliced, so convert it to a list first
-        it_list = list(iterable)
-
-        chunks = [
-            it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
-        ]
-        for chunk in chunks:
+        for chunk in batch_iter(iterable, batch_size):
             rows = await self.runInteraction(
                 desc,
                 self.simple_select_many_txn,
@@ -1853,7 +1845,7 @@ class DatabasePool:
         txn: LoggingTransaction,
         table: str,
         keyvalues: Dict[str, Any],
-        retcols: Iterable[str],
+        retcols: Collection[str],
         allow_none: bool = False,
     ) -> Optional[Dict[str, Any]]:
         select_sql = "SELECT %s FROM %s WHERE %s" % (
@@ -2146,7 +2138,7 @@ class DatabasePool:
         table: str,
         term: Optional[str],
         col: str,
-        retcols: Iterable[str],
+        retcols: Collection[str],
         desc="simple_search_list",
     ) -> Optional[List[Dict[str, Any]]]:
         """Executes a SELECT query on the named table, which may return zero or