diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 6a6d0dcd73..ea672ff89e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -23,6 +23,7 @@ from time import monotonic as monotonic_time
from typing import (
TYPE_CHECKING,
Any,
+ Awaitable,
Callable,
Collection,
Dict,
@@ -57,7 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
-from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
+from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -168,6 +169,7 @@ class LoggingDatabaseConnection:
*,
txn_name: Optional[str] = None,
after_callbacks: Optional[List["_CallbackListEntry"]] = None,
+ async_after_callbacks: Optional[List["_AsyncCallbackListEntry"]] = None,
exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
) -> "LoggingTransaction":
if not txn_name:
@@ -178,6 +180,7 @@ class LoggingDatabaseConnection:
name=txn_name,
database_engine=self.engine,
after_callbacks=after_callbacks,
+ async_after_callbacks=async_after_callbacks,
exception_callbacks=exception_callbacks,
)
@@ -209,6 +212,9 @@ class LoggingDatabaseConnection:
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
_CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
+_AsyncCallbackListEntry = Tuple[
+ Callable[..., Awaitable], Tuple[object, ...], Dict[str, object]
+]
P = ParamSpec("P")
R = TypeVar("R")
@@ -227,6 +233,10 @@ class LoggingTransaction:
that have been added by `call_after` which should be run on
successful completion of the transaction. None indicates that no
callbacks should be allowed to be scheduled to run.
+ async_after_callbacks: A list that asynchronous callbacks will be appended
+ to by `async_call_after` which should run, before after_callbacks, on
+ successful completion of the transaction. None indicates that no
+ callbacks should be allowed to be scheduled to run.
exception_callbacks: A list that callbacks will be appended
to that have been added by `call_on_exception` which should be run
if transaction ends with an error. None indicates that no callbacks
@@ -238,6 +248,7 @@ class LoggingTransaction:
"name",
"database_engine",
"after_callbacks",
+ "async_after_callbacks",
"exception_callbacks",
]
@@ -247,12 +258,14 @@ class LoggingTransaction:
name: str,
database_engine: BaseDatabaseEngine,
after_callbacks: Optional[List[_CallbackListEntry]] = None,
+ async_after_callbacks: Optional[List[_AsyncCallbackListEntry]] = None,
exception_callbacks: Optional[List[_CallbackListEntry]] = None,
):
self.txn = txn
self.name = name
self.database_engine = database_engine
self.after_callbacks = after_callbacks
+ self.async_after_callbacks = async_after_callbacks
self.exception_callbacks = exception_callbacks
def call_after(
@@ -277,6 +290,28 @@ class LoggingTransaction:
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
+ def async_call_after(
+ self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
+ ) -> None:
+ """Call the given asynchronous callback on the main twisted thread after
+ the transaction has finished (but before those added in `call_after`).
+
+ Mostly used to invalidate remote caches after transactions.
+
+ Note that transactions may be retried a few times if they encounter database
+ errors such as serialization failures. Callbacks given to `async_call_after`
+ will accumulate across transaction attempts and will _all_ be called once a
+ transaction attempt succeeds, regardless of whether previous transaction
+ attempts failed. Otherwise, if all transaction attempts fail, all
+ `call_on_exception` callbacks will be run instead.
+ """
+ # if self.async_after_callbacks is None, that means that whatever constructed the
+ # LoggingTransaction isn't expecting there to be any callbacks; assert that
+ # is not the case.
+ assert self.async_after_callbacks is not None
+ # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
+ self.async_after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
+
def call_on_exception(
self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
) -> None:
@@ -574,6 +609,7 @@ class DatabasePool:
conn: LoggingDatabaseConnection,
desc: str,
after_callbacks: List[_CallbackListEntry],
+ async_after_callbacks: List[_AsyncCallbackListEntry],
exception_callbacks: List[_CallbackListEntry],
func: Callable[Concatenate[LoggingTransaction, P], R],
*args: P.args,
@@ -597,6 +633,7 @@ class DatabasePool:
conn
desc
after_callbacks
+ async_after_callbacks
exception_callbacks
func
*args
@@ -659,6 +696,7 @@ class DatabasePool:
cursor = conn.cursor(
txn_name=name,
after_callbacks=after_callbacks,
+ async_after_callbacks=async_after_callbacks,
exception_callbacks=exception_callbacks,
)
try:
@@ -798,6 +836,7 @@ class DatabasePool:
async def _runInteraction() -> R:
after_callbacks: List[_CallbackListEntry] = []
+ async_after_callbacks: List[_AsyncCallbackListEntry] = []
exception_callbacks: List[_CallbackListEntry] = []
if not current_context():
@@ -809,6 +848,7 @@ class DatabasePool:
self.new_transaction,
desc,
after_callbacks,
+ async_after_callbacks,
exception_callbacks,
func,
*args,
@@ -817,15 +857,17 @@ class DatabasePool:
**kwargs,
)
+ # We order these assuming that async functions call out to external
+ # systems (e.g. to invalidate a cache) and the sync functions make these
+ # changes on any local in-memory caches/similar, and thus must be second.
+ for async_callback, async_args, async_kwargs in async_after_callbacks:
+ await async_callback(*async_args, **async_kwargs)
for after_callback, after_args, after_kwargs in after_callbacks:
- await maybe_awaitable(after_callback(*after_args, **after_kwargs))
-
+ after_callback(*after_args, **after_kwargs)
return cast(R, result)
except Exception:
for exception_callback, after_args, after_kwargs in exception_callbacks:
- await maybe_awaitable(
- exception_callback(*after_args, **after_kwargs)
- )
+ exception_callback(*after_args, **after_kwargs)
raise
# To handle cancellation, we ensure that `after_callback`s and
|