diff options
Diffstat (limited to 'synapse/storage/databases/main/cache.py')
-rw-r--r-- | synapse/storage/databases/main/cache.py | 29 |
1 files changed, 22 insertions, 7 deletions
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 12e9a42382..efe9f3ad88 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -33,7 +33,7 @@ from synapse.storage.database import ( ) from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator -from synapse.util.caches.descriptors import _CachedFunction +from synapse.util.caches.descriptors import CachedFunction from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -91,6 +91,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore): else: self._cache_id_gen = None + self.external_cached_functions = {} + + def register_external_cached_function(self, cache_name, func): + self.external_cached_functions[cache_name] = func + async def get_all_updated_caches( self, instance_name: str, last_id: int, current_id: int, limit: int ) -> Tuple[List[Tuple[int, tuple]], int, bool]: @@ -178,7 +183,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore): members_changed = set(row.keys[1:]) self._invalidate_state_caches(room_id, members_changed) else: - self._attempt_to_invalidate_cache(row.cache_func, row.keys) + res = self._attempt_to_invalidate_cache(row.cache_func, row.keys) + if not res: + external_func = self.external_cached_functions[row.cache_func] + if external_func: + external_func.invalidate(row.keys) super().process_replication_rows(stream_name, instance_name, token, rows) @@ -269,9 +278,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): return cache_func.invalidate(keys) - await self.db_pool.runInteraction( - "invalidate_cache_and_stream", - self._send_invalidation_to_replication, + await self.send_invalidation_to_replication( cache_func.__name__, keys, ) @@ -279,7 +286,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): def _invalidate_cache_and_stream( self, txn: LoggingTransaction, - cache_func: _CachedFunction, + cache_func: CachedFunction, keys: Tuple[Any, ...], ) -> None: """Invalidates the cache and adds it to the cache stream so slaves @@ -293,7 +300,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._send_invalidation_to_replication(txn, cache_func.__name__, keys) def _invalidate_all_cache_and_stream( - self, txn: LoggingTransaction, cache_func: _CachedFunction + self, txn: LoggingTransaction, cache_func: CachedFunction ) -> None: """Invalidates the entire cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -334,6 +341,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore): txn, CURRENT_STATE_CACHE_NAME, [room_id] ) + async def send_invalidation_to_replication(self, cache_name, keys): + await self.db_pool.runInteraction( + "send_invalidation_to_replication", + self._send_invalidation_to_replication, + cache_name, + keys, + ) + def _send_invalidation_to_replication( self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]] ) -> None: |