diff options
Diffstat (limited to 'synapse/state')
-rw-r--r-- | synapse/state/__init__.py | 36 |
1 files changed, 21 insertions, 15 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 56606e9afb..fcb7e829d4 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -298,12 +298,18 @@ class StateHandler: state_group_before_event_prev_group = entry.prev_group deltas_to_state_group_before_event = entry.delta_ids + state_ids_before_event = None # We make sure that we have a state group assigned to the state. if entry.state_group is None: - state_ids_before_event = await entry.get_state( - self._state_storage_controller, StateFilter.all() - ) + # store_state_group requires us to have either a previous state group + # (with deltas) or the complete state map. So, if we don't have a + # previous state group, load the complete state map now. + if state_group_before_event_prev_group is None: + state_ids_before_event = await entry.get_state( + self._state_storage_controller, StateFilter.all() + ) + state_group_before_event = ( await self._state_storage_controller.store_state_group( event.event_id, @@ -316,7 +322,6 @@ class StateHandler: entry.state_group = state_group_before_event else: state_group_before_event = entry.state_group - state_ids_before_event = None # # now if it's not a state event, we're done @@ -336,19 +341,20 @@ class StateHandler: # # otherwise, we'll need to create a new state group for after the event # - if state_ids_before_event is None: - state_ids_before_event = await entry.get_state( - self._state_storage_controller, StateFilter.all() - ) key = (event.type, event.state_key) - if key in state_ids_before_event: - replaces = state_ids_before_event[key] - if replaces != event.event_id: - event.unsigned["replaces_state"] = replaces - state_ids_after_event = dict(state_ids_before_event) - state_ids_after_event[key] = event.event_id + if state_ids_before_event is not None: + replaces = state_ids_before_event.get(key) + else: + replaces_state_map = await entry.get_state( + self._state_storage_controller, StateFilter.from_types([key]) + ) + replaces = replaces_state_map.get(key) + + if replaces and replaces != event.event_id: + event.unsigned["replaces_state"] = replaces + delta_ids = {key: event.event_id} state_group_after_event = ( @@ -357,7 +363,7 @@ class StateHandler: event.room_id, prev_group=state_group_before_event, delta_ids=delta_ids, - current_state_ids=state_ids_after_event, + current_state_ids=None, ) ) |