diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 9d867aaf4d..a8aa84dd5f 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -552,7 +552,7 @@ class FederationHandler(BaseHandler):
destination: str,
room_id: str,
event_id: str,
- ) -> Tuple[List[EventBase], List[EventBase]]:
+ ) -> List[EventBase]:
"""Requests all of the room state at a given event from a remote homeserver.
Args:
@@ -573,11 +573,10 @@ class FederationHandler(BaseHandler):
desired_events = set(state_event_ids + auth_event_ids)
- event_map = await self._get_events_from_store_or_dest(
+ failed_to_fetch = await self._get_events_from_store_or_dest(
destination, room_id, desired_events
)
- failed_to_fetch = desired_events - event_map.keys()
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state/auth events for %s %s",
@@ -585,18 +584,44 @@ class FederationHandler(BaseHandler):
failed_to_fetch,
)
+ event_map = await self.store.get_events(state_event_ids, allow_rejected=True)
+
remote_state = [
event_map[e_id] for e_id in state_event_ids if e_id in event_map
]
- auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
- auth_chain.sort(key=lambda e: e.depth)
+ # check for events which were in the wrong room.
+ #
+ # this can happen if a remote server claims that the state or
+ # auth_events at an event in room A are actually events in room B
+
+ bad_events = [
+ (event_id, event.room_id)
+ for idx, event in enumerate(remote_state)
+ if event.room_id != room_id
+ ]
+
+ for bad_event_id, bad_room_id in bad_events:
+ # This is a bogus situation, but since we may only discover it a long time
+ # after it happened, we try our best to carry on, by just omitting the
+ # bad events from the returned auth/state set.
+ logger.warning(
+ "Remote server %s claims event %s in room %s is an auth/state "
+ "event in room %s",
+ destination,
+ bad_event_id,
+ bad_room_id,
+ room_id,
+ )
+
+ if bad_events:
+ remote_state = [e for e in remote_state if e.room_id == room_id]
- return remote_state, auth_chain
+ return remote_state
async def _get_events_from_store_or_dest(
self, destination: str, room_id: str, event_ids: Iterable[str]
- ) -> Dict[str, EventBase]:
+ ) -> Set[str]:
"""Fetch events from a remote destination, checking if we already have them.
Persists any events we don't already have as outliers.
@@ -613,54 +638,25 @@ class FederationHandler(BaseHandler):
Returns:
map from event_id to event
"""
- fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
+ have_events = await self.store.have_seen_events(event_ids)
- missing_events = set(event_ids) - fetched_events.keys()
-
- if missing_events:
- logger.debug(
- "Fetching unknown state/auth events %s for room %s",
- missing_events,
- room_id,
- )
+ missing_events = set(event_ids) - have_events
- await self._get_events_and_persist(
- destination=destination, room_id=room_id, events=missing_events
- )
+ if not missing_events:
+ return set()
- # we need to make sure we re-load from the database to get the rejected
- # state correct.
- fetched_events.update(
- (await self.store.get_events(missing_events, allow_rejected=True))
- )
-
- # check for events which were in the wrong room.
- #
- # this can happen if a remote server claims that the state or
- # auth_events at an event in room A are actually events in room B
-
- bad_events = [
- (event_id, event.room_id)
- for event_id, event in fetched_events.items()
- if event.room_id != room_id
- ]
-
- for bad_event_id, bad_room_id in bad_events:
- # This is a bogus situation, but since we may only discover it a long time
- # after it happened, we try our best to carry on, by just omitting the
- # bad events from the returned auth/state set.
- logger.warning(
- "Remote server %s claims event %s in room %s is an auth/state "
- "event in room %s",
- destination,
- bad_event_id,
- bad_room_id,
- room_id,
- )
+ logger.debug(
+ "Fetching unknown state/auth events %s for room %s",
+ missing_events,
+ room_id,
+ )
- del fetched_events[bad_event_id]
+ await self._get_events_and_persist(
+ destination=destination, room_id=room_id, events=missing_events
+ )
- return fetched_events
+ new_events = await self.store.have_seen_events(missing_events)
+ return missing_events - new_events
async def _get_state_after_missing_prev_event(
self,
@@ -963,27 +959,23 @@ class FederationHandler(BaseHandler):
# For each edge get the current state.
- auth_events = {}
state_events = {}
events_to_state = {}
for e_id in edges:
- state, auth = await self._get_state_for_room(
+ state = await self._get_state_for_room(
destination=dest,
room_id=room_id,
event_id=e_id,
)
- auth_events.update({a.event_id: a for a in auth})
- auth_events.update({s.event_id: s for s in state})
state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state
required_auth = {
a_id
- for event in events
- + list(state_events.values())
- + list(auth_events.values())
+ for event in events + list(state_events.values())
for a_id in event.auth_event_ids()
}
+ auth_events = await self.store.get_events(required_auth, allow_rejected=True)
auth_events.update(
{e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
)
|