summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorMathieu Velten <mathieuv@matrix.org>2022-08-11 14:26:40 +0200
committerMathieu Velten <mathieuv@matrix.org>2022-08-12 16:20:26 +0200
commit941ab4b483e24997106e2478e37057450c4f9c14 (patch)
treeb583812cce5c928935e234e52e5910ea36de0a68 /synapse
parentAdd test for account validity feature (diff)
downloadsynapse-github/mv/test-account-validity.tar.xz
Add cache invalidation to module API github/mv/test-account-validity mv/test-account-validity
Diffstat (limited to 'synapse')
-rw-r--r--synapse/module_api/__init__.py16
-rw-r--r--synapse/storage/_base.py6
-rw-r--r--synapse/storage/databases/main/cache.py29
-rw-r--r--synapse/util/caches/descriptors.py14
4 files changed, 48 insertions, 17 deletions
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 87ba154cb7..19b501c9e3 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -125,7 +125,7 @@ from synapse.types import (
 )
 from synapse.util import Clock
 from synapse.util.async_helpers import maybe_awaitable
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import CachedFunction, cached
 from synapse.util.frozenutils import freeze
 
 if TYPE_CHECKING:
@@ -836,6 +836,20 @@ class ModuleApi:
             self._store.db_pool.runInteraction(desc, func, *args, **kwargs)  # type: ignore[arg-type]
         )
 
+    async def invalidate_cache(
+        self, cached_func: CachedFunction, keys: Tuple[Any, ...]
+    ) -> None:
+        cached_func.invalidate(keys)
+        await self._store.send_invalidation_to_replication(
+            cached_func.__qualname__,
+            keys,
+        )
+
+    def register_cached_function(self, cached_func: CachedFunction) -> None:
+        self._store.register_external_cached_function(
+            cached_func.__qualname__, cached_func
+        )
+
     def complete_sso_login(
         self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
     ) -> None:
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index e30f9c76d4..3f85a33344 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -95,7 +95,7 @@ class SQLBaseStore(metaclass=ABCMeta):
 
     def _attempt_to_invalidate_cache(
         self, cache_name: str, key: Optional[Collection[Any]]
-    ) -> None:
+    ) -> bool:
         """Attempts to invalidate the cache of the given name, ignoring if the
         cache doesn't exist. Mainly used for invalidating caches on workers,
         where they may not have the cache.
@@ -115,7 +115,7 @@ class SQLBaseStore(metaclass=ABCMeta):
         except AttributeError:
             # We probably haven't pulled in the cache in this worker,
             # which is fine.
-            return
+            return False
 
         if key is None:
             cache.invalidate_all()
@@ -125,6 +125,8 @@ class SQLBaseStore(metaclass=ABCMeta):
             invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
             invalidate_method(tuple(key))
 
+        return True
+
 
 def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
     """
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 12e9a42382..efe9f3ad88 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -33,7 +33,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.util.caches.descriptors import _CachedFunction
+from synapse.util.caches.descriptors import CachedFunction
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -91,6 +91,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         else:
             self._cache_id_gen = None
 
+        self.external_cached_functions = {}
+
+    def register_external_cached_function(self, cache_name, func):
+        self.external_cached_functions[cache_name] = func
+
     async def get_all_updated_caches(
         self, instance_name: str, last_id: int, current_id: int, limit: int
     ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
@@ -178,7 +183,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                     members_changed = set(row.keys[1:])
                     self._invalidate_state_caches(room_id, members_changed)
                 else:
-                    self._attempt_to_invalidate_cache(row.cache_func, row.keys)
+                    res = self._attempt_to_invalidate_cache(row.cache_func, row.keys)
+                    if not res:
+                        external_func = self.external_cached_functions[row.cache_func]
+                        if external_func:
+                            external_func.invalidate(row.keys)
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
@@ -269,9 +278,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             return
 
         cache_func.invalidate(keys)
-        await self.db_pool.runInteraction(
-            "invalidate_cache_and_stream",
-            self._send_invalidation_to_replication,
+        await self.send_invalidation_to_replication(
             cache_func.__name__,
             keys,
         )
@@ -279,7 +286,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
     def _invalidate_cache_and_stream(
         self,
         txn: LoggingTransaction,
-        cache_func: _CachedFunction,
+        cache_func: CachedFunction,
         keys: Tuple[Any, ...],
     ) -> None:
         """Invalidates the cache and adds it to the cache stream so slaves
@@ -293,7 +300,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
 
     def _invalidate_all_cache_and_stream(
-        self, txn: LoggingTransaction, cache_func: _CachedFunction
+        self, txn: LoggingTransaction, cache_func: CachedFunction
     ) -> None:
         """Invalidates the entire cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
@@ -334,6 +341,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 txn, CURRENT_STATE_CACHE_NAME, [room_id]
             )
 
+    async def send_invalidation_to_replication(self, cache_name, keys):
+        await self.db_pool.runInteraction(
+            "send_invalidation_to_replication",
+            self._send_invalidation_to_replication,
+            cache_name,
+            keys,
+        )
+
     def _send_invalidation_to_replication(
         self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
     ) -> None:
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 867f315b2a..c3c20a2339 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -52,7 +52,7 @@ CacheKey = Union[Tuple, Any]
 F = TypeVar("F", bound=Callable[..., Any])
 
 
-class _CachedFunction(Generic[F]):
+class CachedFunction(Generic[F]):
     invalidate: Any = None
     invalidate_all: Any = None
     prefill: Any = None
@@ -239,7 +239,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
             return ret2
 
-        wrapped = cast(_CachedFunction, _wrapped)
+        wrapped = cast(CachedFunction, _wrapped)
         wrapped.cache = cache
         obj.__dict__[self.orig.__name__] = wrapped
 
@@ -358,7 +358,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
 
             return make_deferred_yieldable(ret)
 
-        wrapped = cast(_CachedFunction, _wrapped)
+        wrapped = cast(CachedFunction, _wrapped)
 
         if self.num_args == 1:
             assert not self.tree
@@ -577,7 +577,7 @@ def cached(
     cache_context: bool = False,
     iterable: bool = False,
     prune_unread_entries: bool = True,
-) -> Callable[[F], _CachedFunction[F]]:
+) -> Callable[[F], CachedFunction[F]]:
     func = lambda orig: DeferredCacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -589,12 +589,12 @@ def cached(
         prune_unread_entries=prune_unread_entries,
     )
 
-    return cast(Callable[[F], _CachedFunction[F]], func)
+    return cast(Callable[[F], CachedFunction[F]], func)
 
 
 def cachedList(
     *, cached_method_name: str, list_name: str, num_args: Optional[int] = None
-) -> Callable[[F], _CachedFunction[F]]:
+) -> Callable[[F], CachedFunction[F]]:
     """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
 
     Used to do batch lookups for an already created cache. One of the arguments
@@ -630,7 +630,7 @@ def cachedList(
         num_args=num_args,
     )
 
-    return cast(Callable[[F], _CachedFunction[F]], func)
+    return cast(Callable[[F], CachedFunction[F]], func)
 
 
 def _get_cache_key_builder(