diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 56606e9afb..fcb7e829d4 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -298,12 +298,18 @@ class StateHandler:
state_group_before_event_prev_group = entry.prev_group
deltas_to_state_group_before_event = entry.delta_ids
+ state_ids_before_event = None
# We make sure that we have a state group assigned to the state.
if entry.state_group is None:
- state_ids_before_event = await entry.get_state(
- self._state_storage_controller, StateFilter.all()
- )
+ # store_state_group requires us to have either a previous state group
+ # (with deltas) or the complete state map. So, if we don't have a
+ # previous state group, load the complete state map now.
+ if state_group_before_event_prev_group is None:
+ state_ids_before_event = await entry.get_state(
+ self._state_storage_controller, StateFilter.all()
+ )
+
state_group_before_event = (
await self._state_storage_controller.store_state_group(
event.event_id,
@@ -316,7 +322,6 @@ class StateHandler:
entry.state_group = state_group_before_event
else:
state_group_before_event = entry.state_group
- state_ids_before_event = None
#
# now if it's not a state event, we're done
@@ -336,19 +341,20 @@ class StateHandler:
#
# otherwise, we'll need to create a new state group for after the event
#
- if state_ids_before_event is None:
- state_ids_before_event = await entry.get_state(
- self._state_storage_controller, StateFilter.all()
- )
key = (event.type, event.state_key)
- if key in state_ids_before_event:
- replaces = state_ids_before_event[key]
- if replaces != event.event_id:
- event.unsigned["replaces_state"] = replaces
- state_ids_after_event = dict(state_ids_before_event)
- state_ids_after_event[key] = event.event_id
+ if state_ids_before_event is not None:
+ replaces = state_ids_before_event.get(key)
+ else:
+ replaces_state_map = await entry.get_state(
+ self._state_storage_controller, StateFilter.from_types([key])
+ )
+ replaces = replaces_state_map.get(key)
+
+ if replaces and replaces != event.event_id:
+ event.unsigned["replaces_state"] = replaces
+
delta_ids = {key: event.event_id}
state_group_after_event = (
@@ -357,7 +363,7 @@ class StateHandler:
event.room_id,
prev_group=state_group_before_event,
delta_ids=delta_ids,
- current_state_ids=state_ids_after_event,
+ current_state_ids=None,
)
)
|