summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/state/__init__.py78
1 files changed, 44 insertions, 34 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 0b181a2c5a..30bbc0437a 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -422,49 +422,59 @@ class StateHandler:
 
     async def compute_event_context_for_batched(
         self,
-        event: EventBase,
-        state_ids_before_event: StateMap[str],
-    ) -> EventContext:
+        events_and_context: List[Tuple[EventBase, EventContext]],
+        prev_group: int,
+        state_ids_before_event: StateMap,
+    ) -> List[Tuple[EventBase, EventContext]]:
         """
         Generate an event context for an event that has not yet been persisted to the
         database. Intended for use with events that are created to be persisted in a batch.
         Args:
-            event: the event the context is being computed for
-            state_ids_before_event: a state map consisting of the state ids of the events
-            created prior to this event.
-            current_state_group: the current state group before the event.
+            events_and_context: a list of events and their associated contexts
+            prev_group: the state group of the last event persisted before the batched events
+            were created
+            state_ids_before_event: a state map consisting of current state ids
         """
-        state_group_before_event_prev_group = None
-        deltas_to_state_group_before_event = None
-
-        # if the event is not state, we are set
-        if not event.is_state():
-            return EventContext.without_state_group(
-                storage=self._storage_controllers,
-                state_delta_due_to_event={},
-                prev_group=state_group_before_event_prev_group,
-                delta_ids=deltas_to_state_group_before_event,
-                partial_state=False,
-            )
+        # separate out state and non-state contexts
+        state_events = []
+        for event, context in events_and_context:
+            if event.is_state():
+                state_events.append((event, context))
+
+        # get state groups for state events
+        room_id = events_and_context[0][0].room_id
+        assert self.hs.datastores is not None
+        await self.hs.datastores.state.store_state_deltas_for_batched(
+            state_events, room_id, prev_group=prev_group
+        )
 
-        # otherwise, we'll need to create a new state group for after the event
-        key = (event.type, event.state_key)
+        # iterate through all contexts and update everything
+        current_state_group = prev_group
+        for event, context in events_and_context:
+
+            # if the event is not state, we need to update it
+            if not event.is_state():
+                context._state_group = current_state_group
+                context.state_group_before_event = current_state_group
+                context._state_delta_due_to_event = {}
+                context.prev_group = None
+                context.delta_ids = None
+                context.partial_state = False
+
+            # the context should have been updated when storing the state groups but let's
+            # be sure - if it does not have a state group there is a problem
+            if context._state_group is None:
+                raise RuntimeError(
+                    f"Event {event.event_id} is missing a state group."
+                )
+            current_state_group = context._state_group
 
-        if state_ids_before_event is not None:
+            key = (event.type, event.state_key)
             replaces = state_ids_before_event.get(key)
+            if replaces and replaces != event.event_id:
+                event.unsigned["replaces_state"] = replaces
 
-        if replaces and replaces != event.event_id:
-            event.unsigned["replaces_state"] = replaces
-
-        delta_ids = {key: event.event_id}
-
-        context = EventContext.without_state_group(
-            storage=self._storage_controllers,
-            state_delta_due_to_event=delta_ids,
-            delta_ids=delta_ids,
-            partial_state=False,
-        )
-        return context
+        return events_and_context
 
     @measure_func()
     async def resolve_state_groups_for_events(