From 4b965c862dc66c0da5d3240add70e9b5f0aa720b Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Wed, 14 Apr 2021 16:34:27 +0200 Subject: Remove redundant "coding: utf-8" lines (#9786) Part of #9744 Removes all redundant `# -*- coding: utf-8 -*-` lines from files, as python 3 automatically reads source code as utf-8 now. `Signed-off-by: Jonathan de Jong ` --- tests/test_federation.py | 1 - 1 file changed, 1 deletion(-) (limited to 'tests/test_federation.py') diff --git a/tests/test_federation.py b/tests/test_federation.py index 8928597d17..86a44a13da 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); -- cgit 1.5.1 From cc51aaaa7adb0ec2235e027b5184ebda9b660ec4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 14 Apr 2021 12:32:20 -0400 Subject: Check for space membership during a remote join of a restricted room. (#9763) When receiving a /send_join request for a room with join rules set to 'restricted', check if the user is a member of the spaces defined in the 'allow' key of the join rules. This only applies to an experimental room version, as defined in MSC3083. --- changelog.d/9763.feature | 1 + changelog.d/9800.feature | 1 + synapse/handlers/event_auth.py | 82 ++++++++++++++++ synapse/handlers/federation.py | 212 +++++++++++++++++++++++++++------------- synapse/handlers/room_member.py | 62 +----------- synapse/server.py | 5 + tests/test_federation.py | 6 +- 7 files changed, 238 insertions(+), 131 deletions(-) create mode 100644 changelog.d/9763.feature create mode 100644 changelog.d/9800.feature create mode 100644 synapse/handlers/event_auth.py (limited to 'tests/test_federation.py') diff --git a/changelog.d/9763.feature b/changelog.d/9763.feature new file mode 100644 index 0000000000..9404ad2fc0 --- /dev/null +++ b/changelog.d/9763.feature @@ -0,0 +1 @@ +Update experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership. diff --git a/changelog.d/9800.feature b/changelog.d/9800.feature new file mode 100644 index 0000000000..9404ad2fc0 --- /dev/null +++ b/changelog.d/9800.feature @@ -0,0 +1 @@ +Update experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership. diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py new file mode 100644 index 0000000000..06da1a93d9 --- /dev/null +++ b/synapse/handlers/event_auth.py @@ -0,0 +1,82 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from synapse.api.constants import EventTypes, JoinRules +from synapse.api.room_versions import RoomVersion +from synapse.types import StateMap + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class EventAuthHandler: + def __init__(self, hs: "HomeServer"): + self._store = hs.get_datastore() + + async def can_join_without_invite( + self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str + ) -> bool: + """ + Check whether a user can join a room without an invite. + + When joining a room with restricted joined rules (as defined in MSC3083), + the membership of spaces must be checked during join. + + Args: + state_ids: The state of the room as it currently is. + room_version: The room version of the room being joined. + user_id: The user joining the room. + + Returns: + True if the user can join the room, false otherwise. + """ + # This only applies to room versions which support the new join rule. + if not room_version.msc3083_join_rules: + return True + + # If there's no join rule, then it defaults to public (so this doesn't apply). + join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None) + if not join_rules_event_id: + return True + + # If the join rule is not restricted, this doesn't apply. + join_rules_event = await self._store.get_event(join_rules_event_id) + if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED: + return True + + # If allowed is of the wrong form, then only allow invited users. + allowed_spaces = join_rules_event.content.get("allow", []) + if not isinstance(allowed_spaces, list): + return False + + # Get the list of joined rooms and see if there's an overlap. + joined_rooms = await self._store.get_rooms_for_user(user_id) + + # Pull out the other room IDs, invalid data gets filtered. + for space in allowed_spaces: + if not isinstance(space, dict): + continue + + space_id = space.get("space") + if not isinstance(space_id, str): + continue + + # The user was joined to one of the spaces specified, they can join + # this room! + if space_id in joined_rooms: + return True + + # The user was not in any of the required spaces. + return False diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index fe1d83f6b8..0c9bdf51a4 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -103,7 +103,7 @@ logger = logging.getLogger(__name__) @attr.s(slots=True) class _NewEventInfo: - """Holds information about a received event, ready for passing to _handle_new_events + """Holds information about a received event, ready for passing to _auth_and_persist_events Attributes: event: the received event @@ -146,6 +146,7 @@ class FederationHandler(BaseHandler): self.is_mine_id = hs.is_mine_id self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() + self.event_auth_handler = hs.get_event_auth_handler() self._message_handler = hs.get_message_handler() self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config @@ -807,7 +808,10 @@ class FederationHandler(BaseHandler): logger.debug("Processing event: %s", event) try: - await self._handle_new_event(origin, event, state=state) + context = await self.state_handler.compute_event_context( + event, old_state=state + ) + await self._auth_and_persist_event(origin, event, context, state=state) except AuthError as e: raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) @@ -1010,7 +1014,9 @@ class FederationHandler(BaseHandler): ) if ev_infos: - await self._handle_new_events(dest, room_id, ev_infos, backfilled=True) + await self._auth_and_persist_events( + dest, room_id, ev_infos, backfilled=True + ) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -1023,10 +1029,12 @@ class FederationHandler(BaseHandler): # non-outliers assert not event.internal_metadata.is_outlier() + context = await self.state_handler.compute_event_context(event) + # We store these one at a time since each event depends on the # previous to work out the state. # TODO: We can probably do something more clever here. - await self._handle_new_event(dest, event, backfilled=True) + await self._auth_and_persist_event(dest, event, context, backfilled=True) return events @@ -1360,7 +1368,7 @@ class FederationHandler(BaseHandler): event_infos.append(_NewEventInfo(event, None, auth)) - await self._handle_new_events( + await self._auth_and_persist_events( destination, room_id, event_infos, @@ -1666,16 +1674,47 @@ class FederationHandler(BaseHandler): # would introduce the danger of backwards-compatibility problems. event.internal_metadata.send_on_behalf_of = origin - context = await self._handle_new_event(origin, event) + # Calculate the event context. + context = await self.state_handler.compute_event_context(event) + + # Get the state before the new event. + prev_state_ids = await context.get_prev_state_ids() + + # Check if the user is already in the room or invited to the room. + user_id = event.state_key + prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) + newly_joined = True + is_invite = False + if prev_member_event_id: + prev_member_event = await self.store.get_event(prev_member_event_id) + newly_joined = prev_member_event.membership != Membership.JOIN + is_invite = prev_member_event.membership == Membership.INVITE + + # If the member is not already in the room, and not invited, check if + # they should be allowed access via membership in a space. + if ( + newly_joined + and not is_invite + and not await self.event_auth_handler.can_join_without_invite( + prev_state_ids, + event.room_version, + user_id, + ) + ): + raise SynapseError( + 400, + "You do not belong to any of the required spaces to join this room.", + ) + + # Persist the event. + await self._auth_and_persist_event(origin, event, context) logger.debug( - "on_send_join_request: After _handle_new_event: %s, sigs: %s", + "on_send_join_request: After _auth_and_persist_event: %s, sigs: %s", event.event_id, event.signatures, ) - prev_state_ids = await context.get_prev_state_ids() - state_ids = list(prev_state_ids.values()) auth_chain = await self.store.get_auth_chain(event.room_id, state_ids) @@ -1878,10 +1917,11 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = False - await self._handle_new_event(origin, event) + context = await self.state_handler.compute_event_context(event) + await self._auth_and_persist_event(origin, event, context) logger.debug( - "on_send_leave_request: After _handle_new_event: %s, sigs: %s", + "on_send_leave_request: After _auth_and_persist_event: %s, sigs: %s", event.event_id, event.signatures, ) @@ -1989,16 +2029,44 @@ class FederationHandler(BaseHandler): async def get_min_depth_for_context(self, context: str) -> int: return await self.store.get_min_depth(context) - async def _handle_new_event( + async def _auth_and_persist_event( self, origin: str, event: EventBase, + context: EventContext, state: Optional[Iterable[EventBase]] = None, auth_events: Optional[MutableStateMap[EventBase]] = None, backfilled: bool = False, - ) -> EventContext: - context = await self._prep_event( - origin, event, state=state, auth_events=auth_events, backfilled=backfilled + ) -> None: + """ + Process an event by performing auth checks and then persisting to the database. + + Args: + origin: The host the event originates from. + event: The event itself. + context: + The event context. + + NB that this function potentially modifies it. + state: + The state events used to auth the event. If this is not provided + the current state events will be used. + auth_events: + Map from (event_type, state_key) to event + + Normally, our calculated auth_events based on the state of the room + at the event's position in the DAG, though occasionally (eg if the + event is an outlier), may be the auth events claimed by the remote + server. + backfilled: True if the event was backfilled. + """ + context = await self._check_event_auth( + origin, + event, + context, + state=state, + auth_events=auth_events, + backfilled=backfilled, ) try: @@ -2020,9 +2088,7 @@ class FederationHandler(BaseHandler): ) raise - return context - - async def _handle_new_events( + async def _auth_and_persist_events( self, origin: str, room_id: str, @@ -2040,9 +2106,13 @@ class FederationHandler(BaseHandler): async def prep(ev_info: _NewEventInfo): event = ev_info.event with nested_logging_context(suffix=event.event_id): - res = await self._prep_event( + res = await self.state_handler.compute_event_context( + event, old_state=ev_info.state + ) + res = await self._check_event_auth( origin, event, + res, state=ev_info.state, auth_events=ev_info.auth_events, backfilled=backfilled, @@ -2177,49 +2247,6 @@ class FederationHandler(BaseHandler): room_id, [(event, new_event_context)] ) - async def _prep_event( - self, - origin: str, - event: EventBase, - state: Optional[Iterable[EventBase]], - auth_events: Optional[MutableStateMap[EventBase]], - backfilled: bool, - ) -> EventContext: - context = await self.state_handler.compute_event_context(event, old_state=state) - - if not auth_events: - prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = self.auth.compute_auth_events( - event, prev_state_ids, for_verification=True - ) - auth_events_x = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} - - # This is a hack to fix some old rooms where the initial join event - # didn't reference the create event in its auth events. - if event.type == EventTypes.Member and not event.auth_event_ids(): - if len(event.prev_event_ids()) == 1 and event.depth < 5: - c = await self.store.get_event( - event.prev_event_ids()[0], allow_none=True - ) - if c and c.type == EventTypes.Create: - auth_events[(c.type, c.state_key)] = c - - context = await self.do_auth(origin, event, context, auth_events=auth_events) - - if not context.rejected: - await self._check_for_soft_fail(event, state, backfilled) - - if event.type == EventTypes.GuestAccess and not context.rejected: - await self.maybe_kick_guest_users(event) - - # If we are going to send this event over federation we precaclculate - # the joined hosts. - if event.internal_metadata.get_send_on_behalf_of(): - await self.event_creation_handler.cache_joined_hosts_for_event(event) - - return context - async def _check_for_soft_fail( self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool ) -> None: @@ -2330,19 +2357,28 @@ class FederationHandler(BaseHandler): return missing_events - async def do_auth( + async def _check_event_auth( self, origin: str, event: EventBase, context: EventContext, - auth_events: MutableStateMap[EventBase], + state: Optional[Iterable[EventBase]], + auth_events: Optional[MutableStateMap[EventBase]], + backfilled: bool, ) -> EventContext: """ + Checks whether an event should be rejected (for failing auth checks). Args: - origin: - event: + origin: The host the event originates from. + event: The event itself. context: + The event context. + + NB that this function potentially modifies it. + state: + The state events used to auth the event. If this is not provided + the current state events will be used. auth_events: Map from (event_type, state_key) to event @@ -2352,12 +2388,32 @@ class FederationHandler(BaseHandler): server. Also NB that this function adds entries to it. + backfilled: True if the event was backfilled. + Returns: - updated context object + The updated context object. """ room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + if not auth_events: + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = self.auth.compute_auth_events( + event, prev_state_ids, for_verification=True + ) + auth_events_x = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} + + # This is a hack to fix some old rooms where the initial join event + # didn't reference the create event in its auth events. + if event.type == EventTypes.Member and not event.auth_event_ids(): + if len(event.prev_event_ids()) == 1 and event.depth < 5: + c = await self.store.get_event( + event.prev_event_ids()[0], allow_none=True + ) + if c and c.type == EventTypes.Create: + auth_events[(c.type, c.state_key)] = c + try: context = await self._update_auth_events_and_context_for_auth( origin, event, context, auth_events @@ -2379,6 +2435,17 @@ class FederationHandler(BaseHandler): logger.warning("Failed auth resolution for %r because %s", event, e) context.rejected = RejectedReason.AUTH_ERROR + if not context.rejected: + await self._check_for_soft_fail(event, state, backfilled) + + if event.type == EventTypes.GuestAccess and not context.rejected: + await self.maybe_kick_guest_users(event) + + # If we are going to send this event over federation we precaclculate + # the joined hosts. + if event.internal_metadata.get_send_on_behalf_of(): + await self.event_creation_handler.cache_joined_hosts_for_event(event) + return context async def _update_auth_events_and_context_for_auth( @@ -2388,7 +2455,7 @@ class FederationHandler(BaseHandler): context: EventContext, auth_events: MutableStateMap[EventBase], ) -> EventContext: - """Helper for do_auth. See there for docs. + """Helper for _check_event_auth. See there for docs. Checks whether a given event has the expected auth events. If it doesn't then we talk to the remote server to compare state to see if @@ -2468,9 +2535,14 @@ class FederationHandler(BaseHandler): e.internal_metadata.outlier = True logger.debug( - "do_auth %s missing_auth: %s", event.event_id, e.event_id + "_check_event_auth %s missing_auth: %s", + event.event_id, + e.event_id, + ) + context = await self.state_handler.compute_event_context(e) + await self._auth_and_persist_event( + origin, e, context, auth_events=auth ) - await self._handle_new_event(origin, e, auth_events=auth) if e.event_id in event_auth_events: auth_events[(e.type, e.state_key)] = e diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 2bbfac6471..2c5bada1d8 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -19,7 +19,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from synapse import types -from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -28,7 +28,6 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.ratelimiting import Ratelimiter -from synapse.api.room_versions import RoomVersion from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID @@ -64,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.profile_handler = hs.get_profile_handler() self.event_creation_handler = hs.get_event_creation_handler() self.account_data_handler = hs.get_account_data_handler() + self.event_auth_handler = hs.get_event_auth_handler() self.member_linearizer = Linearizer(name="member") @@ -178,62 +178,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id) - async def _can_join_without_invite( - self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str - ) -> bool: - """ - Check whether a user can join a room without an invite. - - When joining a room with restricted joined rules (as defined in MSC3083), - the membership of spaces must be checked during join. - - Args: - state_ids: The state of the room as it currently is. - room_version: The room version of the room being joined. - user_id: The user joining the room. - - Returns: - True if the user can join the room, false otherwise. - """ - # This only applies to room versions which support the new join rule. - if not room_version.msc3083_join_rules: - return True - - # If there's no join rule, then it defaults to public (so this doesn't apply). - join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None) - if not join_rules_event_id: - return True - - # If the join rule is not restricted, this doesn't apply. - join_rules_event = await self.store.get_event(join_rules_event_id) - if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED: - return True - - # If allowed is of the wrong form, then only allow invited users. - allowed_spaces = join_rules_event.content.get("allow", []) - if not isinstance(allowed_spaces, list): - return False - - # Get the list of joined rooms and see if there's an overlap. - joined_rooms = await self.store.get_rooms_for_user(user_id) - - # Pull out the other room IDs, invalid data gets filtered. - for space in allowed_spaces: - if not isinstance(space, dict): - continue - - space_id = space.get("space") - if not isinstance(space_id, str): - continue - - # The user was joined to one of the spaces specified, they can join - # this room! - if space_id in joined_rooms: - return True - - # The user was not in any of the required spaces. - return False - async def _local_membership_update( self, requester: Requester, @@ -302,7 +246,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if ( newly_joined and not user_is_invited - and not await self._can_join_without_invite( + and not await self.event_auth_handler.can_join_without_invite( prev_state_ids, event.room_version, user_id ) ): diff --git a/synapse/server.py b/synapse/server.py index 95a2cd2e5d..045b8f3fca 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -77,6 +77,7 @@ from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.directory import DirectoryHandler from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler +from synapse.handlers.event_auth import EventAuthHandler from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.federation import FederationHandler from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler @@ -749,6 +750,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_space_summary_handler(self) -> SpaceSummaryHandler: return SpaceSummaryHandler(self) + @cache_in_self + def get_event_auth_handler(self) -> EventAuthHandler: + return EventAuthHandler(self) + @cache_in_self def get_external_cache(self) -> ExternalCache: return ExternalCache(self) diff --git a/tests/test_federation.py b/tests/test_federation.py index 86a44a13da..0a3a996ec1 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -75,8 +75,10 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) self.handler = self.homeserver.get_federation_handler() - self.handler.do_auth = lambda origin, event, context, auth_events: succeed( - context + self.handler._check_event_auth = ( + lambda origin, event, context, state, auth_events, backfilled: succeed( + context + ) ) self.client = self.homeserver.get_federation_client() self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( -- cgit 1.5.1 From e8816c6aced208cdf1310393a212b5109892ffbf Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 14 Apr 2021 12:33:37 -0400 Subject: Revert "Check for space membership during a remote join of a restricted room. (#9763)" This reverts commit cc51aaaa7adb0ec2235e027b5184ebda9b660ec4. The PR was prematurely merged and not yet approved. --- changelog.d/9763.feature | 1 - changelog.d/9800.feature | 1 - synapse/handlers/event_auth.py | 82 ---------------- synapse/handlers/federation.py | 212 +++++++++++++--------------------------- synapse/handlers/room_member.py | 62 +++++++++++- synapse/server.py | 5 - tests/test_federation.py | 6 +- 7 files changed, 131 insertions(+), 238 deletions(-) delete mode 100644 changelog.d/9763.feature delete mode 100644 changelog.d/9800.feature delete mode 100644 synapse/handlers/event_auth.py (limited to 'tests/test_federation.py') diff --git a/changelog.d/9763.feature b/changelog.d/9763.feature deleted file mode 100644 index 9404ad2fc0..0000000000 --- a/changelog.d/9763.feature +++ /dev/null @@ -1 +0,0 @@ -Update experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership. diff --git a/changelog.d/9800.feature b/changelog.d/9800.feature deleted file mode 100644 index 9404ad2fc0..0000000000 --- a/changelog.d/9800.feature +++ /dev/null @@ -1 +0,0 @@ -Update experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership. diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py deleted file mode 100644 index 06da1a93d9..0000000000 --- a/synapse/handlers/event_auth.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import TYPE_CHECKING - -from synapse.api.constants import EventTypes, JoinRules -from synapse.api.room_versions import RoomVersion -from synapse.types import StateMap - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class EventAuthHandler: - def __init__(self, hs: "HomeServer"): - self._store = hs.get_datastore() - - async def can_join_without_invite( - self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str - ) -> bool: - """ - Check whether a user can join a room without an invite. - - When joining a room with restricted joined rules (as defined in MSC3083), - the membership of spaces must be checked during join. - - Args: - state_ids: The state of the room as it currently is. - room_version: The room version of the room being joined. - user_id: The user joining the room. - - Returns: - True if the user can join the room, false otherwise. - """ - # This only applies to room versions which support the new join rule. - if not room_version.msc3083_join_rules: - return True - - # If there's no join rule, then it defaults to public (so this doesn't apply). - join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None) - if not join_rules_event_id: - return True - - # If the join rule is not restricted, this doesn't apply. - join_rules_event = await self._store.get_event(join_rules_event_id) - if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED: - return True - - # If allowed is of the wrong form, then only allow invited users. - allowed_spaces = join_rules_event.content.get("allow", []) - if not isinstance(allowed_spaces, list): - return False - - # Get the list of joined rooms and see if there's an overlap. - joined_rooms = await self._store.get_rooms_for_user(user_id) - - # Pull out the other room IDs, invalid data gets filtered. - for space in allowed_spaces: - if not isinstance(space, dict): - continue - - space_id = space.get("space") - if not isinstance(space_id, str): - continue - - # The user was joined to one of the spaces specified, they can join - # this room! - if space_id in joined_rooms: - return True - - # The user was not in any of the required spaces. - return False diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0c9bdf51a4..fe1d83f6b8 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -103,7 +103,7 @@ logger = logging.getLogger(__name__) @attr.s(slots=True) class _NewEventInfo: - """Holds information about a received event, ready for passing to _auth_and_persist_events + """Holds information about a received event, ready for passing to _handle_new_events Attributes: event: the received event @@ -146,7 +146,6 @@ class FederationHandler(BaseHandler): self.is_mine_id = hs.is_mine_id self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() - self.event_auth_handler = hs.get_event_auth_handler() self._message_handler = hs.get_message_handler() self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config @@ -808,10 +807,7 @@ class FederationHandler(BaseHandler): logger.debug("Processing event: %s", event) try: - context = await self.state_handler.compute_event_context( - event, old_state=state - ) - await self._auth_and_persist_event(origin, event, context, state=state) + await self._handle_new_event(origin, event, state=state) except AuthError as e: raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) @@ -1014,9 +1010,7 @@ class FederationHandler(BaseHandler): ) if ev_infos: - await self._auth_and_persist_events( - dest, room_id, ev_infos, backfilled=True - ) + await self._handle_new_events(dest, room_id, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -1029,12 +1023,10 @@ class FederationHandler(BaseHandler): # non-outliers assert not event.internal_metadata.is_outlier() - context = await self.state_handler.compute_event_context(event) - # We store these one at a time since each event depends on the # previous to work out the state. # TODO: We can probably do something more clever here. - await self._auth_and_persist_event(dest, event, context, backfilled=True) + await self._handle_new_event(dest, event, backfilled=True) return events @@ -1368,7 +1360,7 @@ class FederationHandler(BaseHandler): event_infos.append(_NewEventInfo(event, None, auth)) - await self._auth_and_persist_events( + await self._handle_new_events( destination, room_id, event_infos, @@ -1674,47 +1666,16 @@ class FederationHandler(BaseHandler): # would introduce the danger of backwards-compatibility problems. event.internal_metadata.send_on_behalf_of = origin - # Calculate the event context. - context = await self.state_handler.compute_event_context(event) - - # Get the state before the new event. - prev_state_ids = await context.get_prev_state_ids() - - # Check if the user is already in the room or invited to the room. - user_id = event.state_key - prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) - newly_joined = True - is_invite = False - if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) - newly_joined = prev_member_event.membership != Membership.JOIN - is_invite = prev_member_event.membership == Membership.INVITE - - # If the member is not already in the room, and not invited, check if - # they should be allowed access via membership in a space. - if ( - newly_joined - and not is_invite - and not await self.event_auth_handler.can_join_without_invite( - prev_state_ids, - event.room_version, - user_id, - ) - ): - raise SynapseError( - 400, - "You do not belong to any of the required spaces to join this room.", - ) - - # Persist the event. - await self._auth_and_persist_event(origin, event, context) + context = await self._handle_new_event(origin, event) logger.debug( - "on_send_join_request: After _auth_and_persist_event: %s, sigs: %s", + "on_send_join_request: After _handle_new_event: %s, sigs: %s", event.event_id, event.signatures, ) + prev_state_ids = await context.get_prev_state_ids() + state_ids = list(prev_state_ids.values()) auth_chain = await self.store.get_auth_chain(event.room_id, state_ids) @@ -1917,11 +1878,10 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = False - context = await self.state_handler.compute_event_context(event) - await self._auth_and_persist_event(origin, event, context) + await self._handle_new_event(origin, event) logger.debug( - "on_send_leave_request: After _auth_and_persist_event: %s, sigs: %s", + "on_send_leave_request: After _handle_new_event: %s, sigs: %s", event.event_id, event.signatures, ) @@ -2029,44 +1989,16 @@ class FederationHandler(BaseHandler): async def get_min_depth_for_context(self, context: str) -> int: return await self.store.get_min_depth(context) - async def _auth_and_persist_event( + async def _handle_new_event( self, origin: str, event: EventBase, - context: EventContext, state: Optional[Iterable[EventBase]] = None, auth_events: Optional[MutableStateMap[EventBase]] = None, backfilled: bool = False, - ) -> None: - """ - Process an event by performing auth checks and then persisting to the database. - - Args: - origin: The host the event originates from. - event: The event itself. - context: - The event context. - - NB that this function potentially modifies it. - state: - The state events used to auth the event. If this is not provided - the current state events will be used. - auth_events: - Map from (event_type, state_key) to event - - Normally, our calculated auth_events based on the state of the room - at the event's position in the DAG, though occasionally (eg if the - event is an outlier), may be the auth events claimed by the remote - server. - backfilled: True if the event was backfilled. - """ - context = await self._check_event_auth( - origin, - event, - context, - state=state, - auth_events=auth_events, - backfilled=backfilled, + ) -> EventContext: + context = await self._prep_event( + origin, event, state=state, auth_events=auth_events, backfilled=backfilled ) try: @@ -2088,7 +2020,9 @@ class FederationHandler(BaseHandler): ) raise - async def _auth_and_persist_events( + return context + + async def _handle_new_events( self, origin: str, room_id: str, @@ -2106,13 +2040,9 @@ class FederationHandler(BaseHandler): async def prep(ev_info: _NewEventInfo): event = ev_info.event with nested_logging_context(suffix=event.event_id): - res = await self.state_handler.compute_event_context( - event, old_state=ev_info.state - ) - res = await self._check_event_auth( + res = await self._prep_event( origin, event, - res, state=ev_info.state, auth_events=ev_info.auth_events, backfilled=backfilled, @@ -2247,6 +2177,49 @@ class FederationHandler(BaseHandler): room_id, [(event, new_event_context)] ) + async def _prep_event( + self, + origin: str, + event: EventBase, + state: Optional[Iterable[EventBase]], + auth_events: Optional[MutableStateMap[EventBase]], + backfilled: bool, + ) -> EventContext: + context = await self.state_handler.compute_event_context(event, old_state=state) + + if not auth_events: + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = self.auth.compute_auth_events( + event, prev_state_ids, for_verification=True + ) + auth_events_x = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} + + # This is a hack to fix some old rooms where the initial join event + # didn't reference the create event in its auth events. + if event.type == EventTypes.Member and not event.auth_event_ids(): + if len(event.prev_event_ids()) == 1 and event.depth < 5: + c = await self.store.get_event( + event.prev_event_ids()[0], allow_none=True + ) + if c and c.type == EventTypes.Create: + auth_events[(c.type, c.state_key)] = c + + context = await self.do_auth(origin, event, context, auth_events=auth_events) + + if not context.rejected: + await self._check_for_soft_fail(event, state, backfilled) + + if event.type == EventTypes.GuestAccess and not context.rejected: + await self.maybe_kick_guest_users(event) + + # If we are going to send this event over federation we precaclculate + # the joined hosts. + if event.internal_metadata.get_send_on_behalf_of(): + await self.event_creation_handler.cache_joined_hosts_for_event(event) + + return context + async def _check_for_soft_fail( self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool ) -> None: @@ -2357,28 +2330,19 @@ class FederationHandler(BaseHandler): return missing_events - async def _check_event_auth( + async def do_auth( self, origin: str, event: EventBase, context: EventContext, - state: Optional[Iterable[EventBase]], - auth_events: Optional[MutableStateMap[EventBase]], - backfilled: bool, + auth_events: MutableStateMap[EventBase], ) -> EventContext: """ - Checks whether an event should be rejected (for failing auth checks). Args: - origin: The host the event originates from. - event: The event itself. + origin: + event: context: - The event context. - - NB that this function potentially modifies it. - state: - The state events used to auth the event. If this is not provided - the current state events will be used. auth_events: Map from (event_type, state_key) to event @@ -2388,32 +2352,12 @@ class FederationHandler(BaseHandler): server. Also NB that this function adds entries to it. - backfilled: True if the event was backfilled. - Returns: - The updated context object. + updated context object """ room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - if not auth_events: - prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = self.auth.compute_auth_events( - event, prev_state_ids, for_verification=True - ) - auth_events_x = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} - - # This is a hack to fix some old rooms where the initial join event - # didn't reference the create event in its auth events. - if event.type == EventTypes.Member and not event.auth_event_ids(): - if len(event.prev_event_ids()) == 1 and event.depth < 5: - c = await self.store.get_event( - event.prev_event_ids()[0], allow_none=True - ) - if c and c.type == EventTypes.Create: - auth_events[(c.type, c.state_key)] = c - try: context = await self._update_auth_events_and_context_for_auth( origin, event, context, auth_events @@ -2435,17 +2379,6 @@ class FederationHandler(BaseHandler): logger.warning("Failed auth resolution for %r because %s", event, e) context.rejected = RejectedReason.AUTH_ERROR - if not context.rejected: - await self._check_for_soft_fail(event, state, backfilled) - - if event.type == EventTypes.GuestAccess and not context.rejected: - await self.maybe_kick_guest_users(event) - - # If we are going to send this event over federation we precaclculate - # the joined hosts. - if event.internal_metadata.get_send_on_behalf_of(): - await self.event_creation_handler.cache_joined_hosts_for_event(event) - return context async def _update_auth_events_and_context_for_auth( @@ -2455,7 +2388,7 @@ class FederationHandler(BaseHandler): context: EventContext, auth_events: MutableStateMap[EventBase], ) -> EventContext: - """Helper for _check_event_auth. See there for docs. + """Helper for do_auth. See there for docs. Checks whether a given event has the expected auth events. If it doesn't then we talk to the remote server to compare state to see if @@ -2535,14 +2468,9 @@ class FederationHandler(BaseHandler): e.internal_metadata.outlier = True logger.debug( - "_check_event_auth %s missing_auth: %s", - event.event_id, - e.event_id, - ) - context = await self.state_handler.compute_event_context(e) - await self._auth_and_persist_event( - origin, e, context, auth_events=auth + "do_auth %s missing_auth: %s", event.event_id, e.event_id ) + await self._handle_new_event(origin, e, auth_events=auth) if e.event_id in event_auth_events: auth_events[(e.type, e.state_key)] = e diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 2c5bada1d8..2bbfac6471 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -19,7 +19,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from synapse import types -from synapse.api.constants import AccountDataTypes, EventTypes, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules, Membership from synapse.api.errors import ( AuthError, Codes, @@ -28,6 +28,7 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.ratelimiting import Ratelimiter +from synapse.api.room_versions import RoomVersion from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID @@ -63,7 +64,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.profile_handler = hs.get_profile_handler() self.event_creation_handler = hs.get_event_creation_handler() self.account_data_handler = hs.get_account_data_handler() - self.event_auth_handler = hs.get_event_auth_handler() self.member_linearizer = Linearizer(name="member") @@ -178,6 +178,62 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id) + async def _can_join_without_invite( + self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str + ) -> bool: + """ + Check whether a user can join a room without an invite. + + When joining a room with restricted joined rules (as defined in MSC3083), + the membership of spaces must be checked during join. + + Args: + state_ids: The state of the room as it currently is. + room_version: The room version of the room being joined. + user_id: The user joining the room. + + Returns: + True if the user can join the room, false otherwise. + """ + # This only applies to room versions which support the new join rule. + if not room_version.msc3083_join_rules: + return True + + # If there's no join rule, then it defaults to public (so this doesn't apply). + join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None) + if not join_rules_event_id: + return True + + # If the join rule is not restricted, this doesn't apply. + join_rules_event = await self.store.get_event(join_rules_event_id) + if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED: + return True + + # If allowed is of the wrong form, then only allow invited users. + allowed_spaces = join_rules_event.content.get("allow", []) + if not isinstance(allowed_spaces, list): + return False + + # Get the list of joined rooms and see if there's an overlap. + joined_rooms = await self.store.get_rooms_for_user(user_id) + + # Pull out the other room IDs, invalid data gets filtered. + for space in allowed_spaces: + if not isinstance(space, dict): + continue + + space_id = space.get("space") + if not isinstance(space_id, str): + continue + + # The user was joined to one of the spaces specified, they can join + # this room! + if space_id in joined_rooms: + return True + + # The user was not in any of the required spaces. + return False + async def _local_membership_update( self, requester: Requester, @@ -246,7 +302,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if ( newly_joined and not user_is_invited - and not await self.event_auth_handler.can_join_without_invite( + and not await self._can_join_without_invite( prev_state_ids, event.room_version, user_id ) ): diff --git a/synapse/server.py b/synapse/server.py index 045b8f3fca..95a2cd2e5d 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -77,7 +77,6 @@ from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.directory import DirectoryHandler from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler -from synapse.handlers.event_auth import EventAuthHandler from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.federation import FederationHandler from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler @@ -750,10 +749,6 @@ class HomeServer(metaclass=abc.ABCMeta): def get_space_summary_handler(self) -> SpaceSummaryHandler: return SpaceSummaryHandler(self) - @cache_in_self - def get_event_auth_handler(self) -> EventAuthHandler: - return EventAuthHandler(self) - @cache_in_self def get_external_cache(self) -> ExternalCache: return ExternalCache(self) diff --git a/tests/test_federation.py b/tests/test_federation.py index 0a3a996ec1..86a44a13da 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -75,10 +75,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) self.handler = self.homeserver.get_federation_handler() - self.handler._check_event_auth = ( - lambda origin, event, context, state, auth_events, backfilled: succeed( - context - ) + self.handler.do_auth = lambda origin, event, context, auth_events: succeed( + context ) self.client = self.homeserver.get_federation_client() self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( -- cgit 1.5.1 From 936e69825ab684ca580edb6038e86fdb2e561776 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 14 Apr 2021 12:35:28 -0400 Subject: Separate creating an event context from persisting it in the federation handler (#9800) This refactoring allows adding logic that uses the event context before persisting it. --- changelog.d/9800.feature | 1 + synapse/handlers/federation.py | 178 ++++++++++++++++++++++++++--------------- tests/test_federation.py | 6 +- 3 files changed, 118 insertions(+), 67 deletions(-) create mode 100644 changelog.d/9800.feature (limited to 'tests/test_federation.py') diff --git a/changelog.d/9800.feature b/changelog.d/9800.feature new file mode 100644 index 0000000000..9404ad2fc0 --- /dev/null +++ b/changelog.d/9800.feature @@ -0,0 +1 @@ +Update experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index fe1d83f6b8..4b3730aa3b 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -103,7 +103,7 @@ logger = logging.getLogger(__name__) @attr.s(slots=True) class _NewEventInfo: - """Holds information about a received event, ready for passing to _handle_new_events + """Holds information about a received event, ready for passing to _auth_and_persist_events Attributes: event: the received event @@ -807,7 +807,10 @@ class FederationHandler(BaseHandler): logger.debug("Processing event: %s", event) try: - await self._handle_new_event(origin, event, state=state) + context = await self.state_handler.compute_event_context( + event, old_state=state + ) + await self._auth_and_persist_event(origin, event, context, state=state) except AuthError as e: raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) @@ -1010,7 +1013,9 @@ class FederationHandler(BaseHandler): ) if ev_infos: - await self._handle_new_events(dest, room_id, ev_infos, backfilled=True) + await self._auth_and_persist_events( + dest, room_id, ev_infos, backfilled=True + ) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -1023,10 +1028,12 @@ class FederationHandler(BaseHandler): # non-outliers assert not event.internal_metadata.is_outlier() + context = await self.state_handler.compute_event_context(event) + # We store these one at a time since each event depends on the # previous to work out the state. # TODO: We can probably do something more clever here. - await self._handle_new_event(dest, event, backfilled=True) + await self._auth_and_persist_event(dest, event, context, backfilled=True) return events @@ -1360,7 +1367,7 @@ class FederationHandler(BaseHandler): event_infos.append(_NewEventInfo(event, None, auth)) - await self._handle_new_events( + await self._auth_and_persist_events( destination, room_id, event_infos, @@ -1666,10 +1673,11 @@ class FederationHandler(BaseHandler): # would introduce the danger of backwards-compatibility problems. event.internal_metadata.send_on_behalf_of = origin - context = await self._handle_new_event(origin, event) + context = await self.state_handler.compute_event_context(event) + context = await self._auth_and_persist_event(origin, event, context) logger.debug( - "on_send_join_request: After _handle_new_event: %s, sigs: %s", + "on_send_join_request: After _auth_and_persist_event: %s, sigs: %s", event.event_id, event.signatures, ) @@ -1878,10 +1886,11 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = False - await self._handle_new_event(origin, event) + context = await self.state_handler.compute_event_context(event) + await self._auth_and_persist_event(origin, event, context) logger.debug( - "on_send_leave_request: After _handle_new_event: %s, sigs: %s", + "on_send_leave_request: After _auth_and_persist_event: %s, sigs: %s", event.event_id, event.signatures, ) @@ -1989,16 +1998,47 @@ class FederationHandler(BaseHandler): async def get_min_depth_for_context(self, context: str) -> int: return await self.store.get_min_depth(context) - async def _handle_new_event( + async def _auth_and_persist_event( self, origin: str, event: EventBase, + context: EventContext, state: Optional[Iterable[EventBase]] = None, auth_events: Optional[MutableStateMap[EventBase]] = None, backfilled: bool = False, ) -> EventContext: - context = await self._prep_event( - origin, event, state=state, auth_events=auth_events, backfilled=backfilled + """ + Process an event by performing auth checks and then persisting to the database. + + Args: + origin: The host the event originates from. + event: The event itself. + context: + The event context. + + NB that this function potentially modifies it. + state: + The state events used to check the event for soft-fail. If this is + not provided the current state events will be used. + auth_events: + Map from (event_type, state_key) to event + + Normally, our calculated auth_events based on the state of the room + at the event's position in the DAG, though occasionally (eg if the + event is an outlier), may be the auth events claimed by the remote + server. + backfilled: True if the event was backfilled. + + Returns: + The event context. + """ + context = await self._check_event_auth( + origin, + event, + context, + state=state, + auth_events=auth_events, + backfilled=backfilled, ) try: @@ -2022,7 +2062,7 @@ class FederationHandler(BaseHandler): return context - async def _handle_new_events( + async def _auth_and_persist_events( self, origin: str, room_id: str, @@ -2040,9 +2080,13 @@ class FederationHandler(BaseHandler): async def prep(ev_info: _NewEventInfo): event = ev_info.event with nested_logging_context(suffix=event.event_id): - res = await self._prep_event( + res = await self.state_handler.compute_event_context( + event, old_state=ev_info.state + ) + res = await self._check_event_auth( origin, event, + res, state=ev_info.state, auth_events=ev_info.auth_events, backfilled=backfilled, @@ -2177,49 +2221,6 @@ class FederationHandler(BaseHandler): room_id, [(event, new_event_context)] ) - async def _prep_event( - self, - origin: str, - event: EventBase, - state: Optional[Iterable[EventBase]], - auth_events: Optional[MutableStateMap[EventBase]], - backfilled: bool, - ) -> EventContext: - context = await self.state_handler.compute_event_context(event, old_state=state) - - if not auth_events: - prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = self.auth.compute_auth_events( - event, prev_state_ids, for_verification=True - ) - auth_events_x = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} - - # This is a hack to fix some old rooms where the initial join event - # didn't reference the create event in its auth events. - if event.type == EventTypes.Member and not event.auth_event_ids(): - if len(event.prev_event_ids()) == 1 and event.depth < 5: - c = await self.store.get_event( - event.prev_event_ids()[0], allow_none=True - ) - if c and c.type == EventTypes.Create: - auth_events[(c.type, c.state_key)] = c - - context = await self.do_auth(origin, event, context, auth_events=auth_events) - - if not context.rejected: - await self._check_for_soft_fail(event, state, backfilled) - - if event.type == EventTypes.GuestAccess and not context.rejected: - await self.maybe_kick_guest_users(event) - - # If we are going to send this event over federation we precaclculate - # the joined hosts. - if event.internal_metadata.get_send_on_behalf_of(): - await self.event_creation_handler.cache_joined_hosts_for_event(event) - - return context - async def _check_for_soft_fail( self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool ) -> None: @@ -2330,19 +2331,28 @@ class FederationHandler(BaseHandler): return missing_events - async def do_auth( + async def _check_event_auth( self, origin: str, event: EventBase, context: EventContext, - auth_events: MutableStateMap[EventBase], + state: Optional[Iterable[EventBase]], + auth_events: Optional[MutableStateMap[EventBase]], + backfilled: bool, ) -> EventContext: """ + Checks whether an event should be rejected (for failing auth checks). Args: - origin: - event: + origin: The host the event originates from. + event: The event itself. context: + The event context. + + NB that this function potentially modifies it. + state: + The state events used to check the event for soft-fail. If this is + not provided the current state events will be used. auth_events: Map from (event_type, state_key) to event @@ -2352,12 +2362,34 @@ class FederationHandler(BaseHandler): server. Also NB that this function adds entries to it. + + If this is not provided, it is calculated from the previous state IDs. + backfilled: True if the event was backfilled. + Returns: - updated context object + The updated context object. """ room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + if not auth_events: + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = self.auth.compute_auth_events( + event, prev_state_ids, for_verification=True + ) + auth_events_x = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} + + # This is a hack to fix some old rooms where the initial join event + # didn't reference the create event in its auth events. + if event.type == EventTypes.Member and not event.auth_event_ids(): + if len(event.prev_event_ids()) == 1 and event.depth < 5: + c = await self.store.get_event( + event.prev_event_ids()[0], allow_none=True + ) + if c and c.type == EventTypes.Create: + auth_events[(c.type, c.state_key)] = c + try: context = await self._update_auth_events_and_context_for_auth( origin, event, context, auth_events @@ -2379,6 +2411,17 @@ class FederationHandler(BaseHandler): logger.warning("Failed auth resolution for %r because %s", event, e) context.rejected = RejectedReason.AUTH_ERROR + if not context.rejected: + await self._check_for_soft_fail(event, state, backfilled) + + if event.type == EventTypes.GuestAccess and not context.rejected: + await self.maybe_kick_guest_users(event) + + # If we are going to send this event over federation we precaclculate + # the joined hosts. + if event.internal_metadata.get_send_on_behalf_of(): + await self.event_creation_handler.cache_joined_hosts_for_event(event) + return context async def _update_auth_events_and_context_for_auth( @@ -2388,7 +2431,7 @@ class FederationHandler(BaseHandler): context: EventContext, auth_events: MutableStateMap[EventBase], ) -> EventContext: - """Helper for do_auth. See there for docs. + """Helper for _check_event_auth. See there for docs. Checks whether a given event has the expected auth events. If it doesn't then we talk to the remote server to compare state to see if @@ -2468,9 +2511,14 @@ class FederationHandler(BaseHandler): e.internal_metadata.outlier = True logger.debug( - "do_auth %s missing_auth: %s", event.event_id, e.event_id + "_check_event_auth %s missing_auth: %s", + event.event_id, + e.event_id, + ) + context = await self.state_handler.compute_event_context(e) + await self._auth_and_persist_event( + origin, e, context, auth_events=auth ) - await self._handle_new_event(origin, e, auth_events=auth) if e.event_id in event_auth_events: auth_events[(e.type, e.state_key)] = e diff --git a/tests/test_federation.py b/tests/test_federation.py index 86a44a13da..0a3a996ec1 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -75,8 +75,10 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) self.handler = self.homeserver.get_federation_handler() - self.handler.do_auth = lambda origin, event, context, auth_events: succeed( - context + self.handler._check_event_auth = ( + lambda origin, event, context, state, auth_events, backfilled: succeed( + context + ) ) self.client = self.homeserver.get_federation_client() self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( -- cgit 1.5.1