diff options
Diffstat (limited to '')
-rw-r--r-- | changelog.d/6521.misc | 1 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 39 |
2 files changed, 22 insertions, 18 deletions
diff --git a/changelog.d/6521.misc b/changelog.d/6521.misc new file mode 100644 index 0000000000..d9a44389b9 --- /dev/null +++ b/changelog.d/6521.misc @@ -0,0 +1 @@ +Refactor some code in the event authentication path for clarity. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c0dcf9abf8..31c9132ae9 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -379,22 +379,10 @@ class FederationHandler(BaseHandler): ( remote_state, got_auth_chain, - ) = yield self._get_state_for_room(origin, room_id, p) - - # we want the state *after* p; _get_state_for_room returns the - # state *before* p. - remote_event = yield self.federation_client.get_pdu( - [origin], p, room_version, outlier=True + ) = yield self._get_state_for_room( + origin, room_id, p, include_event_in_state=True ) - if remote_event is None: - raise Exception( - "Unable to get missing prev_event %s" % (p,) - ) - - if remote_event.is_state(): - remote_state.append(remote_event) - # XXX hrm I'm not convinced that duplicate events will compare # for equality, so I'm not sure this does what the author # hoped. @@ -583,13 +571,15 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks @log_function - def _get_state_for_room(self, destination, room_id, event_id): + def _get_state_for_room(self, destination, room_id, event_id, include_event_in_state): """Requests all of the room state at a given event from a remote homeserver. Args: destination (str): The remote homeserver to query for the state. room_id (str): The id of the room we're interested in. event_id (str): The id of the event we want the state at. + include_event_in_state: if true, the event itself will be included in the + returned state event list. Returns: Deferred[Tuple[List[EventBase], List[EventBase]]]: @@ -604,6 +594,10 @@ class FederationHandler(BaseHandler): ) desired_events = set(state_event_ids + auth_event_ids) + + if include_event_in_state: + desired_events.add(event_id) + event_map = yield self._get_events_from_store_or_dest( destination, room_id, desired_events ) @@ -616,12 +610,21 @@ class FederationHandler(BaseHandler): failed_to_fetch, ) - pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] - auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] + remote_state = [ + event_map[e_id] for e_id in state_event_ids if e_id in event_map + ] + if include_event_in_state: + remote_event = event_map.get(event_id) + if not remote_event: + raise Exception("Unable to get missing prev_event %s" % (event_id,)) + if remote_event.is_state(): + remote_state.append(remote_event) + + auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] auth_chain.sort(key=lambda e: e.depth) - return pdus, auth_chain + return remote_state, auth_chain @defer.inlineCallbacks def _get_events_from_store_or_dest(self, destination, room_id, event_ids): |