-rw-r--r--  scripts-dev/mypy_synapse_plugin.py                   |  4
-rw-r--r--  synapse/module_api/__init__.py                       | 16
-rw-r--r--  synapse/storage/_base.py                             |  6
-rw-r--r--  synapse/storage/databases/main/cache.py              | 29
-rw-r--r--  synapse/util/caches/descriptors.py                   | 14
-rw-r--r--  tests/replication/test_account_validity.py           |  6
-rw-r--r--  tests/replication/test_module_cache_invalidation.py  | 88
7 files changed, 142 insertions(+), 21 deletions(-)
diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py
index d08517a953..2c377533c0 100644
--- a/scripts-dev/mypy_synapse_plugin.py
+++ b/scripts-dev/mypy_synapse_plugin.py
@@ -29,7 +29,7 @@ class SynapsePlugin(Plugin):
         self, fullname: str
     ) -> Optional[Callable[[MethodSigContext], CallableType]]:
         if fullname.startswith(
-            "synapse.util.caches.descriptors._CachedFunction.__call__"
+            "synapse.util.caches.descriptors.CachedFunction.__call__"
         ) or fullname.startswith(
             "synapse.util.caches.descriptors._LruCachedFunction.__call__"
         ):
@@ -38,7 +38,7 @@
 
 
 def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
-    """Fixes the `_CachedFunction.__call__` signature to be correct.
+    """Fixes the `CachedFunction.__call__` signature to be correct.
 
     It already has *almost* the correct signature, except:
 
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 87ba154cb7..19b501c9e3 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -125,7 +125,7 @@ from synapse.types import (
 )
 from synapse.util import Clock
 from synapse.util.async_helpers import maybe_awaitable
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import CachedFunction, cached
 from synapse.util.frozenutils import freeze
 
 if TYPE_CHECKING:
@@ -836,6 +836,20 @@ class ModuleApi:
             self._store.db_pool.runInteraction(desc, func, *args, **kwargs)  # type: ignore[arg-type]
         )
 
+    async def invalidate_cache(
+        self, cached_func: CachedFunction, keys: Tuple[Any, ...]
+    ) -> None:
+        cached_func.invalidate(keys)
+        await self._store.send_invalidation_to_replication(
+            cached_func.__qualname__,
+            keys,
+        )
+
+    def register_cached_function(self, cached_func: CachedFunction) -> None:
+        self._store.register_external_cached_function(
+            cached_func.__qualname__, cached_func
+        )
+
     def complete_sso_login(
         self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
     ) -> None:
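For context, a minimal sketch of how a module could drive the two new ModuleApi methods added above. Only register_cached_function, invalidate_cache and the cached decorator come from this change; ExampleModule, its get_user_colour cache and the storage it implies are made-up illustrations.

from typing import Any

from synapse.module_api import ModuleApi, cached


class ExampleModule:
    def __init__(self, config: Any, api: ModuleApi):
        self._api = api
        # Register the wrapped function so that invalidations arriving over
        # the replication cache stream can reach this process's copy of it.
        api.register_cached_function(self.get_user_colour)

    @cached()
    async def get_user_colour(self, user_id: str) -> str:
        # Hypothetical expensive lookup, cached per user_id.
        return "blue"

    async def set_user_colour(self, user_id: str, colour: str) -> None:
        # ... persist the new value here ...
        # Drop the local cache entry and stream the invalidation so every
        # other worker that registered the same function drops its copy too.
        await self._api.invalidate_cache(self.get_user_colour, (user_id,))

This is the same pattern the account-validity test mock further down is switched to.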
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index e30f9c76d4..3f85a33344 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -95,7 +95,7 @@ class SQLBaseStore(metaclass=ABCMeta):
 
     def _attempt_to_invalidate_cache(
         self, cache_name: str, key: Optional[Collection[Any]]
-    ) -> None:
+    ) -> bool:
         """Attempts to invalidate the cache of the given name, ignoring if the
         cache doesn't exist. Mainly used for invalidating caches on workers,
         where they may not have the cache.
@@ -115,7 +115,7 @@ class SQLBaseStore(metaclass=ABCMeta):
         except AttributeError:
             # We probably haven't pulled in the cache in this worker,
             # which is fine.
-            return
+            return False
 
         if key is None:
             cache.invalidate_all()
@@ -125,6 +125,8 @@ class SQLBaseStore(metaclass=ABCMeta):
             invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
             invalidate_method(tuple(key))
 
+        return True
+
 
 def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
     """
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 12e9a42382..efe9f3ad88 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -33,7 +33,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.util.caches.descriptors import _CachedFunction
+from synapse.util.caches.descriptors import CachedFunction
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -91,6 +91,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         else:
             self._cache_id_gen = None
 
+        self.external_cached_functions = {}
+
+    def register_external_cached_function(self, cache_name, func):
+        self.external_cached_functions[cache_name] = func
+
     async def get_all_updated_caches(
         self, instance_name: str, last_id: int, current_id: int, limit: int
     ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
@@ -178,7 +183,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 members_changed = set(row.keys[1:])
                 self._invalidate_state_caches(room_id, members_changed)
             else:
-                self._attempt_to_invalidate_cache(row.cache_func, row.keys)
+                res = self._attempt_to_invalidate_cache(row.cache_func, row.keys)
+                if not res:
+                    external_func = self.external_cached_functions[row.cache_func]
+                    if external_func:
+                        external_func.invalidate(row.keys)
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
@@ -269,9 +278,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             return
 
         cache_func.invalidate(keys)
-        await self.db_pool.runInteraction(
-            "invalidate_cache_and_stream",
-            self._send_invalidation_to_replication,
+        await self.send_invalidation_to_replication(
            cache_func.__name__,
            keys,
        )
@@ -279,7 +286,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
     def _invalidate_cache_and_stream(
         self,
         txn: LoggingTransaction,
-        cache_func: _CachedFunction,
+        cache_func: CachedFunction,
         keys: Tuple[Any, ...],
     ) -> None:
         """Invalidates the cache and adds it to the cache stream so slaves
@@ -293,7 +300,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
 
     def _invalidate_all_cache_and_stream(
-        self, txn: LoggingTransaction, cache_func: _CachedFunction
+        self, txn: LoggingTransaction, cache_func: CachedFunction
     ) -> None:
         """Invalidates the entire cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
@@ -334,6 +341,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             txn, CURRENT_STATE_CACHE_NAME, [room_id]
         )
 
+    async def send_invalidation_to_replication(self, cache_name, keys):
+        await self.db_pool.runInteraction(
+            "send_invalidation_to_replication",
+            self._send_invalidation_to_replication,
+            cache_name,
+            keys,
+        )
+
     def _send_invalidation_to_replication(
         self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
     ) -> None:
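For comparison with the module path above, a sketch of how a store inside Synapse typically pairs a database write with a streamed invalidation, using the _invalidate_cache_and_stream helper whose (renamed) signature appears in the hunks above. ExampleWorkerStore, the widgets table and the get_widget cache are invented for illustration only.

from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.util.caches.descriptors import cached


class ExampleWorkerStore(CacheInvalidationWorkerStore):
    @cached()
    async def get_widget(self, widget_id: str) -> str:
        # Hypothetical cached read, keyed on widget_id.
        return await self.db_pool.simple_select_one_onecol(
            table="widgets",
            keyvalues={"widget_id": widget_id},
            retcol="value",
            desc="get_widget",
        )

    async def set_widget(self, widget_id: str, value: str) -> None:
        def set_widget_txn(txn: LoggingTransaction) -> None:
            self.db_pool.simple_upsert_txn(
                txn,
                table="widgets",
                keyvalues={"widget_id": widget_id},
                values={"value": value},
            )
            # Invalidate the local entry and add a row to the cache stream so
            # that workers drop their copies once the row is replicated.
            self._invalidate_cache_and_stream(txn, self.get_widget, (widget_id,))

        await self.db_pool.runInteraction("set_widget", set_widget_txn)

On a worker, process_replication_rows then tries the store's own cache first and, when _attempt_to_invalidate_cache now reports that no such cache exists, falls back to any function registered through register_external_cached_function.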
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 867f315b2a..c3c20a2339 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -52,7 +52,7 @@ CacheKey = Union[Tuple, Any]
 F = TypeVar("F", bound=Callable[..., Any])
 
 
-class _CachedFunction(Generic[F]):
+class CachedFunction(Generic[F]):
     invalidate: Any = None
     invalidate_all: Any = None
     prefill: Any = None
@@ -239,7 +239,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
             return ret2
 
-        wrapped = cast(_CachedFunction, _wrapped)
+        wrapped = cast(CachedFunction, _wrapped)
         wrapped.cache = cache
         obj.__dict__[self.orig.__name__] = wrapped
 
@@ -358,7 +358,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
 
             return make_deferred_yieldable(ret)
 
-        wrapped = cast(_CachedFunction, _wrapped)
+        wrapped = cast(CachedFunction, _wrapped)
 
         if self.num_args == 1:
             assert not self.tree
@@ -577,7 +577,7 @@ def cached(
     cache_context: bool = False,
     iterable: bool = False,
     prune_unread_entries: bool = True,
-) -> Callable[[F], _CachedFunction[F]]:
+) -> Callable[[F], CachedFunction[F]]:
     func = lambda orig: DeferredCacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -589,12 +589,12 @@ def cached(
         prune_unread_entries=prune_unread_entries,
     )
 
-    return cast(Callable[[F], _CachedFunction[F]], func)
+    return cast(Callable[[F], CachedFunction[F]], func)
 
 
 def cachedList(
     *, cached_method_name: str, list_name: str, num_args: Optional[int] = None
-) -> Callable[[F], _CachedFunction[F]]:
+) -> Callable[[F], CachedFunction[F]]:
     """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
 
     Used to do batch lookups for an already created cache. One of the arguments
@@ -630,7 +630,7 @@ def cachedList(
         num_args=num_args,
     )
 
-    return cast(Callable[[F], _CachedFunction[F]], func)
+    return cast(Callable[[F], CachedFunction[F]], func)
 
 
 def _get_cache_key_builder(
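The rename above is what makes the wrapper type importable outside synapse.util.caches.descriptors. As a small illustration (the helper below is hypothetical, not part of the change), external code can now annotate the functions it hands to the new ModuleApi methods:

from synapse.module_api import ModuleApi
from synapse.util.caches.descriptors import CachedFunction


def register_for_invalidation(api: ModuleApi, *funcs: CachedFunction) -> None:
    # CachedFunction (formerly the private _CachedFunction) is what the
    # @cached decorator wraps its target in, and what register_cached_function
    # and invalidate_cache expect to receive.
    for func in funcs:
        api.register_cached_function(func)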
diff --git a/tests/replication/test_account_validity.py b/tests/replication/test_account_validity.py
index 408eb266b9..7353bbce96 100644
--- a/tests/replication/test_account_validity.py
+++ b/tests/replication/test_account_validity.py
@@ -33,6 +33,8 @@ class MockAccountValidityStore:
     ):
         self._api = api
 
+        api.register_cached_function(self.is_user_expired)
+
     async def create_db(self):
         def create_table_txn(txn: LoggingTransaction):
             txn.execute(
@@ -88,13 +90,13 @@ class MockAccountValidityStore:
                 ),
             )
 
-            txn.call_after(self.is_user_expired.invalidate, (user_id,))
-
         await self._api.run_db_interaction(
             "account_validity_set_expired_user",
             set_expired_user_txn,
         )
 
+        await self._api.invalidate_cache(self.is_user_expired, (user_id,))
+
 
 class MockAccountValidity:
     def __init__(
diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py
new file mode 100644
index 0000000000..fa52e11630
--- /dev/null
+++ b/tests/replication/test_module_cache_invalidation.py
@@ -0,0 +1,88 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging, time
+
+import synapse
+from synapse.module_api import ModuleApi, cached
+from synapse.server import HomeServer
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+logger = logging.getLogger(__name__)
+
+FIRST_VALUE = "one"
+SECOND_VALUE = "two"
+
+KEY = "mykey"
+
+
+class TestCache:
+    current_value = FIRST_VALUE
+
+    @cached()
+    async def cached_function(self, user_id: str) -> str:
+        print(self.current_value)
+        return self.current_value
+
+
+class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
+    def test_module_cache_full_invalidation(self):
+        main_cache = TestCache()
+        self.hs.get_module_api().register_cached_function(main_cache.cached_function)
+
+        worker_hs = self.make_worker_hs("synapse.app.generic_worker")
+
+        worker_cache = TestCache()
+        worker_hs.get_module_api().register_cached_function(worker_cache.cached_function)
+
+        self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+        self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+        main_cache.current_value = SECOND_VALUE
+        worker_cache.current_value = SECOND_VALUE
+        # No invalidation yet, should return the cached value on both the main process and the worker
+        self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+        self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+        self.reactor.advance(1)
+
+        # Full invalidation on the main process, should be replicated on the worker that
+        # should returned the updated value too
+        self.get_success(
+            self.hs.get_module_api().invalidate_cache(main_cache.cached_function, (KEY,))
+        )
+
+        self.reactor.advance(1)
+
+        self.assertEqual(SECOND_VALUE, self.get_success(main_cache.cached_function(KEY)))
+        self.assertEqual(SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+    # def test_module_cache_local_invalidation_only(self):
+    #     main_cache = TestCache()
+    #     worker_cache = TestCache()
+
+    #     worker_hs = self.make_worker_hs("synapse.app.generic_worker")
+
+    #     self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+    #     self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+    #     main_cache.current_value = SECOND_VALUE
+    #     worker_cache.current_value = SECOND_VALUE
+    #     # No local invalidation yet, should return the cached value on both the main process and the worker
+    #     self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+    #     self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+    #     # local invalidation on the main process, worker should still return the cached value
+    #     main_cache.cached_function.invalidate((KEY,))
+    #     self.assertEqual(SECOND_VALUE, self.get_success(main_cache.cached_function(KEY)))
+    #     self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))