diff --git a/changelog.d/10896.misc b/changelog.d/10896.misc
index 41de995842..9a765435db 100644
--- a/changelog.d/10896.misc
+++ b/changelog.d/10896.misc
@@ -1 +1 @@
- Clean up some of the federation event authentication code for clarity.
+Clean up some of the federation event authentication code for clarity.
diff --git a/changelog.d/10926.misc b/changelog.d/10926.misc
new file mode 100644
index 0000000000..9a765435db
--- /dev/null
+++ b/changelog.d/10926.misc
@@ -0,0 +1 @@
+Clean up some of the federation event authentication code for clarity.
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 01fd841122..2c4644b4a3 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -68,11 +68,7 @@ from synapse.types import (
UserID,
get_domain_from_id,
)
-from synapse.util.async_helpers import (
- Linearizer,
- concurrently_execute,
- yieldable_gather_results,
-)
+from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.iterutils import batch_iter
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import shortstr
@@ -1189,7 +1185,10 @@ class FederationEventHandler:
allow_rejected=True,
)
- async def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
+ room_version = await self._store.get_room_version_id(room_id)
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+ def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
with nested_logging_context(suffix=event.event_id):
auth = {}
for auth_event_id in event.auth_event_ids():
@@ -1207,17 +1206,15 @@ class FederationEventHandler:
auth[(ae.type, ae.state_key)] = ae
context = EventContext.for_outlier()
- context = await self._check_event_auth(
- origin,
- event,
- context,
- claimed_auth_event_map=auth,
- )
+ try:
+ event_auth.check(room_version_obj, event, auth_events=auth)
+ except AuthError as e:
+ logger.warning("Rejecting %r because %s", event, e)
+ context.rejected = RejectedReason.AUTH_ERROR
+
return event, context
- events_to_persist = (
- x for x in await yieldable_gather_results(prep, fetched_events) if x
- )
+ events_to_persist = (x for x in (prep(event) for event in fetched_events) if x)
await self.persist_events_and_notify(room_id, tuple(events_to_persist))
async def _check_event_auth(
@@ -1226,7 +1223,6 @@ class FederationEventHandler:
event: EventBase,
context: EventContext,
state: Optional[Iterable[EventBase]] = None,
- claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
backfilled: bool = False,
) -> EventContext:
"""
@@ -1242,42 +1238,36 @@ class FederationEventHandler:
The state events used to check the event for soft-fail. If this is
not provided the current state events will be used.
- claimed_auth_event_map:
- A map of (type, state_key) => event for the event's claimed auth_events.
- Possibly including events that were rejected, or are in the wrong room.
-
- Only populated when populating outliers.
-
backfilled: True if the event was backfilled.
Returns:
The updated context object.
"""
- # claimed_auth_event_map should be given iff the event is an outlier
- assert bool(claimed_auth_event_map) == event.internal_metadata.outlier
+ # This method should only be used for non-outliers
+ assert not event.internal_metadata.outlier
room_version = await self._store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
- if claimed_auth_event_map:
- # if we have a copy of the auth events from the event, use that as the
- # basis for auth.
- auth_events = claimed_auth_event_map
- else:
- # otherwise, we calculate what the auth events *should* be, and use that
- prev_state_ids = await context.get_prev_state_ids()
- auth_events_ids = self._event_auth_handler.compute_auth_events(
- event, prev_state_ids, for_verification=True
- )
- auth_events_x = await self._store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
+ # calculate what the auth events *should* be, to use as a basis for auth.
+ prev_state_ids = await context.get_prev_state_ids()
+ auth_events_ids = self._event_auth_handler.compute_auth_events(
+ event, prev_state_ids, for_verification=True
+ )
+ auth_events_x = await self._store.get_events(auth_events_ids)
+ calculated_auth_event_map = {
+ (e.type, e.state_key): e for e in auth_events_x.values()
+ }
try:
(
context,
auth_events_for_auth,
) = await self._update_auth_events_and_context_for_auth(
- origin, event, context, auth_events
+ origin,
+ event,
+ context,
+ calculated_auth_event_map=calculated_auth_event_map,
)
except Exception:
# We don't really mind if the above fails, so lets not fail
@@ -1289,7 +1279,7 @@ class FederationEventHandler:
"Ignoring failure and continuing processing of event.",
event.event_id,
)
- auth_events_for_auth = auth_events
+ auth_events_for_auth = calculated_auth_event_map
try:
event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth)
@@ -1425,7 +1415,7 @@ class FederationEventHandler:
origin: str,
event: EventBase,
context: EventContext,
- input_auth_events: StateMap[EventBase],
+ calculated_auth_event_map: StateMap[EventBase],
) -> Tuple[EventContext, StateMap[EventBase]]:
"""Helper for _check_event_auth. See there for docs.
@@ -1443,19 +1433,17 @@ class FederationEventHandler:
event:
context:
- input_auth_events:
- Map from (event_type, state_key) to event
-
- Normally, our calculated auth_events based on the state of the room
- at the event's position in the DAG, though occasionally (eg if the
- event is an outlier), may be the auth events claimed by the remote
- server.
+ calculated_auth_event_map:
+ Our calculated auth_events based on the state of the room
+ at the event's position in the DAG.
Returns:
updated context, updated auth event map
"""
- # take a copy of input_auth_events before we modify it.
- auth_events: MutableStateMap[EventBase] = dict(input_auth_events)
+ assert not event.internal_metadata.outlier
+
+ # take a copy of calculated_auth_event_map before we modify it.
+ auth_events: MutableStateMap[EventBase] = dict(calculated_auth_event_map)
event_auth_events = set(event.auth_event_ids())
@@ -1496,15 +1484,6 @@ class FederationEventHandler:
}
)
- if event.internal_metadata.is_outlier():
- # XXX: given that, for an outlier, we'll be working with the
- # event's *claimed* auth events rather than those we calculated:
- # (a) is there any point in this test, since different_auth below will
- # obviously be empty
- # (b) alternatively, why don't we do it earlier?
- logger.info("Skipping auth_event fetch for outlier")
- return context, auth_events
-
different_auth = event_auth_events.difference(
e.event_id for e in auth_events.values()
)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index c51e018da1..24fc77d7a7 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -82,7 +82,6 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
event,
context,
state=None,
- claimed_auth_event_map=None,
backfilled=False,
):
return context
|