diff options
Diffstat (limited to 'synapse/state/__init__.py')
-rw-r--r-- | synapse/state/__init__.py | 31 |
1 files changed, 29 insertions, 2 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index fcc24ad129..6babd5963c 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -258,7 +258,10 @@ class StateHandler: return await self.store.get_joined_hosts(room_id, entry) async def compute_event_context( - self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None + self, + event: EventBase, + old_state: Optional[Iterable[EventBase]] = None, + partial_state: bool = False, ) -> EventContext: """Build an EventContext structure for a non-outlier event. @@ -273,6 +276,8 @@ class StateHandler: calculated from existing events. This is normally only specified when receiving an event from federation where we don't have the prev events for, e.g. when backfilling. + partial_state: True if `old_state` is partial and omits non-critical + membership events Returns: The event context. """ @@ -295,8 +300,28 @@ class StateHandler: else: # otherwise, we'll need to resolve the state across the prev_events. - logger.debug("calling resolve_state_groups from compute_event_context") + # partial_state should not be set explicitly in this case: + # we work it out dynamically + assert not partial_state + + # if any of the prev-events have partial state, so do we. + # (This is slightly racy - the prev-events might get fixed up before we use + # their states - but I don't think that really matters; it just means we + # might redundantly recalculate the state for this event later.) + prev_event_ids = event.prev_event_ids() + incomplete_prev_events = await self.store.get_partial_state_events( + prev_event_ids + ) + if any(incomplete_prev_events.values()): + logger.debug( + "New/incoming event %s refers to prev_events %s with partial state", + event.event_id, + [k for (k, v) in incomplete_prev_events.items() if v], + ) + partial_state = True + + logger.debug("calling resolve_state_groups from compute_event_context") entry = await self.resolve_state_groups_for_events( event.room_id, event.prev_event_ids() ) @@ -342,6 +367,7 @@ class StateHandler: prev_state_ids=state_ids_before_event, prev_group=state_group_before_event_prev_group, delta_ids=deltas_to_state_group_before_event, + partial_state=partial_state, ) # @@ -373,6 +399,7 @@ class StateHandler: prev_state_ids=state_ids_before_event, prev_group=state_group_before_event, delta_ids=delta_ids, + partial_state=partial_state, ) @measure_func() |