summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/9763.feature1
-rw-r--r--changelog.d/9800.feature1
-rw-r--r--synapse/handlers/event_auth.py82
-rw-r--r--synapse/handlers/federation.py212
-rw-r--r--synapse/handlers/room_member.py62
-rw-r--r--synapse/server.py5
-rw-r--r--tests/test_federation.py6
7 files changed, 131 insertions, 238 deletions
diff --git a/changelog.d/9763.feature b/changelog.d/9763.feature
deleted file mode 100644
index 9404ad2fc0..0000000000
--- a/changelog.d/9763.feature
+++ /dev/null
@@ -1 +0,0 @@
-Update experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership.
diff --git a/changelog.d/9800.feature b/changelog.d/9800.feature
deleted file mode 100644
index 9404ad2fc0..0000000000
--- a/changelog.d/9800.feature
+++ /dev/null
@@ -1 +0,0 @@
-Update experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership.
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
deleted file mode 100644
index 06da1a93d9..0000000000
--- a/synapse/handlers/event_auth.py
+++ /dev/null
@@ -1,82 +0,0 @@
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import TYPE_CHECKING
-
-from synapse.api.constants import EventTypes, JoinRules
-from synapse.api.room_versions import RoomVersion
-from synapse.types import StateMap
-
-if TYPE_CHECKING:
-    from synapse.server import HomeServer
-
-
-class EventAuthHandler:
-    def __init__(self, hs: "HomeServer"):
-        self._store = hs.get_datastore()
-
-    async def can_join_without_invite(
-        self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str
-    ) -> bool:
-        """
-        Check whether a user can join a room without an invite.
-
-        When joining a room with restricted joined rules (as defined in MSC3083),
-        the membership of spaces must be checked during join.
-
-        Args:
-            state_ids: The state of the room as it currently is.
-            room_version: The room version of the room being joined.
-            user_id: The user joining the room.
-
-        Returns:
-            True if the user can join the room, false otherwise.
-        """
-        # This only applies to room versions which support the new join rule.
-        if not room_version.msc3083_join_rules:
-            return True
-
-        # If there's no join rule, then it defaults to public (so this doesn't apply).
-        join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
-        if not join_rules_event_id:
-            return True
-
-        # If the join rule is not restricted, this doesn't apply.
-        join_rules_event = await self._store.get_event(join_rules_event_id)
-        if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED:
-            return True
-
-        # If allowed is of the wrong form, then only allow invited users.
-        allowed_spaces = join_rules_event.content.get("allow", [])
-        if not isinstance(allowed_spaces, list):
-            return False
-
-        # Get the list of joined rooms and see if there's an overlap.
-        joined_rooms = await self._store.get_rooms_for_user(user_id)
-
-        # Pull out the other room IDs, invalid data gets filtered.
-        for space in allowed_spaces:
-            if not isinstance(space, dict):
-                continue
-
-            space_id = space.get("space")
-            if not isinstance(space_id, str):
-                continue
-
-            # The user was joined to one of the spaces specified, they can join
-            # this room!
-            if space_id in joined_rooms:
-                return True
-
-        # The user was not in any of the required spaces.
-        return False
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 0c9bdf51a4..fe1d83f6b8 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -103,7 +103,7 @@ logger = logging.getLogger(__name__)
 
 @attr.s(slots=True)
 class _NewEventInfo:
-    """Holds information about a received event, ready for passing to _auth_and_persist_events
+    """Holds information about a received event, ready for passing to _handle_new_events
 
     Attributes:
         event: the received event
@@ -146,7 +146,6 @@ class FederationHandler(BaseHandler):
         self.is_mine_id = hs.is_mine_id
         self.spam_checker = hs.get_spam_checker()
         self.event_creation_handler = hs.get_event_creation_handler()
-        self.event_auth_handler = hs.get_event_auth_handler()
         self._message_handler = hs.get_message_handler()
         self._server_notices_mxid = hs.config.server_notices_mxid
         self.config = hs.config
@@ -808,10 +807,7 @@ class FederationHandler(BaseHandler):
         logger.debug("Processing event: %s", event)
 
         try:
-            context = await self.state_handler.compute_event_context(
-                event, old_state=state
-            )
-            await self._auth_and_persist_event(origin, event, context, state=state)
+            await self._handle_new_event(origin, event, state=state)
         except AuthError as e:
             raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
 
@@ -1014,9 +1010,7 @@ class FederationHandler(BaseHandler):
             )
 
         if ev_infos:
-            await self._auth_and_persist_events(
-                dest, room_id, ev_infos, backfilled=True
-            )
+            await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
 
         # Step 2: Persist the rest of the events in the chunk one by one
         events.sort(key=lambda e: e.depth)
@@ -1029,12 +1023,10 @@ class FederationHandler(BaseHandler):
             # non-outliers
             assert not event.internal_metadata.is_outlier()
 
-            context = await self.state_handler.compute_event_context(event)
-
             # We store these one at a time since each event depends on the
             # previous to work out the state.
             # TODO: We can probably do something more clever here.
-            await self._auth_and_persist_event(dest, event, context, backfilled=True)
+            await self._handle_new_event(dest, event, backfilled=True)
 
         return events
 
@@ -1368,7 +1360,7 @@ class FederationHandler(BaseHandler):
 
             event_infos.append(_NewEventInfo(event, None, auth))
 
-        await self._auth_and_persist_events(
+        await self._handle_new_events(
             destination,
             room_id,
             event_infos,
@@ -1674,47 +1666,16 @@ class FederationHandler(BaseHandler):
         # would introduce the danger of backwards-compatibility problems.
         event.internal_metadata.send_on_behalf_of = origin
 
-        # Calculate the event context.
-        context = await self.state_handler.compute_event_context(event)
-
-        # Get the state before the new event.
-        prev_state_ids = await context.get_prev_state_ids()
-
-        # Check if the user is already in the room or invited to the room.
-        user_id = event.state_key
-        prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
-        newly_joined = True
-        is_invite = False
-        if prev_member_event_id:
-            prev_member_event = await self.store.get_event(prev_member_event_id)
-            newly_joined = prev_member_event.membership != Membership.JOIN
-            is_invite = prev_member_event.membership == Membership.INVITE
-
-        # If the member is not already in the room, and not invited, check if
-        # they should be allowed access via membership in a space.
-        if (
-            newly_joined
-            and not is_invite
-            and not await self.event_auth_handler.can_join_without_invite(
-                prev_state_ids,
-                event.room_version,
-                user_id,
-            )
-        ):
-            raise SynapseError(
-                400,
-                "You do not belong to any of the required spaces to join this room.",
-            )
-
-        # Persist the event.
-        await self._auth_and_persist_event(origin, event, context)
+        context = await self._handle_new_event(origin, event)
 
         logger.debug(
-            "on_send_join_request: After _auth_and_persist_event: %s, sigs: %s",
+            "on_send_join_request: After _handle_new_event: %s, sigs: %s",
             event.event_id,
             event.signatures,
         )
 
+        prev_state_ids = await context.get_prev_state_ids()
+
         state_ids = list(prev_state_ids.values())
         auth_chain = await self.store.get_auth_chain(event.room_id, state_ids)
 
@@ -1917,11 +1878,10 @@ class FederationHandler(BaseHandler):
 
         event.internal_metadata.outlier = False
 
-        context = await self.state_handler.compute_event_context(event)
-        await self._auth_and_persist_event(origin, event, context)
+        await self._handle_new_event(origin, event)
 
         logger.debug(
-            "on_send_leave_request: After _auth_and_persist_event: %s, sigs: %s",
+            "on_send_leave_request: After _handle_new_event: %s, sigs: %s",
             event.event_id,
             event.signatures,
         )
@@ -2029,44 +1989,16 @@ class FederationHandler(BaseHandler):
     async def get_min_depth_for_context(self, context: str) -> int:
         return await self.store.get_min_depth(context)
 
-    async def _auth_and_persist_event(
+    async def _handle_new_event(
         self,
         origin: str,
         event: EventBase,
-        context: EventContext,
         state: Optional[Iterable[EventBase]] = None,
         auth_events: Optional[MutableStateMap[EventBase]] = None,
         backfilled: bool = False,
-    ) -> None:
-        """
-        Process an event by performing auth checks and then persisting to the database.
-
-        Args:
-            origin: The host the event originates from.
-            event: The event itself.
-            context:
-                The event context.
-
-                NB that this function potentially modifies it.
-            state:
-                The state events used to auth the event. If this is not provided
-                the current state events will be used.
-            auth_events:
-                Map from (event_type, state_key) to event
-
-                Normally, our calculated auth_events based on the state of the room
-                at the event's position in the DAG, though occasionally (eg if the
-                event is an outlier), may be the auth events claimed by the remote
-                server.
-            backfilled: True if the event was backfilled.
-        """
-        context = await self._check_event_auth(
-            origin,
-            event,
-            context,
-            state=state,
-            auth_events=auth_events,
-            backfilled=backfilled,
+    ) -> EventContext:
+        context = await self._prep_event(
+            origin, event, state=state, auth_events=auth_events, backfilled=backfilled
         )
 
         try:
@@ -2088,7 +2020,9 @@ class FederationHandler(BaseHandler):
             )
             raise
 
-    async def _auth_and_persist_events(
+        return context
+
+    async def _handle_new_events(
         self,
         origin: str,
         room_id: str,
@@ -2106,13 +2040,9 @@ class FederationHandler(BaseHandler):
         async def prep(ev_info: _NewEventInfo):
             event = ev_info.event
             with nested_logging_context(suffix=event.event_id):
-                res = await self.state_handler.compute_event_context(
-                    event, old_state=ev_info.state
-                )
-                res = await self._check_event_auth(
+                res = await self._prep_event(
                     origin,
                     event,
-                    res,
                     state=ev_info.state,
                     auth_events=ev_info.auth_events,
                     backfilled=backfilled,
@@ -2247,6 +2177,49 @@ class FederationHandler(BaseHandler):
             room_id, [(event, new_event_context)]
         )
 
+    async def _prep_event(
+        self,
+        origin: str,
+        event: EventBase,
+        state: Optional[Iterable[EventBase]],
+        auth_events: Optional[MutableStateMap[EventBase]],
+        backfilled: bool,
+    ) -> EventContext:
+        context = await self.state_handler.compute_event_context(event, old_state=state)
+
+        if not auth_events:
+            prev_state_ids = await context.get_prev_state_ids()
+            auth_events_ids = self.auth.compute_auth_events(
+                event, prev_state_ids, for_verification=True
+            )
+            auth_events_x = await self.store.get_events(auth_events_ids)
+            auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
+
+        # This is a hack to fix some old rooms where the initial join event
+        # didn't reference the create event in its auth events.
+        if event.type == EventTypes.Member and not event.auth_event_ids():
+            if len(event.prev_event_ids()) == 1 and event.depth < 5:
+                c = await self.store.get_event(
+                    event.prev_event_ids()[0], allow_none=True
+                )
+                if c and c.type == EventTypes.Create:
+                    auth_events[(c.type, c.state_key)] = c
+
+        context = await self.do_auth(origin, event, context, auth_events=auth_events)
+
+        if not context.rejected:
+            await self._check_for_soft_fail(event, state, backfilled)
+
+        if event.type == EventTypes.GuestAccess and not context.rejected:
+            await self.maybe_kick_guest_users(event)
+
+        # If we are going to send this event over federation we precaclculate
+        # the joined hosts.
+        if event.internal_metadata.get_send_on_behalf_of():
+            await self.event_creation_handler.cache_joined_hosts_for_event(event)
+
+        return context
+
     async def _check_for_soft_fail(
         self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
     ) -> None:
@@ -2357,28 +2330,19 @@ class FederationHandler(BaseHandler):
 
         return missing_events
 
-    async def _check_event_auth(
+    async def do_auth(
         self,
         origin: str,
         event: EventBase,
         context: EventContext,
-        state: Optional[Iterable[EventBase]],
-        auth_events: Optional[MutableStateMap[EventBase]],
-        backfilled: bool,
+        auth_events: MutableStateMap[EventBase],
     ) -> EventContext:
         """
-        Checks whether an event should be rejected (for failing auth checks).
 
         Args:
-            origin: The host the event originates from.
-            event: The event itself.
+            origin:
+            event:
             context:
-                The event context.
-
-                NB that this function potentially modifies it.
-            state:
-                The state events used to auth the event. If this is not provided
-                the current state events will be used.
             auth_events:
                 Map from (event_type, state_key) to event
 
@@ -2388,32 +2352,12 @@ class FederationHandler(BaseHandler):
                 server.
 
                 Also NB that this function adds entries to it.
-            backfilled: True if the event was backfilled.
-
         Returns:
-            The updated context object.
+            updated context object
         """
         room_version = await self.store.get_room_version_id(event.room_id)
         room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
 
-        if not auth_events:
-            prev_state_ids = await context.get_prev_state_ids()
-            auth_events_ids = self.auth.compute_auth_events(
-                event, prev_state_ids, for_verification=True
-            )
-            auth_events_x = await self.store.get_events(auth_events_ids)
-            auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
-
-        # This is a hack to fix some old rooms where the initial join event
-        # didn't reference the create event in its auth events.
-        if event.type == EventTypes.Member and not event.auth_event_ids():
-            if len(event.prev_event_ids()) == 1 and event.depth < 5:
-                c = await self.store.get_event(
-                    event.prev_event_ids()[0], allow_none=True
-                )
-                if c and c.type == EventTypes.Create:
-                    auth_events[(c.type, c.state_key)] = c
-
         try:
             context = await self._update_auth_events_and_context_for_auth(
                 origin, event, context, auth_events
@@ -2435,17 +2379,6 @@ class FederationHandler(BaseHandler):
             logger.warning("Failed auth resolution for %r because %s", event, e)
             context.rejected = RejectedReason.AUTH_ERROR
 
-        if not context.rejected:
-            await self._check_for_soft_fail(event, state, backfilled)
-
-        if event.type == EventTypes.GuestAccess and not context.rejected:
-            await self.maybe_kick_guest_users(event)
-
-        # If we are going to send this event over federation we precaclculate
-        # the joined hosts.
-        if event.internal_metadata.get_send_on_behalf_of():
-            await self.event_creation_handler.cache_joined_hosts_for_event(event)
-
         return context
 
     async def _update_auth_events_and_context_for_auth(
@@ -2455,7 +2388,7 @@ class FederationHandler(BaseHandler):
         context: EventContext,
         auth_events: MutableStateMap[EventBase],
     ) -> EventContext:
-        """Helper for _check_event_auth. See there for docs.
+        """Helper for do_auth. See there for docs.
 
         Checks whether a given event has the expected auth events. If it
         doesn't then we talk to the remote server to compare state to see if
@@ -2535,14 +2468,9 @@ class FederationHandler(BaseHandler):
                         e.internal_metadata.outlier = True
 
                         logger.debug(
-                            "_check_event_auth %s missing_auth: %s",
-                            event.event_id,
-                            e.event_id,
-                        )
-                        context = await self.state_handler.compute_event_context(e)
-                        await self._auth_and_persist_event(
-                            origin, e, context, auth_events=auth
+                            "do_auth %s missing_auth: %s", event.event_id, e.event_id
                         )
+                        await self._handle_new_event(origin, e, auth_events=auth)
 
                         if e.event_id in event_auth_events:
                             auth_events[(e.type, e.state_key)] = e
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 2c5bada1d8..2bbfac6471 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -19,7 +19,7 @@ from http import HTTPStatus
 from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
 from synapse import types
-from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules, Membership
 from synapse.api.errors import (
     AuthError,
     Codes,
@@ -28,6 +28,7 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.api.ratelimiting import Ratelimiter
+from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
@@ -63,7 +64,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.profile_handler = hs.get_profile_handler()
         self.event_creation_handler = hs.get_event_creation_handler()
         self.account_data_handler = hs.get_account_data_handler()
-        self.event_auth_handler = hs.get_event_auth_handler()
 
         self.member_linearizer = Linearizer(name="member")
 
@@ -178,6 +178,62 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id)
 
+    async def _can_join_without_invite(
+        self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str
+    ) -> bool:
+        """
+        Check whether a user can join a room without an invite.
+
+        When joining a room with restricted joined rules (as defined in MSC3083),
+        the membership of spaces must be checked during join.
+
+        Args:
+            state_ids: The state of the room as it currently is.
+            room_version: The room version of the room being joined.
+            user_id: The user joining the room.
+
+        Returns:
+            True if the user can join the room, false otherwise.
+        """
+        # This only applies to room versions which support the new join rule.
+        if not room_version.msc3083_join_rules:
+            return True
+
+        # If there's no join rule, then it defaults to public (so this doesn't apply).
+        join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
+        if not join_rules_event_id:
+            return True
+
+        # If the join rule is not restricted, this doesn't apply.
+        join_rules_event = await self.store.get_event(join_rules_event_id)
+        if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED:
+            return True
+
+        # If allowed is of the wrong form, then only allow invited users.
+        allowed_spaces = join_rules_event.content.get("allow", [])
+        if not isinstance(allowed_spaces, list):
+            return False
+
+        # Get the list of joined rooms and see if there's an overlap.
+        joined_rooms = await self.store.get_rooms_for_user(user_id)
+
+        # Pull out the other room IDs, invalid data gets filtered.
+        for space in allowed_spaces:
+            if not isinstance(space, dict):
+                continue
+
+            space_id = space.get("space")
+            if not isinstance(space_id, str):
+                continue
+
+            # The user was joined to one of the spaces specified, they can join
+            # this room!
+            if space_id in joined_rooms:
+                return True
+
+        # The user was not in any of the required spaces.
+        return False
+
     async def _local_membership_update(
         self,
         requester: Requester,
@@ -246,7 +302,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             if (
                 newly_joined
                 and not user_is_invited
-                and not await self.event_auth_handler.can_join_without_invite(
+                and not await self._can_join_without_invite(
                     prev_state_ids, event.room_version, user_id
                 )
             ):
diff --git a/synapse/server.py b/synapse/server.py
index 045b8f3fca..95a2cd2e5d 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -77,7 +77,6 @@ from synapse.handlers.devicemessage import DeviceMessageHandler
 from synapse.handlers.directory import DirectoryHandler
 from synapse.handlers.e2e_keys import E2eKeysHandler
 from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler
-from synapse.handlers.event_auth import EventAuthHandler
 from synapse.handlers.events import EventHandler, EventStreamHandler
 from synapse.handlers.federation import FederationHandler
 from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler
@@ -751,10 +750,6 @@ class HomeServer(metaclass=abc.ABCMeta):
         return SpaceSummaryHandler(self)
 
     @cache_in_self
-    def get_event_auth_handler(self) -> EventAuthHandler:
-        return EventAuthHandler(self)
-
-    @cache_in_self
     def get_external_cache(self) -> ExternalCache:
         return ExternalCache(self)
 
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 0a3a996ec1..86a44a13da 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -75,10 +75,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         )
 
         self.handler = self.homeserver.get_federation_handler()
-        self.handler._check_event_auth = (
-            lambda origin, event, context, state, auth_events, backfilled: succeed(
-                context
-            )
+        self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
+            context
         )
         self.client = self.homeserver.get_federation_client()
         self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(