summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2019-12-11 16:37:51 +0000
committerGitHub <noreply@github.com>2019-12-11 16:37:51 +0000
commit20453565176cfd358212a23cf89dfd2deab1d690 (patch)
tree83d70909cc0b03d00dd089a1408d6d25e6fb6d6f /synapse/handlers
parentMerge pull request #6517 from matrix-org/rav/event_auth/13 (diff)
downloadsynapse-20453565176cfd358212a23cf89dfd2deab1d690.tar.xz
Add `include_event_in_state` to _get_state_for_room (#6521)
Make it return the state *after* the requested event, rather than the one
before it. This is a bit easier and requires fewer calls to
get_events_from_store_or_dest.
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/federation.py50
1 files changed, 28 insertions, 22 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bcd3b422aa..62985bab9f 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -378,22 +378,10 @@ class FederationHandler(BaseHandler):
                             (
                                 remote_state,
                                 got_auth_chain,
-                            ) = await self._get_state_for_room(origin, room_id, p)
-
-                            # we want the state *after* p; _get_state_for_room returns the
-                            # state *before* p.
-                            remote_event = await self.federation_client.get_pdu(
-                                [origin], p, room_version, outlier=True
+                            ) = await self._get_state_for_room(
+                                origin, room_id, p, include_event_in_state=True
                             )
 
-                            if remote_event is None:
-                                raise Exception(
-                                    "Unable to get missing prev_event %s" % (p,)
-                                )
-
-                            if remote_event.is_state():
-                                remote_state.append(remote_event)
-
                             # XXX hrm I'm not convinced that duplicate events will compare
                             # for equality, so I'm not sure this does what the author
                             # hoped.
@@ -579,20 +567,25 @@ class FederationHandler(BaseHandler):
                     else:
                         raise
 
-    @log_function
     async def _get_state_for_room(
-        self, destination: str, room_id: str, event_id: str
+        self,
+        destination: str,
+        room_id: str,
+        event_id: str,
+        include_event_in_state: bool = False,
     ) -> Tuple[List[EventBase], List[EventBase]]:
         """Requests all of the room state at a given event from a remote homeserver.
 
         Args:
-            destination:: The remote homeserver to query for the state.
+            destination: The remote homeserver to query for the state.
             room_id: The id of the room we're interested in.
             event_id: The id of the event we want the state at.
+            include_event_in_state: if true, the event itself will be included in the
+                returned state event list.
 
         Returns:
-            A list of events in the state, and a list of events in the auth chain
-            for the given event.
+            A list of events in the state, possibly including the event itself, and
+            a list of events in the auth chain for the given event.
         """
         (
             state_event_ids,
@@ -602,6 +595,10 @@ class FederationHandler(BaseHandler):
         )
 
         desired_events = set(state_event_ids + auth_event_ids)
+
+        if include_event_in_state:
+            desired_events.add(event_id)
+
         event_map = await self._get_events_from_store_or_dest(
             destination, room_id, desired_events
         )
@@ -614,12 +611,21 @@ class FederationHandler(BaseHandler):
                 failed_to_fetch,
             )
 
-        pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
-        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
+        remote_state = [
+            event_map[e_id] for e_id in state_event_ids if e_id in event_map
+        ]
+
+        if include_event_in_state:
+            remote_event = event_map.get(event_id)
+            if not remote_event:
+                raise Exception("Unable to get missing prev_event %s" % (event_id,))
+            if remote_event.is_state():
+                remote_state.append(remote_event)
 
+        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
         auth_chain.sort(key=lambda e: e.depth)
 
-        return pdus, auth_chain
+        return remote_state, auth_chain
 
     async def _get_events_from_store_or_dest(
         self, destination: str, room_id: str, event_ids: Iterable[str]