summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2022-07-14 22:52:26 +0100
committerGitHub <noreply@github.com>2022-07-14 21:52:26 +0000
commitfe15a865a5a5bc8e89d770e43dae702aa2a809cb (patch)
tree86a36b867132fd6294cc6358450fe80d506efbca
parentCHANGES.md: fix link to upgrade notes (diff)
downloadsynapse-fe15a865a5a5bc8e89d770e43dae702aa2a809cb.tar.xz
Rip out auth-event reconciliation code (#12943)
There is a corner in `_check_event_auth` (long known as "the weird corner") where, if we get an event with auth_events which don't match those we were expecting, we attempt to resolve the diffence between our state and the remote's with a state resolution.

This isn't specced, and there's general agreement we shouldn't be doing it.

However, it turns out that the faster-joins code was relying on it, so we need to introduce something similar (but rather simpler) for that.
-rw-r--r--changelog.d/12943.misc1
-rw-r--r--synapse/handlers/federation_event.py277
-rw-r--r--synapse/state/__init__.py26
-rw-r--r--tests/handlers/test_federation.py140
-rw-r--r--tests/rest/client/test_third_party_rules.py11
-rw-r--r--tests/test_federation.py8
6 files changed, 88 insertions, 375 deletions
diff --git a/changelog.d/12943.misc b/changelog.d/12943.misc
new file mode 100644
index 0000000000..f66bb3ec32
--- /dev/null
+++ b/changelog.d/12943.misc
@@ -0,0 +1 @@
+Remove code which incorrectly attempted to reconcile state with remote servers when processing incoming events.
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index c74117c19a..b1dab57447 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import collections
 import itertools
 import logging
 from http import HTTPStatus
@@ -347,7 +348,7 @@ class FederationEventHandler:
         event.internal_metadata.send_on_behalf_of = origin
 
         context = await self._state_handler.compute_event_context(event)
-        context = await self._check_event_auth(origin, event, context)
+        await self._check_event_auth(origin, event, context)
         if context.rejected:
             raise SynapseError(
                 403, f"{event.membership} event was rejected", Codes.FORBIDDEN
@@ -485,7 +486,7 @@ class FederationEventHandler:
                 partial_state=partial_state,
             )
 
-            context = await self._check_event_auth(origin, event, context)
+            await self._check_event_auth(origin, event, context)
             if context.rejected:
                 raise SynapseError(400, "Join event was rejected")
 
@@ -1116,11 +1117,7 @@ class FederationEventHandler:
             state_ids_before_event=state_ids,
         )
         try:
-            context = await self._check_event_auth(
-                origin,
-                event,
-                context,
-            )
+            await self._check_event_auth(origin, event, context)
         except AuthError as e:
             # This happens only if we couldn't find the auth events. We'll already have
             # logged a warning, so now we just convert to a FederationError.
@@ -1495,11 +1492,8 @@ class FederationEventHandler:
         )
 
     async def _check_event_auth(
-        self,
-        origin: str,
-        event: EventBase,
-        context: EventContext,
-    ) -> EventContext:
+        self, origin: str, event: EventBase, context: EventContext
+    ) -> None:
         """
         Checks whether an event should be rejected (for failing auth checks).
 
@@ -1509,9 +1503,6 @@ class FederationEventHandler:
             context:
                 The event context.
 
-        Returns:
-            The updated context object.
-
         Raises:
             AuthError if we were unable to find copies of the event's auth events.
                (Most other failures just cause us to set `context.rejected`.)
@@ -1526,7 +1517,7 @@ class FederationEventHandler:
             logger.warning("While validating received event %r: %s", event, e)
             # TODO: use a different rejected reason here?
             context.rejected = RejectedReason.AUTH_ERROR
-            return context
+            return
 
         # next, check that we have all of the event's auth events.
         #
@@ -1538,6 +1529,9 @@ class FederationEventHandler:
         )
 
         # ... and check that the event passes auth at those auth events.
+        # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
+        #   4. Passes authorization rules based on the event’s auth events,
+        #      otherwise it is rejected.
         try:
             await check_state_independent_auth_rules(self._store, event)
             check_state_dependent_auth_rules(event, claimed_auth_events)
@@ -1546,55 +1540,90 @@ class FederationEventHandler:
                 "While checking auth of %r against auth_events: %s", event, e
             )
             context.rejected = RejectedReason.AUTH_ERROR
-            return context
+            return
+
+        # now check the auth rules pass against the room state before the event
+        # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
+        #   5. Passes authorization rules based on the state before the event,
+        #      otherwise it is rejected.
+        #
+        # ... however, if we only have partial state for the room, then there is a good
+        # chance that we'll be missing some of the state needed to auth the new event.
+        # So, we state-resolve the auth events that we are given against the state that
+        # we know about, which ensures things like bans are applied. (Note that we'll
+        # already have checked we have all the auth events, in
+        # _load_or_fetch_auth_events_for_event above)
+        if context.partial_state:
+            room_version = await self._store.get_room_version_id(event.room_id)
+
+            local_state_id_map = await context.get_prev_state_ids()
+            claimed_auth_events_id_map = {
+                (ev.type, ev.state_key): ev.event_id for ev in claimed_auth_events
+            }
+
+            state_for_auth_id_map = (
+                await self._state_resolution_handler.resolve_events_with_store(
+                    event.room_id,
+                    room_version,
+                    [local_state_id_map, claimed_auth_events_id_map],
+                    event_map=None,
+                    state_res_store=StateResolutionStore(self._store),
+                )
+            )
+        else:
+            event_types = event_auth.auth_types_for_event(event.room_version, event)
+            state_for_auth_id_map = await context.get_prev_state_ids(
+                StateFilter.from_types(event_types)
+            )
 
-        # now check auth against what we think the auth events *should* be.
-        event_types = event_auth.auth_types_for_event(event.room_version, event)
-        prev_state_ids = await context.get_prev_state_ids(
-            StateFilter.from_types(event_types)
+        calculated_auth_event_ids = self._event_auth_handler.compute_auth_events(
+            event, state_for_auth_id_map, for_verification=True
         )
 
-        auth_events_ids = self._event_auth_handler.compute_auth_events(
-            event, prev_state_ids, for_verification=True
+        # if those are the same, we're done here.
+        if collections.Counter(event.auth_event_ids()) == collections.Counter(
+            calculated_auth_event_ids
+        ):
+            return
+
+        # otherwise, re-run the auth checks based on what we calculated.
+        calculated_auth_events = await self._store.get_events_as_list(
+            calculated_auth_event_ids
         )
-        auth_events_x = await self._store.get_events(auth_events_ids)
+
+        # log the differences
+
+        claimed_auth_event_map = {(e.type, e.state_key): e for e in claimed_auth_events}
         calculated_auth_event_map = {
-            (e.type, e.state_key): e for e in auth_events_x.values()
+            (e.type, e.state_key): e for e in calculated_auth_events
         }
+        logger.info(
+            "event's auth_events are different to our calculated auth_events. "
+            "Claimed but not calculated: %s. Calculated but not claimed: %s",
+            [
+                ev
+                for k, ev in claimed_auth_event_map.items()
+                if k not in calculated_auth_event_map
+                or calculated_auth_event_map[k].event_id != ev.event_id
+            ],
+            [
+                ev
+                for k, ev in calculated_auth_event_map.items()
+                if k not in claimed_auth_event_map
+                or claimed_auth_event_map[k].event_id != ev.event_id
+            ],
+        )
 
         try:
-            updated_auth_events = await self._update_auth_events_for_auth(
+            check_state_dependent_auth_rules(event, calculated_auth_events)
+        except AuthError as e:
+            logger.warning(
+                "While checking auth of %r against room state before the event: %s",
                 event,
-                calculated_auth_event_map=calculated_auth_event_map,
-            )
-        except Exception:
-            # We don't really mind if the above fails, so lets not fail
-            # processing if it does. However, it really shouldn't fail so
-            # let's still log as an exception since we'll still want to fix
-            # any bugs.
-            logger.exception(
-                "Failed to double check auth events for %s with remote. "
-                "Ignoring failure and continuing processing of event.",
-                event.event_id,
-            )
-            updated_auth_events = None
-
-        if updated_auth_events:
-            context = await self._update_context_for_auth_events(
-                event, context, updated_auth_events
+                e,
             )
-            auth_events_for_auth = updated_auth_events
-        else:
-            auth_events_for_auth = calculated_auth_event_map
-
-        try:
-            check_state_dependent_auth_rules(event, auth_events_for_auth.values())
-        except AuthError as e:
-            logger.warning("Failed auth resolution for %r because %s", event, e)
             context.rejected = RejectedReason.AUTH_ERROR
 
-        return context
-
     async def _maybe_kick_guest_users(self, event: EventBase) -> None:
         if event.type != EventTypes.GuestAccess:
             return
@@ -1704,93 +1733,6 @@ class FederationEventHandler:
             soft_failed_event_counter.inc()
             event.internal_metadata.soft_failed = True
 
-    async def _update_auth_events_for_auth(
-        self,
-        event: EventBase,
-        calculated_auth_event_map: StateMap[EventBase],
-    ) -> Optional[StateMap[EventBase]]:
-        """Helper for _check_event_auth. See there for docs.
-
-        Checks whether a given event has the expected auth events. If it
-        doesn't then we talk to the remote server to compare state to see if
-        we can come to a consensus (e.g. if one server missed some valid
-        state).
-
-        This attempts to resolve any potential divergence of state between
-        servers, but is not essential and so failures should not block further
-        processing of the event.
-
-        Args:
-            event:
-
-            calculated_auth_event_map:
-                Our calculated auth_events based on the state of the room
-                at the event's position in the DAG.
-
-        Returns:
-            updated auth event map, or None if no changes are needed.
-
-        """
-        assert not event.internal_metadata.outlier
-
-        # check for events which are in the event's claimed auth_events, but not
-        # in our calculated event map.
-        event_auth_events = set(event.auth_event_ids())
-        different_auth = event_auth_events.difference(
-            e.event_id for e in calculated_auth_event_map.values()
-        )
-
-        if not different_auth:
-            return None
-
-        logger.info(
-            "auth_events refers to events which are not in our calculated auth "
-            "chain: %s",
-            different_auth,
-        )
-
-        # XXX: currently this checks for redactions but I'm not convinced that is
-        # necessary?
-        different_events = await self._store.get_events_as_list(different_auth)
-
-        # double-check they're all in the same room - we should already have checked
-        # this but it doesn't hurt to check again.
-        for d in different_events:
-            assert (
-                d.room_id == event.room_id
-            ), f"Event {event.event_id} refers to auth_event {d.event_id} which is in a different room"
-
-        # now we state-resolve between our own idea of the auth events, and the remote's
-        # idea of them.
-
-        local_state = calculated_auth_event_map.values()
-        remote_auth_events = dict(calculated_auth_event_map)
-        remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
-        remote_state = remote_auth_events.values()
-
-        room_version = await self._store.get_room_version_id(event.room_id)
-        new_state = await self._state_handler.resolve_events(
-            room_version, (local_state, remote_state), event
-        )
-        different_state = {
-            (d.type, d.state_key): d
-            for d in new_state.values()
-            if calculated_auth_event_map.get((d.type, d.state_key)) != d
-        }
-        if not different_state:
-            logger.info("State res returned no new state")
-            return None
-
-        logger.info(
-            "After state res: updating auth_events with new state %s",
-            different_state.values(),
-        )
-
-        # take a copy of calculated_auth_event_map before we modify it.
-        auth_events = dict(calculated_auth_event_map)
-        auth_events.update(different_state)
-        return auth_events
-
     async def _load_or_fetch_auth_events_for_event(
         self, destination: str, event: EventBase
     ) -> Collection[EventBase]:
@@ -1888,61 +1830,6 @@ class FederationEventHandler:
 
         await self._auth_and_persist_outliers(room_id, remote_auth_events)
 
-    async def _update_context_for_auth_events(
-        self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
-    ) -> EventContext:
-        """Update the state_ids in an event context after auth event resolution,
-        storing the changes as a new state group.
-
-        Args:
-            event: The event we're handling the context for
-
-            context: initial event context
-
-            auth_events: Events to update in the event context.
-
-        Returns:
-            new event context
-        """
-        # exclude the state key of the new event from the current_state in the context.
-        if event.is_state():
-            event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
-        else:
-            event_key = None
-        state_updates = {
-            k: a.event_id for k, a in auth_events.items() if k != event_key
-        }
-
-        current_state_ids = await context.get_current_state_ids()
-        current_state_ids = dict(current_state_ids)  # type: ignore
-
-        current_state_ids.update(state_updates)
-
-        prev_state_ids = await context.get_prev_state_ids()
-        prev_state_ids = dict(prev_state_ids)
-
-        prev_state_ids.update({k: a.event_id for k, a in auth_events.items()})
-
-        # create a new state group as a delta from the existing one.
-        prev_group = context.state_group
-        state_group = await self._state_storage_controller.store_state_group(
-            event.event_id,
-            event.room_id,
-            prev_group=prev_group,
-            delta_ids=state_updates,
-            current_state_ids=current_state_ids,
-        )
-
-        return EventContext.with_state(
-            storage=self._storage_controllers,
-            state_group=state_group,
-            state_group_before_event=context.state_group_before_event,
-            state_delta_due_to_event=state_updates,
-            prev_group=prev_group,
-            delta_ids=state_updates,
-            partial_state=context.partial_state,
-        )
-
     async def _run_push_actions_and_persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
     ) -> None:
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 9f0a36652c..dcd272034d 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -24,7 +24,6 @@ from typing import (
     DefaultDict,
     Dict,
     FrozenSet,
-    Iterable,
     List,
     Mapping,
     Optional,
@@ -422,31 +421,6 @@ class StateHandler:
         )
         return result
 
-    async def resolve_events(
-        self,
-        room_version: str,
-        state_sets: Collection[Iterable[EventBase]],
-        event: EventBase,
-    ) -> StateMap[EventBase]:
-        logger.info(
-            "Resolving state for %s with %d groups", event.room_id, len(state_sets)
-        )
-        state_set_ids = [
-            {(ev.type, ev.state_key): ev.event_id for ev in st} for st in state_sets
-        ]
-
-        state_map = {ev.event_id: ev for st in state_sets for ev in st}
-
-        new_state = await self._state_resolution_handler.resolve_events_with_store(
-            event.room_id,
-            room_version,
-            state_set_ids,
-            event_map=state_map,
-            state_res_store=StateResolutionStore(self.store),
-        )
-
-        return {key: state_map[ev_id] for key, ev_id in new_state.items()}
-
     async def update_current_state(self, room_id: str) -> None:
         """Recalculates the current state for a room, and persists it.
 
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 712933f9ca..8a0bb91f40 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import List, cast
+from typing import cast
 from unittest import TestCase
 
 from twisted.test.proto_helpers import MemoryReactor
@@ -50,8 +50,6 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
         hs = self.setup_test_homeserver(federation_http_client=None)
         self.handler = hs.get_federation_handler()
         self.store = hs.get_datastores().main
-        self.state_storage_controller = hs.get_storage_controllers().state
-        self._event_auth_handler = hs.get_event_auth_handler()
         return hs
 
     def test_exchange_revoked_invite(self) -> None:
@@ -314,142 +312,6 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
             )
         self.get_success(d)
 
-    def test_backfill_floating_outlier_membership_auth(self) -> None:
-        """
-        As the local homeserver, check that we can properly process a federated
-        event from the OTHER_SERVER with auth_events that include a floating
-        membership event from the OTHER_SERVER.
-
-        Regression test, see #10439.
-        """
-        OTHER_SERVER = "otherserver"
-        OTHER_USER = "@otheruser:" + OTHER_SERVER
-
-        # create the room
-        user_id = self.register_user("kermit", "test")
-        tok = self.login("kermit", "test")
-        room_id = self.helper.create_room_as(
-            room_creator=user_id,
-            is_public=True,
-            tok=tok,
-            extra_content={
-                "preset": "public_chat",
-            },
-        )
-        room_version = self.get_success(self.store.get_room_version(room_id))
-
-        prev_event_ids = self.get_success(self.store.get_prev_events_for_room(room_id))
-        (
-            most_recent_prev_event_id,
-            most_recent_prev_event_depth,
-        ) = self.get_success(self.store.get_max_depth_of(prev_event_ids))
-        # mapping from (type, state_key) -> state_event_id
-        assert most_recent_prev_event_id is not None
-        prev_state_map = self.get_success(
-            self.state_storage_controller.get_state_ids_for_event(
-                most_recent_prev_event_id
-            )
-        )
-        # List of state event ID's
-        prev_state_ids = list(prev_state_map.values())
-        auth_event_ids = prev_state_ids
-        auth_events = list(
-            self.get_success(self.store.get_events(auth_event_ids)).values()
-        )
-
-        # build a floating outlier member state event
-        fake_prev_event_id = "$" + random_string(43)
-        member_event_dict = {
-            "type": EventTypes.Member,
-            "content": {
-                "membership": "join",
-            },
-            "state_key": OTHER_USER,
-            "room_id": room_id,
-            "sender": OTHER_USER,
-            "depth": most_recent_prev_event_depth,
-            "prev_events": [fake_prev_event_id],
-            "origin_server_ts": self.clock.time_msec(),
-            "signatures": {OTHER_SERVER: {"ed25519:key_version": "SomeSignatureHere"}},
-        }
-        builder = self.hs.get_event_builder_factory().for_room_version(
-            room_version, member_event_dict
-        )
-        member_event = self.get_success(
-            builder.build(
-                prev_event_ids=member_event_dict["prev_events"],
-                auth_event_ids=self._event_auth_handler.compute_auth_events(
-                    builder,
-                    prev_state_map,
-                    for_verification=False,
-                ),
-                depth=member_event_dict["depth"],
-            )
-        )
-        # Override the signature added from "test" homeserver that we created the event with
-        member_event.signatures = member_event_dict["signatures"]
-
-        # Add the new member_event to the StateMap
-        updated_state_map = dict(prev_state_map)
-        updated_state_map[
-            (member_event.type, member_event.state_key)
-        ] = member_event.event_id
-        auth_events.append(member_event)
-
-        # build and send an event authed based on the member event
-        message_event_dict = {
-            "type": EventTypes.Message,
-            "content": {},
-            "room_id": room_id,
-            "sender": OTHER_USER,
-            "depth": most_recent_prev_event_depth,
-            "prev_events": prev_event_ids.copy(),
-            "origin_server_ts": self.clock.time_msec(),
-            "signatures": {OTHER_SERVER: {"ed25519:key_version": "SomeSignatureHere"}},
-        }
-        builder = self.hs.get_event_builder_factory().for_room_version(
-            room_version, message_event_dict
-        )
-        message_event = self.get_success(
-            builder.build(
-                prev_event_ids=message_event_dict["prev_events"],
-                auth_event_ids=self._event_auth_handler.compute_auth_events(
-                    builder,
-                    updated_state_map,
-                    for_verification=False,
-                ),
-                depth=message_event_dict["depth"],
-            )
-        )
-        # Override the signature added from "test" homeserver that we created the event with
-        message_event.signatures = message_event_dict["signatures"]
-
-        # Stub the /event_auth response from the OTHER_SERVER
-        async def get_event_auth(
-            destination: str, room_id: str, event_id: str
-        ) -> List[EventBase]:
-            return [
-                event_from_pdu_json(ae.get_pdu_json(), room_version=room_version)
-                for ae in auth_events
-            ]
-
-        self.handler.federation_client.get_event_auth = get_event_auth  # type: ignore[assignment]
-
-        with LoggingContext("receive_pdu"):
-            # Fake the OTHER_SERVER federating the message event over to our local homeserver
-            d = run_in_background(
-                self.hs.get_federation_event_handler().on_receive_pdu,
-                OTHER_SERVER,
-                message_event,
-            )
-        self.get_success(d)
-
-        # Now try and get the events on our local homeserver
-        stored_event = self.get_success(
-            self.store.get_event(message_event.event_id, allow_none=True)
-        )
-        self.assertTrue(stored_event is not None)
-
     @unittest.override_config(
         {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
     )
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 5eb0f243f7..9a48e9286f 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -21,7 +21,6 @@ from synapse.api.constants import EventTypes, LoginType, Membership
 from synapse.api.errors import SynapseError
 from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
 from synapse.rest import admin
 from synapse.rest.client import account, login, profile, room
@@ -113,14 +112,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
 
         # Have this homeserver skip event auth checks. This is necessary due to
         # event auth checks ensuring that events were signed by the sender's homeserver.
-        async def _check_event_auth(
-            origin: str,
-            event: EventBase,
-            context: EventContext,
-            *args: Any,
-            **kwargs: Any,
-        ) -> EventContext:
-            return context
+        async def _check_event_auth(origin: Any, event: Any, context: Any) -> None:
+            pass
 
         hs.get_federation_event_handler()._check_event_auth = _check_event_auth  # type: ignore[assignment]
 
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 0cbef70bfa..779fad1f63 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -81,12 +81,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         self.handler = self.homeserver.get_federation_handler()
         federation_event_handler = self.homeserver.get_federation_event_handler()
 
-        async def _check_event_auth(
-            origin,
-            event,
-            context,
-        ):
-            return context
+        async def _check_event_auth(origin, event, context):
+            pass
 
         federation_event_handler._check_event_auth = _check_event_auth
         self.client = self.homeserver.get_federation_client()