diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/replication/test_account_validity.py | 6 | ||||
-rw-r--r-- | tests/replication/test_module_cache_invalidation.py | 88 |
2 files changed, 92 insertions, 2 deletions
diff --git a/tests/replication/test_account_validity.py b/tests/replication/test_account_validity.py index 408eb266b9..7353bbce96 100644 --- a/tests/replication/test_account_validity.py +++ b/tests/replication/test_account_validity.py @@ -33,6 +33,8 @@ class MockAccountValidityStore: ): self._api = api + api.register_cached_function(self.is_user_expired) + async def create_db(self): def create_table_txn(txn: LoggingTransaction): txn.execute( @@ -88,13 +90,13 @@ class MockAccountValidityStore: ), ) - txn.call_after(self.is_user_expired.invalidate, (user_id,)) - await self._api.run_db_interaction( "account_validity_set_expired_user", set_expired_user_txn, ) + await self._api.invalidate_cache(self.is_user_expired, (user_id,)) + class MockAccountValidity: def __init__( diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py new file mode 100644 index 0000000000..fa52e11630 --- /dev/null +++ b/tests/replication/test_module_cache_invalidation.py @@ -0,0 +1,88 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging, time + +import synapse +from synapse.module_api import ModuleApi, cached +from synapse.server import HomeServer + +from tests.replication._base import BaseMultiWorkerStreamTestCase + +logger = logging.getLogger(__name__) + +FIRST_VALUE = "one" +SECOND_VALUE = "two" + +KEY = "mykey" + +class TestCache: + current_value = FIRST_VALUE + + @cached() + async def cached_function(self, user_id: str) -> str: + print(self.current_value) + return self.current_value + +class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase): + + def test_module_cache_full_invalidation(self): + main_cache = TestCache() + self.hs.get_module_api().register_cached_function(main_cache.cached_function) + + worker_hs = self.make_worker_hs("synapse.app.generic_worker") + + worker_cache = TestCache() + worker_hs.get_module_api().register_cached_function(worker_cache.cached_function) + + self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY))) + self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))) + + main_cache.current_value = SECOND_VALUE + worker_cache.current_value = SECOND_VALUE + # No invalidation yet, should return the cached value on both the main process and the worker + self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY))) + self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))) + + self.reactor.advance(1) + + # Full invalidation on the main process, should be replicated on the worker that + # should returned the updated value too + self.get_success( + self.hs.get_module_api().invalidate_cache(main_cache.cached_function, (KEY,)) + ) + + self.reactor.advance(1) + + self.assertEqual(SECOND_VALUE, self.get_success(main_cache.cached_function(KEY))) + self.assertEqual(SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY))) + + # def test_module_cache_local_invalidation_only(self): + # main_cache = TestCache() + # worker_cache = TestCache() + + # worker_hs = self.make_worker_hs("synapse.app.generic_worker") + + # self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY))) + # self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))) + + # main_cache.current_value = SECOND_VALUE + # worker_cache.current_value = SECOND_VALUE + # # No local invalidation yet, should return the cached value on both the main process and the worker + # self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY))) + # self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))) + + # # local invalidation on the main process, worker should still return the cached value + # main_cache.cached_function.invalidate((KEY,)) + # self.assertEqual(SECOND_VALUE, self.get_success(main_cache.cached_function(KEY))) + # self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))) |