summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/6521.misc1
-rw-r--r--synapse/handlers/federation.py39
2 files changed, 22 insertions, 18 deletions
diff --git a/changelog.d/6521.misc b/changelog.d/6521.misc
new file mode 100644
index 0000000000..d9a44389b9
--- /dev/null
+++ b/changelog.d/6521.misc
@@ -0,0 +1 @@
+Refactor some code in the event authentication path for clarity.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c0dcf9abf8..31c9132ae9 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -379,22 +379,10 @@ class FederationHandler(BaseHandler):
                             (
                                 remote_state,
                                 got_auth_chain,
-                            ) = yield self._get_state_for_room(origin, room_id, p)
-
-                            # we want the state *after* p; _get_state_for_room returns the
-                            # state *before* p.
-                            remote_event = yield self.federation_client.get_pdu(
-                                [origin], p, room_version, outlier=True
+                            ) = yield self._get_state_for_room(
+                                origin, room_id, p, include_event_in_state=True
                             )
 
-                            if remote_event is None:
-                                raise Exception(
-                                    "Unable to get missing prev_event %s" % (p,)
-                                )
-
-                            if remote_event.is_state():
-                                remote_state.append(remote_event)
-
                             # XXX hrm I'm not convinced that duplicate events will compare
                             # for equality, so I'm not sure this does what the author
                             # hoped.
@@ -583,13 +571,15 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     @log_function
-    def _get_state_for_room(self, destination, room_id, event_id):
+    def _get_state_for_room(self, destination, room_id, event_id, include_event_in_state):
         """Requests all of the room state at a given event from a remote homeserver.
 
         Args:
             destination (str): The remote homeserver to query for the state.
             room_id (str): The id of the room we're interested in.
             event_id (str): The id of the event we want the state at.
+            include_event_in_state: if true, the event itself will be included in the
+                returned state event list.
 
         Returns:
             Deferred[Tuple[List[EventBase], List[EventBase]]]:
@@ -604,6 +594,10 @@ class FederationHandler(BaseHandler):
         )
 
         desired_events = set(state_event_ids + auth_event_ids)
+
+        if include_event_in_state:
+            desired_events.add(event_id)
+
         event_map = yield self._get_events_from_store_or_dest(
             destination, room_id, desired_events
         )
@@ -616,12 +610,21 @@ class FederationHandler(BaseHandler):
                 failed_to_fetch,
             )
 
-        pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
-        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
+        remote_state = [
+            event_map[e_id] for e_id in state_event_ids if e_id in event_map
+        ]
 
+        if include_event_in_state:
+            remote_event = event_map.get(event_id)
+            if not remote_event:
+                raise Exception("Unable to get missing prev_event %s" % (event_id,))
+            if remote_event.is_state():
+                remote_state.append(remote_event)
+
+        auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
         auth_chain.sort(key=lambda e: e.depth)
 
-        return pdus, auth_chain
+        return remote_state, auth_chain
 
     @defer.inlineCallbacks
     def _get_events_from_store_or_dest(self, destination, room_id, event_ids):