summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/_base.py9
-rw-r--r--synapse/storage/database.py54
-rw-r--r--synapse/storage/databases/main/censor_events.py2
-rw-r--r--synapse/storage/databases/main/events.py6
-rw-r--r--synapse/storage/databases/main/events_worker.py48
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py1
-rw-r--r--synapse/storage/databases/main/purge_events.py2
7 files changed, 101 insertions, 21 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b8c8dcd76b..a2f8310388 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -96,6 +96,10 @@ class SQLBaseStore(metaclass=ABCMeta):
         cache doesn't exist. Mainly used for invalidating caches on workers,
         where they may not have the cache.
 
+        Note that this function does not invalidate any remote caches, only the
+        local in-memory ones. Any remote invalidation must be performed before
+        calling this.
+
         Args:
             cache_name
             key: Entry to invalidate. If None then invalidates the entire
@@ -112,7 +116,10 @@ class SQLBaseStore(metaclass=ABCMeta):
         if key is None:
             cache.invalidate_all()
         else:
-            cache.invalidate(tuple(key))
+            # Prefer any local-only invalidation method. Invalidating any non-local
+            # cache must be be done before this.
+            invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
+            invalidate_method(tuple(key))
 
 
 def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 6a6d0dcd73..ea672ff89e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -23,6 +23,7 @@ from time import monotonic as monotonic_time
 from typing import (
     TYPE_CHECKING,
     Any,
+    Awaitable,
     Callable,
     Collection,
     Dict,
@@ -57,7 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
-from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
+from synapse.util.async_helpers import delay_cancellation
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -168,6 +169,7 @@ class LoggingDatabaseConnection:
         *,
         txn_name: Optional[str] = None,
         after_callbacks: Optional[List["_CallbackListEntry"]] = None,
+        async_after_callbacks: Optional[List["_AsyncCallbackListEntry"]] = None,
         exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
     ) -> "LoggingTransaction":
         if not txn_name:
@@ -178,6 +180,7 @@ class LoggingDatabaseConnection:
             name=txn_name,
             database_engine=self.engine,
             after_callbacks=after_callbacks,
+            async_after_callbacks=async_after_callbacks,
             exception_callbacks=exception_callbacks,
         )
 
@@ -209,6 +212,9 @@ class LoggingDatabaseConnection:
 
 # The type of entry which goes on our after_callbacks and exception_callbacks lists.
 _CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
+_AsyncCallbackListEntry = Tuple[
+    Callable[..., Awaitable], Tuple[object, ...], Dict[str, object]
+]
 
 P = ParamSpec("P")
 R = TypeVar("R")
@@ -227,6 +233,10 @@ class LoggingTransaction:
             that have been added by `call_after` which should be run on
             successful completion of the transaction. None indicates that no
             callbacks should be allowed to be scheduled to run.
+        async_after_callbacks: A list that asynchronous callbacks will be appended
+            to by `async_call_after` which should run, before after_callbacks, on
+            successful completion of the transaction. None indicates that no
+            callbacks should be allowed to be scheduled to run.
         exception_callbacks: A list that callbacks will be appended
             to that have been added by `call_on_exception` which should be run
             if transaction ends with an error. None indicates that no callbacks
@@ -238,6 +248,7 @@ class LoggingTransaction:
         "name",
         "database_engine",
         "after_callbacks",
+        "async_after_callbacks",
         "exception_callbacks",
     ]
 
@@ -247,12 +258,14 @@ class LoggingTransaction:
         name: str,
         database_engine: BaseDatabaseEngine,
         after_callbacks: Optional[List[_CallbackListEntry]] = None,
+        async_after_callbacks: Optional[List[_AsyncCallbackListEntry]] = None,
         exception_callbacks: Optional[List[_CallbackListEntry]] = None,
     ):
         self.txn = txn
         self.name = name
         self.database_engine = database_engine
         self.after_callbacks = after_callbacks
+        self.async_after_callbacks = async_after_callbacks
         self.exception_callbacks = exception_callbacks
 
     def call_after(
@@ -277,6 +290,28 @@ class LoggingTransaction:
         # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
         self.after_callbacks.append((callback, args, kwargs))  # type: ignore[arg-type]
 
+    def async_call_after(
+        self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
+    ) -> None:
+        """Call the given asynchronous callback on the main twisted thread after
+        the transaction has finished (but before those added in `call_after`).
+
+        Mostly used to invalidate remote caches after transactions.
+
+        Note that transactions may be retried a few times if they encounter database
+        errors such as serialization failures. Callbacks given to `async_call_after`
+        will accumulate across transaction attempts and will _all_ be called once a
+        transaction attempt succeeds, regardless of whether previous transaction
+        attempts failed. Otherwise, if all transaction attempts fail, all
+        `call_on_exception` callbacks will be run instead.
+        """
+        # if self.async_after_callbacks is None, that means that whatever constructed the
+        # LoggingTransaction isn't expecting there to be any callbacks; assert that
+        # is not the case.
+        assert self.async_after_callbacks is not None
+        # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
+        self.async_after_callbacks.append((callback, args, kwargs))  # type: ignore[arg-type]
+
     def call_on_exception(
         self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
     ) -> None:
@@ -574,6 +609,7 @@ class DatabasePool:
         conn: LoggingDatabaseConnection,
         desc: str,
         after_callbacks: List[_CallbackListEntry],
+        async_after_callbacks: List[_AsyncCallbackListEntry],
         exception_callbacks: List[_CallbackListEntry],
         func: Callable[Concatenate[LoggingTransaction, P], R],
         *args: P.args,
@@ -597,6 +633,7 @@ class DatabasePool:
             conn
             desc
             after_callbacks
+            async_after_callbacks
             exception_callbacks
             func
             *args
@@ -659,6 +696,7 @@ class DatabasePool:
                 cursor = conn.cursor(
                     txn_name=name,
                     after_callbacks=after_callbacks,
+                    async_after_callbacks=async_after_callbacks,
                     exception_callbacks=exception_callbacks,
                 )
                 try:
@@ -798,6 +836,7 @@ class DatabasePool:
 
         async def _runInteraction() -> R:
             after_callbacks: List[_CallbackListEntry] = []
+            async_after_callbacks: List[_AsyncCallbackListEntry] = []
             exception_callbacks: List[_CallbackListEntry] = []
 
             if not current_context():
@@ -809,6 +848,7 @@ class DatabasePool:
                         self.new_transaction,
                         desc,
                         after_callbacks,
+                        async_after_callbacks,
                         exception_callbacks,
                         func,
                         *args,
@@ -817,15 +857,17 @@ class DatabasePool:
                         **kwargs,
                     )
 
+                # We order these assuming that async functions call out to external
+                # systems (e.g. to invalidate a cache) and the sync functions make these
+                # changes on any local in-memory caches/similar, and thus must be second.
+                for async_callback, async_args, async_kwargs in async_after_callbacks:
+                    await async_callback(*async_args, **async_kwargs)
                 for after_callback, after_args, after_kwargs in after_callbacks:
-                    await maybe_awaitable(after_callback(*after_args, **after_kwargs))
-
+                    after_callback(*after_args, **after_kwargs)
                 return cast(R, result)
             except Exception:
                 for exception_callback, after_args, after_kwargs in exception_callbacks:
-                    await maybe_awaitable(
-                        exception_callback(*after_args, **after_kwargs)
-                    )
+                    exception_callback(*after_args, **after_kwargs)
                 raise
 
         # To handle cancellation, we ensure that `after_callback`s and
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index fd3fc298b3..58177ecec1 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -194,7 +194,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
             # changed its content in the database. We can't call
             # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
             # right type.
-            txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+            self.invalidate_get_event_cache_after_txn(txn, event.event_id)
             # Send that invalidation to replication so that other workers also invalidate
             # the event cache.
             self._send_invalidation_to_replication(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index fa2266ba20..156e1bd5ab 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1293,7 +1293,7 @@ class PersistEventsStore:
         depth_updates: Dict[str, int] = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
-            txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
+            self.store.invalidate_get_event_cache_after_txn(txn, event.event_id)
             # Then update the `stream_ordering` position to mark the latest
             # event as the front of the room. This should not be done for
             # backfilled events because backfilled events have negative
@@ -1675,7 +1675,7 @@ class PersistEventsStore:
                     (cache_entry.event.event_id,), cache_entry
                 )
 
-        txn.call_after(prefill)
+        txn.async_call_after(prefill)
 
     def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
         """Invalidate the caches for the redacted event.
@@ -1684,7 +1684,7 @@ class PersistEventsStore:
         _invalidate_caches_for_event.
         """
         assert event.redacts is not None
-        txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
+        self.store.invalidate_get_event_cache_after_txn(txn, event.redacts)
         txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
         txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index f3935bfead..4435373146 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -712,17 +712,41 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_entry_map
 
-    async def _invalidate_get_event_cache(self, event_id: str) -> None:
-        # First we invalidate the asynchronous cache instance. This may include
-        # out-of-process caches such as Redis/memcache. Once complete we can
-        # invalidate any in memory cache. The ordering is important here to
-        # ensure we don't pull in any remote invalid value after we invalidate
-        # the in-memory cache.
+    def invalidate_get_event_cache_after_txn(
+        self, txn: LoggingTransaction, event_id: str
+    ) -> None:
+        """
+        Prepares a database transaction to invalidate the get event cache for a given
+        event ID when executed successfully. This is achieved by attaching two callbacks
+        to the transaction, one to invalidate the async cache and one for the in memory
+        sync cache (importantly called in that order).
+
+        Arguments:
+            txn: the database transaction to attach the callbacks to
+            event_id: the event ID to be invalidated from caches
+        """
+
+        txn.async_call_after(self._invalidate_async_get_event_cache, event_id)
+        txn.call_after(self._invalidate_local_get_event_cache, event_id)
+
+    async def _invalidate_async_get_event_cache(self, event_id: str) -> None:
+        """
+        Invalidates an event in the asyncronous get event cache, which may be remote.
+
+        Arguments:
+            event_id: the event ID to invalidate
+        """
+
         await self._get_event_cache.invalidate((event_id,))
-        self._event_ref.pop(event_id, None)
-        self._current_event_fetches.pop(event_id, None)
 
     def _invalidate_local_get_event_cache(self, event_id: str) -> None:
+        """
+        Invalidates an event in local in-memory get event caches.
+
+        Arguments:
+            event_id: the event ID to invalidate
+        """
+
         self._get_event_cache.invalidate_local((event_id,))
         self._event_ref.pop(event_id, None)
         self._current_event_fetches.pop(event_id, None)
@@ -958,7 +982,13 @@ class EventsWorkerStore(SQLBaseStore):
                 }
 
                 row_dict = self.db_pool.new_transaction(
-                    conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
+                    conn,
+                    "do_fetch",
+                    [],
+                    [],
+                    [],
+                    self._fetch_event_rows,
+                    events_to_fetch,
                 )
 
                 # We only want to resolve deferreds from the main thread
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 9a63f953fb..efd136a864 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -66,6 +66,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
                 "initialise_mau_threepids",
                 [],
                 [],
+                [],
                 self._initialise_reserved_users,
                 hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
             )
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 6d42276503..f6822707e4 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -304,7 +304,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
                 self._invalidate_cache_and_stream(
                     txn, self.have_seen_event, (room_id, event_id)
                 )
-                txn.call_after(self._invalidate_get_event_cache, event_id)
+                self.invalidate_get_event_cache_after_txn(txn, event_id)
 
         logger.info("[purge] done")