diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 569f618193..a4941e58f6 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -35,6 +35,8 @@ from typing import (
Iterable,
Iterator,
List,
+ Literal,
+ Mapping,
Optional,
Sequence,
Tuple,
@@ -46,7 +48,7 @@ from typing import (
import attr
from prometheus_client import Counter, Histogram
-from typing_extensions import Concatenate, Literal, ParamSpec
+from typing_extensions import Concatenate, ParamSpec
from twisted.enterprise import adbapi
from twisted.internet.interfaces import IReactorCore
@@ -64,6 +66,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor, SQLQueryParameters
+from synapse.types import StrCollection
from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter
@@ -1095,6 +1098,48 @@ class DatabasePool:
txn.execute(sql, vals)
+ @staticmethod
+ def simple_insert_returning_txn(
+ txn: LoggingTransaction,
+ table: str,
+ values: Dict[str, Any],
+ returning: StrCollection,
+ ) -> Tuple[Any, ...]:
+ """Executes a `INSERT INTO... RETURNING...` statement (or equivalent for
+ SQLite versions that don't support it).
+ """
+
+ if txn.database_engine.supports_returning:
+ sql = "INSERT INTO %s (%s) VALUES(%s) RETURNING %s" % (
+ table,
+ ", ".join(k for k in values.keys()),
+ ", ".join("?" for _ in values.keys()),
+ ", ".join(k for k in returning),
+ )
+
+ txn.execute(sql, list(values.values()))
+ row = txn.fetchone()
+ assert row is not None
+ return row
+ else:
+ # For old versions of SQLite we do a standard insert and then can
+ # use `last_insert_rowid` to get at the row we just inserted
+ DatabasePool.simple_insert_txn(
+ txn,
+ table=table,
+ values=values,
+ )
+ txn.execute("SELECT last_insert_rowid()")
+ row = txn.fetchone()
+ assert row is not None
+ (rowid,) = row
+
+ row = DatabasePool.simple_select_one_txn(
+ txn, table=table, keyvalues={"rowid": rowid}, retcols=returning
+ )
+ assert row is not None
+ return row
+
async def simple_insert_many(
self,
table: str,
@@ -1254,9 +1299,9 @@ class DatabasePool:
self,
txn: LoggingTransaction,
table: str,
- keyvalues: Dict[str, Any],
- values: Dict[str, Any],
- insertion_values: Optional[Dict[str, Any]] = None,
+ keyvalues: Mapping[str, Any],
+ values: Mapping[str, Any],
+ insertion_values: Optional[Mapping[str, Any]] = None,
where_clause: Optional[str] = None,
) -> bool:
"""
@@ -1299,9 +1344,9 @@ class DatabasePool:
self,
txn: LoggingTransaction,
table: str,
- keyvalues: Dict[str, Any],
- values: Dict[str, Any],
- insertion_values: Optional[Dict[str, Any]] = None,
+ keyvalues: Mapping[str, Any],
+ values: Mapping[str, Any],
+ insertion_values: Optional[Mapping[str, Any]] = None,
where_clause: Optional[str] = None,
lock: bool = True,
) -> bool:
@@ -1322,7 +1367,7 @@ class DatabasePool:
if lock:
# We need to lock the table :(
- self.engine.lock_table(txn, table)
+ txn.database_engine.lock_table(txn, table)
def _getwhere(key: str) -> str:
# If the value we're passing in is None (aka NULL), we need to use
@@ -1376,13 +1421,13 @@ class DatabasePool:
# successfully inserted
return True
+ @staticmethod
def simple_upsert_txn_native_upsert(
- self,
txn: LoggingTransaction,
table: str,
- keyvalues: Dict[str, Any],
- values: Dict[str, Any],
- insertion_values: Optional[Dict[str, Any]] = None,
+ keyvalues: Mapping[str, Any],
+ values: Mapping[str, Any],
+ insertion_values: Optional[Mapping[str, Any]] = None,
where_clause: Optional[str] = None,
) -> bool:
"""
@@ -1535,8 +1580,8 @@ class DatabasePool:
self.simple_upsert_txn_emulated(txn, table, _keys, _vals, lock=False)
+ @staticmethod
def simple_upsert_many_txn_native_upsert(
- self,
txn: LoggingTransaction,
table: str,
key_names: Collection[str],
@@ -1966,8 +2011,8 @@ class DatabasePool:
def simple_update_txn(
txn: LoggingTransaction,
table: str,
- keyvalues: Dict[str, Any],
- updatevalues: Dict[str, Any],
+ keyvalues: Mapping[str, Any],
+ updatevalues: Mapping[str, Any],
) -> int:
"""
Update rows in the given database table.
@@ -2115,10 +2160,26 @@ class DatabasePool:
if rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
- # Ideally we could use the overload decorator here to specify that the
- # return type is only optional if allow_none is True, but this does not work
- # when you call a static method from an instance.
- # See https://github.com/python/mypy/issues/7781
+ @overload
+ @staticmethod
+ def simple_select_one_txn(
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Collection[str],
+ allow_none: Literal[False] = False,
+ ) -> Tuple[Any, ...]: ...
+
+ @overload
+ @staticmethod
+ def simple_select_one_txn(
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Collection[str],
+ allow_none: Literal[True] = True,
+ ) -> Optional[Tuple[Any, ...]]: ...
+
@staticmethod
def simple_select_one_txn(
txn: LoggingTransaction,
|