diff options
Diffstat (limited to 'synapse/state')
-rw-r--r-- | synapse/state/__init__.py | 71 |
1 files changed, 70 insertions, 1 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 3787d35b24..833ffec3de 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -190,6 +190,7 @@ class StateHandler: room_id: str, event_ids: Collection[str], state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> StateMap[str]: """Fetch the state after each of the given event IDs. Resolve them and return. @@ -200,13 +201,18 @@ class StateHandler: Args: room_id: the room_id containing the given events. event_ids: the events whose state should be fetched and resolved. + await_full_state: if `True`, will block if we do not yet have complete state + at the given `event_id`s, regardless of whether `state_filter` is + satisfied by partial state. Returns: the state dict (a mapping from (event_type, state_key) -> event_id) which holds the resolution of the states after the given event IDs. """ logger.debug("calling resolve_state_groups from compute_state_after_events") - ret = await self.resolve_state_groups_for_events(room_id, event_ids) + ret = await self.resolve_state_groups_for_events( + room_id, event_ids, await_full_state + ) return await ret.get_state(self._state_storage_controller, state_filter) async def get_current_user_ids_in_room( @@ -420,6 +426,69 @@ class StateHandler: partial_state=partial_state, ) + async def compute_event_context_for_batched( + self, + event: EventBase, + state_ids_before_event: StateMap[str], + current_state_group: int, + ) -> EventContext: + """ + Generate an event context for an event that has not yet been persisted to the + database. Intended for use with events that are created to be persisted in a batch. + Args: + event: the event the context is being computed for + state_ids_before_event: a state map consisting of the state ids of the events + created prior to this event. + current_state_group: the current state group before the event. + """ + state_group_before_event_prev_group = None + deltas_to_state_group_before_event = None + + state_group_before_event = current_state_group + + # if the event is not state, we are set + if not event.is_state(): + return EventContext.with_state( + storage=self._storage_controllers, + state_group_before_event=state_group_before_event, + state_group=state_group_before_event, + state_delta_due_to_event={}, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + partial_state=False, + ) + + # otherwise, we'll need to create a new state group for after the event + key = (event.type, event.state_key) + + if state_ids_before_event is not None: + replaces = state_ids_before_event.get(key) + + if replaces and replaces != event.event_id: + event.unsigned["replaces_state"] = replaces + + delta_ids = {key: event.event_id} + + state_group_after_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event, + delta_ids=delta_ids, + current_state_ids=None, + ) + ) + + return EventContext.with_state( + storage=self._storage_controllers, + state_group=state_group_after_event, + state_group_before_event=state_group_before_event, + state_delta_due_to_event=delta_ids, + prev_group=state_group_before_event, + delta_ids=delta_ids, + partial_state=False, + ) + @measure_func() async def resolve_state_groups_for_events( self, room_id: str, event_ids: Collection[str], await_full_state: bool = True |