diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 1d20c441f3..598a66f74c 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -353,17 +353,16 @@ class FederationHandler(BaseHandler):
# Ask the remote server for the states we don't
# know about
for p in prevs - seen:
- logger.info(
- "Requesting state at missing prev_event %s",
- event_id,
- )
+ logger.info("Requesting state after missing prev_event %s", p)
with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
- (remote_state, _,) = await self._get_state_for_room(
- origin, room_id, p, include_event_in_state=True
+ remote_state = (
+ await self._get_state_after_missing_prev_event(
+ origin, room_id, p
+ )
)
remote_state_map = {
@@ -539,7 +538,6 @@ class FederationHandler(BaseHandler):
destination: str,
room_id: str,
event_id: str,
- include_event_in_state: bool = False,
) -> Tuple[List[EventBase], List[EventBase]]:
"""Requests all of the room state at a given event from a remote homeserver.
@@ -547,11 +545,9 @@ class FederationHandler(BaseHandler):
destination: The remote homeserver to query for the state.
room_id: The id of the room we're interested in.
event_id: The id of the event we want the state at.
- include_event_in_state: if true, the event itself will be included in the
- returned state event list.
Returns:
- A list of events in the state, possibly including the event itself, and
+ A list of events in the state, not including the event itself, and
a list of events in the auth chain for the given event.
"""
(
@@ -563,9 +559,6 @@ class FederationHandler(BaseHandler):
desired_events = set(state_event_ids + auth_event_ids)
- if include_event_in_state:
- desired_events.add(event_id)
-
event_map = await self._get_events_from_store_or_dest(
destination, room_id, desired_events
)
@@ -582,13 +575,6 @@ class FederationHandler(BaseHandler):
event_map[e_id] for e_id in state_event_ids if e_id in event_map
]
- if include_event_in_state:
- remote_event = event_map.get(event_id)
- if not remote_event:
- raise Exception("Unable to get missing prev_event %s" % (event_id,))
- if remote_event.is_state() and remote_event.rejected_reason is None:
- remote_state.append(remote_event)
-
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
auth_chain.sort(key=lambda e: e.depth)
@@ -662,6 +648,131 @@ class FederationHandler(BaseHandler):
return fetched_events
+ async def _get_state_after_missing_prev_event(
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ ) -> List[EventBase]:
+ """Requests all of the room state at a given event from a remote homeserver.
+
+ Args:
+ destination: The remote homeserver to query for the state.
+ room_id: The id of the room we're interested in.
+ event_id: The id of the event we want the state at.
+
+ Returns:
+ A list of events in the state, including the event itself
+ """
+ # TODO: This function is basically the same as _get_state_for_room. Can
+ # we make backfill() use it, rather than having two code paths? I think the
+ # only difference is that backfill() persists the prev events separately.
+
+ (
+ state_event_ids,
+ auth_event_ids,
+ ) = await self.federation_client.get_room_state_ids(
+ destination, room_id, event_id=event_id
+ )
+
+ logger.debug(
+ "state_ids returned %i state events, %i auth events",
+ len(state_event_ids),
+ len(auth_event_ids),
+ )
+
+ # start by just trying to fetch the events from the store
+ desired_events = set(state_event_ids)
+ desired_events.add(event_id)
+ logger.debug("Fetching %i events from cache/store", len(desired_events))
+ fetched_events = await self.store.get_events(
+ desired_events, allow_rejected=True
+ )
+
+ missing_desired_events = desired_events - fetched_events.keys()
+ logger.debug(
+ "We are missing %i events (got %i)",
+ len(missing_desired_events),
+ len(fetched_events),
+ )
+
+ # We probably won't need most of the auth events, so let's just check which
+ # we have for now, rather than thrashing the event cache with them all
+ # unnecessarily.
+
+ # TODO: we probably won't actually need all of the auth events, since we
+ # already have a bunch of the state events. It would be nice if the
+ # federation api gave us a way of finding out which we actually need.
+
+ missing_auth_events = set(auth_event_ids) - fetched_events.keys()
+ missing_auth_events.difference_update(
+ await self.store.have_seen_events(missing_auth_events)
+ )
+ logger.debug("We are also missing %i auth events", len(missing_auth_events))
+
+ missing_events = missing_desired_events | missing_auth_events
+ logger.debug("Fetching %i events from remote", len(missing_events))
+ await self._get_events_and_persist(
+ destination=destination, room_id=room_id, events=missing_events
+ )
+
+ # we need to make sure we re-load from the database to get the rejected
+ # state correct.
+ fetched_events.update(
+ (await self.store.get_events(missing_desired_events, allow_rejected=True))
+ )
+
+ # check for events which were in the wrong room.
+ #
+ # this can happen if a remote server claims that the state or
+ # auth_events at an event in room A are actually events in room B
+
+ bad_events = [
+ (event_id, event.room_id)
+ for event_id, event in fetched_events.items()
+ if event.room_id != room_id
+ ]
+
+ for bad_event_id, bad_room_id in bad_events:
+ # This is a bogus situation, but since we may only discover it a long time
+ # after it happened, we try our best to carry on, by just omitting the
+ # bad events from the returned state set.
+ logger.warning(
+ "Remote server %s claims event %s in room %s is an auth/state "
+ "event in room %s",
+ destination,
+ bad_event_id,
+ bad_room_id,
+ room_id,
+ )
+
+ del fetched_events[bad_event_id]
+
+ # if we couldn't get the prev event in question, that's a problem.
+ remote_event = fetched_events.get(event_id)
+ if not remote_event:
+ raise Exception("Unable to get missing prev_event %s" % (event_id,))
+
+ # missing state at that event is a warning, not a blocker
+ # XXX: this doesn't sound right? it means that we'll end up with incomplete
+ # state.
+ failed_to_fetch = desired_events - fetched_events.keys()
+ if failed_to_fetch:
+ logger.warning(
+ "Failed to fetch missing state events for %s %s",
+ event_id,
+ failed_to_fetch,
+ )
+
+ remote_state = [
+ fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
+ ]
+
+ if remote_event.is_state() and remote_event.rejected_reason is None:
+ remote_state.append(remote_event)
+
+ return remote_state
+
async def _process_received_pdu(
self,
origin: str,
@@ -841,7 +952,6 @@ class FederationHandler(BaseHandler):
destination=dest,
room_id=room_id,
event_id=e_id,
- include_event_in_state=False,
)
auth_events.update({a.event_id: a for a in auth})
auth_events.update({s.event_id: s for s in state})
|