diff --git a/changelog.d/10781.misc b/changelog.d/10781.misc
new file mode 100644
index 0000000000..9a765435db
--- /dev/null
+++ b/changelog.d/10781.misc
@@ -0,0 +1 @@
+Clean up some of the federation event authentication code for clarity.
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index ccbfce0219..9ec90ac8c1 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -909,12 +909,18 @@ class FederationEventHandler:
context = await self._state_handler.compute_event_context(
event, old_state=state
)
- await self._auth_and_persist_event(
- origin, event, context, state=state, backfilled=backfilled
+ context = await self._check_event_auth(
+ origin,
+ event,
+ context,
+ state=state,
+ backfilled=backfilled,
)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
+ await self._run_push_actions_and_persist_event(event, context, backfilled)
+
if backfilled:
return
@@ -1239,51 +1245,6 @@ class FederationEventHandler:
],
)
- async def _auth_and_persist_event(
- self,
- origin: str,
- event: EventBase,
- context: EventContext,
- state: Optional[Iterable[EventBase]] = None,
- claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
- backfilled: bool = False,
- ) -> None:
- """
- Process an event by performing auth checks and then persisting to the database.
-
- Args:
- origin: The host the event originates from.
- event: The event itself.
- context:
- The event context.
-
- state:
- The state events used to check the event for soft-fail. If this is
- not provided the current state events will be used.
-
- claimed_auth_event_map:
- A map of (type, state_key) => event for the event's claimed auth_events.
- Possibly incomplete, and possibly including events that are not yet
- persisted, or authed, or in the right room.
-
- Only populated when populating outliers.
-
- backfilled: True if the event was backfilled.
- """
- # claimed_auth_event_map should be given iff the event is an outlier
- assert bool(claimed_auth_event_map) == event.internal_metadata.outlier
-
- context = await self._check_event_auth(
- origin,
- event,
- context,
- state=state,
- claimed_auth_event_map=claimed_auth_event_map,
- backfilled=backfilled,
- )
-
- await self._run_push_actions_and_persist_event(event, context, backfilled)
-
async def _check_event_auth(
self,
origin: str,
@@ -1558,39 +1519,45 @@ class FederationEventHandler:
event.room_id, [e.event_id for e in remote_auth_chain]
)
- for e in remote_auth_chain:
- if e.event_id in seen_remotes:
+ for auth_event in remote_auth_chain:
+ if auth_event.event_id in seen_remotes:
continue
- if e.event_id == event.event_id:
+ if auth_event.event_id == event.event_id:
continue
try:
- auth_ids = e.auth_event_ids()
+ auth_ids = auth_event.auth_event_ids()
auth = {
(e.type, e.state_key): e
for e in remote_auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
- e.internal_metadata.outlier = True
+ auth_event.internal_metadata.outlier = True
logger.debug(
"_check_event_auth %s missing_auth: %s",
event.event_id,
- e.event_id,
+ auth_event.event_id,
)
missing_auth_event_context = (
- await self._state_handler.compute_event_context(e)
+ await self._state_handler.compute_event_context(auth_event)
)
- await self._auth_and_persist_event(
+
+ missing_auth_event_context = await self._check_event_auth(
origin,
- e,
+ auth_event,
missing_auth_event_context,
claimed_auth_event_map=auth,
)
+ await self.persist_events_and_notify(
+ event.room_id, [(auth_event, missing_auth_event_context)]
+ )
- if e.event_id in event_auth_events:
- auth_events[(e.type, e.state_key)] = e
+ if auth_event.event_id in event_auth_events:
+ auth_events[
+ (auth_event.type, auth_event.state_key)
+ ] = auth_event
except AuthError:
pass
@@ -1733,10 +1700,13 @@ class FederationEventHandler:
context: The event context.
backfilled: True if the event was backfilled.
"""
+ # this method should not be called on outliers (those code paths call
+ # persist_events_and_notify directly.)
+ assert not event.internal_metadata.outlier
+
try:
if (
- not event.internal_metadata.is_outlier()
- and not backfilled
+ not backfilled
and not context.rejected
and (await self._store.get_min_depth(event.room_id)) <= event.depth
):
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 61c9d7c2ef..c51e018da1 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -76,9 +76,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.handler = self.homeserver.get_federation_handler()
federation_event_handler = self.homeserver.get_federation_event_handler()
- federation_event_handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed(
- context
- )
+
+ async def _check_event_auth(
+ origin,
+ event,
+ context,
+ state=None,
+ claimed_auth_event_map=None,
+ backfilled=False,
+ ):
+ return context
+
+ federation_event_handler._check_event_auth = _check_event_auth
self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
pdus
|