summary refs log tree commit diff
path: root/synapse/storage/data_stores/main/cache.py
diff options
context:
space:
mode:
author: Erik Johnston <erik@matrix.org> 2020-05-07 13:51:08 +0100
committer: GitHub <noreply@github.com> 2020-05-07 13:51:08 +0100
commit: d7983b63a6746d92225295f1e9d521f847cf8ba7 (patch)
tree: 13d581210d94c26bd75036592996f6f53f7d4bb2 /synapse/storage/data_stores/main/cache.py
parent: Merge pull request #7398 from Starbix/alpine-3.11 (diff)
download: synapse-d7983b63a6746d92225295f1e9d521f847cf8ba7.tar.xz
Support any process writing to cache invalidation stream. (#7436)
Diffstat (limited to 'synapse/storage/data_stores/main/cache.py')
-rw-r--r-- synapse/storage/data_stores/main/cache.py 84
1 files changed, 46 insertions, 38 deletions
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py

index 4dc5da3fe8..342a87a46b 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -16,11 +16,10 @@
 
 import itertools
 import logging
-from typing import Any, Iterable, Optional, Tuple
-
-from twisted.internet import defer
+from typing import Any, Iterable, Optional
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.util.iterutils import batch_iter
 
@@ -33,47 +32,58 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
 
 
 class CacheInvalidationWorkerStore(SQLBaseStore):
-    def get_all_updated_caches(self, last_id, current_id, limit):
+    def __init__(self, database: Database, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self._instance_name = hs.get_instance_name()
+
+    async def get_all_updated_caches(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ):
+        """Fetches cache invalidation rows between the two given IDs written
+        by the given instance. Returns at most `limit` rows.
+        """
+
         if last_id == current_id:
-            return defer.succeed([])
+            return []
 
         def get_all_updated_caches_txn(txn):
             # We purposefully don't bound by the current token, as we want to
             # send across cache invalidations as quickly as possible. Cache
             # invalidations are idempotent, so duplicates are fine.
-            sql = (
-                "SELECT stream_id, cache_func, keys, invalidation_ts"
-                " FROM cache_invalidation_stream"
-                " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
-            )
-            txn.execute(sql, (last_id, limit))
+            sql = """
+                SELECT stream_id, cache_func, keys, invalidation_ts
+                FROM cache_invalidation_stream_by_instance
+                WHERE stream_id > ? AND instance_name = ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
+            txn.execute(sql, (last_id, instance_name, limit))
             return txn.fetchall()
 
-        return self.db.runInteraction(
+        return await self.db.runInteraction(
             "get_all_updated_caches", get_all_updated_caches_txn
         )
 
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == "caches":
+            if self._cache_id_gen:
+                self._cache_id_gen.advance(instance_name, token)
 
-class CacheInvalidationStore(CacheInvalidationWorkerStore):
-    async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
-        """Invalidates the cache and adds it to the cache stream so slaves
-        will know to invalidate their caches.
+            for row in rows:
+                if row.cache_func == CURRENT_STATE_CACHE_NAME:
+                    if row.keys is None:
+                        raise Exception(
+                            "Can't send an 'invalidate all' for current state cache"
+                        )
 
-        This should only be used to invalidate caches where slaves won't
-        otherwise know from other replication streams that the cache should
-        be invalidated.
-        """
-        cache_func = getattr(self, cache_name, None)
-        if not cache_func:
-            return
-
-        cache_func.invalidate(keys)
-        await self.runInteraction(
-            "invalidate_cache_and_stream",
-            self._send_invalidation_to_replication,
-            cache_func.__name__,
-            keys,
-        )
+                    room_id = row.keys[0]
+                    members_changed = set(row.keys[1:])
+                    self._invalidate_state_caches(room_id, members_changed)
+                else:
+                    self._attempt_to_invalidate_cache(row.cache_func, row.keys)
+
+        super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def _invalidate_cache_and_stream(self, txn, cache_func, keys):
         """Invalidates the cache and adds it to the cache stream so slaves
@@ -147,10 +157,7 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
             # the transaction. However, we want to only get an ID when we want
             # to use it, here, so we need to call __enter__ manually, and have
             # __exit__ called after the transaction finishes.
-            ctx = self._cache_id_gen.get_next()
-            stream_id = ctx.__enter__()
-            txn.call_on_exception(ctx.__exit__, None, None, None)
-            txn.call_after(ctx.__exit__, None, None, None)
+            stream_id = self._cache_id_gen.get_next_txn(txn)
             txn.call_after(self.hs.get_notifier().on_new_replication_data)
 
             if keys is not None:
@@ -158,17 +165,18 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
 
             self.db.simple_insert_txn(
                 txn,
-                table="cache_invalidation_stream",
+                table="cache_invalidation_stream_by_instance",
                 values={
                     "stream_id": stream_id,
+                    "instance_name": self._instance_name,
                     "cache_func": cache_name,
                     "keys": keys,
                     "invalidation_ts": self.clock.time_msec(),
                 },
             )
 
-    def get_cache_stream_token(self):
+    def get_cache_stream_token(self, instance_name):
         if self._cache_id_gen:
-            return self._cache_id_gen.get_current_token()
+            return self._cache_id_gen.get_current_token(instance_name)
         else:
             return 0