summary refs log tree commit diff
path: root/synapse/storage/database.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/database.py')
-rw-r--r-- synapse/storage/database.py 137
1 files changed, 110 insertions, 27 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 0693d39006..2cacc7dd6c 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -13,8 +13,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import logging
 import time
+import types
 from collections import defaultdict
 from sys import intern
 from time import monotonic as monotonic_time
@@ -53,6 +55,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
+from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -175,7 +178,7 @@ class LoggingDatabaseConnection:
     def rollback(self) -> None:
         self.conn.rollback()
 
-    def __enter__(self) -> "Connection":
+    def __enter__(self) -> "LoggingDatabaseConnection":
         self.conn.__enter__()
         return self
 
@@ -526,6 +529,12 @@ class DatabasePool:
         the function will correctly handle being aborted and retried half way
         through its execution.
 
+        Similarly, the arguments to `func` (`args`, `kwargs`) should not be generators,
+        since they could be evaluated multiple times (which would produce an empty
+        result on the second or subsequent evaluation). Likewise, the closure of `func`
+        must not reference any generators.  This method attempts to detect such usage
+        and will log an error.
+
         Args:
             conn
             desc
@@ -536,6 +545,39 @@ class DatabasePool:
             **kwargs
         """
 
+        # Robustness check: ensure that none of the arguments are generators, since that
+        # will fail if we have to repeat the transaction.
+        # For now, we just log an error, and hope that it works on the first attempt.
+        # TODO: raise an exception.
+        for i, arg in enumerate(args):
+            if inspect.isgenerator(arg):
+                logger.error(
+                    "Programming error: generator passed to new_transaction as "
+                    "argument %i to function %s",
+                    i,
+                    func,
+                )
+        for name, val in kwargs.items():
+            if inspect.isgenerator(val):
+                logger.error(
+                    "Programming error: generator passed to new_transaction as "
+                    "argument %s to function %s",
+                    name,
+                    func,
+                )
+        # also check variables referenced in func's closure
+        if inspect.isfunction(func):
+            f = cast(types.FunctionType, func)
+            if f.__closure__:
+                for i, cell in enumerate(f.__closure__):
+                    if inspect.isgenerator(cell.cell_contents):
+                        logger.error(
+                            "Programming error: function %s references generator %s "
+                            "via its closure",
+                            f,
+                            f.__code__.co_freevars[i],
+                        )
+
         start = monotonic_time()
         txn_id = self._TXN_ID
 
@@ -896,6 +938,9 @@ class DatabasePool:
     ) -> None:
         """Executes an INSERT query on the named table.
 
+        The input is given as a list of dicts, with one dict per row.
+        Generally simple_insert_many_values should be preferred for new code.
+
         Args:
             table: string giving the table name
             values: dict of new column names and values for them
@@ -909,6 +954,9 @@ class DatabasePool:
     ) -> None:
         """Executes an INSERT query on the named table.
 
+        The input is given as a list of dicts, with one dict per row.
+        Generally simple_insert_many_values_txn should be preferred for new code.
+
         Args:
             txn: The transaction to use.
             table: string giving the table name
@@ -933,23 +981,66 @@ class DatabasePool:
             if k != keys[0]:
                 raise RuntimeError("All items must have the same keys")
 
+        return DatabasePool.simple_insert_many_values_txn(txn, table, keys[0], vals)
+
+    async def simple_insert_many_values(
+        self,
+        table: str,
+        keys: Collection[str],
+        values: Collection[Collection[Any]],
+        desc: str,
+    ) -> None:
+        """Executes an INSERT query on the named table.
+
+        The input is given as a list of rows, where each row is a list of values.
+        (Any collection is fine, but not a generator: it may be evaluated repeatedly.)
+
+        Args:
+            table: string giving the table name
+            keys: list of column names
+            values: for each row, a list of values in the same order as `keys`
+            desc: description of the transaction, for logging and metrics
+        """
+        await self.runInteraction(
+            desc, self.simple_insert_many_values_txn, table, keys, values
+        )
+
+    @staticmethod
+    def simple_insert_many_values_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keys: Collection[str],
+        values: Iterable[Iterable[Any]],
+    ) -> None:
+        """Executes an INSERT query on the named table.
+
+        The input is given as a list of rows, where each row is a list of values.
+        (Actually any iterable is fine.)
+
+        Args:
+            txn: The transaction to use.
+            table: string giving the table name
+            keys: list of column names
+            values: for each row, a list of values in the same order as `keys`
+        """
+
         if isinstance(txn.database_engine, PostgresEngine):
             # We use `execute_values` as it can be a lot faster than `execute_batch`,
             # but it's only available on postgres.
             sql = "INSERT INTO %s (%s) VALUES ?" % (
                 table,
-                ", ".join(k for k in keys[0]),
+                ", ".join(k for k in keys),
             )
 
-            txn.execute_values(sql, vals, fetch=False)
+            txn.execute_values(sql, values, fetch=False)
         else:
             sql = "INSERT INTO %s (%s) VALUES(%s)" % (
                 table,
-                ", ".join(k for k in keys[0]),
-                ", ".join("?" for _ in keys[0]),
+                ", ".join(k for k in keys),
+                ", ".join("?" for _ in keys),
             )
 
-            txn.execute_batch(sql, vals)
+            txn.execute_batch(sql, values)
 
     async def simple_upsert(
         self,
@@ -1177,9 +1268,9 @@ class DatabasePool:
         self,
         table: str,
         key_names: Collection[str],
-        key_values: Collection[Iterable[Any]],
+        key_values: Collection[Collection[Any]],
         value_names: Collection[str],
-        value_values: Iterable[Iterable[Any]],
+        value_values: Collection[Collection[Any]],
         desc: str,
     ) -> None:
         """
@@ -1337,7 +1428,7 @@ class DatabasePool:
         self,
         table: str,
         keyvalues: Dict[str, Any],
-        retcols: Iterable[str],
+        retcols: Collection[str],
         allow_none: Literal[False] = False,
         desc: str = "simple_select_one",
     ) -> Dict[str, Any]:
@@ -1348,7 +1439,7 @@ class DatabasePool:
         self,
         table: str,
         keyvalues: Dict[str, Any],
-        retcols: Iterable[str],
+        retcols: Collection[str],
         allow_none: Literal[True] = True,
         desc: str = "simple_select_one",
     ) -> Optional[Dict[str, Any]]:
@@ -1358,7 +1449,7 @@ class DatabasePool:
         self,
         table: str,
         keyvalues: Dict[str, Any],
-        retcols: Iterable[str],
+        retcols: Collection[str],
         allow_none: bool = False,
         desc: str = "simple_select_one",
     ) -> Optional[Dict[str, Any]]:
@@ -1528,7 +1619,7 @@ class DatabasePool:
         self,
         table: str,
         keyvalues: Optional[Dict[str, Any]],
-        retcols: Iterable[str],
+        retcols: Collection[str],
         desc: str = "simple_select_list",
     ) -> List[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which may return zero or
@@ -1591,7 +1682,7 @@ class DatabasePool:
         table: str,
         column: str,
         iterable: Iterable[Any],
-        retcols: Iterable[str],
+        retcols: Collection[str],
         keyvalues: Optional[Dict[str, Any]] = None,
         desc: str = "simple_select_many_batch",
         batch_size: int = 100,
@@ -1614,16 +1705,7 @@ class DatabasePool:
 
         results: List[Dict[str, Any]] = []
 
-        if not iterable:
-            return results
-
-        # iterables can not be sliced, so convert it to a list first
-        it_list = list(iterable)
-
-        chunks = [
-            it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
-        ]
-        for chunk in chunks:
+        for chunk in batch_iter(iterable, batch_size):
             rows = await self.runInteraction(
                 desc,
                 self.simple_select_many_txn,
@@ -1763,7 +1845,7 @@ class DatabasePool:
         txn: LoggingTransaction,
         table: str,
         keyvalues: Dict[str, Any],
-        retcols: Iterable[str],
+        retcols: Collection[str],
         allow_none: bool = False,
     ) -> Optional[Dict[str, Any]]:
         select_sql = "SELECT %s FROM %s WHERE %s" % (
@@ -1871,7 +1953,7 @@ class DatabasePool:
         self,
         table: str,
         column: str,
-        iterable: Iterable[Any],
+        iterable: Collection[Any],
         keyvalues: Dict[str, Any],
         desc: str,
     ) -> int:
@@ -1882,7 +1964,8 @@ class DatabasePool:
         Args:
             table: string giving the table name
             column: column name to test for inclusion against `iterable`
-            iterable: list
+            iterable: list of values to match against `column`. NB cannot be a generator
+                as it may be evaluated multiple times.
             keyvalues: dict of column names and values to select the rows with
             desc: description of the transaction, for logging and metrics
 
@@ -2055,7 +2138,7 @@ class DatabasePool:
         table: str,
         term: Optional[str],
         col: str,
-        retcols: Iterable[str],
+        retcols: Collection[str],
         desc="simple_search_list",
     ) -> Optional[List[Dict[str, Any]]]:
         """Executes a SELECT query on the named table, which may return zero or