summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/15233.misc1
-rw-r--r--synapse/events/snapshot.py159
-rw-r--r--synapse/storage/controllers/persist_events.py5
-rw-r--r--tests/events/test_snapshot.py3
-rw-r--r--tests/storage/test_event_chain.py5
-rw-r--r--tests/test_state.py11
6 files changed, 126 insertions, 58 deletions
diff --git a/changelog.d/15233.misc b/changelog.d/15233.misc
new file mode 100644
index 0000000000..1dff00bf3c
--- /dev/null
+++ b/changelog.d/15233.misc
@@ -0,0 +1 @@
+Replace `EventContext` fields `prev_group` and `delta_ids` with field `state_group_deltas`.
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index e7e8225b8e..a43498ed4d 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
 import attr
 from immutabledict import immutabledict
@@ -107,33 +107,32 @@ class EventContext(UnpersistedEventContextBase):
         state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None
             then this is the delta of the state between the two groups.
 
-        prev_group: If it is known, ``state_group``'s prev_group. Note that this being
-            None does not necessarily mean that ``state_group`` does not have
-            a prev_group!
+        state_group_deltas: If not empty, this is a dict collecting a mapping of the state
+            difference between state groups.
 
-            If the event is a state event, this is normally the same as
-            ``state_group_before_event``.
+            The keys are a tuple of two integers: the initial group and final state group.
+            The corresponding value is a state map representing the state delta between
+            these state groups.
 
-            If ``state_group`` is None (ie, the event is an outlier), ``prev_group``
-            will always also be ``None``.
+            The dictionary is expected to have at most two entries with state groups of:
 
-            Note that this *not* (necessarily) the state group associated with
-            ``_prev_state_ids``.
+            1. The state group before the event and after the event.
+            2. The state group preceding the state group before the event and the
+               state group before the event.
 
-        delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group``
-            and ``state_group``.
+            This information is collected and stored as part of an optimization for persisting
+            events.
 
         partial_state: if True, we may be storing this event with a temporary,
             incomplete state.
     """
 
     _storage: "StorageControllers"
+    state_group_deltas: Dict[Tuple[int, int], StateMap[str]]
     rejected: Optional[str] = None
     _state_group: Optional[int] = None
     state_group_before_event: Optional[int] = None
     _state_delta_due_to_event: Optional[StateMap[str]] = None
-    prev_group: Optional[int] = None
-    delta_ids: Optional[StateMap[str]] = None
     app_service: Optional[ApplicationService] = None
 
     partial_state: bool = False
@@ -145,16 +144,14 @@ class EventContext(UnpersistedEventContextBase):
         state_group_before_event: Optional[int],
         state_delta_due_to_event: Optional[StateMap[str]],
         partial_state: bool,
-        prev_group: Optional[int] = None,
-        delta_ids: Optional[StateMap[str]] = None,
+        state_group_deltas: Dict[Tuple[int, int], StateMap[str]],
     ) -> "EventContext":
         return EventContext(
             storage=storage,
             state_group=state_group,
             state_group_before_event=state_group_before_event,
             state_delta_due_to_event=state_delta_due_to_event,
-            prev_group=prev_group,
-            delta_ids=delta_ids,
+            state_group_deltas=state_group_deltas,
             partial_state=partial_state,
         )
 
@@ -163,7 +160,7 @@ class EventContext(UnpersistedEventContextBase):
         storage: "StorageControllers",
     ) -> "EventContext":
         """Return an EventContext instance suitable for persisting an outlier event"""
-        return EventContext(storage=storage)
+        return EventContext(storage=storage, state_group_deltas={})
 
     async def persist(self, event: EventBase) -> "EventContext":
         return self
@@ -183,13 +180,15 @@ class EventContext(UnpersistedEventContextBase):
             "state_group": self._state_group,
             "state_group_before_event": self.state_group_before_event,
             "rejected": self.rejected,
-            "prev_group": self.prev_group,
+            "state_group_deltas": _encode_state_group_delta(self.state_group_deltas),
             "state_delta_due_to_event": _encode_state_dict(
                 self._state_delta_due_to_event
             ),
-            "delta_ids": _encode_state_dict(self.delta_ids),
             "app_service_id": self.app_service.id if self.app_service else None,
             "partial_state": self.partial_state,
+            # add dummy delta_ids and prev_group for backwards compatibility
+            "delta_ids": None,
+            "prev_group": None,
         }
 
     @staticmethod
@@ -204,17 +203,24 @@ class EventContext(UnpersistedEventContextBase):
         Returns:
             The event context.
         """
+        # workaround for backwards/forwards compatibility: if the input doesn't have a value
+        # for "state_group_deltas" just assign an empty dict
+        state_group_deltas = input.get("state_group_deltas", None)
+        if state_group_deltas:
+            state_group_deltas = _decode_state_group_delta(state_group_deltas)
+        else:
+            state_group_deltas = {}
+
         context = EventContext(
             # We use the state_group and prev_state_id stuff to pull the
             # current_state_ids out of the DB and construct prev_state_ids.
             storage=storage,
             state_group=input["state_group"],
             state_group_before_event=input["state_group_before_event"],
-            prev_group=input["prev_group"],
+            state_group_deltas=state_group_deltas,
             state_delta_due_to_event=_decode_state_dict(
                 input["state_delta_due_to_event"]
             ),
-            delta_ids=_decode_state_dict(input["delta_ids"]),
             rejected=input["rejected"],
             partial_state=input.get("partial_state", False),
         )
@@ -349,7 +355,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
     _storage: "StorageControllers"
     state_group_before_event: Optional[int]
     state_group_after_event: Optional[int]
-    state_delta_due_to_event: Optional[dict]
+    state_delta_due_to_event: Optional[StateMap[str]]
     prev_group_for_state_group_before_event: Optional[int]
     delta_ids_to_state_group_before_event: Optional[StateMap[str]]
     partial_state: bool
@@ -380,26 +386,16 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
 
         events_and_persisted_context = []
         for event, unpersisted_context in amended_events_and_context:
-            if event.is_state():
-                context = EventContext(
-                    storage=unpersisted_context._storage,
-                    state_group=unpersisted_context.state_group_after_event,
-                    state_group_before_event=unpersisted_context.state_group_before_event,
-                    state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
-                    partial_state=unpersisted_context.partial_state,
-                    prev_group=unpersisted_context.state_group_before_event,
-                    delta_ids=unpersisted_context.state_delta_due_to_event,
-                )
-            else:
-                context = EventContext(
-                    storage=unpersisted_context._storage,
-                    state_group=unpersisted_context.state_group_after_event,
-                    state_group_before_event=unpersisted_context.state_group_before_event,
-                    state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
-                    partial_state=unpersisted_context.partial_state,
-                    prev_group=unpersisted_context.prev_group_for_state_group_before_event,
-                    delta_ids=unpersisted_context.delta_ids_to_state_group_before_event,
-                )
+            state_group_deltas = unpersisted_context._build_state_group_deltas()
+
+            context = EventContext(
+                storage=unpersisted_context._storage,
+                state_group=unpersisted_context.state_group_after_event,
+                state_group_before_event=unpersisted_context.state_group_before_event,
+                state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
+                partial_state=unpersisted_context.partial_state,
+                state_group_deltas=state_group_deltas,
+            )
             events_and_persisted_context.append((event, context))
         return events_and_persisted_context
 
@@ -452,11 +448,11 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
 
         # if the event isn't a state event the state group doesn't change
         if not self.state_delta_due_to_event:
-            state_group_after_event = self.state_group_before_event
+            self.state_group_after_event = self.state_group_before_event
 
         # otherwise if it is a state event we need to get a state group for it
         else:
-            state_group_after_event = await self._storage.state.store_state_group(
+            self.state_group_after_event = await self._storage.state.store_state_group(
                 event.event_id,
                 event.room_id,
                 prev_group=self.state_group_before_event,
@@ -464,16 +460,81 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
                 current_state_ids=None,
             )
 
+        state_group_deltas = self._build_state_group_deltas()
+
         return EventContext.with_state(
             storage=self._storage,
-            state_group=state_group_after_event,
+            state_group=self.state_group_after_event,
             state_group_before_event=self.state_group_before_event,
             state_delta_due_to_event=self.state_delta_due_to_event,
+            state_group_deltas=state_group_deltas,
             partial_state=self.partial_state,
-            prev_group=self.state_group_before_event,
-            delta_ids=self.state_delta_due_to_event,
         )
 
+    def _build_state_group_deltas(self) -> Dict[Tuple[int, int], StateMap]:
+        """
+        Collect deltas between the state groups associated with this context
+        """
+        state_group_deltas = {}
+
+        # if we know the state group before the event and after the event, add them and the
+        # state delta between them to state_group_deltas
+        if self.state_group_before_event and self.state_group_after_event:
+            # if we have the state groups we should have the delta
+            assert self.state_delta_due_to_event is not None
+            state_group_deltas[
+                (
+                    self.state_group_before_event,
+                    self.state_group_after_event,
+                )
+            ] = self.state_delta_due_to_event
+
+        # the state group before the event may also have a state group which precedes it, if
+        # we have that and the state group before the event, add them and the state
+        # delta between them to state_group_deltas
+        if (
+            self.prev_group_for_state_group_before_event
+            and self.state_group_before_event
+        ):
+            # if we have both state groups we should have the delta between them
+            assert self.delta_ids_to_state_group_before_event is not None
+            state_group_deltas[
+                (
+                    self.prev_group_for_state_group_before_event,
+                    self.state_group_before_event,
+                )
+            ] = self.delta_ids_to_state_group_before_event
+
+        return state_group_deltas
+
+
+def _encode_state_group_delta(
+    state_group_delta: Dict[Tuple[int, int], StateMap[str]]
+) -> List[Tuple[int, int, Optional[List[Tuple[str, str, str]]]]]:
+    if not state_group_delta:
+        return []
+
+    state_group_delta_encoded = []
+    for key, value in state_group_delta.items():
+        state_group_delta_encoded.append((key[0], key[1], _encode_state_dict(value)))
+
+    return state_group_delta_encoded
+
+
+def _decode_state_group_delta(
+    input: List[Tuple[int, int, List[Tuple[str, str, str]]]]
+) -> Dict[Tuple[int, int], StateMap[str]]:
+    if not input:
+        return {}
+
+    state_group_deltas = {}
+    for state_group_1, state_group_2, state_dict in input:
+        state_map = _decode_state_dict(state_dict)
+        assert state_map is not None
+        state_group_deltas[(state_group_1, state_group_2)] = state_map
+
+    return state_group_deltas
+
 
 def _encode_state_dict(
     state_dict: Optional[StateMap[str]],
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index f1d2c71c91..35c0680365 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -839,9 +839,8 @@ class EventsPersistenceStorageController:
                         "group" % (ev.event_id,)
                     )
                 continue
-
-            if ctx.prev_group:
-                state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
+            if ctx.state_group_deltas:
+                state_group_deltas.update(ctx.state_group_deltas)
 
         # We need to map the event_ids to their state groups. First, let's
         # check if the event is one we're persisting, in which case we can
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 6687c28e8f..b5e42f9600 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -101,8 +101,7 @@ class TestEventContext(unittest.HomeserverTestCase):
         self.assertEqual(
             context.state_group_before_event, d_context.state_group_before_event
         )
-        self.assertEqual(context.prev_group, d_context.prev_group)
-        self.assertEqual(context.delta_ids, d_context.delta_ids)
+        self.assertEqual(context.state_group_deltas, d_context.state_group_deltas)
         self.assertEqual(context.app_service, d_context.app_service)
 
         self.assertEqual(
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index e39b63edac..48ebfadaab 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -401,7 +401,10 @@ class EventChainStoreTestCase(HomeserverTestCase):
             assert persist_events_store is not None
             persist_events_store._store_event_txn(
                 txn,
-                [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
+                [
+                    (e, EventContext(self.hs.get_storage_controllers(), {}))
+                    for e in events
+                ],
             )
 
             # Actually call the function that calculates the auth chain stuff.
diff --git a/tests/test_state.py b/tests/test_state.py
index 7a49b87953..eded38c766 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -555,10 +555,15 @@ class StateTestCase(unittest.TestCase):
             (e.event_id for e in old_state + [event]), current_state_ids.values()
         )
 
-        self.assertIsNotNone(context.state_group_before_event)
+        assert context.state_group_before_event is not None
+        assert context.state_group is not None
+        self.assertEqual(
+            context.state_group_deltas.get(
+                (context.state_group_before_event, context.state_group)
+            ),
+            {(event.type, event.state_key): event.event_id},
+        )
         self.assertNotEqual(context.state_group_before_event, context.state_group)
-        self.assertEqual(context.state_group_before_event, context.prev_group)
-        self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
 
     @defer.inlineCallbacks
     def test_trivial_annotate_message(