summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/10624.misc1
-rw-r--r--synapse/handlers/federation.py229
2 files changed, 125 insertions, 105 deletions
diff --git a/changelog.d/10624.misc b/changelog.d/10624.misc
new file mode 100644
index 0000000000..9a765435db
--- /dev/null
+++ b/changelog.d/10624.misc
@@ -0,0 +1 @@
+Clean up some of the federation event authentication code for clarity.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 529d025c39..de86918b7b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -285,12 +285,12 @@ class FederationHandler(BaseHandler):
         #  - Fetching any missing prev events to fill in gaps in the graph
         #  - Fetching state if we have a hole in the graph
         if not pdu.internal_metadata.is_outlier():
-            prevs = set(pdu.prev_event_ids())
-            seen = await self.store.have_events_in_timeline(prevs)
-            missing_prevs = prevs - seen
+            if sent_to_us_directly:
+                prevs = set(pdu.prev_event_ids())
+                seen = await self.store.have_events_in_timeline(prevs)
+                missing_prevs = prevs - seen
 
-            if missing_prevs:
-                if sent_to_us_directly:
+                if missing_prevs:
                     # We only backfill backwards to the min depth.
                     min_depth = await self.get_min_depth_for_context(pdu.room_id)
                     logger.debug("min_depth: %d", min_depth)
@@ -351,106 +351,8 @@ class FederationHandler(BaseHandler):
                             affected=pdu.event_id,
                         )
 
-                else:
-                    # We don't have all of the prev_events for this event.
-                    #
-                    # In this case, we need to fall back to asking another server in the
-                    # federation for the state at this event. That's ok provided we then
-                    # resolve the state against other bits of the DAG before using it (which
-                    # will ensure that you can't just take over a room by sending an event,
-                    # withholding its prev_events, and declaring yourself to be an admin in
-                    # the subsequent state request).
-                    #
-                    # Since we're pulling this event as a missing prev_event, then clearly
-                    # this event is not going to become the only forward-extremity and we are
-                    # guaranteed to resolve its state against our existing forward
-                    # extremities, so that should be fine.
-                    #
-                    # XXX this really feels like it could/should be merged with the above,
-                    # but there is an interaction with min_depth that I'm not really
-                    # following.
-                    logger.info(
-                        "Event %s is missing prev_events %s: calculating state for a "
-                        "backwards extremity",
-                        event_id,
-                        shortstr(missing_prevs),
-                    )
-
-                    # Calculate the state after each of the previous events, and
-                    # resolve them to find the correct state at the current event.
-                    event_map = {event_id: pdu}
-                    try:
-                        # Get the state of the events we know about
-                        ours = await self.state_store.get_state_groups_ids(
-                            room_id, seen
-                        )
-
-                        # state_maps is a list of mappings from (type, state_key) to event_id
-                        state_maps: List[StateMap[str]] = list(ours.values())
-
-                        # we don't need this any more, let's delete it.
-                        del ours
-
-                        # Ask the remote server for the states we don't
-                        # know about
-                        for p in missing_prevs:
-                            logger.info(
-                                "Requesting state after missing prev_event %s", p
-                            )
-
-                            with nested_logging_context(p):
-                                # note that if any of the missing prevs share missing state or
-                                # auth events, the requests to fetch those events are deduped
-                                # by the get_pdu_cache in federation_client.
-                                remote_state = (
-                                    await self._get_state_after_missing_prev_event(
-                                        origin, room_id, p
-                                    )
-                                )
-
-                                remote_state_map = {
-                                    (x.type, x.state_key): x.event_id
-                                    for x in remote_state
-                                }
-                                state_maps.append(remote_state_map)
-
-                                for x in remote_state:
-                                    event_map[x.event_id] = x
-
-                        room_version = await self.store.get_room_version_id(room_id)
-                        state_map = await self._state_resolution_handler.resolve_events_with_store(
-                            room_id,
-                            room_version,
-                            state_maps,
-                            event_map,
-                            state_res_store=StateResolutionStore(self.store),
-                        )
-
-                        # We need to give _process_received_pdu the actual state events
-                        # rather than event ids, so generate that now.
-
-                        # First though we need to fetch all the events that are in
-                        # state_map, so we can build up the state below.
-                        evs = await self.store.get_events(
-                            list(state_map.values()),
-                            get_prev_content=False,
-                            redact_behaviour=EventRedactBehaviour.AS_IS,
-                        )
-                        event_map.update(evs)
-
-                        state = [event_map[e] for e in state_map.values()]
-                    except Exception:
-                        logger.warning(
-                            "Error attempting to resolve state at missing "
-                            "prev_events",
-                            exc_info=True,
-                        )
-                        raise FederationError(
-                            "ERROR",
-                            403,
-                            "We can't get valid state history.",
-                            affected=event_id,
-                        )
+            else:
+                state = await self._resolve_state_at_missing_prevs(origin, pdu)
 
         # A second round of checks for all events. Check that the event passes auth
         # based on `auth_events`, this allows us to assert that the event would
@@ -1493,6 +1395,123 @@ class FederationHandler(BaseHandler):
                 event_infos,
             )
 
+    async def _resolve_state_at_missing_prevs(
+        self, dest: str, event: EventBase
+    ) -> Optional[Iterable[EventBase]]:
+        """Calculate the state at an event with missing prev_events.
+
+        This is used when we have pulled a batch of events from a remote server, and
+        still don't have all the prev_events.
+
+        If we already have all the prev_events for `event`, this method does nothing.
+
+        Otherwise, the missing prevs become new backwards extremities, and we fall back
+        to asking the remote server for the state after each missing `prev_event`,
+        and resolving across them.
+
+        That's ok provided we then resolve the state against other bits of the DAG
+        before using it - in other words, that the received event `event` is not going
+        to become the only forwards_extremity in the room (which will ensure that you
+        can't just take over a room by sending an event, withholding its prev_events,
+        and declaring yourself to be an admin in the subsequent state request).
+
+        In other words: we should only call this method if `event` has been *pulled*
+        as part of a batch of missing prev events, or similar.
+
+        Params:
+            dest: the remote server to ask for state at the missing prevs. Typically,
+                this will be the server we got `event` from.
+            event: an event to check for missing prevs.
+
+        Returns:
+            if we already had all the prev events, `None`. Otherwise, returns a list of
+            the events in the state at `event`.
+        """
+        room_id = event.room_id
+        event_id = event.event_id
+
+        prevs = set(event.prev_event_ids())
+        seen = await self.store.have_events_in_timeline(prevs)
+        missing_prevs = prevs - seen
+
+        if not missing_prevs:
+            return None
+
+        logger.info(
+            "Event %s is missing prev_events %s: calculating state for a "
+            "backwards extremity",
+            event_id,
+            shortstr(missing_prevs),
+        )
+        # Calculate the state after each of the previous events, and
+        # resolve them to find the correct state at the current event.
+        event_map = {event_id: event}
+        try:
+            # Get the state of the events we know about
+            ours = await self.state_store.get_state_groups_ids(room_id, seen)
+
+            # state_maps is a list of mappings from (type, state_key) to event_id
+            state_maps: List[StateMap[str]] = list(ours.values())
+
+            # we don't need this any more, let's delete it.
+            del ours
+
+            # Ask the remote server for the states we don't
+            # know about
+            for p in missing_prevs:
+                logger.info("Requesting state after missing prev_event %s", p)
+
+                with nested_logging_context(p):
+                    # note that if any of the missing prevs share missing state or
+                    # auth events, the requests to fetch those events are deduped
+                    # by the get_pdu_cache in federation_client.
+                    remote_state = await self._get_state_after_missing_prev_event(
+                        dest, room_id, p
+                    )
+
+                    remote_state_map = {
+                        (x.type, x.state_key): x.event_id for x in remote_state
+                    }
+                    state_maps.append(remote_state_map)
+
+                    for x in remote_state:
+                        event_map[x.event_id] = x
+
+            room_version = await self.store.get_room_version_id(room_id)
+            state_map = await self._state_resolution_handler.resolve_events_with_store(
+                room_id,
+                room_version,
+                state_maps,
+                event_map,
+                state_res_store=StateResolutionStore(self.store),
+            )
+
+            # We need to give _process_received_pdu the actual state events
+            # rather than event ids, so generate that now.
+
+            # First though we need to fetch all the events that are in
+            # state_map, so we can build up the state below.
+            evs = await self.store.get_events(
+                list(state_map.values()),
+                get_prev_content=False,
+                redact_behaviour=EventRedactBehaviour.AS_IS,
+            )
+            event_map.update(evs)
+
+            state = [event_map[e] for e in state_map.values()]
+        except Exception:
+            logger.warning(
+                "Error attempting to resolve state at missing prev_events",
+                exc_info=True,
+            )
+            raise FederationError(
+                "ERROR",
+                403,
+                "We can't get valid state history.",
+                affected=event_id,
+            )
+        return state
+
     def _sanity_check_event(self, ev: EventBase) -> None:
         """
         Do some early sanity checks of a received event