diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index e30f9c76d4..3f85a33344 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -95,7 +95,7 @@ class SQLBaseStore(metaclass=ABCMeta):
def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]]
- ) -> None:
+ ) -> bool:
"""Attempts to invalidate the cache of the given name, ignoring if the
cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache.
@@ -115,7 +115,7 @@ class SQLBaseStore(metaclass=ABCMeta):
except AttributeError:
# We probably haven't pulled in the cache in this worker,
# which is fine.
- return
+ return False
if key is None:
cache.invalidate_all()
@@ -125,6 +125,8 @@ class SQLBaseStore(metaclass=ABCMeta):
invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
invalidate_method(tuple(key))
+ return True
+
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
"""
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 12e9a42382..efe9f3ad88 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -33,7 +33,7 @@ from synapse.storage.database import (
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.util.caches.descriptors import _CachedFunction
+from synapse.util.caches.descriptors import CachedFunction
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -91,6 +91,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
else:
self._cache_id_gen = None
+ self.external_cached_functions = {}
+
+ def register_external_cached_function(self, cache_name, func):
+ self.external_cached_functions[cache_name] = func
+
async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
@@ -178,7 +183,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
members_changed = set(row.keys[1:])
self._invalidate_state_caches(room_id, members_changed)
else:
- self._attempt_to_invalidate_cache(row.cache_func, row.keys)
+ res = self._attempt_to_invalidate_cache(row.cache_func, row.keys)
+ if not res:
+ external_func = self.external_cached_functions[row.cache_func]
+ if external_func:
+ external_func.invalidate(row.keys)
super().process_replication_rows(stream_name, instance_name, token, rows)
@@ -269,9 +278,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
return
cache_func.invalidate(keys)
- await self.db_pool.runInteraction(
- "invalidate_cache_and_stream",
- self._send_invalidation_to_replication,
+ await self.send_invalidation_to_replication(
cache_func.__name__,
keys,
)
@@ -279,7 +286,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
def _invalidate_cache_and_stream(
self,
txn: LoggingTransaction,
- cache_func: _CachedFunction,
+ cache_func: CachedFunction,
keys: Tuple[Any, ...],
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
@@ -293,7 +300,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
def _invalidate_all_cache_and_stream(
- self, txn: LoggingTransaction, cache_func: _CachedFunction
+ self, txn: LoggingTransaction, cache_func: CachedFunction
) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -334,6 +341,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn, CURRENT_STATE_CACHE_NAME, [room_id]
)
+ async def send_invalidation_to_replication(self, cache_name, keys):
+ await self.db_pool.runInteraction(
+ "send_invalidation_to_replication",
+ self._send_invalidation_to_replication,
+ cache_name,
+ keys,
+ )
+
def _send_invalidation_to_replication(
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
) -> None:
|