summary refs log tree commit diff
path: root/synapse/storage/database.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/storage/database.py99
1 files changed, 80 insertions, 19 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py

index 569f618193..a4941e58f6 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py
@@ -35,6 +35,8 @@ from typing import ( Iterable, Iterator, List, + Literal, + Mapping, Optional, Sequence, Tuple, @@ -46,7 +48,7 @@ from typing import ( import attr from prometheus_client import Counter, Histogram -from typing_extensions import Concatenate, Literal, ParamSpec +from typing_extensions import Concatenate, ParamSpec from twisted.enterprise import adbapi from twisted.internet.interfaces import IReactorCore @@ -64,6 +66,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor, SQLQueryParameters +from synapse.types import StrCollection from synapse.util.async_helpers import delay_cancellation from synapse.util.iterutils import batch_iter @@ -1095,6 +1098,48 @@ class DatabasePool: txn.execute(sql, vals) + @staticmethod + def simple_insert_returning_txn( + txn: LoggingTransaction, + table: str, + values: Dict[str, Any], + returning: StrCollection, + ) -> Tuple[Any, ...]: + """Executes a `INSERT INTO... RETURNING...` statement (or equivalent for + SQLite versions that don't support it). + """ + + if txn.database_engine.supports_returning: + sql = "INSERT INTO %s (%s) VALUES(%s) RETURNING %s" % ( + table, + ", ".join(k for k in values.keys()), + ", ".join("?" for _ in values.keys()), + ", ".join(k for k in returning), + ) + + txn.execute(sql, list(values.values())) + row = txn.fetchone() + assert row is not None + return row + else: + # For old versions of SQLite we do a standard insert and then can + # use `last_insert_rowid` to get at the row we just inserted + DatabasePool.simple_insert_txn( + txn, + table=table, + values=values, + ) + txn.execute("SELECT last_insert_rowid()") + row = txn.fetchone() + assert row is not None + (rowid,) = row + + row = DatabasePool.simple_select_one_txn( + txn, table=table, keyvalues={"rowid": rowid}, retcols=returning + ) + assert row is not None + return row + async def simple_insert_many( self, table: str, @@ -1254,9 +1299,9 @@ class DatabasePool: self, txn: LoggingTransaction, table: str, - keyvalues: Dict[str, Any], - values: Dict[str, Any], - insertion_values: Optional[Dict[str, Any]] = None, + keyvalues: Mapping[str, Any], + values: Mapping[str, Any], + insertion_values: Optional[Mapping[str, Any]] = None, where_clause: Optional[str] = None, ) -> bool: """ @@ -1299,9 +1344,9 @@ class DatabasePool: self, txn: LoggingTransaction, table: str, - keyvalues: Dict[str, Any], - values: Dict[str, Any], - insertion_values: Optional[Dict[str, Any]] = None, + keyvalues: Mapping[str, Any], + values: Mapping[str, Any], + insertion_values: Optional[Mapping[str, Any]] = None, where_clause: Optional[str] = None, lock: bool = True, ) -> bool: @@ -1322,7 +1367,7 @@ class DatabasePool: if lock: # We need to lock the table :( - self.engine.lock_table(txn, table) + txn.database_engine.lock_table(txn, table) def _getwhere(key: str) -> str: # If the value we're passing in is None (aka NULL), we need to use @@ -1376,13 +1421,13 @@ class DatabasePool: # successfully inserted return True + @staticmethod def simple_upsert_txn_native_upsert( - self, txn: LoggingTransaction, table: str, - keyvalues: Dict[str, Any], - values: Dict[str, Any], - insertion_values: Optional[Dict[str, Any]] = None, + keyvalues: Mapping[str, Any], + values: Mapping[str, Any], + insertion_values: Optional[Mapping[str, Any]] = None, where_clause: Optional[str] = None, ) -> bool: """ @@ -1535,8 +1580,8 @@ class DatabasePool: self.simple_upsert_txn_emulated(txn, table, _keys, _vals, lock=False) + @staticmethod def simple_upsert_many_txn_native_upsert( - self, txn: LoggingTransaction, table: str, key_names: Collection[str], @@ -1966,8 +2011,8 @@ class DatabasePool: def simple_update_txn( txn: LoggingTransaction, table: str, - keyvalues: Dict[str, Any], - updatevalues: Dict[str, Any], + keyvalues: Mapping[str, Any], + updatevalues: Mapping[str, Any], ) -> int: """ Update rows in the given database table. @@ -2115,10 +2160,26 @@ class DatabasePool: if rowcount > 1: raise StoreError(500, "More than one row matched (%s)" % (table,)) - # Ideally we could use the overload decorator here to specify that the - # return type is only optional if allow_none is True, but this does not work - # when you call a static method from an instance. - # See https://github.com/python/mypy/issues/7781 + @overload + @staticmethod + def simple_select_one_txn( + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + retcols: Collection[str], + allow_none: Literal[False] = False, + ) -> Tuple[Any, ...]: ... + + @overload + @staticmethod + def simple_select_one_txn( + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + retcols: Collection[str], + allow_none: Literal[True] = True, + ) -> Optional[Tuple[Any, ...]]: ... + @staticmethod def simple_select_one_txn( txn: LoggingTransaction,