summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2021-08-19 18:05:12 +0100
committerGitHub <noreply@github.com>2021-08-19 17:05:12 +0000
commite81d62009ee091d69d56bca702ba0edd39becb1c (patch)
treeb20911f0ffe6e7fee3a022ff658144bf73cdce58
parentExtract `_resolve_state_at_missing_prevs` (#10624) (diff)
downloadsynapse-e81d62009ee091d69d56bca702ba0edd39becb1c.tar.xz
Split `on_receive_pdu` in half (#10640)
Here we split on_receive_pdu into two functions (on_receive_pdu and process_pulled_event), rather than having both cases in the same method. There's a tiny bit of overlap, but not that much.
Diffstat (limited to '')
-rw-r--r--changelog.d/10640.misc1
-rw-r--r--synapse/federation/federation_server.py4
-rw-r--r--synapse/handlers/federation.py236
-rw-r--r--tests/test_federation.py10
4 files changed, 142 insertions, 109 deletions
diff --git a/changelog.d/10640.misc b/changelog.d/10640.misc
new file mode 100644
index 0000000000..9a765435db
--- /dev/null
+++ b/changelog.d/10640.misc
@@ -0,0 +1 @@
+Clean up some of the federation event authentication code for clarity.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index afd8f8580a..e1b58d40c5 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1005,9 +1005,7 @@ class FederationServer(FederationBase):
             async with lock:
                 logger.info("handling received PDU: %s", event)
                 try:
-                    await self.handler.on_receive_pdu(
-                        origin, event, sent_to_us_directly=True
-                    )
+                    await self.handler.on_receive_pdu(origin, event)
                 except FederationError as e:
                     # XXX: Ideally we'd inform the remote we failed to process
                     # the event, but we can't return an error in the transaction
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index de86918b7b..246df43501 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -203,18 +203,13 @@ class FederationHandler(BaseHandler):
 
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
 
-    async def on_receive_pdu(
-        self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
-    ) -> None:
-        """Process a PDU received via a federation /send/ transaction, or
-        via backfill of missing prev_events
+    async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
+        """Process a PDU received via a federation /send/ transaction
 
         Args:
             origin: server which initiated the /send/ transaction. Will
                 be used to fetch missing events or state.
             pdu: received PDU
-            sent_to_us_directly: True if this event was pushed to us; False if
-                we pulled it as the result of a missing prev_event.
         """
 
         room_id = pdu.room_id
@@ -276,8 +271,6 @@ class FederationHandler(BaseHandler):
             )
             return None
 
-        state = None
-
         # Check that the event passes auth based on the state at the event. This is
         # done for events that are to be added to the timeline (non-outliers).
         #
@@ -285,83 +278,72 @@ class FederationHandler(BaseHandler):
         #  - Fetching any missing prev events to fill in gaps in the graph
         #  - Fetching state if we have a hole in the graph
         if not pdu.internal_metadata.is_outlier():
-            if sent_to_us_directly:
-                prevs = set(pdu.prev_event_ids())
-                seen = await self.store.have_events_in_timeline(prevs)
-                missing_prevs = prevs - seen
-
-                if missing_prevs:
-                    # We only backfill backwards to the min depth.
-                    min_depth = await self.get_min_depth_for_context(pdu.room_id)
-                    logger.debug("min_depth: %d", min_depth)
-
-                    if min_depth is not None and pdu.depth > min_depth:
-                        # If we're missing stuff, ensure we only fetch stuff one
-                        # at a time.
+            prevs = set(pdu.prev_event_ids())
+            seen = await self.store.have_events_in_timeline(prevs)
+            missing_prevs = prevs - seen
+
+            if missing_prevs:
+                # We only backfill backwards to the min depth.
+                min_depth = await self.get_min_depth_for_context(pdu.room_id)
+                logger.debug("min_depth: %d", min_depth)
+
+                if min_depth is not None and pdu.depth > min_depth:
+                    # If we're missing stuff, ensure we only fetch stuff one
+                    # at a time.
+                    logger.info(
+                        "Acquiring room lock to fetch %d missing prev_events: %s",
+                        len(missing_prevs),
+                        shortstr(missing_prevs),
+                    )
+                    with (await self._room_pdu_linearizer.queue(pdu.room_id)):
                         logger.info(
-                            "Acquiring room lock to fetch %d missing prev_events: %s",
+                            "Acquired room lock to fetch %d missing prev_events",
                             len(missing_prevs),
-                            shortstr(missing_prevs),
                         )
-                        with (await self._room_pdu_linearizer.queue(pdu.room_id)):
-                            logger.info(
-                                "Acquired room lock to fetch %d missing prev_events",
-                                len(missing_prevs),
+
+                        try:
+                            await self._get_missing_events_for_pdu(
+                                origin, pdu, prevs, min_depth
                             )
+                        except Exception as e:
+                            raise Exception(
+                                "Error fetching missing prev_events for %s: %s"
+                                % (event_id, e)
+                            ) from e
 
-                            try:
-                                await self._get_missing_events_for_pdu(
-                                    origin, pdu, prevs, min_depth
-                                )
-                            except Exception as e:
-                                raise Exception(
-                                    "Error fetching missing prev_events for %s: %s"
-                                    % (event_id, e)
-                                ) from e
-
-                        # Update the set of things we've seen after trying to
-                        # fetch the missing stuff
-                        seen = await self.store.have_events_in_timeline(prevs)
-                        missing_prevs = prevs - seen
-
-                        if not missing_prevs:
-                            logger.info("Found all missing prev_events")
-
-                    if missing_prevs:
-                        # since this event was pushed to us, it is possible for it to
-                        # become the only forward-extremity in the room, and we would then
-                        # trust its state to be the state for the whole room. This is very
-                        # bad. Further, if the event was pushed to us, there is no excuse
-                        # for us not to have all the prev_events. (XXX: apart from
-                        # min_depth?)
-                        #
-                        # We therefore reject any such events.
-                        logger.warning(
-                            "Rejecting: failed to fetch %d prev events: %s",
-                            len(missing_prevs),
-                            shortstr(missing_prevs),
-                        )
-                        raise FederationError(
-                            "ERROR",
-                            403,
-                            (
-                                "Your server isn't divulging details about prev_events "
-                                "referenced in this event."
-                            ),
-                            affected=pdu.event_id,
-                        )
+                    # Update the set of things we've seen after trying to
+                    # fetch the missing stuff
+                    seen = await self.store.have_events_in_timeline(prevs)
+                    missing_prevs = prevs - seen
 
-            else:
-                state = await self._resolve_state_at_missing_prevs(origin, pdu)
+                    if not missing_prevs:
+                        logger.info("Found all missing prev_events")
+
+                if missing_prevs:
+                    # since this event was pushed to us, it is possible for it to
+                    # become the only forward-extremity in the room, and we would then
+                    # trust its state to be the state for the whole room. This is very
+                    # bad. Further, if the event was pushed to us, there is no excuse
+                    # for us not to have all the prev_events. (XXX: apart from
+                    # min_depth?)
+                    #
+                    # We therefore reject any such events.
+                    logger.warning(
+                        "Rejecting: failed to fetch %d prev events: %s",
+                        len(missing_prevs),
+                        shortstr(missing_prevs),
+                    )
+                    raise FederationError(
+                        "ERROR",
+                        403,
+                        (
+                            "Your server isn't divulging details about prev_events "
+                            "referenced in this event."
+                        ),
+                        affected=pdu.event_id,
+                    )
 
-        # A second round of checks for all events. Check that the event passes auth
-        # based on `auth_events`, this allows us to assert that the event would
-        # have been allowed at some point. If an event passes this check its OK
-        # for it to be used as part of a returned `/state` request, as either
-        # a) we received the event as part of the original join and so trust it, or
-        # b) we'll do a state resolution with existing state before it becomes
-        # part of the "current state", which adds more protection.
-        await self._process_received_pdu(origin, pdu, state=state)
+        await self._process_received_pdu(origin, pdu, state=None)
 
     async def _get_missing_events_for_pdu(
         self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
@@ -461,24 +443,7 @@ class FederationHandler(BaseHandler):
             return
 
         logger.info("Got %d prev_events", len(missing_events))
-
-        # We want to sort these by depth so we process them and
-        # tell clients about them in order.
-        missing_events.sort(key=lambda x: x.depth)
-
-        for ev in missing_events:
-            logger.info("Handling received prev_event %s", ev)
-            with nested_logging_context(ev.event_id):
-                try:
-                    await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
-                except FederationError as e:
-                    if e.code == 403:
-                        logger.warning(
-                            "Received prev_event %s failed history check.",
-                            ev.event_id,
-                        )
-                    else:
-                        raise
+        await self._process_pulled_events(origin, missing_events)
 
     async def _get_state_for_room(
         self,
@@ -1395,6 +1360,81 @@ class FederationHandler(BaseHandler):
                 event_infos,
             )
 
+    async def _process_pulled_events(
+        self, origin: str, events: Iterable[EventBase]
+    ) -> None:
+        """Process a batch of events we have pulled from a remote server
+
+        Pulls in any events required to auth the events, persists the received events,
+        and notifies clients, if appropriate.
+
+        Assumes the events have already had their signatures and hashes checked.
+
+        Params:
+            origin: The server we received these events from
+            events: The received events.
+        """
+
+        # We want to sort these by depth so we process them and
+        # tell clients about them in order.
+        sorted_events = sorted(events, key=lambda x: x.depth)
+
+        for ev in sorted_events:
+            with nested_logging_context(ev.event_id):
+                await self._process_pulled_event(origin, ev)
+
+    async def _process_pulled_event(self, origin: str, event: EventBase) -> None:
+        """Process a single event that we have pulled from a remote server
+
+        Pulls in any events required to auth the event, persists the received event,
+        and notifies clients, if appropriate.
+
+        Assumes the event has already had its signatures and hashes checked.
+
+        This is somewhat equivalent to on_receive_pdu, but applies somewhat different
+        logic in the case that we are missing prev_events (in particular, it just
+        requests the state at that point, rather than triggering a get_missing_events) -
+        so is appropriate when we have pulled the event from a remote server, rather
+        than having it pushed to us.
+
+        Params:
+            origin: The server we received this event from
+            events: The received event
+        """
+        logger.info("Processing pulled event %s", event)
+
+        # these should not be outliers.
+        assert not event.internal_metadata.is_outlier()
+
+        event_id = event.event_id
+
+        existing = await self.store.get_event(
+            event_id, allow_none=True, allow_rejected=True
+        )
+        if existing:
+            if not existing.internal_metadata.is_outlier():
+                logger.info(
+                    "Ignoring received event %s which we have already seen",
+                    event_id,
+                )
+                return
+            logger.info("De-outliering event %s", event_id)
+
+        try:
+            self._sanity_check_event(event)
+        except SynapseError as err:
+            logger.warning("Event %s failed sanity check: %s", event_id, err)
+            return
+
+        try:
+            state = await self._resolve_state_at_missing_prevs(origin, event)
+            await self._process_received_pdu(origin, event, state=state)
+        except FederationError as e:
+            if e.code == 403:
+                logger.warning("Pulled event %s failed history check.", event_id)
+            else:
+                raise
+
     async def _resolve_state_at_missing_prevs(
         self, dest: str, event: EventBase
     ) -> Optional[Iterable[EventBase]]:
@@ -1780,7 +1820,7 @@ class FederationHandler(BaseHandler):
                     p,
                 )
                 with nested_logging_context(p.event_id):
-                    await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+                    await self.on_receive_pdu(origin, p)
             except Exception as e:
                 logger.warning(
                     "Error handling queued PDU %s from %s: %s", p.event_id, origin, e
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 3785799f46..348fcb72a7 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -85,11 +85,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
         # Send the join, it should return None (which is not an error)
         self.assertEqual(
-            self.get_success(
-                self.handler.on_receive_pdu(
-                    "test.serv", join_event, sent_to_us_directly=True
-                )
-            ),
+            self.get_success(self.handler.on_receive_pdu("test.serv", join_event)),
             None,
         )
 
@@ -135,9 +131,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
         with LoggingContext("test-context"):
             failure = self.get_failure(
-                self.handler.on_receive_pdu(
-                    "test.serv", lying_event, sent_to_us_directly=True
-                ),
+                self.handler.on_receive_pdu("test.serv", lying_event),
                 FederationError,
             )