-rw-r--r--  changelog.d/13667.feature                              1
-rw-r--r--  scripts-dev/mypy_synapse_plugin.py                     4
-rw-r--r--  synapse/module_api/__init__.py                        33
-rw-r--r--  synapse/storage/_base.py                              23
-rw-r--r--  synapse/storage/databases/main/cache.py               20
-rw-r--r--  synapse/util/caches/descriptors.py                    14
-rw-r--r--  tests/replication/test_module_cache_invalidation.py   79
7 files changed, 153 insertions, 21 deletions
diff --git a/changelog.d/13667.feature b/changelog.d/13667.feature
new file mode 100644
index 0000000000..a0b3cfe18c
--- /dev/null
+++ b/changelog.d/13667.feature
@@ -0,0 +1 @@
+Add cache invalidation across workers to module API.
diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py
index d08517a953..2c377533c0 100644
--- a/scripts-dev/mypy_synapse_plugin.py
+++ b/scripts-dev/mypy_synapse_plugin.py
@@ -29,7 +29,7 @@ class SynapsePlugin(Plugin):
         self, fullname: str
     ) -> Optional[Callable[[MethodSigContext], CallableType]]:
         if fullname.startswith(
-            "synapse.util.caches.descriptors._CachedFunction.__call__"
+            "synapse.util.caches.descriptors.CachedFunction.__call__"
         ) or fullname.startswith(
             "synapse.util.caches.descriptors._LruCachedFunction.__call__"
         ):
@@ -38,7 +38,7 @@ class SynapsePlugin(Plugin):
 
 
 def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
-    """Fixes the `_CachedFunction.__call__` signature to be correct.
+    """Fixes the `CachedFunction.__call__` signature to be correct.
 
     It already has *almost* the correct signature, except:
 
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 9287c0fb8d..59755bff6d 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -125,7 +125,7 @@ from synapse.types import (
 )
 from synapse.util import Clock
 from synapse.util.async_helpers import maybe_awaitable
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import CachedFunction, cached
 from synapse.util.frozenutils import freeze
 
 if TYPE_CHECKING:
@@ -836,6 +836,37 @@ class ModuleApi:
             self._store.db_pool.runInteraction(desc, func, *args, **kwargs)  # type: ignore[arg-type]
         )
 
+    def register_cached_function(self, cached_func: CachedFunction) -> None:
+        """Register a cached function that should be invalidated across workers.
+        Invalidation local to a worker can be done directly using `cached_func.invalidate`;
+        invalidation that needs to reach other workers must instead go through `invalidate_cache`
+        on the module API.
+
+        Args:
+            cached_func: The cached function that will be registered to receive
+                invalidation locally and from other workers.
+        """
+        self._store.register_external_cached_function(
+            f"{cached_func.__module__}.{cached_func.__name__}", cached_func
+        )
+
+    async def invalidate_cache(
+        self, cached_func: CachedFunction, keys: Tuple[Any, ...]
+    ) -> None:
+        """Invalidate a cache entry of a cached function across workers. The cached function
+        needs to be registered on all workers first with `register_cached_function`.
+
+        Args:
+            cached_func: The cached function whose entry should be invalidated.
+            keys: The keys of the entry to invalidate, usually matching the arguments
+                of the cached function.
+        """
+        cached_func.invalidate(keys)
+        await self._store.send_invalidation_to_replication(
+            f"{cached_func.__module__}.{cached_func.__name__}",
+            keys,
+        )
+
     async def complete_sso_login_async(
         self,
         registered_user_id: str,
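
For context, a hypothetical third-party module could use the two new methods above roughly as follows. This is a minimal sketch based on the docstrings and the test added in this change; the module class, its constructor arguments, and the nickname storage are illustrative and not part of Synapse.

from typing import Dict

from synapse.module_api import ModuleApi, cached


class MyModule:
    """Hypothetical module demonstrating cross-worker cache invalidation."""

    def __init__(self, config: dict, api: ModuleApi) -> None:
        self._api = api
        self._nicknames: Dict[str, str] = {}
        # Register the cached function so that invalidations arriving from
        # other workers over replication are applied to this worker too.
        self._api.register_cached_function(self.get_nickname)

    @cached()
    async def get_nickname(self, user_id: str) -> str:
        # Stand-in for an expensive lookup, cached per user_id.
        return self._nicknames.get(user_id, "")

    async def set_nickname(self, user_id: str, nickname: str) -> None:
        self._nicknames[user_id] = nickname
        # Invalidate the entry locally *and* stream the invalidation to every
        # other worker; the keys tuple matches the cached function's arguments.
        await self._api.invalidate_cache(self.get_nickname, (user_id,))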
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index e30f9c76d4..303a5d5298 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,12 +15,13 @@
 # limitations under the License.
 import logging
 from abc import ABCMeta
-from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Union
 
 from synapse.storage.database import make_in_list_sql_clause  # noqa: F401; noqa: F401
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.types import get_domain_from_id
 from synapse.util import json_decoder
+from synapse.util.caches.descriptors import CachedFunction
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -47,6 +48,8 @@ class SQLBaseStore(metaclass=ABCMeta):
         self.database_engine = database.engine
         self.db_pool = database
 
+        self.external_cached_functions: Dict[str, CachedFunction] = {}
+
     def process_replication_rows(
         self,
         stream_name: str,
@@ -95,7 +98,7 @@ class SQLBaseStore(metaclass=ABCMeta):
 
     def _attempt_to_invalidate_cache(
         self, cache_name: str, key: Optional[Collection[Any]]
-    ) -> None:
+    ) -> bool:
         """Attempts to invalidate the cache of the given name, ignoring if the
         cache doesn't exist. Mainly used for invalidating caches on workers,
         where they may not have the cache.
@@ -113,9 +116,12 @@ class SQLBaseStore(metaclass=ABCMeta):
         try:
             cache = getattr(self, cache_name)
         except AttributeError:
-            # We probably haven't pulled in the cache in this worker,
-            # which is fine.
-            return
+            # Check if an externally defined module cache has been registered
+            cache = self.external_cached_functions.get(cache_name)
+            if not cache:
+                # We probably haven't pulled in the cache in this worker,
+                # which is fine.
+                return False
 
         if key is None:
             cache.invalidate_all()
@@ -125,6 +131,13 @@ class SQLBaseStore(metaclass=ABCMeta):
             invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
             invalidate_method(tuple(key))
 
+        return True
+
+    def register_external_cached_function(
+        self, cache_name: str, func: CachedFunction
+    ) -> None:
+        self.external_cached_functions[cache_name] = func
+
 
 def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
     """
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 12e9a42382..2c421151c1 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -33,7 +33,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.util.caches.descriptors import _CachedFunction
+from synapse.util.caches.descriptors import CachedFunction
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -269,9 +269,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             return
 
         cache_func.invalidate(keys)
-        await self.db_pool.runInteraction(
-            "invalidate_cache_and_stream",
-            self._send_invalidation_to_replication,
+        await self.send_invalidation_to_replication(
             cache_func.__name__,
             keys,
         )
@@ -279,7 +277,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
     def _invalidate_cache_and_stream(
         self,
         txn: LoggingTransaction,
-        cache_func: _CachedFunction,
+        cache_func: CachedFunction,
         keys: Tuple[Any, ...],
     ) -> None:
         """Invalidates the cache and adds it to the cache stream so slaves
@@ -293,7 +291,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
 
     def _invalidate_all_cache_and_stream(
-        self, txn: LoggingTransaction, cache_func: _CachedFunction
+        self, txn: LoggingTransaction, cache_func: CachedFunction
     ) -> None:
         """Invalidates the entire cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
@@ -334,6 +332,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 txn, CURRENT_STATE_CACHE_NAME, [room_id]
             )
 
+    async def send_invalidation_to_replication(
+        self, cache_name: str, keys: Optional[Collection[Any]]
+    ) -> None:
+        await self.db_pool.runInteraction(
+            "send_invalidation_to_replication",
+            self._send_invalidation_to_replication,
+            cache_name,
+            keys,
+        )
+
     def _send_invalidation_to_replication(
         self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
     ) -> None:
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 10aff4d04a..3909f1caea 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -53,7 +53,7 @@ CacheKey = Union[Tuple, Any]
 F = TypeVar("F", bound=Callable[..., Any])
 
 
-class _CachedFunction(Generic[F]):
+class CachedFunction(Generic[F]):
     invalidate: Any = None
     invalidate_all: Any = None
     prefill: Any = None
@@ -242,7 +242,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
             return ret2
 
-        wrapped = cast(_CachedFunction, _wrapped)
+        wrapped = cast(CachedFunction, _wrapped)
         wrapped.cache = cache
         obj.__dict__[self.name] = wrapped
 
@@ -363,7 +363,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
 
             return make_deferred_yieldable(ret)
 
-        wrapped = cast(_CachedFunction, _wrapped)
+        wrapped = cast(CachedFunction, _wrapped)
 
         if self.num_args == 1:
             assert not self.tree
@@ -572,7 +572,7 @@ def cached(
     iterable: bool = False,
     prune_unread_entries: bool = True,
     name: Optional[str] = None,
-) -> Callable[[F], _CachedFunction[F]]:
+) -> Callable[[F], CachedFunction[F]]:
     func = lambda orig: DeferredCacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -585,7 +585,7 @@ def cached(
         name=name,
     )
 
-    return cast(Callable[[F], _CachedFunction[F]], func)
+    return cast(Callable[[F], CachedFunction[F]], func)
 
 
 def cachedList(
@@ -594,7 +594,7 @@ def cachedList(
     list_name: str,
     num_args: Optional[int] = None,
     name: Optional[str] = None,
-) -> Callable[[F], _CachedFunction[F]]:
+) -> Callable[[F], CachedFunction[F]]:
     """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
 
     Used to do batch lookups for an already created cache. One of the arguments
@@ -631,7 +631,7 @@ def cachedList(
         name=name,
     )
 
-    return cast(Callable[[F], _CachedFunction[F]], func)
+    return cast(Callable[[F], CachedFunction[F]], func)
 
 
 def _get_cache_key_builder(
diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py
new file mode 100644
index 0000000000..b93cae67d3
--- /dev/null
+++ b/tests/replication/test_module_cache_invalidation.py
@@ -0,0 +1,79 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import synapse
+from synapse.module_api import cached
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+logger = logging.getLogger(__name__)
+
+FIRST_VALUE = "one"
+SECOND_VALUE = "two"
+
+KEY = "mykey"
+
+
+class TestCache:
+    current_value = FIRST_VALUE
+
+    @cached()
+    async def cached_function(self, user_id: str) -> str:
+        return self.current_value
+
+
+class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+    ]
+
+    def test_module_cache_full_invalidation(self):
+        main_cache = TestCache()
+        self.hs.get_module_api().register_cached_function(main_cache.cached_function)
+
+        worker_hs = self.make_worker_hs("synapse.app.generic_worker")
+
+        worker_cache = TestCache()
+        worker_hs.get_module_api().register_cached_function(
+            worker_cache.cached_function
+        )
+
+        self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+        self.assertEqual(
+            FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
+        )
+
+        main_cache.current_value = SECOND_VALUE
+        worker_cache.current_value = SECOND_VALUE
+        # No invalidation yet, should return the cached value on both the main process and the worker
+        self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+        self.assertEqual(
+            FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
+        )
+
+        # Full invalidation on the main process; it should be replicated to the worker,
+        # which should then return the updated value too.
+        self.get_success(
+            self.hs.get_module_api().invalidate_cache(
+                main_cache.cached_function, (KEY,)
+            )
+        )
+
+        self.assertEqual(
+            SECOND_VALUE, self.get_success(main_cache.cached_function(KEY))
+        )
+        self.assertEqual(
+            SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY))
+        )