diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bcd3b422aa..62985bab9f 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -378,22 +378,10 @@ class FederationHandler(BaseHandler):
(
remote_state,
got_auth_chain,
- ) = await self._get_state_for_room(origin, room_id, p)
-
- # we want the state *after* p; _get_state_for_room returns the
- # state *before* p.
- remote_event = await self.federation_client.get_pdu(
- [origin], p, room_version, outlier=True
+ ) = await self._get_state_for_room(
+ origin, room_id, p, include_event_in_state=True
)
- if remote_event is None:
- raise Exception(
- "Unable to get missing prev_event %s" % (p,)
- )
-
- if remote_event.is_state():
- remote_state.append(remote_event)
-
# XXX hrm I'm not convinced that duplicate events will compare
# for equality, so I'm not sure this does what the author
# hoped.
@@ -579,20 +567,25 @@ class FederationHandler(BaseHandler):
else:
raise
- @log_function
async def _get_state_for_room(
- self, destination: str, room_id: str, event_id: str
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ include_event_in_state: bool = False,
) -> Tuple[List[EventBase], List[EventBase]]:
"""Requests all of the room state at a given event from a remote homeserver.
Args:
- destination:: The remote homeserver to query for the state.
+ destination: The remote homeserver to query for the state.
room_id: The id of the room we're interested in.
event_id: The id of the event we want the state at.
+ include_event_in_state: if true, the event itself will be included in the
+ returned state event list.
Returns:
- A list of events in the state, and a list of events in the auth chain
- for the given event.
+ A list of events in the state, possibly including the event itself, and
+ a list of events in the auth chain for the given event.
"""
(
state_event_ids,
@@ -602,6 +595,10 @@ class FederationHandler(BaseHandler):
)
desired_events = set(state_event_ids + auth_event_ids)
+
+ if include_event_in_state:
+ desired_events.add(event_id)
+
event_map = await self._get_events_from_store_or_dest(
destination, room_id, desired_events
)
@@ -614,12 +611,21 @@ class FederationHandler(BaseHandler):
failed_to_fetch,
)
- pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
- auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
+ remote_state = [
+ event_map[e_id] for e_id in state_event_ids if e_id in event_map
+ ]
+
+ if include_event_in_state:
+ remote_event = event_map.get(event_id)
+ if not remote_event:
+ raise Exception("Unable to get missing prev_event %s" % (event_id,))
+ if remote_event.is_state():
+ remote_state.append(remote_event)
+ auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
auth_chain.sort(key=lambda e: e.depth)
- return pdus, auth_chain
+ return remote_state, auth_chain
async def _get_events_from_store_or_dest(
self, destination: str, room_id: str, event_ids: Iterable[str]
|