diff options
author | Richard van der Hoff <1389908+richvdh@users.noreply.github.com> | 2019-12-11 16:37:51 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-12-11 16:37:51 +0000 |
commit | 20453565176cfd358212a23cf89dfd2deab1d690 (patch) | |
tree | 83d70909cc0b03d00dd089a1408d6d25e6fb6d6f /synapse/handlers/federation.py | |
parent | Merge pull request #6517 from matrix-org/rav/event_auth/13 (diff) | |
download | synapse-20453565176cfd358212a23cf89dfd2deab1d690.tar.xz |
Add `include_event_in_state` to _get_state_for_room (#6521)
Make it return the state *after* the requested event, rather than the one before it. This is a bit easier and requires fewer calls to get_events_from_store_or_dest.
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r-- | synapse/handlers/federation.py | 50 |
1 files changed, 28 insertions, 22 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index bcd3b422aa..62985bab9f 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -378,22 +378,10 @@ class FederationHandler(BaseHandler): ( remote_state, got_auth_chain, - ) = await self._get_state_for_room(origin, room_id, p) - - # we want the state *after* p; _get_state_for_room returns the - # state *before* p. - remote_event = await self.federation_client.get_pdu( - [origin], p, room_version, outlier=True + ) = await self._get_state_for_room( + origin, room_id, p, include_event_in_state=True ) - if remote_event is None: - raise Exception( - "Unable to get missing prev_event %s" % (p,) - ) - - if remote_event.is_state(): - remote_state.append(remote_event) - # XXX hrm I'm not convinced that duplicate events will compare # for equality, so I'm not sure this does what the author # hoped. @@ -579,20 +567,25 @@ class FederationHandler(BaseHandler): else: raise - @log_function async def _get_state_for_room( - self, destination: str, room_id: str, event_id: str + self, + destination: str, + room_id: str, + event_id: str, + include_event_in_state: bool = False, ) -> Tuple[List[EventBase], List[EventBase]]: """Requests all of the room state at a given event from a remote homeserver. Args: - destination:: The remote homeserver to query for the state. + destination: The remote homeserver to query for the state. room_id: The id of the room we're interested in. event_id: The id of the event we want the state at. + include_event_in_state: if true, the event itself will be included in the + returned state event list. Returns: - A list of events in the state, and a list of events in the auth chain - for the given event. + A list of events in the state, possibly including the event itself, and + a list of events in the auth chain for the given event. """ ( state_event_ids, @@ -602,6 +595,10 @@ class FederationHandler(BaseHandler): ) desired_events = set(state_event_ids + auth_event_ids) + + if include_event_in_state: + desired_events.add(event_id) + event_map = await self._get_events_from_store_or_dest( destination, room_id, desired_events ) @@ -614,12 +611,21 @@ class FederationHandler(BaseHandler): failed_to_fetch, ) - pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] - auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] + remote_state = [ + event_map[e_id] for e_id in state_event_ids if e_id in event_map + ] + + if include_event_in_state: + remote_event = event_map.get(event_id) + if not remote_event: + raise Exception("Unable to get missing prev_event %s" % (event_id,)) + if remote_event.is_state(): + remote_state.append(remote_event) + auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] auth_chain.sort(key=lambda e: e.depth) - return pdus, auth_chain + return remote_state, auth_chain async def _get_events_from_store_or_dest( self, destination: str, room_id: str, event_ids: Iterable[str] |