summary refs log tree commit diff
path: root/synapse/module_api
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/module_api')
-rw-r--r--synapse/module_api/__init__.py33
1 files changed, 32 insertions, 1 deletions
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 9287c0fb8d..59755bff6d 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -125,7 +125,7 @@ from synapse.types import (
 )
 from synapse.util import Clock
 from synapse.util.async_helpers import maybe_awaitable
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import CachedFunction, cached
 from synapse.util.frozenutils import freeze
 
 if TYPE_CHECKING:
@@ -836,6 +836,37 @@ class ModuleApi:
             self._store.db_pool.runInteraction(desc, func, *args, **kwargs)  # type: ignore[arg-type]
         )
 
+    def register_cached_function(self, cached_func: CachedFunction) -> None:
+        """Register a cached function that should be invalidated across workers.
+        Invalidation local to a worker can be done directly using `cached_func.invalidate`,
+        however invalidation that needs to go to other workers needs to call `invalidate_cache`
+        on the module API instead.
+
+        Args:
+            cached_function: The cached function that will be registered to receive invalidation
+            locally and from other workers.
+        """
+        self._store.register_external_cached_function(
+            f"{cached_func.__module__}.{cached_func.__name__}", cached_func
+        )
+
+    async def invalidate_cache(
+        self, cached_func: CachedFunction, keys: Tuple[Any, ...]
+    ) -> None:
+        """Invalidate a cache entry of a cached function across workers. The cached function
+        needs to be registered on all workers first with `register_cached_function`.
+
+        Args:
+            cached_function: The cached function that needs an invalidation
+            keys: keys of the entry to invalidate, usually matching the arguments of the
+            cached function.
+        """
+        cached_func.invalidate(keys)
+        await self._store.send_invalidation_to_replication(
+            f"{cached_func.__module__}.{cached_func.__name__}",
+            keys,
+        )
+
     async def complete_sso_login_async(
         self,
         registered_user_id: str,