diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index fcc24ad129..6babd5963c 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -258,7 +258,10 @@ class StateHandler:
return await self.store.get_joined_hosts(room_id, entry)
async def compute_event_context(
- self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
+ self,
+ event: EventBase,
+ old_state: Optional[Iterable[EventBase]] = None,
+ partial_state: bool = False,
) -> EventContext:
"""Build an EventContext structure for a non-outlier event.
@@ -273,6 +276,8 @@ class StateHandler:
calculated from existing events. This is normally only specified
when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling.
+ partial_state: True if `old_state` is partial and omits non-critical
+ membership events
Returns:
The event context.
"""
@@ -295,8 +300,28 @@ class StateHandler:
else:
# otherwise, we'll need to resolve the state across the prev_events.
- logger.debug("calling resolve_state_groups from compute_event_context")
+ # partial_state should not be set explicitly in this case:
+ # we work it out dynamically
+ assert not partial_state
+
+ # if any of the prev-events have partial state, so do we.
+ # (This is slightly racy - the prev-events might get fixed up before we use
+ # their states - but I don't think that really matters; it just means we
+ # might redundantly recalculate the state for this event later.)
+ prev_event_ids = event.prev_event_ids()
+ incomplete_prev_events = await self.store.get_partial_state_events(
+ prev_event_ids
+ )
+ if any(incomplete_prev_events.values()):
+ logger.debug(
+ "New/incoming event %s refers to prev_events %s with partial state",
+ event.event_id,
+ [k for (k, v) in incomplete_prev_events.items() if v],
+ )
+ partial_state = True
+
+ logger.debug("calling resolve_state_groups from compute_event_context")
entry = await self.resolve_state_groups_for_events(
event.room_id, event.prev_event_ids()
)
@@ -342,6 +367,7 @@ class StateHandler:
prev_state_ids=state_ids_before_event,
prev_group=state_group_before_event_prev_group,
delta_ids=deltas_to_state_group_before_event,
+ partial_state=partial_state,
)
#
@@ -373,6 +399,7 @@ class StateHandler:
prev_state_ids=state_ids_before_event,
prev_group=state_group_before_event,
delta_ids=delta_ids,
+ partial_state=partial_state,
)
@measure_func()
|