summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py23
-rw-r--r--synapse/storage/databases/main/cache.py20
2 files changed, 32 insertions, 11 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index e30f9c76d4..303a5d5298 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,12 +15,13 @@
 # limitations under the License.
 import logging
 from abc import ABCMeta
-from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Union
 
 from synapse.storage.database import make_in_list_sql_clause  # noqa: F401; noqa: F401
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.types import get_domain_from_id
 from synapse.util import json_decoder
+from synapse.util.caches.descriptors import CachedFunction
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -47,6 +48,8 @@ class SQLBaseStore(metaclass=ABCMeta):
         self.database_engine = database.engine
         self.db_pool = database
 
+        self.external_cached_functions: Dict[str, CachedFunction] = {}
+
     def process_replication_rows(
         self,
         stream_name: str,
@@ -95,7 +98,7 @@ class SQLBaseStore(metaclass=ABCMeta):
 
     def _attempt_to_invalidate_cache(
         self, cache_name: str, key: Optional[Collection[Any]]
-    ) -> None:
+    ) -> bool:
         """Attempts to invalidate the cache of the given name, ignoring if the
         cache doesn't exist. Mainly used for invalidating caches on workers,
         where they may not have the cache.
@@ -113,9 +116,12 @@ class SQLBaseStore(metaclass=ABCMeta):
         try:
             cache = getattr(self, cache_name)
         except AttributeError:
-            # We probably haven't pulled in the cache in this worker,
-            # which is fine.
-            return
+            # Check if an externally defined module cache has been registered
+            cache = self.external_cached_functions.get(cache_name)
+            if not cache:
+                # We probably haven't pulled in the cache in this worker,
+                # which is fine.
+                return False
 
         if key is None:
             cache.invalidate_all()
@@ -125,6 +131,13 @@ class SQLBaseStore(metaclass=ABCMeta):
             invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
             invalidate_method(tuple(key))
 
+        return True
+
+    def register_external_cached_function(
+        self, cache_name: str, func: CachedFunction
+    ) -> None:
+        self.external_cached_functions[cache_name] = func
+
 
 def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
     """
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 12e9a42382..2c421151c1 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -33,7 +33,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.util.caches.descriptors import _CachedFunction
+from synapse.util.caches.descriptors import CachedFunction
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -269,9 +269,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             return
 
         cache_func.invalidate(keys)
-        await self.db_pool.runInteraction(
-            "invalidate_cache_and_stream",
-            self._send_invalidation_to_replication,
+        await self.send_invalidation_to_replication(
             cache_func.__name__,
             keys,
         )
@@ -279,7 +277,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
     def _invalidate_cache_and_stream(
         self,
         txn: LoggingTransaction,
-        cache_func: _CachedFunction,
+        cache_func: CachedFunction,
         keys: Tuple[Any, ...],
     ) -> None:
         """Invalidates the cache and adds it to the cache stream so slaves
@@ -293,7 +291,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
 
     def _invalidate_all_cache_and_stream(
-        self, txn: LoggingTransaction, cache_func: _CachedFunction
+        self, txn: LoggingTransaction, cache_func: CachedFunction
     ) -> None:
         """Invalidates the entire cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
@@ -334,6 +332,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 txn, CURRENT_STATE_CACHE_NAME, [room_id]
             )
 
+    async def send_invalidation_to_replication(
+        self, cache_name: str, keys: Optional[Collection[Any]]
+    ) -> None:
+        await self.db_pool.runInteraction(
+            "send_invalidation_to_replication",
+            self._send_invalidation_to_replication,
+            cache_name,
+            keys,
+        )
+
     def _send_invalidation_to_replication(
         self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
     ) -> None: