diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 9287c0fb8d..59755bff6d 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -125,7 +125,7 @@ from synapse.types import (
)
from synapse.util import Clock
from synapse.util.async_helpers import maybe_awaitable
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import CachedFunction, cached
from synapse.util.frozenutils import freeze
if TYPE_CHECKING:
@@ -836,6 +836,37 @@ class ModuleApi:
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type]
)
+ def register_cached_function(self, cached_func: CachedFunction) -> None:
+ """Register a cached function that should be invalidated across workers.
+ Invalidation local to a worker can be done directly using `cached_func.invalidate`,
+ however invalidation that needs to go to other workers needs to call `invalidate_cache`
+ on the module API instead.
+
+ Args:
+ cached_function: The cached function that will be registered to receive invalidation
+ locally and from other workers.
+ """
+ self._store.register_external_cached_function(
+ f"{cached_func.__module__}.{cached_func.__name__}", cached_func
+ )
+
+ async def invalidate_cache(
+ self, cached_func: CachedFunction, keys: Tuple[Any, ...]
+ ) -> None:
+ """Invalidate a cache entry of a cached function across workers. The cached function
+ needs to be registered on all workers first with `register_cached_function`.
+
+ Args:
+ cached_function: The cached function that needs an invalidation
+ keys: keys of the entry to invalidate, usually matching the arguments of the
+ cached function.
+ """
+ cached_func.invalidate(keys)
+ await self._store.send_invalidation_to_replication(
+ f"{cached_func.__module__}.{cached_func.__name__}",
+ keys,
+ )
+
async def complete_sso_login_async(
self,
registered_user_id: str,
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index e30f9c76d4..303a5d5298 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,12 +15,13 @@
# limitations under the License.
import logging
from abc import ABCMeta
-from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Union
from synapse.storage.database import make_in_list_sql_clause # noqa: F401; noqa: F401
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.types import get_domain_from_id
from synapse.util import json_decoder
+from synapse.util.caches.descriptors import CachedFunction
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -47,6 +48,8 @@ class SQLBaseStore(metaclass=ABCMeta):
self.database_engine = database.engine
self.db_pool = database
+ self.external_cached_functions: Dict[str, CachedFunction] = {}
+
def process_replication_rows(
self,
stream_name: str,
@@ -95,7 +98,7 @@ class SQLBaseStore(metaclass=ABCMeta):
def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]]
- ) -> None:
+ ) -> bool:
"""Attempts to invalidate the cache of the given name, ignoring if the
cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache.
@@ -113,9 +116,12 @@ class SQLBaseStore(metaclass=ABCMeta):
try:
cache = getattr(self, cache_name)
except AttributeError:
- # We probably haven't pulled in the cache in this worker,
- # which is fine.
- return
+ # Check if an externally defined module cache has been registered
+ cache = self.external_cached_functions.get(cache_name)
+ if not cache:
+ # We probably haven't pulled in the cache in this worker,
+ # which is fine.
+ return False
if key is None:
cache.invalidate_all()
@@ -125,6 +131,13 @@ class SQLBaseStore(metaclass=ABCMeta):
invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
invalidate_method(tuple(key))
+ return True
+
+ def register_external_cached_function(
+ self, cache_name: str, func: CachedFunction
+ ) -> None:
+ self.external_cached_functions[cache_name] = func
+
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
"""
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 12e9a42382..2c421151c1 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -33,7 +33,7 @@ from synapse.storage.database import (
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.util.caches.descriptors import _CachedFunction
+from synapse.util.caches.descriptors import CachedFunction
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -269,9 +269,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
return
cache_func.invalidate(keys)
- await self.db_pool.runInteraction(
- "invalidate_cache_and_stream",
- self._send_invalidation_to_replication,
+ await self.send_invalidation_to_replication(
cache_func.__name__,
keys,
)
@@ -279,7 +277,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
def _invalidate_cache_and_stream(
self,
txn: LoggingTransaction,
- cache_func: _CachedFunction,
+ cache_func: CachedFunction,
keys: Tuple[Any, ...],
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
@@ -293,7 +291,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
def _invalidate_all_cache_and_stream(
- self, txn: LoggingTransaction, cache_func: _CachedFunction
+ self, txn: LoggingTransaction, cache_func: CachedFunction
) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -334,6 +332,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn, CURRENT_STATE_CACHE_NAME, [room_id]
)
+ async def send_invalidation_to_replication(
+ self, cache_name: str, keys: Optional[Collection[Any]]
+ ) -> None:
+ await self.db_pool.runInteraction(
+ "send_invalidation_to_replication",
+ self._send_invalidation_to_replication,
+ cache_name,
+ keys,
+ )
+
def _send_invalidation_to_replication(
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
) -> None:
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 10aff4d04a..3909f1caea 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -53,7 +53,7 @@ CacheKey = Union[Tuple, Any]
F = TypeVar("F", bound=Callable[..., Any])
-class _CachedFunction(Generic[F]):
+class CachedFunction(Generic[F]):
invalidate: Any = None
invalidate_all: Any = None
prefill: Any = None
@@ -242,7 +242,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
return ret2
- wrapped = cast(_CachedFunction, _wrapped)
+ wrapped = cast(CachedFunction, _wrapped)
wrapped.cache = cache
obj.__dict__[self.name] = wrapped
@@ -363,7 +363,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
return make_deferred_yieldable(ret)
- wrapped = cast(_CachedFunction, _wrapped)
+ wrapped = cast(CachedFunction, _wrapped)
if self.num_args == 1:
assert not self.tree
@@ -572,7 +572,7 @@ def cached(
iterable: bool = False,
prune_unread_entries: bool = True,
name: Optional[str] = None,
-) -> Callable[[F], _CachedFunction[F]]:
+) -> Callable[[F], CachedFunction[F]]:
func = lambda orig: DeferredCacheDescriptor(
orig,
max_entries=max_entries,
@@ -585,7 +585,7 @@ def cached(
name=name,
)
- return cast(Callable[[F], _CachedFunction[F]], func)
+ return cast(Callable[[F], CachedFunction[F]], func)
def cachedList(
@@ -594,7 +594,7 @@ def cachedList(
list_name: str,
num_args: Optional[int] = None,
name: Optional[str] = None,
-) -> Callable[[F], _CachedFunction[F]]:
+) -> Callable[[F], CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
Used to do batch lookups for an already created cache. One of the arguments
@@ -631,7 +631,7 @@ def cachedList(
name=name,
)
- return cast(Callable[[F], _CachedFunction[F]], func)
+ return cast(Callable[[F], CachedFunction[F]], func)
def _get_cache_key_builder(
|