diff --git a/changelog.d/7769.misc b/changelog.d/7769.misc
new file mode 100644
index 0000000000..2e200286ce
--- /dev/null
+++ b/changelog.d/7769.misc
@@ -0,0 +1 @@
+Add early returns to `_check_for_soft_fail`.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 8f0b9be791..fa5854578d 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -2061,76 +2061,67 @@ class FederationHandler(BaseHandler):
# For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we
# "soft-fail" the event.
- do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier()
- if do_soft_fail_check:
- extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
-
- extrem_ids = set(extrem_ids)
- prev_event_ids = set(event.prev_event_ids())
-
- if extrem_ids == prev_event_ids:
- # If they're the same then the current state is the same as the
- # state at the event, so no point rechecking auth for soft fail.
- do_soft_fail_check = False
-
- if do_soft_fail_check:
- room_version = await self.store.get_room_version_id(event.room_id)
- room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
-
- # Calculate the "current state".
- if state is not None:
- # If we're explicitly given the state then we won't have all the
- # prev events, and so we have a gap in the graph. In this case
- # we want to be a little careful as we might have been down for
- # a while and have an incorrect view of the current state,
- # however we still want to do checks as gaps are easy to
- # maliciously manufacture.
- #
- # So we use a "current state" that is actually a state
- # resolution across the current forward extremities and the
- # given state at the event. This should correctly handle cases
- # like bans, especially with state res v2.
+ if backfilled or event.internal_metadata.is_outlier():
+ return
- state_sets = await self.state_store.get_state_groups(
- event.room_id, extrem_ids
- )
- state_sets = list(state_sets.values())
- state_sets.append(state)
- current_state_ids = await self.state_handler.resolve_events(
- room_version, state_sets, event
- )
- current_state_ids = {
- k: e.event_id for k, e in current_state_ids.items()
- }
- else:
- current_state_ids = await self.state_handler.get_current_state_ids(
- event.room_id, latest_event_ids=extrem_ids
- )
+ extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
+ extrem_ids = set(extrem_ids)
+ prev_event_ids = set(event.prev_event_ids())
- logger.debug(
- "Doing soft-fail check for %s: state %s",
- event.event_id,
- current_state_ids,
+ if extrem_ids == prev_event_ids:
+ # If they're the same then the current state is the same as the
+ # state at the event, so no point rechecking auth for soft fail.
+ return
+
+ room_version = await self.store.get_room_version_id(event.room_id)
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+ # Calculate the "current state".
+ if state is not None:
+ # If we're explicitly given the state then we won't have all the
+ # prev events, and so we have a gap in the graph. In this case
+ # we want to be a little careful as we might have been down for
+ # a while and have an incorrect view of the current state,
+ # however we still want to do checks as gaps are easy to
+ # maliciously manufacture.
+ #
+ # So we use a "current state" that is actually a state
+ # resolution across the current forward extremities and the
+ # given state at the event. This should correctly handle cases
+ # like bans, especially with state res v2.
+
+ state_sets = await self.state_store.get_state_groups(
+ event.room_id, extrem_ids
+ )
+ state_sets = list(state_sets.values())
+ state_sets.append(state)
+ current_state_ids = await self.state_handler.resolve_events(
+ room_version, state_sets, event
+ )
+ current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
+ else:
+ current_state_ids = await self.state_handler.get_current_state_ids(
+ event.room_id, latest_event_ids=extrem_ids
)
- # Now check if event pass auth against said current state
- auth_types = auth_types_for_event(event)
- current_state_ids = [
- e for k, e in current_state_ids.items() if k in auth_types
- ]
+ logger.debug(
+ "Doing soft-fail check for %s: state %s", event.event_id, current_state_ids,
+ )
- current_auth_events = await self.store.get_events(current_state_ids)
- current_auth_events = {
- (e.type, e.state_key): e for e in current_auth_events.values()
- }
+ # Now check if event pass auth against said current state
+ auth_types = auth_types_for_event(event)
+ current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
- try:
- event_auth.check(
- room_version_obj, event, auth_events=current_auth_events
- )
- except AuthError as e:
- logger.warning("Soft-failing %r because %s", event, e)
- event.internal_metadata.soft_failed = True
+ current_auth_events = await self.store.get_events(current_state_ids)
+ current_auth_events = {
+ (e.type, e.state_key): e for e in current_auth_events.values()
+ }
+
+ try:
+ event_auth.check(room_version_obj, event, auth_events=current_auth_events)
+ except AuthError as e:
+ logger.warning("Soft-failing %r because %s", event, e)
+ event.internal_metadata.soft_failed = True
async def on_query_auth(
self, origin, event_id, room_id, remote_auth_chain, rejects, missing
|