diff options
Diffstat (limited to 'synapse/storage/database.py')
-rw-r--r-- | synapse/storage/database.py | 137 |
1 files changed, 110 insertions, 27 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 0693d39006..2cacc7dd6c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -13,8 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging import time +import types from collections import defaultdict from sys import intern from time import monotonic as monotonic_time @@ -53,6 +55,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor +from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.server import HomeServer @@ -175,7 +178,7 @@ class LoggingDatabaseConnection: def rollback(self) -> None: self.conn.rollback() - def __enter__(self) -> "Connection": + def __enter__(self) -> "LoggingDatabaseConnection": self.conn.__enter__() return self @@ -526,6 +529,12 @@ class DatabasePool: the function will correctly handle being aborted and retried half way through its execution. + Similarly, the arguments to `func` (`args`, `kwargs`) should not be generators, + since they could be evaluated multiple times (which would produce an empty + result on the second or subsequent evaluation). Likewise, the closure of `func` + must not reference any generators. This method attempts to detect such usage + and will log an error. + Args: conn desc @@ -536,6 +545,39 @@ class DatabasePool: **kwargs """ + # Robustness check: ensure that none of the arguments are generators, since that + # will fail if we have to repeat the transaction. + # For now, we just log an error, and hope that it works on the first attempt. + # TODO: raise an exception. + for i, arg in enumerate(args): + if inspect.isgenerator(arg): + logger.error( + "Programming error: generator passed to new_transaction as " + "argument %i to function %s", + i, + func, + ) + for name, val in kwargs.items(): + if inspect.isgenerator(val): + logger.error( + "Programming error: generator passed to new_transaction as " + "argument %s to function %s", + name, + func, + ) + # also check variables referenced in func's closure + if inspect.isfunction(func): + f = cast(types.FunctionType, func) + if f.__closure__: + for i, cell in enumerate(f.__closure__): + if inspect.isgenerator(cell.cell_contents): + logger.error( + "Programming error: function %s references generator %s " + "via its closure", + f, + f.__code__.co_freevars[i], + ) + start = monotonic_time() txn_id = self._TXN_ID @@ -896,6 +938,9 @@ class DatabasePool: ) -> None: """Executes an INSERT query on the named table. + The input is given as a list of dicts, with one dict per row. + Generally simple_insert_many_values should be preferred for new code. + Args: table: string giving the table name values: dict of new column names and values for them @@ -909,6 +954,9 @@ class DatabasePool: ) -> None: """Executes an INSERT query on the named table. + The input is given as a list of dicts, with one dict per row. + Generally simple_insert_many_values_txn should be preferred for new code. + Args: txn: The transaction to use. table: string giving the table name @@ -933,23 +981,66 @@ class DatabasePool: if k != keys[0]: raise RuntimeError("All items must have the same keys") + return DatabasePool.simple_insert_many_values_txn(txn, table, keys[0], vals) + + async def simple_insert_many_values( + self, + table: str, + keys: Collection[str], + values: Collection[Collection[Any]], + desc: str, + ) -> None: + """Executes an INSERT query on the named table. + + The input is given as a list of rows, where each row is a list of values. + (Actually any iterable is fine.) + + Args: + table: string giving the table name + keys: list of column names + values: for each row, a list of values in the same order as `keys` + desc: description of the transaction, for logging and metrics + """ + await self.runInteraction( + desc, self.simple_insert_many_values_txn, table, keys, values + ) + + @staticmethod + def simple_insert_many_values_txn( + txn: LoggingTransaction, + table: str, + keys: Collection[str], + values: Iterable[Iterable[Any]], + ) -> None: + """Executes an INSERT query on the named table. + + The input is given as a list of rows, where each row is a list of values. + (Actually any iterable is fine.) + + Args: + txn: The transaction to use. + table: string giving the table name + keys: list of column names + values: for each row, a list of values in the same order as `keys` + """ + if isinstance(txn.database_engine, PostgresEngine): # We use `execute_values` as it can be a lot faster than `execute_batch`, # but it's only available on postgres. sql = "INSERT INTO %s (%s) VALUES ?" % ( table, - ", ".join(k for k in keys[0]), + ", ".join(k for k in keys), ) - txn.execute_values(sql, vals, fetch=False) + txn.execute_values(sql, values, fetch=False) else: sql = "INSERT INTO %s (%s) VALUES(%s)" % ( table, - ", ".join(k for k in keys[0]), - ", ".join("?" for _ in keys[0]), + ", ".join(k for k in keys), + ", ".join("?" for _ in keys), ) - txn.execute_batch(sql, vals) + txn.execute_batch(sql, values) async def simple_upsert( self, @@ -1177,9 +1268,9 @@ class DatabasePool: self, table: str, key_names: Collection[str], - key_values: Collection[Iterable[Any]], + key_values: Collection[Collection[Any]], value_names: Collection[str], - value_values: Iterable[Iterable[Any]], + value_values: Collection[Collection[Any]], desc: str, ) -> None: """ @@ -1337,7 +1428,7 @@ class DatabasePool: self, table: str, keyvalues: Dict[str, Any], - retcols: Iterable[str], + retcols: Collection[str], allow_none: Literal[False] = False, desc: str = "simple_select_one", ) -> Dict[str, Any]: @@ -1348,7 +1439,7 @@ class DatabasePool: self, table: str, keyvalues: Dict[str, Any], - retcols: Iterable[str], + retcols: Collection[str], allow_none: Literal[True] = True, desc: str = "simple_select_one", ) -> Optional[Dict[str, Any]]: @@ -1358,7 +1449,7 @@ class DatabasePool: self, table: str, keyvalues: Dict[str, Any], - retcols: Iterable[str], + retcols: Collection[str], allow_none: bool = False, desc: str = "simple_select_one", ) -> Optional[Dict[str, Any]]: @@ -1528,7 +1619,7 @@ class DatabasePool: self, table: str, keyvalues: Optional[Dict[str, Any]], - retcols: Iterable[str], + retcols: Collection[str], desc: str = "simple_select_list", ) -> List[Dict[str, Any]]: """Executes a SELECT query on the named table, which may return zero or @@ -1591,7 +1682,7 @@ class DatabasePool: table: str, column: str, iterable: Iterable[Any], - retcols: Iterable[str], + retcols: Collection[str], keyvalues: Optional[Dict[str, Any]] = None, desc: str = "simple_select_many_batch", batch_size: int = 100, @@ -1614,16 +1705,7 @@ class DatabasePool: results: List[Dict[str, Any]] = [] - if not iterable: - return results - - # iterables can not be sliced, so convert it to a list first - it_list = list(iterable) - - chunks = [ - it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) - ] - for chunk in chunks: + for chunk in batch_iter(iterable, batch_size): rows = await self.runInteraction( desc, self.simple_select_many_txn, @@ -1763,7 +1845,7 @@ class DatabasePool: txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any], - retcols: Iterable[str], + retcols: Collection[str], allow_none: bool = False, ) -> Optional[Dict[str, Any]]: select_sql = "SELECT %s FROM %s WHERE %s" % ( @@ -1871,7 +1953,7 @@ class DatabasePool: self, table: str, column: str, - iterable: Iterable[Any], + iterable: Collection[Any], keyvalues: Dict[str, Any], desc: str, ) -> int: @@ -1882,7 +1964,8 @@ class DatabasePool: Args: table: string giving the table name column: column name to test for inclusion against `iterable` - iterable: list + iterable: list of values to match against `column`. NB cannot be a generator + as it may be evaluated multiple times. keyvalues: dict of column names and values to select the rows with desc: description of the transaction, for logging and metrics @@ -2055,7 +2138,7 @@ class DatabasePool: table: str, term: Optional[str], col: str, - retcols: Iterable[str], + retcols: Collection[str], desc="simple_search_list", ) -> Optional[List[Dict[str, Any]]]: """Executes a SELECT query on the named table, which may return zero or |