diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index 342a87a46b..eac5a4e55b 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -16,8 +16,13 @@
 
 import itertools
 import logging
-from typing import Any, Iterable, Optional
+from typing import Any, Iterable, Optional, Tuple
 
+from synapse.api.constants import EventTypes
+from synapse.replication.tcp.streams.events import (
+    EventsStreamCurrentStateRow,
+    EventsStreamEventRow,
+)
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
@@ -66,7 +71,22 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         )
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "caches":
+        if stream_name == "events":
+            for row in rows:
+                self._process_event_stream_row(token, row)
+        elif stream_name == "backfill":
+            for row in rows:
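+                # backfilled events are given negative stream orderings, hence -token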
+                self._invalidate_caches_for_event(
+                    -token,
+                    row.event_id,
+                    row.room_id,
+                    row.type,
+                    row.state_key,
+                    row.redacts,
+                    row.relates_to,
+                    backfilled=True,
+                )
+        elif stream_name == "caches":
             if self._cache_id_gen:
                 self._cache_id_gen.advance(instance_name, token)
 
@@ -85,6 +105,84 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def _process_event_stream_row(self, token, row):
+        data = row.data
+
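+        # the events stream multiplexes several row types; dispatch on TypeId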
+        if row.type == EventsStreamEventRow.TypeId:
+            self._invalidate_caches_for_event(
+                token,
+                data.event_id,
+                data.room_id,
+                data.type,
+                data.state_key,
+                data.redacts,
+                data.relates_to,
+                backfilled=False,
+            )
+        elif row.type == EventsStreamCurrentStateRow.TypeId:
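+            # mark the room as having a current-state change at this position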
+            self._curr_state_delta_stream_cache.entity_has_changed(
+                row.data.room_id, token
+            )
+
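+            # a membership change can alter which rooms the user is in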
+            if data.type == EventTypes.Member:
+                self.get_rooms_for_user_with_stream_ordering.invalidate(
+                    (data.state_key,)
+                )
+        else:
+            raise Exception("Unknown events stream row type %s" % (row.type,))
+
+    def _invalidate_caches_for_event(
+        self,
+        stream_ordering,
+        event_id,
+        room_id,
+        etype,
+        state_key,
+        redacts,
+        relates_to,
+        backfilled,
+    ):
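+        # evict any cached copy of the event itself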
+        self._invalidate_get_event_cache(event_id)
+
+        self.get_latest_event_ids_in_room.invalidate((room_id,))
+
+        self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
+
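+        # only live (non-backfilled) events advance the room's stream position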
+        if not backfilled:
+            self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
+
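+        # evict the redacted event so it is re-read with the redaction applied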
+        if redacts:
+            self._invalidate_get_event_cache(redacts)
+
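+        # for membership events, state_key is the ID of the affected user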
+        if etype == EventTypes.Member:
+            self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
+            self.get_invited_rooms_for_local_user.invalidate((state_key,))
+
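+        # relation caches are keyed on the parent (related-to) event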
+        if relates_to:
+            self.get_relations_for_event.invalidate_many((relates_to,))
+            self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
+            self.get_applicable_edit.invalidate((relates_to,))
+
+    async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
+        """Invalidates the cache and adds it to the cache stream so slaves
+        will know to invalidate their caches.
+
+        This should only be used to invalidate caches where slaves won't
+        otherwise know from other replication streams that the cache should
+        be invalidated.
+        """
+        cache_func = getattr(self, cache_name, None)
+        if not cache_func:
+            return
+
+        cache_func.invalidate(keys)
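+        # poke the cache invalidation stream so other workers invalidate too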
+        await self.db.runInteraction(
+            "invalidate_cache_and_stream",
+            self._send_invalidation_to_replication,
+            cache_func.__name__,
+            keys,
+        )
+
     def _invalidate_cache_and_stream(self, txn, cache_func, keys):
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.