diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py
index d08517a953..2c377533c0 100644
--- a/scripts-dev/mypy_synapse_plugin.py
+++ b/scripts-dev/mypy_synapse_plugin.py
@@ -29,7 +29,7 @@ class SynapsePlugin(Plugin):
self, fullname: str
) -> Optional[Callable[[MethodSigContext], CallableType]]:
if fullname.startswith(
- "synapse.util.caches.descriptors._CachedFunction.__call__"
+ "synapse.util.caches.descriptors.CachedFunction.__call__"
) or fullname.startswith(
"synapse.util.caches.descriptors._LruCachedFunction.__call__"
):
@@ -38,7 +38,7 @@ class SynapsePlugin(Plugin):
def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
- """Fixes the `_CachedFunction.__call__` signature to be correct.
+ """Fixes the `CachedFunction.__call__` signature to be correct.
It already has *almost* the correct signature, except:
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 87ba154cb7..19b501c9e3 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -125,7 +125,7 @@ from synapse.types import (
)
from synapse.util import Clock
from synapse.util.async_helpers import maybe_awaitable
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import CachedFunction, cached
from synapse.util.frozenutils import freeze
if TYPE_CHECKING:
@@ -836,6 +836,20 @@ class ModuleApi:
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type]
)
+ async def invalidate_cache(
+ self, cached_func: CachedFunction, keys: Tuple[Any, ...]
+ ) -> None:
+ cached_func.invalidate(keys)
+ await self._store.send_invalidation_to_replication(
+ cached_func.__qualname__,
+ keys,
+ )
+
+ def register_cached_function(self, cached_func: CachedFunction) -> None:
+ self._store.register_external_cached_function(
+ cached_func.__qualname__, cached_func
+ )
+
def complete_sso_login(
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
) -> None:
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index e30f9c76d4..3f85a33344 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -95,7 +95,7 @@ class SQLBaseStore(metaclass=ABCMeta):
def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]]
- ) -> None:
+ ) -> bool:
"""Attempts to invalidate the cache of the given name, ignoring if the
cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache.
@@ -115,7 +115,7 @@ class SQLBaseStore(metaclass=ABCMeta):
except AttributeError:
# We probably haven't pulled in the cache in this worker,
# which is fine.
- return
+ return False
if key is None:
cache.invalidate_all()
@@ -125,6 +125,8 @@ class SQLBaseStore(metaclass=ABCMeta):
invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
invalidate_method(tuple(key))
+ return True
+
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
"""
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 12e9a42382..efe9f3ad88 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -33,7 +33,7 @@ from synapse.storage.database import (
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.util.caches.descriptors import _CachedFunction
+from synapse.util.caches.descriptors import CachedFunction
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -91,6 +91,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
else:
self._cache_id_gen = None
+ self.external_cached_functions = {}
+
+ def register_external_cached_function(self, cache_name, func):
+ self.external_cached_functions[cache_name] = func
+
async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
@@ -178,7 +183,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
members_changed = set(row.keys[1:])
self._invalidate_state_caches(room_id, members_changed)
else:
- self._attempt_to_invalidate_cache(row.cache_func, row.keys)
+ res = self._attempt_to_invalidate_cache(row.cache_func, row.keys)
+ if not res:
+ external_func = self.external_cached_functions[row.cache_func]
+ if external_func:
+ external_func.invalidate(row.keys)
super().process_replication_rows(stream_name, instance_name, token, rows)
@@ -269,9 +278,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
return
cache_func.invalidate(keys)
- await self.db_pool.runInteraction(
- "invalidate_cache_and_stream",
- self._send_invalidation_to_replication,
+ await self.send_invalidation_to_replication(
cache_func.__name__,
keys,
)
@@ -279,7 +286,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
def _invalidate_cache_and_stream(
self,
txn: LoggingTransaction,
- cache_func: _CachedFunction,
+ cache_func: CachedFunction,
keys: Tuple[Any, ...],
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
@@ -293,7 +300,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
def _invalidate_all_cache_and_stream(
- self, txn: LoggingTransaction, cache_func: _CachedFunction
+ self, txn: LoggingTransaction, cache_func: CachedFunction
) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -334,6 +341,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn, CURRENT_STATE_CACHE_NAME, [room_id]
)
+ async def send_invalidation_to_replication(self, cache_name, keys):
+ await self.db_pool.runInteraction(
+ "send_invalidation_to_replication",
+ self._send_invalidation_to_replication,
+ cache_name,
+ keys,
+ )
+
def _send_invalidation_to_replication(
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
) -> None:
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 867f315b2a..c3c20a2339 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -52,7 +52,7 @@ CacheKey = Union[Tuple, Any]
F = TypeVar("F", bound=Callable[..., Any])
-class _CachedFunction(Generic[F]):
+class CachedFunction(Generic[F]):
invalidate: Any = None
invalidate_all: Any = None
prefill: Any = None
@@ -239,7 +239,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
return ret2
- wrapped = cast(_CachedFunction, _wrapped)
+ wrapped = cast(CachedFunction, _wrapped)
wrapped.cache = cache
obj.__dict__[self.orig.__name__] = wrapped
@@ -358,7 +358,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
return make_deferred_yieldable(ret)
- wrapped = cast(_CachedFunction, _wrapped)
+ wrapped = cast(CachedFunction, _wrapped)
if self.num_args == 1:
assert not self.tree
@@ -577,7 +577,7 @@ def cached(
cache_context: bool = False,
iterable: bool = False,
prune_unread_entries: bool = True,
-) -> Callable[[F], _CachedFunction[F]]:
+) -> Callable[[F], CachedFunction[F]]:
func = lambda orig: DeferredCacheDescriptor(
orig,
max_entries=max_entries,
@@ -589,12 +589,12 @@ def cached(
prune_unread_entries=prune_unread_entries,
)
- return cast(Callable[[F], _CachedFunction[F]], func)
+ return cast(Callable[[F], CachedFunction[F]], func)
def cachedList(
*, cached_method_name: str, list_name: str, num_args: Optional[int] = None
-) -> Callable[[F], _CachedFunction[F]]:
+) -> Callable[[F], CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
Used to do batch lookups for an already created cache. One of the arguments
@@ -630,7 +630,7 @@ def cachedList(
num_args=num_args,
)
- return cast(Callable[[F], _CachedFunction[F]], func)
+ return cast(Callable[[F], CachedFunction[F]], func)
def _get_cache_key_builder(
diff --git a/tests/replication/test_account_validity.py b/tests/replication/test_account_validity.py
index 408eb266b9..7353bbce96 100644
--- a/tests/replication/test_account_validity.py
+++ b/tests/replication/test_account_validity.py
@@ -33,6 +33,8 @@ class MockAccountValidityStore:
):
self._api = api
+ api.register_cached_function(self.is_user_expired)
+
async def create_db(self):
def create_table_txn(txn: LoggingTransaction):
txn.execute(
@@ -88,13 +90,13 @@ class MockAccountValidityStore:
),
)
- txn.call_after(self.is_user_expired.invalidate, (user_id,))
-
await self._api.run_db_interaction(
"account_validity_set_expired_user",
set_expired_user_txn,
)
+ await self._api.invalidate_cache(self.is_user_expired, (user_id,))
+
class MockAccountValidity:
def __init__(
diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py
new file mode 100644
index 0000000000..fa52e11630
--- /dev/null
+++ b/tests/replication/test_module_cache_invalidation.py
@@ -0,0 +1,88 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging, time
+
+import synapse
+from synapse.module_api import ModuleApi, cached
+from synapse.server import HomeServer
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+logger = logging.getLogger(__name__)
+
+FIRST_VALUE = "one"
+SECOND_VALUE = "two"
+
+KEY = "mykey"
+
+class TestCache:
+ current_value = FIRST_VALUE
+
+ @cached()
+ async def cached_function(self, user_id: str) -> str:
+ print(self.current_value)
+ return self.current_value
+
+class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
+
+ def test_module_cache_full_invalidation(self):
+ main_cache = TestCache()
+ self.hs.get_module_api().register_cached_function(main_cache.cached_function)
+
+ worker_hs = self.make_worker_hs("synapse.app.generic_worker")
+
+ worker_cache = TestCache()
+ worker_hs.get_module_api().register_cached_function(worker_cache.cached_function)
+
+ self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+ self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+ main_cache.current_value = SECOND_VALUE
+ worker_cache.current_value = SECOND_VALUE
+ # No invalidation yet, should return the cached value on both the main process and the worker
+ self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+ self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+ self.reactor.advance(1)
+
+ # Full invalidation on the main process, should be replicated on the worker that
+ # should returned the updated value too
+ self.get_success(
+ self.hs.get_module_api().invalidate_cache(main_cache.cached_function, (KEY,))
+ )
+
+ self.reactor.advance(1)
+
+ self.assertEqual(SECOND_VALUE, self.get_success(main_cache.cached_function(KEY)))
+ self.assertEqual(SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+ # def test_module_cache_local_invalidation_only(self):
+ # main_cache = TestCache()
+ # worker_cache = TestCache()
+
+ # worker_hs = self.make_worker_hs("synapse.app.generic_worker")
+
+ # self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+ # self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+ # main_cache.current_value = SECOND_VALUE
+ # worker_cache.current_value = SECOND_VALUE
+ # # No local invalidation yet, should return the cached value on both the main process and the worker
+ # self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+ # self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+ # # local invalidation on the main process, worker should still return the cached value
+ # main_cache.cached_function.invalidate((KEY,))
+ # self.assertEqual(SECOND_VALUE, self.get_success(main_cache.cached_function(KEY)))
+ # self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
|