summary refs log tree commit diff
path: root/synapse/storage/database.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/database.py')
-rw-r--r--synapse/storage/database.py54
1 files changed, 48 insertions, 6 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 6a6d0dcd73..ea672ff89e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -23,6 +23,7 @@ from time import monotonic as monotonic_time
 from typing import (
     TYPE_CHECKING,
     Any,
+    Awaitable,
     Callable,
     Collection,
     Dict,
@@ -57,7 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
-from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
+from synapse.util.async_helpers import delay_cancellation
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -168,6 +169,7 @@ class LoggingDatabaseConnection:
         *,
         txn_name: Optional[str] = None,
         after_callbacks: Optional[List["_CallbackListEntry"]] = None,
+        async_after_callbacks: Optional[List["_AsyncCallbackListEntry"]] = None,
         exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
     ) -> "LoggingTransaction":
         if not txn_name:
@@ -178,6 +180,7 @@ class LoggingDatabaseConnection:
             name=txn_name,
             database_engine=self.engine,
             after_callbacks=after_callbacks,
+            async_after_callbacks=async_after_callbacks,
             exception_callbacks=exception_callbacks,
         )
 
@@ -209,6 +212,9 @@ class LoggingDatabaseConnection:
 
 # The type of entry which goes on our after_callbacks and exception_callbacks lists.
 _CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
+_AsyncCallbackListEntry = Tuple[
+    Callable[..., Awaitable], Tuple[object, ...], Dict[str, object]
+]
 
 P = ParamSpec("P")
 R = TypeVar("R")
@@ -227,6 +233,10 @@ class LoggingTransaction:
             that have been added by `call_after` which should be run on
             successful completion of the transaction. None indicates that no
             callbacks should be allowed to be scheduled to run.
+        async_after_callbacks: A list that asynchronous callbacks will be appended
+            to by `async_call_after` which should run, before after_callbacks, on
+            successful completion of the transaction. None indicates that no
+            callbacks should be allowed to be scheduled to run.
         exception_callbacks: A list that callbacks will be appended
             to that have been added by `call_on_exception` which should be run
             if transaction ends with an error. None indicates that no callbacks
@@ -238,6 +248,7 @@ class LoggingTransaction:
         "name",
         "database_engine",
         "after_callbacks",
+        "async_after_callbacks",
         "exception_callbacks",
     ]
 
@@ -247,12 +258,14 @@ class LoggingTransaction:
         name: str,
         database_engine: BaseDatabaseEngine,
         after_callbacks: Optional[List[_CallbackListEntry]] = None,
+        async_after_callbacks: Optional[List[_AsyncCallbackListEntry]] = None,
         exception_callbacks: Optional[List[_CallbackListEntry]] = None,
     ):
         self.txn = txn
         self.name = name
         self.database_engine = database_engine
         self.after_callbacks = after_callbacks
+        self.async_after_callbacks = async_after_callbacks
         self.exception_callbacks = exception_callbacks
 
     def call_after(
@@ -277,6 +290,28 @@ class LoggingTransaction:
         # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
         self.after_callbacks.append((callback, args, kwargs))  # type: ignore[arg-type]
 
+    def async_call_after(
+        self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
+    ) -> None:
+        """Call the given asynchronous callback on the main twisted thread after
+        the transaction has finished (but before those added in `call_after`).
+
+        Mostly used to invalidate remote caches after transactions.
+
+        Note that transactions may be retried a few times if they encounter database
+        errors such as serialization failures. Callbacks given to `async_call_after`
+        will accumulate across transaction attempts and will _all_ be called once a
+        transaction attempt succeeds, regardless of whether previous transaction
+        attempts failed. Otherwise, if all transaction attempts fail, all
+        `call_on_exception` callbacks will be run instead.
+        """
+        # if self.async_after_callbacks is None, that means that whatever constructed the
+        # LoggingTransaction isn't expecting there to be any callbacks; assert that
+        # is not the case.
+        assert self.async_after_callbacks is not None
+        # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
+        self.async_after_callbacks.append((callback, args, kwargs))  # type: ignore[arg-type]
+
     def call_on_exception(
         self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
     ) -> None:
@@ -574,6 +609,7 @@ class DatabasePool:
         conn: LoggingDatabaseConnection,
         desc: str,
         after_callbacks: List[_CallbackListEntry],
+        async_after_callbacks: List[_AsyncCallbackListEntry],
         exception_callbacks: List[_CallbackListEntry],
         func: Callable[Concatenate[LoggingTransaction, P], R],
         *args: P.args,
@@ -597,6 +633,7 @@ class DatabasePool:
             conn
             desc
             after_callbacks
+            async_after_callbacks
             exception_callbacks
             func
             *args
@@ -659,6 +696,7 @@ class DatabasePool:
                 cursor = conn.cursor(
                     txn_name=name,
                     after_callbacks=after_callbacks,
+                    async_after_callbacks=async_after_callbacks,
                     exception_callbacks=exception_callbacks,
                 )
                 try:
@@ -798,6 +836,7 @@ class DatabasePool:
 
         async def _runInteraction() -> R:
             after_callbacks: List[_CallbackListEntry] = []
+            async_after_callbacks: List[_AsyncCallbackListEntry] = []
             exception_callbacks: List[_CallbackListEntry] = []
 
             if not current_context():
@@ -809,6 +848,7 @@ class DatabasePool:
                         self.new_transaction,
                         desc,
                         after_callbacks,
+                        async_after_callbacks,
                         exception_callbacks,
                         func,
                         *args,
@@ -817,15 +857,17 @@ class DatabasePool:
                         **kwargs,
                     )
 
+                # We order these assuming that async functions call out to external
+                # systems (e.g. to invalidate a cache) and the sync functions make these
+                # changes on any local in-memory caches/similar, and thus must be second.
+                for async_callback, async_args, async_kwargs in async_after_callbacks:
+                    await async_callback(*async_args, **async_kwargs)
                 for after_callback, after_args, after_kwargs in after_callbacks:
-                    await maybe_awaitable(after_callback(*after_args, **after_kwargs))
-
+                    after_callback(*after_args, **after_kwargs)
                 return cast(R, result)
             except Exception:
                 for exception_callback, after_args, after_kwargs in exception_callbacks:
-                    await maybe_awaitable(
-                        exception_callback(*after_args, **after_kwargs)
-                    )
+                    exception_callback(*after_args, **after_kwargs)
                 raise
 
         # To handle cancellation, we ensure that `after_callback`s and