summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/10781.misc1
-rw-r--r--synapse/handlers/federation_event.py90
-rw-r--r--tests/test_federation.py15
3 files changed, 43 insertions, 63 deletions
diff --git a/changelog.d/10781.misc b/changelog.d/10781.misc
new file mode 100644
index 0000000000..9a765435db
--- /dev/null
+++ b/changelog.d/10781.misc
@@ -0,0 +1 @@
+Clean up some of the federation event authentication code for clarity.
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index ccbfce0219..9ec90ac8c1 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -909,12 +909,18 @@ class FederationEventHandler:
             context = await self._state_handler.compute_event_context(
                 event, old_state=state
             )
-            await self._auth_and_persist_event(
-                origin, event, context, state=state, backfilled=backfilled
+            context = await self._check_event_auth(
+                origin,
+                event,
+                context,
+                state=state,
+                backfilled=backfilled,
             )
         except AuthError as e:
             raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
 
+        await self._run_push_actions_and_persist_event(event, context, backfilled)
+
         if backfilled:
             return
 
@@ -1239,51 +1245,6 @@ class FederationEventHandler:
             ],
         )
 
-    async def _auth_and_persist_event(
-        self,
-        origin: str,
-        event: EventBase,
-        context: EventContext,
-        state: Optional[Iterable[EventBase]] = None,
-        claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
-        backfilled: bool = False,
-    ) -> None:
-        """
-        Process an event by performing auth checks and then persisting to the database.
-
-        Args:
-            origin: The host the event originates from.
-            event: The event itself.
-            context:
-                The event context.
-
-            state:
-                The state events used to check the event for soft-fail. If this is
-                not provided the current state events will be used.
-
-            claimed_auth_event_map:
-                A map of (type, state_key) => event for the event's claimed auth_events.
-                Possibly incomplete, and possibly including events that are not yet
-                persisted, or authed, or in the right room.
-
-                Only populated when populating outliers.
-
-            backfilled: True if the event was backfilled.
-        """
-        # claimed_auth_event_map should be given iff the event is an outlier
-        assert bool(claimed_auth_event_map) == event.internal_metadata.outlier
-
-        context = await self._check_event_auth(
-            origin,
-            event,
-            context,
-            state=state,
-            claimed_auth_event_map=claimed_auth_event_map,
-            backfilled=backfilled,
-        )
-
-        await self._run_push_actions_and_persist_event(event, context, backfilled)
-
     async def _check_event_auth(
         self,
         origin: str,
@@ -1558,39 +1519,45 @@ class FederationEventHandler:
                     event.room_id, [e.event_id for e in remote_auth_chain]
                 )
 
-                for e in remote_auth_chain:
-                    if e.event_id in seen_remotes:
+                for auth_event in remote_auth_chain:
+                    if auth_event.event_id in seen_remotes:
                         continue
 
-                    if e.event_id == event.event_id:
+                    if auth_event.event_id == event.event_id:
                         continue
 
                     try:
-                        auth_ids = e.auth_event_ids()
+                        auth_ids = auth_event.auth_event_ids()
                         auth = {
                             (e.type, e.state_key): e
                             for e in remote_auth_chain
                             if e.event_id in auth_ids or e.type == EventTypes.Create
                         }
-                        e.internal_metadata.outlier = True
+                        auth_event.internal_metadata.outlier = True
 
                         logger.debug(
                             "_check_event_auth %s missing_auth: %s",
                             event.event_id,
-                            e.event_id,
+                            auth_event.event_id,
                         )
                         missing_auth_event_context = (
-                            await self._state_handler.compute_event_context(e)
+                            await self._state_handler.compute_event_context(auth_event)
                         )
-                        await self._auth_and_persist_event(
+
+                        missing_auth_event_context = await self._check_event_auth(
                             origin,
-                            e,
+                            auth_event,
                             missing_auth_event_context,
                             claimed_auth_event_map=auth,
                         )
+                        await self.persist_events_and_notify(
+                            event.room_id, [(auth_event, missing_auth_event_context)]
+                        )
 
-                        if e.event_id in event_auth_events:
-                            auth_events[(e.type, e.state_key)] = e
+                        if auth_event.event_id in event_auth_events:
+                            auth_events[
+                                (auth_event.type, auth_event.state_key)
+                            ] = auth_event
                     except AuthError:
                         pass
 
@@ -1733,10 +1700,13 @@ class FederationEventHandler:
             context: The event context.
             backfilled: True if the event was backfilled.
         """
+        # this method should not be called on outliers (those code paths call
+        # persist_events_and_notify directly.)
+        assert not event.internal_metadata.outlier
+
         try:
             if (
-                not event.internal_metadata.is_outlier()
-                and not backfilled
+                not backfilled
                 and not context.rejected
                 and (await self._store.get_min_depth(event.room_id)) <= event.depth
             ):
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 61c9d7c2ef..c51e018da1 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -76,9 +76,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
         self.handler = self.homeserver.get_federation_handler()
         federation_event_handler = self.homeserver.get_federation_event_handler()
-        federation_event_handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed(
-            context
-        )
+
+        async def _check_event_auth(
+            origin,
+            event,
+            context,
+            state=None,
+            claimed_auth_event_map=None,
+            backfilled=False,
+        ):
+            return context
+
+        federation_event_handler._check_event_auth = _check_event_auth
         self.client = self.homeserver.get_federation_client()
         self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
             pdus