summary refs log tree commit diff
path: root/tests/replication
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication')
-rw-r--r--tests/replication/test_module_cache_invalidation.py79
1 files changed, 79 insertions, 0 deletions
diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py
new file mode 100644
index 0000000000..b93cae67d3
--- /dev/null
+++ b/tests/replication/test_module_cache_invalidation.py
@@ -0,0 +1,79 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import synapse
+from synapse.module_api import cached
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+logger = logging.getLogger(__name__)
+
+FIRST_VALUE = "one"
+SECOND_VALUE = "two"
+
+KEY = "mykey"
+
+
+class TestCache:
+    current_value = FIRST_VALUE
+
+    @cached()
+    async def cached_function(self, user_id: str) -> str:
+        return self.current_value
+
+
+class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+    ]
+
+    def test_module_cache_full_invalidation(self):
+        main_cache = TestCache()
+        self.hs.get_module_api().register_cached_function(main_cache.cached_function)
+
+        worker_hs = self.make_worker_hs("synapse.app.generic_worker")
+
+        worker_cache = TestCache()
+        worker_hs.get_module_api().register_cached_function(
+            worker_cache.cached_function
+        )
+
+        self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+        self.assertEqual(
+            FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
+        )
+
+        main_cache.current_value = SECOND_VALUE
+        worker_cache.current_value = SECOND_VALUE
+        # No invalidation yet, should return the cached value on both the main process and the worker
+        self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+        self.assertEqual(
+            FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
+        )
+
+        # Full invalidation on the main process, should be replicated on the worker that
+        # should returned the updated value too
+        self.get_success(
+            self.hs.get_module_api().invalidate_cache(
+                main_cache.cached_function, (KEY,)
+            )
+        )
+
+        self.assertEqual(
+            SECOND_VALUE, self.get_success(main_cache.cached_function(KEY))
+        )
+        self.assertEqual(
+            SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY))
+        )