summary refs log tree commit diff
path: root/synapse/handlers/federation_event.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/federation_event.py')
-rw-r--r--synapse/handlers/federation_event.py277
1 files changed, 82 insertions, 195 deletions
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index c74117c19a..b1dab57447 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import collections
 import itertools
 import logging
 from http import HTTPStatus
@@ -347,7 +348,7 @@ class FederationEventHandler:
         event.internal_metadata.send_on_behalf_of = origin
 
         context = await self._state_handler.compute_event_context(event)
-        context = await self._check_event_auth(origin, event, context)
+        await self._check_event_auth(origin, event, context)
         if context.rejected:
             raise SynapseError(
                 403, f"{event.membership} event was rejected", Codes.FORBIDDEN
@@ -485,7 +486,7 @@ class FederationEventHandler:
                 partial_state=partial_state,
             )
 
-            context = await self._check_event_auth(origin, event, context)
+            await self._check_event_auth(origin, event, context)
             if context.rejected:
                 raise SynapseError(400, "Join event was rejected")
 
@@ -1116,11 +1117,7 @@ class FederationEventHandler:
             state_ids_before_event=state_ids,
         )
         try:
-            context = await self._check_event_auth(
-                origin,
-                event,
-                context,
-            )
+            await self._check_event_auth(origin, event, context)
         except AuthError as e:
             # This happens only if we couldn't find the auth events. We'll already have
             # logged a warning, so now we just convert to a FederationError.
@@ -1495,11 +1492,8 @@ class FederationEventHandler:
         )
 
     async def _check_event_auth(
-        self,
-        origin: str,
-        event: EventBase,
-        context: EventContext,
-    ) -> EventContext:
+        self, origin: str, event: EventBase, context: EventContext
+    ) -> None:
         """
         Checks whether an event should be rejected (for failing auth checks).
 
@@ -1509,9 +1503,6 @@ class FederationEventHandler:
             context:
                 The event context.
 
-        Returns:
-            The updated context object.
-
         Raises:
             AuthError if we were unable to find copies of the event's auth events.
                (Most other failures just cause us to set `context.rejected`.)
@@ -1526,7 +1517,7 @@ class FederationEventHandler:
             logger.warning("While validating received event %r: %s", event, e)
             # TODO: use a different rejected reason here?
             context.rejected = RejectedReason.AUTH_ERROR
-            return context
+            return
 
         # next, check that we have all of the event's auth events.
         #
@@ -1538,6 +1529,9 @@ class FederationEventHandler:
         )
 
         # ... and check that the event passes auth at those auth events.
+        # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
+        #   4. Passes authorization rules based on the event’s auth events,
+        #      otherwise it is rejected.
         try:
             await check_state_independent_auth_rules(self._store, event)
             check_state_dependent_auth_rules(event, claimed_auth_events)
@@ -1546,55 +1540,90 @@ class FederationEventHandler:
                 "While checking auth of %r against auth_events: %s", event, e
             )
             context.rejected = RejectedReason.AUTH_ERROR
-            return context
+            return
+
+        # now check the auth rules pass against the room state before the event
+        # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
+        #   5. Passes authorization rules based on the state before the event,
+        #      otherwise it is rejected.
+        #
+        # ... however, if we only have partial state for the room, then there is a good
+        # chance that we'll be missing some of the state needed to auth the new event.
+        # So, we state-resolve the auth events that we are given against the state that
+        # we know about, which ensures things like bans are applied. (Note that we'll
+        # already have checked we have all the auth events, in
+        # _load_or_fetch_auth_events_for_event above)
+        if context.partial_state:
+            room_version = await self._store.get_room_version_id(event.room_id)
+
+            local_state_id_map = await context.get_prev_state_ids()
+            claimed_auth_events_id_map = {
+                (ev.type, ev.state_key): ev.event_id for ev in claimed_auth_events
+            }
+
+            state_for_auth_id_map = (
+                await self._state_resolution_handler.resolve_events_with_store(
+                    event.room_id,
+                    room_version,
+                    [local_state_id_map, claimed_auth_events_id_map],
+                    event_map=None,
+                    state_res_store=StateResolutionStore(self._store),
+                )
+            )
+        else:
+            event_types = event_auth.auth_types_for_event(event.room_version, event)
+            state_for_auth_id_map = await context.get_prev_state_ids(
+                StateFilter.from_types(event_types)
+            )
 
-        # now check auth against what we think the auth events *should* be.
-        event_types = event_auth.auth_types_for_event(event.room_version, event)
-        prev_state_ids = await context.get_prev_state_ids(
-            StateFilter.from_types(event_types)
+        calculated_auth_event_ids = self._event_auth_handler.compute_auth_events(
+            event, state_for_auth_id_map, for_verification=True
         )
 
-        auth_events_ids = self._event_auth_handler.compute_auth_events(
-            event, prev_state_ids, for_verification=True
+        # if those are the same, we're done here.
+        if collections.Counter(event.auth_event_ids()) == collections.Counter(
+            calculated_auth_event_ids
+        ):
+            return
+
+        # otherwise, re-run the auth checks based on what we calculated.
+        calculated_auth_events = await self._store.get_events_as_list(
+            calculated_auth_event_ids
         )
-        auth_events_x = await self._store.get_events(auth_events_ids)
+
+        # log the differences
+
+        claimed_auth_event_map = {(e.type, e.state_key): e for e in claimed_auth_events}
         calculated_auth_event_map = {
-            (e.type, e.state_key): e for e in auth_events_x.values()
+            (e.type, e.state_key): e for e in calculated_auth_events
         }
+        logger.info(
+            "event's auth_events are different to our calculated auth_events. "
+            "Claimed but not calculated: %s. Calculated but not claimed: %s",
+            [
+                ev
+                for k, ev in claimed_auth_event_map.items()
+                if k not in calculated_auth_event_map
+                or calculated_auth_event_map[k].event_id != ev.event_id
+            ],
+            [
+                ev
+                for k, ev in calculated_auth_event_map.items()
+                if k not in claimed_auth_event_map
+                or claimed_auth_event_map[k].event_id != ev.event_id
+            ],
+        )
 
         try:
-            updated_auth_events = await self._update_auth_events_for_auth(
+            check_state_dependent_auth_rules(event, calculated_auth_events)
+        except AuthError as e:
+            logger.warning(
+                "While checking auth of %r against room state before the event: %s",
                 event,
-                calculated_auth_event_map=calculated_auth_event_map,
-            )
-        except Exception:
-            # We don't really mind if the above fails, so lets not fail
-            # processing if it does. However, it really shouldn't fail so
-            # let's still log as an exception since we'll still want to fix
-            # any bugs.
-            logger.exception(
-                "Failed to double check auth events for %s with remote. "
-                "Ignoring failure and continuing processing of event.",
-                event.event_id,
-            )
-            updated_auth_events = None
-
-        if updated_auth_events:
-            context = await self._update_context_for_auth_events(
-                event, context, updated_auth_events
+                e,
             )
-            auth_events_for_auth = updated_auth_events
-        else:
-            auth_events_for_auth = calculated_auth_event_map
-
-        try:
-            check_state_dependent_auth_rules(event, auth_events_for_auth.values())
-        except AuthError as e:
-            logger.warning("Failed auth resolution for %r because %s", event, e)
             context.rejected = RejectedReason.AUTH_ERROR
 
-        return context
-
     async def _maybe_kick_guest_users(self, event: EventBase) -> None:
         if event.type != EventTypes.GuestAccess:
             return
@@ -1704,93 +1733,6 @@ class FederationEventHandler:
             soft_failed_event_counter.inc()
             event.internal_metadata.soft_failed = True
 
-    async def _update_auth_events_for_auth(
-        self,
-        event: EventBase,
-        calculated_auth_event_map: StateMap[EventBase],
-    ) -> Optional[StateMap[EventBase]]:
-        """Helper for _check_event_auth. See there for docs.
-
-        Checks whether a given event has the expected auth events. If it
-        doesn't then we talk to the remote server to compare state to see if
-        we can come to a consensus (e.g. if one server missed some valid
-        state).
-
-        This attempts to resolve any potential divergence of state between
-        servers, but is not essential and so failures should not block further
-        processing of the event.
-
-        Args:
-            event:
-
-            calculated_auth_event_map:
-                Our calculated auth_events based on the state of the room
-                at the event's position in the DAG.
-
-        Returns:
-            updated auth event map, or None if no changes are needed.
-
-        """
-        assert not event.internal_metadata.outlier
-
-        # check for events which are in the event's claimed auth_events, but not
-        # in our calculated event map.
-        event_auth_events = set(event.auth_event_ids())
-        different_auth = event_auth_events.difference(
-            e.event_id for e in calculated_auth_event_map.values()
-        )
-
-        if not different_auth:
-            return None
-
-        logger.info(
-            "auth_events refers to events which are not in our calculated auth "
-            "chain: %s",
-            different_auth,
-        )
-
-        # XXX: currently this checks for redactions but I'm not convinced that is
-        # necessary?
-        different_events = await self._store.get_events_as_list(different_auth)
-
-        # double-check they're all in the same room - we should already have checked
-        # this but it doesn't hurt to check again.
-        for d in different_events:
-            assert (
-                d.room_id == event.room_id
-            ), f"Event {event.event_id} refers to auth_event {d.event_id} which is in a different room"
-
-        # now we state-resolve between our own idea of the auth events, and the remote's
-        # idea of them.
-
-        local_state = calculated_auth_event_map.values()
-        remote_auth_events = dict(calculated_auth_event_map)
-        remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
-        remote_state = remote_auth_events.values()
-
-        room_version = await self._store.get_room_version_id(event.room_id)
-        new_state = await self._state_handler.resolve_events(
-            room_version, (local_state, remote_state), event
-        )
-        different_state = {
-            (d.type, d.state_key): d
-            for d in new_state.values()
-            if calculated_auth_event_map.get((d.type, d.state_key)) != d
-        }
-        if not different_state:
-            logger.info("State res returned no new state")
-            return None
-
-        logger.info(
-            "After state res: updating auth_events with new state %s",
-            different_state.values(),
-        )
-
-        # take a copy of calculated_auth_event_map before we modify it.
-        auth_events = dict(calculated_auth_event_map)
-        auth_events.update(different_state)
-        return auth_events
-
     async def _load_or_fetch_auth_events_for_event(
         self, destination: str, event: EventBase
     ) -> Collection[EventBase]:
@@ -1888,61 +1830,6 @@ class FederationEventHandler:
 
         await self._auth_and_persist_outliers(room_id, remote_auth_events)
 
-    async def _update_context_for_auth_events(
-        self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
-    ) -> EventContext:
-        """Update the state_ids in an event context after auth event resolution,
-        storing the changes as a new state group.
-
-        Args:
-            event: The event we're handling the context for
-
-            context: initial event context
-
-            auth_events: Events to update in the event context.
-
-        Returns:
-            new event context
-        """
-        # exclude the state key of the new event from the current_state in the context.
-        if event.is_state():
-            event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
-        else:
-            event_key = None
-        state_updates = {
-            k: a.event_id for k, a in auth_events.items() if k != event_key
-        }
-
-        current_state_ids = await context.get_current_state_ids()
-        current_state_ids = dict(current_state_ids)  # type: ignore
-
-        current_state_ids.update(state_updates)
-
-        prev_state_ids = await context.get_prev_state_ids()
-        prev_state_ids = dict(prev_state_ids)
-
-        prev_state_ids.update({k: a.event_id for k, a in auth_events.items()})
-
-        # create a new state group as a delta from the existing one.
-        prev_group = context.state_group
-        state_group = await self._state_storage_controller.store_state_group(
-            event.event_id,
-            event.room_id,
-            prev_group=prev_group,
-            delta_ids=state_updates,
-            current_state_ids=current_state_ids,
-        )
-
-        return EventContext.with_state(
-            storage=self._storage_controllers,
-            state_group=state_group,
-            state_group_before_event=context.state_group_before_event,
-            state_delta_due_to_event=state_updates,
-            prev_group=prev_group,
-            delta_ids=state_updates,
-            partial_state=context.partial_state,
-        )
-
     async def _run_push_actions_and_persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
     ) -> None: