diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c0dcf9abf8..31c9132ae9 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -379,22 +379,10 @@ class FederationHandler(BaseHandler):
(
remote_state,
got_auth_chain,
- ) = yield self._get_state_for_room(origin, room_id, p)
-
- # we want the state *after* p; _get_state_for_room returns the
- # state *before* p.
- remote_event = yield self.federation_client.get_pdu(
- [origin], p, room_version, outlier=True
+ ) = yield self._get_state_for_room(
+ origin, room_id, p, include_event_in_state=True
)
- if remote_event is None:
- raise Exception(
- "Unable to get missing prev_event %s" % (p,)
- )
-
- if remote_event.is_state():
- remote_state.append(remote_event)
-
# XXX hrm I'm not convinced that duplicate events will compare
# for equality, so I'm not sure this does what the author
# hoped.
@@ -583,13 +571,15 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
- def _get_state_for_room(self, destination, room_id, event_id):
+ def _get_state_for_room(self, destination, room_id, event_id, include_event_in_state):
"""Requests all of the room state at a given event from a remote homeserver.
Args:
destination (str): The remote homeserver to query for the state.
room_id (str): The id of the room we're interested in.
event_id (str): The id of the event we want the state at.
+ include_event_in_state: if true, the event itself will be included in the
+ returned state event list.
Returns:
Deferred[Tuple[List[EventBase], List[EventBase]]]:
@@ -604,6 +594,10 @@ class FederationHandler(BaseHandler):
)
desired_events = set(state_event_ids + auth_event_ids)
+
+ if include_event_in_state:
+ desired_events.add(event_id)
+
event_map = yield self._get_events_from_store_or_dest(
destination, room_id, desired_events
)
@@ -616,12 +610,21 @@ class FederationHandler(BaseHandler):
failed_to_fetch,
)
- pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
- auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
+ remote_state = [
+ event_map[e_id] for e_id in state_event_ids if e_id in event_map
+ ]
+ if include_event_in_state:
+ remote_event = event_map.get(event_id)
+ if not remote_event:
+ raise Exception("Unable to get missing prev_event %s" % (event_id,))
+ if remote_event.is_state():
+ remote_state.append(remote_event)
+
+ auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
auth_chain.sort(key=lambda e: e.depth)
- return pdus, auth_chain
+ return remote_state, auth_chain
@defer.inlineCallbacks
def _get_events_from_store_or_dest(self, destination, room_id, event_ids):
|