summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/10901.misc1
-rw-r--r--synapse/handlers/federation_event.py91
2 files changed, 23 insertions, 69 deletions
diff --git a/changelog.d/10901.misc b/changelog.d/10901.misc
new file mode 100644
index 0000000000..9a765435db
--- /dev/null
+++ b/changelog.d/10901.misc
@@ -0,0 +1 @@
+Clean up some of the federation event authentication code for clarity.
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 7d468bd2df..4eefcc36d8 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -27,11 +27,8 @@ from typing import (
     Tuple,
 )
 
-import attr
 from prometheus_client import Counter
 
-from twisted.internet import defer
-
 from synapse import event_auth
 from synapse.api.constants import (
     EventContentFields,
@@ -54,11 +51,7 @@ from synapse.event_auth import auth_types_for_event
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.federation.federation_client import InvalidResponseError
-from synapse.logging.context import (
-    make_deferred_yieldable,
-    nested_logging_context,
-    run_in_background,
-)
+from synapse.logging.context import nested_logging_context, run_in_background
 from synapse.logging.utils import log_function
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
@@ -75,7 +68,11 @@ from synapse.types import (
     UserID,
     get_domain_from_id,
 )
-from synapse.util.async_helpers import Linearizer, concurrently_execute
+from synapse.util.async_helpers import (
+    Linearizer,
+    concurrently_execute,
+    yieldable_gather_results,
+)
 from synapse.util.iterutils import batch_iter
 from synapse.util.retryutils import NotRetryingDestination
 from synapse.util.stringutils import shortstr
@@ -92,30 +89,6 @@ soft_failed_event_counter = Counter(
 )
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class _NewEventInfo:
-    """Holds information about a received event, ready for passing to _auth_and_persist_events
-
-    Attributes:
-        event: the received event
-
-        claimed_auth_event_map: a map of (type, state_key) => event for the event's
-            claimed auth_events.
-
-            This can include events which have not yet been persisted, in the case that
-            we are backfilling a batch of events.
-
-            Note: May be incomplete: if we were unable to find all of the claimed auth
-            events. Also, treat the contents with caution: the events might also have
-            been rejected, might not yet have been authorized themselves, or they might
-            be in the wrong room.
-
-    """
-
-    event: EventBase
-    claimed_auth_event_map: StateMap[EventBase]
-
-
 class FederationEventHandler:
     """Handles events that originated from federation.
 
@@ -1203,47 +1176,27 @@ class FederationEventHandler:
             allow_rejected=True,
         )
 
-        event_infos = []
-        for event in fetched_events:
-            auth = {}
-            for auth_event_id in event.auth_event_ids():
-                ae = persisted_events.get(auth_event_id)
-                if ae:
-                    auth[(ae.type, ae.state_key)] = ae
-                else:
-                    logger.info("Missing auth event %s", auth_event_id)
-
-            event_infos.append(_NewEventInfo(event, auth))
-
-        if not event_infos:
-            return
-
-        async def prep(ev_info: _NewEventInfo) -> EventContext:
-            event = ev_info.event
+        async def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
             with nested_logging_context(suffix=event.event_id):
-                res = EventContext.for_outlier()
-                res = await self._check_event_auth(
+                auth = {}
+                for auth_event_id in event.auth_event_ids():
+                    ae = persisted_events.get(auth_event_id)
+                    if ae:
+                        auth[(ae.type, ae.state_key)] = ae
+                    else:
+                        logger.info("Missing auth event %s", auth_event_id)
+
+                context = EventContext.for_outlier()
+                context = await self._check_event_auth(
                     origin,
                     event,
-                    res,
-                    claimed_auth_event_map=ev_info.claimed_auth_event_map,
+                    context,
+                    claimed_auth_event_map=auth,
                 )
-            return res
-
-        contexts = await make_deferred_yieldable(
-            defer.gatherResults(
-                [run_in_background(prep, ev_info) for ev_info in event_infos],
-                consumeErrors=True,
-            )
-        )
+            return event, context
 
-        await self.persist_events_and_notify(
-            room_id,
-            [
-                (ev_info.event, context)
-                for ev_info, context in zip(event_infos, contexts)
-            ],
-        )
+        events_to_persist = await yieldable_gather_results(prep, fetched_events)
+        await self.persist_events_and_notify(room_id, events_to_persist)
 
     async def _check_event_auth(
         self,