summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/16510.misc1
-rw-r--r--synapse/replication/tcp/streams/events.py45
-rw-r--r--synapse/storage/databases/main/cache.py8
-rw-r--r--tests/replication/tcp/streams/test_events.py91
4 files changed, 115 insertions, 30 deletions
diff --git a/changelog.d/16510.misc b/changelog.d/16510.misc
new file mode 100644
index 0000000000..5556b5d74c
--- /dev/null
+++ b/changelog.d/16510.misc
@@ -0,0 +1 @@
+Improve replication performance when purging rooms.
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index ad9b760713..da6d948e1b 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import heapq
+from collections import defaultdict
 from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast
 
 import attr
@@ -51,8 +52,19 @@ data part are:
  * The state_key of the state which has changed
  * The event id of the new state
 
+A "state-all" row is sent whenever the "current state" in a room changes, but there are
+too many state updates for a particular room in the same update. This replaces any
+"state" rows on a per-room basis. The fields in the data part are:
+
+* The room id for the state changes
+
 """
 
+# Any room with more than _MAX_STATE_UPDATES_PER_ROOM will send a EventsStreamAllStateRow
+# instead of individual EventsStreamEventRow. This is predominantly useful when
+# purging large rooms.
+_MAX_STATE_UPDATES_PER_ROOM = 150
+
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class EventsStreamRow:
@@ -111,9 +123,17 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
     event_id: Optional[str]
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class EventsStreamAllStateRow(BaseEventsStreamRow):
+    TypeId = "state-all"
+
+    room_id: str
+
+
 _EventRows: Tuple[Type[BaseEventsStreamRow], ...] = (
     EventsStreamEventRow,
     EventsStreamCurrentStateRow,
+    EventsStreamAllStateRow,
 )
 
 TypeToRow = {Row.TypeId: Row for Row in _EventRows}
@@ -213,9 +233,28 @@ class EventsStream(Stream):
             if stream_id <= upper_limit
         )
 
+        # Separate out rooms that have many state updates, listeners should clear
+        # all state for those rooms.
+        state_updates_by_room = defaultdict(list)
+        for stream_id, room_id, _type, _state_key, _event_id in state_rows:
+            state_updates_by_room[room_id].append(stream_id)
+
+        state_all_rows = [
+            (stream_ids[-1], room_id)
+            for room_id, stream_ids in state_updates_by_room.items()
+            if len(stream_ids) >= _MAX_STATE_UPDATES_PER_ROOM
+        ]
+        state_all_updates: Iterable[Tuple[int, Tuple]] = (
+            (max_stream_id, (EventsStreamAllStateRow.TypeId, (room_id,)))
+            for (max_stream_id, room_id) in state_all_rows
+        )
+
+        # Any remaining state updates are sent individually.
+        state_all_rooms = {room_id for _, room_id in state_all_rows}
         state_updates: Iterable[Tuple[int, Tuple]] = (
             (stream_id, (EventsStreamCurrentStateRow.TypeId, rest))
             for (stream_id, *rest) in state_rows
+            if rest[0] not in state_all_rooms
         )
 
         ex_outliers_updates: Iterable[Tuple[int, Tuple]] = (
@@ -224,7 +263,11 @@ class EventsStream(Stream):
         )
 
         # we need to return a sorted list, so merge them together.
-        updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates))
+        updates = list(
+            heapq.merge(
+                event_updates, state_all_updates, state_updates, ex_outliers_updates
+            )
+        )
         return updates, upper_limit, limited
 
     @classmethod
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2fbd389c71..4d0470ffd9 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -23,6 +23,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
 from synapse.replication.tcp.streams import BackfillStream, CachesStream
 from synapse.replication.tcp.streams.events import (
     EventsStream,
+    EventsStreamAllStateRow,
     EventsStreamCurrentStateRow,
     EventsStreamEventRow,
     EventsStreamRow,
@@ -264,6 +265,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                     (data.state_key,)
                 )
                 self.get_rooms_for_user.invalidate((data.state_key,))  # type: ignore[attr-defined]
+        elif row.type == EventsStreamAllStateRow.TypeId:
+            assert isinstance(data, EventsStreamAllStateRow)
+            # Similar to the above, but the entire caches are invalidated. This is
+            # unfortunate for the membership caches, but should recover quickly.
+            self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)  # type: ignore[attr-defined]
+            self.get_rooms_for_user_with_stream_ordering.invalidate_all()  # type: ignore[attr-defined]
+            self.get_rooms_for_user.invalidate_all()  # type: ignore[attr-defined]
         else:
             raise Exception("Unknown events stream row type %s" % (row.type,))
 
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 128fc3e046..b8ab4ee54b 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -14,6 +14,8 @@
 
 from typing import Any, List, Optional
 
+from parameterized import parameterized
+
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes, Membership
@@ -21,6 +23,8 @@ from synapse.events import EventBase
 from synapse.replication.tcp.commands import RdataCommand
 from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
 from synapse.replication.tcp.streams.events import (
+    _MAX_STATE_UPDATES_PER_ROOM,
+    EventsStreamAllStateRow,
     EventsStreamCurrentStateRow,
     EventsStreamEventRow,
     EventsStreamRow,
@@ -106,11 +110,21 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
-    def test_update_function_huge_state_change(self) -> None:
+    @parameterized.expand(
+        [(_STREAM_UPDATE_TARGET_ROW_COUNT, False), (_MAX_STATE_UPDATES_PER_ROOM, True)]
+    )
+    def test_update_function_huge_state_change(
+        self, num_state_changes: int, collapse_state_changes: bool
+    ) -> None:
         """Test replication with many state events
 
         Ensures that all events are correctly replicated when there are lots of
         state change rows to be replicated.
+
+        Args:
+            num_state_changes: The number of state changes to create.
+            collapse_state_changes: Whether the state changes are expected to be
+                collapsed or not.
         """
 
         # we want to generate lots of state changes at a single stream ID.
@@ -145,7 +159,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         events = [
             self._inject_state_event(sender=OTHER_USER)
-            for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT)
+            for _ in range(num_state_changes)
         ]
 
         self.replicate()
@@ -202,8 +216,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
             row for row in self.test_handler.received_rdata_rows if row[0] == "events"
         ]
 
-        # first check the first two rows, which should be state1
-
+        # first check the first two rows, which should be the state1 event.
         stream_name, token, row = received_rows.pop(0)
         self.assertEqual("events", stream_name)
         self.assertIsInstance(row, EventsStreamRow)
@@ -217,7 +230,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
         self.assertEqual(row.data.event_id, state1.event_id)
 
-        # now the last two rows, which should be state2
+        # now the last two rows, which should be the state2 event.
         stream_name, token, row = received_rows.pop(-2)
         self.assertEqual("events", stream_name)
         self.assertIsInstance(row, EventsStreamRow)
@@ -231,34 +244,54 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
         self.assertEqual(row.data.event_id, state2.event_id)
 
-        # that should leave us with the rows for the PL event
-        self.assertEqual(len(received_rows), len(events) + 2)
+        # Based on the number of
+        if collapse_state_changes:
+            # that should leave us with the rows for the PL event, the state changes
+            # get collapsed into a single row.
+            self.assertEqual(len(received_rows), 2)
 
-        stream_name, token, row = received_rows.pop(0)
-        self.assertEqual("events", stream_name)
-        self.assertIsInstance(row, EventsStreamRow)
-        self.assertEqual(row.type, "ev")
-        self.assertIsInstance(row.data, EventsStreamEventRow)
-        self.assertEqual(row.data.event_id, pl_event.event_id)
+            stream_name, token, row = received_rows.pop(0)
+            self.assertEqual("events", stream_name)
+            self.assertIsInstance(row, EventsStreamRow)
+            self.assertEqual(row.type, "ev")
+            self.assertIsInstance(row.data, EventsStreamEventRow)
+            self.assertEqual(row.data.event_id, pl_event.event_id)
 
-        # the state rows are unsorted
-        state_rows: List[EventsStreamCurrentStateRow] = []
-        for stream_name, _, row in received_rows:
+            stream_name, token, row = received_rows.pop(0)
+            self.assertIsInstance(row, EventsStreamRow)
+            self.assertEqual(row.type, "state-all")
+            self.assertIsInstance(row.data, EventsStreamAllStateRow)
+            self.assertEqual(row.data.room_id, state2.room_id)
+
+        else:
+            # that should leave us with the rows for the PL event
+            self.assertEqual(len(received_rows), len(events) + 2)
+
+            stream_name, token, row = received_rows.pop(0)
             self.assertEqual("events", stream_name)
             self.assertIsInstance(row, EventsStreamRow)
-            self.assertEqual(row.type, "state")
-            self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
-            state_rows.append(row.data)
-
-        state_rows.sort(key=lambda r: r.state_key)
-
-        sr = state_rows.pop(0)
-        self.assertEqual(sr.type, EventTypes.PowerLevels)
-        self.assertEqual(sr.event_id, pl_event.event_id)
-        for sr in state_rows:
-            self.assertEqual(sr.type, "test_state_event")
-            # "None" indicates the state has been deleted
-            self.assertIsNone(sr.event_id)
+            self.assertEqual(row.type, "ev")
+            self.assertIsInstance(row.data, EventsStreamEventRow)
+            self.assertEqual(row.data.event_id, pl_event.event_id)
+
+            # the state rows are unsorted
+            state_rows: List[EventsStreamCurrentStateRow] = []
+            for stream_name, _, row in received_rows:
+                self.assertEqual("events", stream_name)
+                self.assertIsInstance(row, EventsStreamRow)
+                self.assertEqual(row.type, "state")
+                self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
+                state_rows.append(row.data)
+
+            state_rows.sort(key=lambda r: r.state_key)
+
+            sr = state_rows.pop(0)
+            self.assertEqual(sr.type, EventTypes.PowerLevels)
+            self.assertEqual(sr.event_id, pl_event.event_id)
+            for sr in state_rows:
+                self.assertEqual(sr.type, "test_state_event")
+                # "None" indicates the state has been deleted
+                self.assertIsNone(sr.event_id)
 
     def test_update_function_state_row_limit(self) -> None:
         """Test replication with many state events over several stream ids."""