diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b8c8dcd76b..a2f8310388 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -96,6 +96,10 @@ class SQLBaseStore(metaclass=ABCMeta):
cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache.
+ Note that this function does not invalidate any remote caches, only the
+ local in-memory ones. Any remote invalidation must be performed before
+ calling this.
+
Args:
cache_name
key: Entry to invalidate. If None then invalidates the entire
@@ -112,7 +116,10 @@ class SQLBaseStore(metaclass=ABCMeta):
if key is None:
cache.invalidate_all()
else:
- cache.invalidate(tuple(key))
+ # Prefer any local-only invalidation method. Invalidating any non-local
+ # cache must be be done before this.
+ invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
+ invalidate_method(tuple(key))
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 6a6d0dcd73..ea672ff89e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -23,6 +23,7 @@ from time import monotonic as monotonic_time
from typing import (
TYPE_CHECKING,
Any,
+ Awaitable,
Callable,
Collection,
Dict,
@@ -57,7 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
-from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
+from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -168,6 +169,7 @@ class LoggingDatabaseConnection:
*,
txn_name: Optional[str] = None,
after_callbacks: Optional[List["_CallbackListEntry"]] = None,
+ async_after_callbacks: Optional[List["_AsyncCallbackListEntry"]] = None,
exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
) -> "LoggingTransaction":
if not txn_name:
@@ -178,6 +180,7 @@ class LoggingDatabaseConnection:
name=txn_name,
database_engine=self.engine,
after_callbacks=after_callbacks,
+ async_after_callbacks=async_after_callbacks,
exception_callbacks=exception_callbacks,
)
@@ -209,6 +212,9 @@ class LoggingDatabaseConnection:
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
_CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
+_AsyncCallbackListEntry = Tuple[
+ Callable[..., Awaitable], Tuple[object, ...], Dict[str, object]
+]
P = ParamSpec("P")
R = TypeVar("R")
@@ -227,6 +233,10 @@ class LoggingTransaction:
that have been added by `call_after` which should be run on
successful completion of the transaction. None indicates that no
callbacks should be allowed to be scheduled to run.
+ async_after_callbacks: A list that asynchronous callbacks will be appended
+ to by `async_call_after` which should run, before after_callbacks, on
+ successful completion of the transaction. None indicates that no
+ callbacks should be allowed to be scheduled to run.
exception_callbacks: A list that callbacks will be appended
to that have been added by `call_on_exception` which should be run
if transaction ends with an error. None indicates that no callbacks
@@ -238,6 +248,7 @@ class LoggingTransaction:
"name",
"database_engine",
"after_callbacks",
+ "async_after_callbacks",
"exception_callbacks",
]
@@ -247,12 +258,14 @@ class LoggingTransaction:
name: str,
database_engine: BaseDatabaseEngine,
after_callbacks: Optional[List[_CallbackListEntry]] = None,
+ async_after_callbacks: Optional[List[_AsyncCallbackListEntry]] = None,
exception_callbacks: Optional[List[_CallbackListEntry]] = None,
):
self.txn = txn
self.name = name
self.database_engine = database_engine
self.after_callbacks = after_callbacks
+ self.async_after_callbacks = async_after_callbacks
self.exception_callbacks = exception_callbacks
def call_after(
@@ -277,6 +290,28 @@ class LoggingTransaction:
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
+ def async_call_after(
+ self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
+ ) -> None:
+ """Call the given asynchronous callback on the main twisted thread after
+ the transaction has finished (but before those added in `call_after`).
+
+ Mostly used to invalidate remote caches after transactions.
+
+ Note that transactions may be retried a few times if they encounter database
+ errors such as serialization failures. Callbacks given to `async_call_after`
+ will accumulate across transaction attempts and will _all_ be called once a
+ transaction attempt succeeds, regardless of whether previous transaction
+ attempts failed. Otherwise, if all transaction attempts fail, all
+ `call_on_exception` callbacks will be run instead.
+ """
+ # if self.async_after_callbacks is None, that means that whatever constructed the
+ # LoggingTransaction isn't expecting there to be any callbacks; assert that
+ # is not the case.
+ assert self.async_after_callbacks is not None
+ # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
+ self.async_after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
+
def call_on_exception(
self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
) -> None:
@@ -574,6 +609,7 @@ class DatabasePool:
conn: LoggingDatabaseConnection,
desc: str,
after_callbacks: List[_CallbackListEntry],
+ async_after_callbacks: List[_AsyncCallbackListEntry],
exception_callbacks: List[_CallbackListEntry],
func: Callable[Concatenate[LoggingTransaction, P], R],
*args: P.args,
@@ -597,6 +633,7 @@ class DatabasePool:
conn
desc
after_callbacks
+ async_after_callbacks
exception_callbacks
func
*args
@@ -659,6 +696,7 @@ class DatabasePool:
cursor = conn.cursor(
txn_name=name,
after_callbacks=after_callbacks,
+ async_after_callbacks=async_after_callbacks,
exception_callbacks=exception_callbacks,
)
try:
@@ -798,6 +836,7 @@ class DatabasePool:
async def _runInteraction() -> R:
after_callbacks: List[_CallbackListEntry] = []
+ async_after_callbacks: List[_AsyncCallbackListEntry] = []
exception_callbacks: List[_CallbackListEntry] = []
if not current_context():
@@ -809,6 +848,7 @@ class DatabasePool:
self.new_transaction,
desc,
after_callbacks,
+ async_after_callbacks,
exception_callbacks,
func,
*args,
@@ -817,15 +857,17 @@ class DatabasePool:
**kwargs,
)
+ # We order these assuming that async functions call out to external
+ # systems (e.g. to invalidate a cache) and the sync functions make these
+ # changes on any local in-memory caches/similar, and thus must be second.
+ for async_callback, async_args, async_kwargs in async_after_callbacks:
+ await async_callback(*async_args, **async_kwargs)
for after_callback, after_args, after_kwargs in after_callbacks:
- await maybe_awaitable(after_callback(*after_args, **after_kwargs))
-
+ after_callback(*after_args, **after_kwargs)
return cast(R, result)
except Exception:
for exception_callback, after_args, after_kwargs in exception_callbacks:
- await maybe_awaitable(
- exception_callback(*after_args, **after_kwargs)
- )
+ exception_callback(*after_args, **after_kwargs)
raise
# To handle cancellation, we ensure that `after_callback`s and
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index fd3fc298b3..58177ecec1 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -194,7 +194,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
# changed its content in the database. We can't call
# self._invalidate_cache_and_stream because self.get_event_cache isn't of the
# right type.
- txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+ self.invalidate_get_event_cache_after_txn(txn, event.event_id)
# Send that invalidation to replication so that other workers also invalidate
# the event cache.
self._send_invalidation_to_replication(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index fa2266ba20..156e1bd5ab 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1293,7 +1293,7 @@ class PersistEventsStore:
depth_updates: Dict[str, int] = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
- txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
+ self.store.invalidate_get_event_cache_after_txn(txn, event.event_id)
# Then update the `stream_ordering` position to mark the latest
# event as the front of the room. This should not be done for
# backfilled events because backfilled events have negative
@@ -1675,7 +1675,7 @@ class PersistEventsStore:
(cache_entry.event.event_id,), cache_entry
)
- txn.call_after(prefill)
+ txn.async_call_after(prefill)
def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
"""Invalidate the caches for the redacted event.
@@ -1684,7 +1684,7 @@ class PersistEventsStore:
_invalidate_caches_for_event.
"""
assert event.redacts is not None
- txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
+ self.store.invalidate_get_event_cache_after_txn(txn, event.redacts)
txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index f3935bfead..4435373146 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -712,17 +712,41 @@ class EventsWorkerStore(SQLBaseStore):
return event_entry_map
- async def _invalidate_get_event_cache(self, event_id: str) -> None:
- # First we invalidate the asynchronous cache instance. This may include
- # out-of-process caches such as Redis/memcache. Once complete we can
- # invalidate any in memory cache. The ordering is important here to
- # ensure we don't pull in any remote invalid value after we invalidate
- # the in-memory cache.
+ def invalidate_get_event_cache_after_txn(
+ self, txn: LoggingTransaction, event_id: str
+ ) -> None:
+ """
+ Prepares a database transaction to invalidate the get event cache for a given
+ event ID when executed successfully. This is achieved by attaching two callbacks
+ to the transaction, one to invalidate the async cache and one for the in memory
+ sync cache (importantly called in that order).
+
+ Arguments:
+ txn: the database transaction to attach the callbacks to
+ event_id: the event ID to be invalidated from caches
+ """
+
+ txn.async_call_after(self._invalidate_async_get_event_cache, event_id)
+ txn.call_after(self._invalidate_local_get_event_cache, event_id)
+
+ async def _invalidate_async_get_event_cache(self, event_id: str) -> None:
+ """
+ Invalidates an event in the asyncronous get event cache, which may be remote.
+
+ Arguments:
+ event_id: the event ID to invalidate
+ """
+
await self._get_event_cache.invalidate((event_id,))
- self._event_ref.pop(event_id, None)
- self._current_event_fetches.pop(event_id, None)
def _invalidate_local_get_event_cache(self, event_id: str) -> None:
+ """
+ Invalidates an event in local in-memory get event caches.
+
+ Arguments:
+ event_id: the event ID to invalidate
+ """
+
self._get_event_cache.invalidate_local((event_id,))
self._event_ref.pop(event_id, None)
self._current_event_fetches.pop(event_id, None)
@@ -958,7 +982,13 @@ class EventsWorkerStore(SQLBaseStore):
}
row_dict = self.db_pool.new_transaction(
- conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
+ conn,
+ "do_fetch",
+ [],
+ [],
+ [],
+ self._fetch_event_rows,
+ events_to_fetch,
)
# We only want to resolve deferreds from the main thread
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 9a63f953fb..efd136a864 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -66,6 +66,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
"initialise_mau_threepids",
[],
[],
+ [],
self._initialise_reserved_users,
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
)
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 6d42276503..f6822707e4 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -304,7 +304,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
self._invalidate_cache_and_stream(
txn, self.have_seen_event, (room_id, event_id)
)
- txn.call_after(self._invalidate_get_event_cache, event_id)
+ self.invalidate_get_event_cache_after_txn(txn, event_id)
logger.info("[purge] done")
|