summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/replication/test_account_validity.py6
-rw-r--r--tests/replication/test_module_cache_invalidation.py88
2 files changed, 92 insertions, 2 deletions
diff --git a/tests/replication/test_account_validity.py b/tests/replication/test_account_validity.py
index 408eb266b9..7353bbce96 100644
--- a/tests/replication/test_account_validity.py
+++ b/tests/replication/test_account_validity.py
@@ -33,6 +33,8 @@ class MockAccountValidityStore:
     ):
         self._api = api
 
+        api.register_cached_function(self.is_user_expired)
+
     async def create_db(self):
         def create_table_txn(txn: LoggingTransaction):
             txn.execute(
@@ -88,13 +90,13 @@ class MockAccountValidityStore:
                 ),
             )
 
-            txn.call_after(self.is_user_expired.invalidate, (user_id,))
-
         await self._api.run_db_interaction(
             "account_validity_set_expired_user",
             set_expired_user_txn,
         )
 
+        await self._api.invalidate_cache(self.is_user_expired, (user_id,))
+
 
 class MockAccountValidity:
     def __init__(
diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py
new file mode 100644
index 0000000000..fa52e11630
--- /dev/null
+++ b/tests/replication/test_module_cache_invalidation.py
@@ -0,0 +1,88 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging, time
+
+import synapse
+from synapse.module_api import ModuleApi, cached
+from synapse.server import HomeServer
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+logger = logging.getLogger(__name__)
+
+FIRST_VALUE = "one"
+SECOND_VALUE = "two"
+
+KEY = "mykey"
+
+class TestCache:
+    current_value = FIRST_VALUE
+
+    @cached()
+    async def cached_function(self, user_id: str) -> str:
+        print(self.current_value)
+        return self.current_value
+
+class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
+
+    def test_module_cache_full_invalidation(self):
+        main_cache = TestCache()
+        self.hs.get_module_api().register_cached_function(main_cache.cached_function)
+
+        worker_hs = self.make_worker_hs("synapse.app.generic_worker")
+
+        worker_cache = TestCache()
+        worker_hs.get_module_api().register_cached_function(worker_cache.cached_function)
+
+        self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+        self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+        main_cache.current_value = SECOND_VALUE
+        worker_cache.current_value = SECOND_VALUE
+        # No invalidation yet, should return the cached value on both the main process and the worker
+        self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+        self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+        self.reactor.advance(1)
+
+        # Full invalidation on the main process, should be replicated on the worker that
+        # should returned the updated value too
+        self.get_success(
+            self.hs.get_module_api().invalidate_cache(main_cache.cached_function, (KEY,))
+        )
+
+        self.reactor.advance(1)
+        
+        self.assertEqual(SECOND_VALUE, self.get_success(main_cache.cached_function(KEY)))
+        self.assertEqual(SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+    # def test_module_cache_local_invalidation_only(self):
+    #     main_cache = TestCache()
+    #     worker_cache = TestCache()
+
+    #     worker_hs = self.make_worker_hs("synapse.app.generic_worker")
+
+    #     self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+    #     self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+    #     main_cache.current_value = SECOND_VALUE
+    #     worker_cache.current_value = SECOND_VALUE
+    #     # No local invalidation yet, should return the cached value on both the main process and the worker
+    #     self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+    #     self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+    #     # local invalidation on the main process, worker should still return the cached value
+    #     main_cache.cached_function.invalidate((KEY,))
+    #     self.assertEqual(SECOND_VALUE, self.get_success(main_cache.cached_function(KEY)))
+    #     self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))