diff options
-rw-r--r-- | synapse/state/__init__.py | 78 |
1 files changed, 44 insertions, 34 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 0b181a2c5a..30bbc0437a 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -422,49 +422,59 @@ class StateHandler: async def compute_event_context_for_batched( self, - event: EventBase, - state_ids_before_event: StateMap[str], - ) -> EventContext: + events_and_context: List[Tuple[EventBase, EventContext]], + prev_group: int, + state_ids_before_event: StateMap, + ) -> List[Tuple[EventBase, EventContext]]: """ Generate an event context for an event that has not yet been persisted to the database. Intended for use with events that are created to be persisted in a batch. Args: - event: the event the context is being computed for - state_ids_before_event: a state map consisting of the state ids of the events - created prior to this event. - current_state_group: the current state group before the event. + events_and_context: a list of events and their associated contexts + prev_group: the state group of the last event persisted before the batched events + were created + state_ids_before_event: a state map consisting of current state ids """ - state_group_before_event_prev_group = None - deltas_to_state_group_before_event = None - - # if the event is not state, we are set - if not event.is_state(): - return EventContext.without_state_group( - storage=self._storage_controllers, - state_delta_due_to_event={}, - prev_group=state_group_before_event_prev_group, - delta_ids=deltas_to_state_group_before_event, - partial_state=False, - ) + # separate out state and non-state contexts + state_events = [] + for event, context in events_and_context: + if event.is_state(): + state_events.append((event, context)) + + # get state groups for state events + room_id = events_and_context[0][0].room_id + assert self.hs.datastores is not None + await self.hs.datastores.state.store_state_deltas_for_batched( + state_events, room_id, prev_group=prev_group + ) - # otherwise, we'll need to create a new state group for after the event - key = (event.type, event.state_key) + # iterate through all contexts and update everything + current_state_group = prev_group + for event, context in events_and_context: + + # if the event is not state, we need to update it + if not event.is_state(): + context._state_group = current_state_group + context.state_group_before_event = current_state_group + context._state_delta_due_to_event = {} + context.prev_group = None + context.delta_ids = None + context.partial_state = False + + # the context should have been updated when storing the state groups but let's + # be sure - if it does not have a state group there is a problem + if context._state_group is None: + raise RuntimeError( + f"Event {event.event_id} is missing a state group." + ) + current_state_group = context._state_group - if state_ids_before_event is not None: + key = (event.type, event.state_key) replaces = state_ids_before_event.get(key) + if replaces and replaces != event.event_id: + event.unsigned["replaces_state"] = replaces - if replaces and replaces != event.event_id: - event.unsigned["replaces_state"] = replaces - - delta_ids = {key: event.event_id} - - context = EventContext.without_state_group( - storage=self._storage_controllers, - state_delta_due_to_event=delta_ids, - delta_ids=delta_ids, - partial_state=False, - ) - return context + return events_and_context @measure_func() async def resolve_state_groups_for_events( |