diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 0693d39006..2cacc7dd6c 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -13,8 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
import time
+import types
from collections import defaultdict
from sys import intern
from time import monotonic as monotonic_time
@@ -53,6 +55,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
+from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -175,7 +178,7 @@ class LoggingDatabaseConnection:
def rollback(self) -> None:
self.conn.rollback()
- def __enter__(self) -> "Connection":
+ def __enter__(self) -> "LoggingDatabaseConnection":
self.conn.__enter__()
return self
@@ -526,6 +529,12 @@ class DatabasePool:
the function will correctly handle being aborted and retried half way
through its execution.
+ Similarly, the arguments to `func` (`args`, `kwargs`) should not be generators,
+ since they could be evaluated multiple times (which would produce an empty
+ result on the second or subsequent evaluation). Likewise, the closure of `func`
+ must not reference any generators. This method attempts to detect such usage
+ and will log an error.
+
Args:
conn
desc
@@ -536,6 +545,39 @@ class DatabasePool:
**kwargs
"""
+ # Robustness check: ensure that none of the arguments are generators, since that
+ # will fail if we have to repeat the transaction.
+ # For now, we just log an error, and hope that it works on the first attempt.
+ # TODO: raise an exception.
+ for i, arg in enumerate(args):
+ if inspect.isgenerator(arg):
+ logger.error(
+ "Programming error: generator passed to new_transaction as "
+ "argument %i to function %s",
+ i,
+ func,
+ )
+ for name, val in kwargs.items():
+ if inspect.isgenerator(val):
+ logger.error(
+ "Programming error: generator passed to new_transaction as "
+ "argument %s to function %s",
+ name,
+ func,
+ )
+ # also check variables referenced in func's closure
+ if inspect.isfunction(func):
+ f = cast(types.FunctionType, func)
+ if f.__closure__:
+ for i, cell in enumerate(f.__closure__):
+ if inspect.isgenerator(cell.cell_contents):
+ logger.error(
+ "Programming error: function %s references generator %s "
+ "via its closure",
+ f,
+ f.__code__.co_freevars[i],
+ )
+
start = monotonic_time()
txn_id = self._TXN_ID
@@ -896,6 +938,9 @@ class DatabasePool:
) -> None:
"""Executes an INSERT query on the named table.
+ The input is given as a list of dicts, with one dict per row.
+ Generally simple_insert_many_values should be preferred for new code.
+
Args:
table: string giving the table name
values: dict of new column names and values for them
@@ -909,6 +954,9 @@ class DatabasePool:
) -> None:
"""Executes an INSERT query on the named table.
+ The input is given as a list of dicts, with one dict per row.
+ Generally simple_insert_many_values_txn should be preferred for new code.
+
Args:
txn: The transaction to use.
table: string giving the table name
@@ -933,23 +981,66 @@ class DatabasePool:
if k != keys[0]:
raise RuntimeError("All items must have the same keys")
+ return DatabasePool.simple_insert_many_values_txn(txn, table, keys[0], vals)
+
+ async def simple_insert_many_values(
+ self,
+ table: str,
+ keys: Collection[str],
+ values: Collection[Collection[Any]],
+ desc: str,
+ ) -> None:
+ """Executes an INSERT query on the named table.
+
+ The input is given as a list of rows, where each row is a list of values.
+ (Any `Collection` is fine, but not a generator: it can only be read once.)
+
+ Args:
+ table: string giving the table name
+ keys: list of column names
+ values: for each row, a list of values in the same order as `keys`
+ desc: description of the transaction, for logging and metrics
+ """
+ await self.runInteraction(
+ desc, self.simple_insert_many_values_txn, table, keys, values
+ )
+
+ @staticmethod
+ def simple_insert_many_values_txn(
+ txn: LoggingTransaction,
+ table: str,
+ keys: Collection[str],
+ values: Iterable[Iterable[Any]],
+ ) -> None:
+ """Executes an INSERT query on the named table.
+
+ The input is given as a list of rows, where each row is a list of values.
+ (Actually any iterable is fine.)
+
+ Args:
+ txn: The transaction to use.
+ table: string giving the table name
+ keys: list of column names
+ values: for each row, a list of values in the same order as `keys`
+ """
+
if isinstance(txn.database_engine, PostgresEngine):
# We use `execute_values` as it can be a lot faster than `execute_batch`,
# but it's only available on postgres.
sql = "INSERT INTO %s (%s) VALUES ?" % (
table,
- ", ".join(k for k in keys[0]),
+ ", ".join(k for k in keys),
)
- txn.execute_values(sql, vals, fetch=False)
+ txn.execute_values(sql, values, fetch=False)
else:
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table,
- ", ".join(k for k in keys[0]),
- ", ".join("?" for _ in keys[0]),
+ ", ".join(k for k in keys),
+ ", ".join("?" for _ in keys),
)
- txn.execute_batch(sql, vals)
+ txn.execute_batch(sql, values)
async def simple_upsert(
self,
@@ -1177,9 +1268,9 @@ class DatabasePool:
self,
table: str,
key_names: Collection[str],
- key_values: Collection[Iterable[Any]],
+ key_values: Collection[Collection[Any]],
value_names: Collection[str],
- value_values: Iterable[Iterable[Any]],
+ value_values: Collection[Collection[Any]],
desc: str,
) -> None:
"""
@@ -1337,7 +1428,7 @@ class DatabasePool:
self,
table: str,
keyvalues: Dict[str, Any],
- retcols: Iterable[str],
+ retcols: Collection[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one",
) -> Dict[str, Any]:
@@ -1348,7 +1439,7 @@ class DatabasePool:
self,
table: str,
keyvalues: Dict[str, Any],
- retcols: Iterable[str],
+ retcols: Collection[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
@@ -1358,7 +1449,7 @@ class DatabasePool:
self,
table: str,
keyvalues: Dict[str, Any],
- retcols: Iterable[str],
+ retcols: Collection[str],
allow_none: bool = False,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
@@ -1528,7 +1619,7 @@ class DatabasePool:
self,
table: str,
keyvalues: Optional[Dict[str, Any]],
- retcols: Iterable[str],
+ retcols: Collection[str],
desc: str = "simple_select_list",
) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
@@ -1591,7 +1682,7 @@ class DatabasePool:
table: str,
column: str,
iterable: Iterable[Any],
- retcols: Iterable[str],
+ retcols: Collection[str],
keyvalues: Optional[Dict[str, Any]] = None,
desc: str = "simple_select_many_batch",
batch_size: int = 100,
@@ -1614,16 +1705,7 @@ class DatabasePool:
results: List[Dict[str, Any]] = []
- if not iterable:
- return results
-
- # iterables can not be sliced, so convert it to a list first
- it_list = list(iterable)
-
- chunks = [
- it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
- ]
- for chunk in chunks:
+ for chunk in batch_iter(iterable, batch_size):
rows = await self.runInteraction(
desc,
self.simple_select_many_txn,
@@ -1763,7 +1845,7 @@ class DatabasePool:
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
- retcols: Iterable[str],
+ retcols: Collection[str],
allow_none: bool = False,
) -> Optional[Dict[str, Any]]:
select_sql = "SELECT %s FROM %s WHERE %s" % (
@@ -1871,7 +1953,7 @@ class DatabasePool:
self,
table: str,
column: str,
- iterable: Iterable[Any],
+ iterable: Collection[Any],
keyvalues: Dict[str, Any],
desc: str,
) -> int:
@@ -1882,7 +1964,8 @@ class DatabasePool:
Args:
table: string giving the table name
column: column name to test for inclusion against `iterable`
- iterable: list
+ iterable: list of values to match against `column`. NB cannot be a generator
+ as it may be evaluated multiple times.
keyvalues: dict of column names and values to select the rows with
desc: description of the transaction, for logging and metrics
@@ -2055,7 +2138,7 @@ class DatabasePool:
table: str,
term: Optional[str],
col: str,
- retcols: Iterable[str],
+ retcols: Collection[str],
desc="simple_search_list",
) -> Optional[List[Dict[str, Any]]]:
"""Executes a SELECT query on the named table, which may return zero or
|