summary refs log tree commit diff
path: root/synapse/replication/slave/storage/events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/slave/storage/events.py')
-rw-r--r--synapse/replication/slave/storage/events.py31
1 files changed, 22 insertions, 9 deletions
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index c57385d92f..b457c5563f 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -16,7 +16,10 @@
 import logging
 
 from synapse.api.constants import EventTypes
-from synapse.replication.tcp.streams.events import EventsStreamEventRow
+from synapse.replication.tcp.streams.events import (
+    EventsStreamCurrentStateRow,
+    EventsStreamEventRow,
+)
 from synapse.storage.event_federation import EventFederationWorkerStore
 from synapse.storage.event_push_actions import EventPushActionsWorkerStore
 from synapse.storage.events_worker import EventsWorkerStore
@@ -80,14 +83,7 @@ class SlavedEventStore(EventFederationWorkerStore,
         if stream_name == "events":
             self._stream_id_gen.advance(token)
             for row in rows:
-                if row.type != EventsStreamEventRow.TypeId:
-                    continue
-                data = row.data
-                self.invalidate_caches_for_event(
-                    token, data.event_id, data.room_id, data.type, data.state_key,
-                    data.redacts,
-                    backfilled=False,
-                )
+                self._process_event_stream_row(token, row)
         elif stream_name == "backfill":
             self._backfill_id_gen.advance(-token)
             for row in rows:
@@ -100,6 +96,23 @@ class SlavedEventStore(EventFederationWorkerStore,
             stream_name, token, rows
         )
 
+    def _process_event_stream_row(self, token, row):
+        data = row.data
+
+        if row.type == EventsStreamEventRow.TypeId:
+            self.invalidate_caches_for_event(
+                token, data.event_id, data.room_id, data.type, data.state_key,
+                data.redacts,
+                backfilled=False,
+            )
+        elif row.type == EventsStreamCurrentStateRow.TypeId:
+            if data.type == EventTypes.Member:
+                self.get_rooms_for_user_with_stream_ordering.invalidate(
+                    (data.state_key, ),
+                )
+        else:
+            raise Exception("Unknown events stream row type %s" % (row.type, ))
+
     def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
                                     etype, state_key, redacts, backfilled):
         self._invalidate_get_event_cache(event_id)