diff options
Diffstat (limited to 'synapse/handlers/federation_event.py')
-rw-r--r-- | synapse/handlers/federation_event.py | 140 |
1 files changed, 65 insertions, 75 deletions
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 3c7ae2ea39..b65bbaff69 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -463,7 +463,9 @@ class FederationEventHandler: with nested_logging_context(suffix=event.event_id): context = await self._state_handler.compute_event_context( event, - old_state=state, + state_ids_before_event={ + (e.type, e.state_key): e.event_id for e in state + }, partial_state=partial_state, ) @@ -501,7 +503,7 @@ class FederationEventHandler: # build a new state group for it if need be context = await self._state_handler.compute_event_context( event, - old_state=state, + state_ids_before_event=state, ) if context.partial_state: # this can happen if some or all of the event's prev_events still have @@ -765,7 +767,7 @@ class FederationEventHandler: async def _resolve_state_at_missing_prevs( self, dest: str, event: EventBase - ) -> Optional[Iterable[EventBase]]: + ) -> Optional[StateMap[str]]: """Calculate the state at an event with missing prev_events. This is used when we have pulled a batch of events from a remote server, and @@ -792,8 +794,8 @@ class FederationEventHandler: event: an event to check for missing prevs. Returns: - if we already had all the prev events, `None`. Otherwise, returns a list of - the events in the state at `event`. + if we already had all the prev events, `None`. Otherwise, returns + the state at `event`. """ room_id = event.room_id event_id = event.event_id @@ -837,13 +839,7 @@ class FederationEventHandler: dest, room_id, p ) - remote_state_map = { - (x.type, x.state_key): x.event_id for x in remote_state - } - state_maps.append(remote_state_map) - - for x in remote_state: - event_map[x.event_id] = x + state_maps.append(remote_state) room_version = await self._store.get_room_version_id(room_id) state_map = await self._state_resolution_handler.resolve_events_with_store( @@ -854,19 +850,6 @@ class FederationEventHandler: state_res_store=StateResolutionStore(self._store), ) - # We need to give _process_received_pdu the actual state events - # rather than event ids, so generate that now. - - # First though we need to fetch all the events that are in - # state_map, so we can build up the state below. - evs = await self._store.get_events( - list(state_map.values()), - get_prev_content=False, - redact_behaviour=EventRedactBehaviour.as_is, - ) - event_map.update(evs) - - state = [event_map[e] for e in state_map.values()] except Exception: logger.warning( "Error attempting to resolve state at missing prev_events", @@ -878,14 +861,14 @@ class FederationEventHandler: "We can't get valid state history.", affected=event_id, ) - return state + return state_map async def _get_state_after_missing_prev_event( self, destination: str, room_id: str, event_id: str, - ) -> List[EventBase]: + ) -> StateMap[str]: """Requests all of the room state at a given event from a remote homeserver. Args: @@ -894,7 +877,7 @@ class FederationEventHandler: event_id: The id of the event we want the state at. Returns: - A list of events in the state, including the event itself + The state *after* the given event. """ ( state_event_ids, @@ -913,15 +896,13 @@ class FederationEventHandler: desired_events = set(state_event_ids) desired_events.add(event_id) logger.debug("Fetching %i events from cache/store", len(desired_events)) - fetched_events = await self._store.get_events( - desired_events, allow_rejected=True - ) + have_events = await self._store.have_seen_events(room_id, desired_events) - missing_desired_events = desired_events - fetched_events.keys() + missing_desired_events = desired_events - have_events logger.debug( "We are missing %i events (got %i)", len(missing_desired_events), - len(fetched_events), + len(have_events), ) # We probably won't need most of the auth events, so let's just check which @@ -932,7 +913,7 @@ class FederationEventHandler: # already have a bunch of the state events. It would be nice if the # federation api gave us a way of finding out which we actually need. - missing_auth_events = set(auth_event_ids) - fetched_events.keys() + missing_auth_events = set(auth_event_ids) - have_events missing_auth_events.difference_update( await self._store.have_seen_events(room_id, missing_auth_events) ) @@ -958,47 +939,54 @@ class FederationEventHandler: destination=destination, room_id=room_id, event_ids=missing_events ) - # we need to make sure we re-load from the database to get the rejected - # state correct. - fetched_events.update( - await self._store.get_events(missing_desired_events, allow_rejected=True) - ) + event_metadata = await self._store.get_metadata_for_events(state_event_ids) # check for events which were in the wrong room. # # this can happen if a remote server claims that the state or # auth_events at an event in room A are actually events in room B - bad_events = [ - (event_id, event.room_id) - for event_id, event in fetched_events.items() - if event.room_id != room_id - ] + event_metadata = await self._store.get_metadata_for_events(state_event_ids) - for bad_event_id, bad_room_id in bad_events: - # This is a bogus situation, but since we may only discover it a long time - # after it happened, we try our best to carry on, by just omitting the - # bad events from the returned state set. - logger.warning( - "Remote server %s claims event %s in room %s is an auth/state " - "event in room %s", - destination, - bad_event_id, - bad_room_id, - room_id, - ) + state_map = {} + + for state_event_id, metadata in event_metadata.items(): + if metadata.room_id != room_id: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned state set. + logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + state_event_id, + metadata.room_id, + room_id, + ) + continue + + if metadata.state_key is None: + logger.warning( + "Remote server gave us non-state event in state: %s", state_event_id + ) + continue - del fetched_events[bad_event_id] + state_map[(metadata.event_type, metadata.state_key)] = state_event_id # if we couldn't get the prev event in question, that's a problem. - remote_event = fetched_events.get(event_id) + remote_event = await self._store.get_event( + event_id, + allow_none=True, + allow_rejected=True, + redact_behaviour=EventRedactBehaviour.as_is, + ) if not remote_event: raise Exception("Unable to get missing prev_event %s" % (event_id,)) # missing state at that event is a warning, not a blocker # XXX: this doesn't sound right? it means that we'll end up with incomplete # state. - failed_to_fetch = desired_events - fetched_events.keys() + failed_to_fetch = desired_events - event_metadata.keys() if failed_to_fetch: logger.warning( "Failed to fetch missing state events for %s %s", @@ -1006,14 +994,12 @@ class FederationEventHandler: failed_to_fetch, ) - remote_state = [ - fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events - ] - if remote_event.is_state() and remote_event.rejected_reason is None: - remote_state.append(remote_event) + state_map[ + (remote_event.type, remote_event.state_key) + ] = remote_event.event_id - return remote_state + return state_map async def _get_state_and_persist( self, destination: str, room_id: str, event_id: str @@ -1040,7 +1026,7 @@ class FederationEventHandler: self, origin: str, event: EventBase, - state: Optional[Iterable[EventBase]], + state: Optional[StateMap[str]], backfilled: bool = False, ) -> None: """Called when we have a new non-outlier event. @@ -1074,7 +1060,7 @@ class FederationEventHandler: try: context = await self._state_handler.compute_event_context( - event, old_state=state + event, state_ids_before_event=state ) context = await self._check_event_auth( origin, @@ -1565,7 +1551,7 @@ class FederationEventHandler: async def _check_for_soft_fail( self, event: EventBase, - state: Optional[Iterable[EventBase]], + state: Optional[StateMap[str]], origin: str, ) -> None: """Checks if we should soft fail the event; if so, marks the event as @@ -1605,17 +1591,21 @@ class FederationEventHandler: # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets_d = await self._state_store.get_state_groups( + state_sets_d = await self._state_store.get_state_groups_ids( event.room_id, extrem_ids ) - state_sets: List[Iterable[EventBase]] = list(state_sets_d.values()) + state_sets: List[StateMap[str]] = list(state_sets_d.values()) state_sets.append(state) - current_states = await self._state_handler.resolve_events( - room_version, state_sets, event + + current_state_ids = ( + await self._state_resolution_handler.resolve_events_with_store( + event.room_id, + room_version, + state_sets, + event_map={}, + state_res_store=StateResolutionStore(self._store), + ) ) - current_state_ids: StateMap[str] = { - k: e.event_id for k, e in current_states.items() - } else: current_state_ids = await self._store.get_filtered_current_state_ids( event.room_id, StateFilter.from_types(auth_types) |