diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index 4dc5da3fe8..342a87a46b 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -16,11 +16,10 @@
import itertools
import logging
-from typing import Any, Iterable, Optional, Tuple
-
-from twisted.internet import defer
+from typing import Any, Iterable, Optional
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter
@@ -33,47 +32,58 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationWorkerStore(SQLBaseStore):
- def get_all_updated_caches(self, last_id, current_id, limit):
+ def __init__(self, database: Database, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self._instance_name = hs.get_instance_name()
+
+ async def get_all_updated_caches(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ):
+ """Fetches cache invalidation rows between the two given IDs written
+ by the given instance. Returns at most `limit` rows.
+ """
+
if last_id == current_id:
- return defer.succeed([])
+ return []
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
- sql = (
- "SELECT stream_id, cache_func, keys, invalidation_ts"
- " FROM cache_invalidation_stream"
- " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, limit))
+ sql = """
+ SELECT stream_id, cache_func, keys, invalidation_ts
+ FROM cache_invalidation_stream_by_instance
+ WHERE stream_id > ? AND instance_name = ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_id, instance_name, limit))
return txn.fetchall()
- return self.db.runInteraction(
+ return await self.db.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
+ if stream_name == "caches":
+ if self._cache_id_gen:
+ self._cache_id_gen.advance(instance_name, token)
-class CacheInvalidationStore(CacheInvalidationWorkerStore):
- async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
- """Invalidates the cache and adds it to the cache stream so slaves
- will know to invalidate their caches.
+ for row in rows:
+ if row.cache_func == CURRENT_STATE_CACHE_NAME:
+ if row.keys is None:
+ raise Exception(
+ "Can't send an 'invalidate all' for current state cache"
+ )
- This should only be used to invalidate caches where slaves won't
- otherwise know from other replication streams that the cache should
- be invalidated.
- """
- cache_func = getattr(self, cache_name, None)
- if not cache_func:
- return
-
- cache_func.invalidate(keys)
- await self.runInteraction(
- "invalidate_cache_and_stream",
- self._send_invalidation_to_replication,
- cache_func.__name__,
- keys,
- )
+ room_id = row.keys[0]
+ members_changed = set(row.keys[1:])
+ self._invalidate_state_caches(room_id, members_changed)
+ else:
+ self._attempt_to_invalidate_cache(row.cache_func, row.keys)
+
+ super().process_replication_rows(stream_name, instance_name, token, rows)
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
"""Invalidates the cache and adds it to the cache stream so slaves
@@ -147,10 +157,7 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
# the transaction. However, we want to only get an ID when we want
# to use it, here, so we need to call __enter__ manually, and have
# __exit__ called after the transaction finishes.
- ctx = self._cache_id_gen.get_next()
- stream_id = ctx.__enter__()
- txn.call_on_exception(ctx.__exit__, None, None, None)
- txn.call_after(ctx.__exit__, None, None, None)
+ stream_id = self._cache_id_gen.get_next_txn(txn)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
if keys is not None:
@@ -158,17 +165,18 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
self.db.simple_insert_txn(
txn,
- table="cache_invalidation_stream",
+ table="cache_invalidation_stream_by_instance",
values={
"stream_id": stream_id,
+ "instance_name": self._instance_name,
"cache_func": cache_name,
"keys": keys,
"invalidation_ts": self.clock.time_msec(),
},
)
- def get_cache_stream_token(self):
+ def get_cache_stream_token(self, instance_name):
if self._cache_id_gen:
- return self._cache_id_gen.get_current_token()
+ return self._cache_id_gen.get_current_token(instance_name)
else:
return 0
|