diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py
new file mode 100644
index 0000000000..fa52e11630
--- /dev/null
+++ b/tests/replication/test_module_cache_invalidation.py
@@ -0,0 +1,88 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging, time
+
+import synapse
+from synapse.module_api import ModuleApi, cached
+from synapse.server import HomeServer
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+logger = logging.getLogger(__name__)
+
+FIRST_VALUE = "one"
+SECOND_VALUE = "two"
+
+KEY = "mykey"
+
+class TestCache:
+ current_value = FIRST_VALUE
+
+ @cached()
+ async def cached_function(self, user_id: str) -> str:
+ print(self.current_value)
+ return self.current_value
+
+class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
+
+ def test_module_cache_full_invalidation(self):
+ main_cache = TestCache()
+ self.hs.get_module_api().register_cached_function(main_cache.cached_function)
+
+ worker_hs = self.make_worker_hs("synapse.app.generic_worker")
+
+ worker_cache = TestCache()
+ worker_hs.get_module_api().register_cached_function(worker_cache.cached_function)
+
+ self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+ self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+ main_cache.current_value = SECOND_VALUE
+ worker_cache.current_value = SECOND_VALUE
+ # No invalidation yet, should return the cached value on both the main process and the worker
+ self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+ self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+ self.reactor.advance(1)
+
+ # Full invalidation on the main process, should be replicated on the worker that
+ # should returned the updated value too
+ self.get_success(
+ self.hs.get_module_api().invalidate_cache(main_cache.cached_function, (KEY,))
+ )
+
+ self.reactor.advance(1)
+
+ self.assertEqual(SECOND_VALUE, self.get_success(main_cache.cached_function(KEY)))
+ self.assertEqual(SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+ # def test_module_cache_local_invalidation_only(self):
+ # main_cache = TestCache()
+ # worker_cache = TestCache()
+
+ # worker_hs = self.make_worker_hs("synapse.app.generic_worker")
+
+ # self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+ # self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+ # main_cache.current_value = SECOND_VALUE
+ # worker_cache.current_value = SECOND_VALUE
+ # # No local invalidation yet, should return the cached value on both the main process and the worker
+ # self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+ # self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
+
+ # # local invalidation on the main process, worker should still return the cached value
+ # main_cache.cached_function.invalidate((KEY,))
+ # self.assertEqual(SECOND_VALUE, self.get_success(main_cache.cached_function(KEY)))
+ # self.assertEqual(FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)))
|