summary refs log tree commit diff
path: root/synapse/state
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state')
-rw-r--r--synapse/state/__init__.py71
1 files changed, 70 insertions, 1 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 3787d35b24..833ffec3de 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -190,6 +190,7 @@ class StateHandler:
         room_id: str,
         event_ids: Collection[str],
         state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> StateMap[str]:
         """Fetch the state after each of the given event IDs. Resolve them and return.
 
@@ -200,13 +201,18 @@ class StateHandler:
         Args:
             room_id: the room_id containing the given events.
             event_ids: the events whose state should be fetched and resolved.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at the given `event_id`s, regardless of whether `state_filter` is
+                satisfied by partial state.
 
         Returns:
             the state dict (a mapping from (event_type, state_key) -> event_id) which
             holds the resolution of the states after the given event IDs.
         """
         logger.debug("calling resolve_state_groups from compute_state_after_events")
-        ret = await self.resolve_state_groups_for_events(room_id, event_ids)
+        ret = await self.resolve_state_groups_for_events(
+            room_id, event_ids, await_full_state
+        )
         return await ret.get_state(self._state_storage_controller, state_filter)
 
     async def get_current_user_ids_in_room(
@@ -420,6 +426,69 @@ class StateHandler:
             partial_state=partial_state,
         )
 
+    async def compute_event_context_for_batched(
+        self,
+        event: EventBase,
+        state_ids_before_event: StateMap[str],
+        current_state_group: int,
+    ) -> EventContext:
+        """
+        Generate an event context for an event that has not yet been persisted to the
+        database. Intended for use with events that are created to be persisted in a batch.
+        Args:
+            event: the event the context is being computed for
+            state_ids_before_event: a state map consisting of the state ids of the events
+            created prior to this event.
+            current_state_group: the current state group before the event.
+        """
+        state_group_before_event_prev_group = None
+        deltas_to_state_group_before_event = None
+
+        state_group_before_event = current_state_group
+
+        # if the event is not state, we are set
+        if not event.is_state():
+            return EventContext.with_state(
+                storage=self._storage_controllers,
+                state_group_before_event=state_group_before_event,
+                state_group=state_group_before_event,
+                state_delta_due_to_event={},
+                prev_group=state_group_before_event_prev_group,
+                delta_ids=deltas_to_state_group_before_event,
+                partial_state=False,
+            )
+
+        # otherwise, we'll need to create a new state group for after the event
+        key = (event.type, event.state_key)
+
+        if state_ids_before_event is not None:
+            replaces = state_ids_before_event.get(key)
+
+        if replaces and replaces != event.event_id:
+            event.unsigned["replaces_state"] = replaces
+
+        delta_ids = {key: event.event_id}
+
+        state_group_after_event = (
+            await self._state_storage_controller.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=state_group_before_event,
+                delta_ids=delta_ids,
+                current_state_ids=None,
+            )
+        )
+
+        return EventContext.with_state(
+            storage=self._storage_controllers,
+            state_group=state_group_after_event,
+            state_group_before_event=state_group_before_event,
+            state_delta_due_to_event=delta_ids,
+            prev_group=state_group_before_event,
+            delta_ids=delta_ids,
+            partial_state=False,
+        )
+
     @measure_func()
     async def resolve_state_groups_for_events(
         self, room_id: str, event_ids: Collection[str], await_full_state: bool = True