summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_client.py17
-rw-r--r--synapse/federation/federation_server.py42
-rw-r--r--synapse/federation/transport/client.py12
-rw-r--r--synapse/federation/transport/server.py255
4 files changed, 222 insertions, 104 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py

index 35b28b3ed2..ed09c6af1f 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py
@@ -1,6 +1,5 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# Copyrignt 2020 Sorunome -# Copyrignt 2020 The Matrix.org Foundation C.I.C. +# Copyright 2015-2021 The Matrix.org Foundation C.I.C. +# Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -622,6 +621,7 @@ class FederationClient(FederationBase): no servers successfully handle the request. """ valid_memberships = {Membership.JOIN, Membership.LEAVE, Membership.KNOCK} + if membership not in valid_memberships: raise RuntimeError( "make_membership_event called with membership='%s', must be one of %s" @@ -640,6 +640,13 @@ class FederationClient(FederationBase): if not room_version: raise UnsupportedRoomVersionError() + if not room_version.msc2403_knocking and membership == Membership.KNOCK: + raise SynapseError( + 400, + "This room version does not support knocking", + errcode=Codes.FORBIDDEN, + ) + pdu_dict = ret.get("event", None) if not isinstance(pdu_dict, dict): raise InvalidResponseError("Bad 'event' field in response") @@ -977,7 +984,7 @@ class FederationClient(FederationBase): return await self._do_send_knock(destination, pdu) return await self._try_destination_list( - "xyz.amorgan.knock/send_knock", destinations, send_request + "send_knock", destinations, send_request ) async def _do_send_knock(self, destination: str, pdu: EventBase) -> JsonDict: @@ -997,7 +1004,7 @@ class FederationClient(FederationBase): """ time_now = self._clock.time_msec() - return await self.transport_layer.send_knock_v2( + return await self.transport_layer.send_knock_v1( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 60a26741a0..3466523c3c 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py
@@ -130,7 +130,7 @@ class FederationServer(FederationBase): # come in waves. self._state_resp_cache = ResponseCache( hs.get_clock(), "state_resp", timeout_ms=30000 - ) # type: ResponseCache[Tuple[str, str]] + ) # type: ResponseCache[Tuple[str, Optional[str]]] self._state_ids_resp_cache = ResponseCache( hs.get_clock(), "state_ids_resp", timeout_ms=30000 ) # type: ResponseCache[Tuple[str, str]] @@ -139,6 +139,8 @@ class FederationServer(FederationBase): hs.config.federation.federation_metrics_domains ) + self._room_prejoin_state_types = hs.config.api.room_prejoin_state + async def on_backfill_request( self, origin: str, room_id: str, versions: List[str], limit: int ) -> Tuple[int, Dict[str, Any]]: @@ -407,7 +409,7 @@ class FederationServer(FederationBase): ) async def on_room_state_request( - self, origin: str, room_id: str, event_id: str + self, origin: str, room_id: str, event_id: Optional[str] ) -> Tuple[int, Dict[str, Any]]: origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) @@ -464,7 +466,7 @@ class FederationServer(FederationBase): return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} async def _on_context_state_request_compute( - self, room_id: str, event_id: str + self, room_id: str, event_id: Optional[str] ) -> Dict[str, list]: if event_id: pdus = await self.handler.get_state_for_pdu( @@ -607,16 +609,29 @@ class FederationServer(FederationBase): origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) - room_version = await self.store.get_room_version_id(room_id) - if room_version not in supported_versions: + room_version = await self.store.get_room_version(room_id) + + # Check that this room version is supported by the remote homeserver + if room_version.identifier not in supported_versions: logger.warning( - "Room version %s not in %s", room_version, supported_versions + "Room version %s not in %s", room_version.identifier, supported_versions + ) + raise IncompatibleRoomVersionError(room_version=room_version.identifier) + + # Check that this room supports knocking as defined by its room version + if not room_version.msc2403_knocking: + raise SynapseError( + 403, + "This room version does not support knocking", + errcode=Codes.FORBIDDEN, ) - raise IncompatibleRoomVersionError(room_version=room_version) pdu = await self.handler.on_make_knock_request(origin, room_id, user_id) time_now = self._clock.time_msec() - return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} + return { + "event": pdu.get_pdu_json(time_now), + "room_version": room_version.identifier, + } async def on_send_knock_request( self, @@ -640,6 +655,15 @@ class FederationServer(FederationBase): logger.debug("on_send_knock_request: content: %s", content) room_version = await self.store.get_room_version(room_id) + + # Check that this room supports knocking as defined by its room version + if not room_version.msc2403_knocking: + raise SynapseError( + 403, + "This room version does not support knocking", + errcode=Codes.FORBIDDEN, + ) + pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) @@ -657,7 +681,7 @@ class FederationServer(FederationBase): # related to the room while the knock request is pending. stripped_room_state = ( await self.store.get_stripped_room_state_from_event_context( - event_context, DEFAULT_ROOM_STATE_TYPES + event_context, self._room_prejoin_state_types ) ) return {"knock_state_events": stripped_room_state} diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index bf5b541deb..e6c3cf9bb0 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py
@@ -1,5 +1,3 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd # Copyright 2020 Sorunome # Copyright 2020 The Matrix.org Foundation C.I.C. # @@ -223,6 +221,7 @@ class TransportLayerClient: is not in our federation whitelist """ valid_memberships = {Membership.JOIN, Membership.LEAVE, Membership.KNOCK} + if membership not in valid_memberships: raise RuntimeError( "make_membership_event called with membership='%s', must be one of %s" @@ -335,7 +334,7 @@ class TransportLayerClient: return response @log_function - async def send_knock_v2( + async def send_knock_v1( self, destination: str, room_id: str, @@ -362,12 +361,7 @@ class TransportLayerClient: The list of state events may be empty. """ - path = _create_path( - FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock", - "/send_knock/%s/%s", - room_id, - event_id, - ) + path = _create_v1_path("/send_knock/%s/%s", room_id, event_id) return await self.client.put_json( destination=destination, path=path, data=content diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index f1e659571a..a9942b41fb 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py
@@ -1,6 +1,4 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# Copyright 2019-2020 The Matrix.org Foundation C.I.C. +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,15 +26,17 @@ from synapse.api.urls import ( FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX, ) +from synapse.handlers.groups_local import GroupsLocalHandler from synapse.http.server import HttpServer, JsonResource from synapse.http.servlet import ( assert_params_in_dict, parse_boolean_from_args, parse_integer_from_args, parse_json_object_from_request, - parse_list_from_args, parse_string_from_args, + parse_strings_from_args, ) +from synapse.logging import opentracing from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( SynapseTags, @@ -277,10 +277,17 @@ class BaseFederationServlet: RATELIMIT = True # Whether to rate limit requests or not - def __init__(self, handler, authenticator, ratelimiter, server_name): - self.handler = handler + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + self.hs = hs self.authenticator = authenticator self.ratelimiter = ratelimiter + self.server_name = server_name def _wrap(self, func): authenticator = self.authenticator @@ -340,6 +347,8 @@ class BaseFederationServlet: ) with scope: + opentracing.inject_response_headers(request.responseHeaders) + if origin and self.RATELIMIT: with ratelimiter.ratelimit(origin) as d: await d @@ -377,17 +386,30 @@ class BaseFederationServlet: ) -class FederationSendServlet(BaseFederationServlet): +class BaseFederationServerServlet(BaseFederationServlet): + """Abstract base class for federation servlet classes which provides a federation server handler. + + See BaseFederationServlet for more information. + """ + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_federation_server() + + +class FederationSendServlet(BaseFederationServerServlet): PATH = "/send/(?P<transaction_id>[^/]*)/?" # We ratelimit manually in the handler as we queue up the requests and we # don't want to fill up the ratelimiter with blocked requests. RATELIMIT = False - def __init__(self, handler, server_name, **kwargs): - super().__init__(handler, server_name=server_name, **kwargs) - self.server_name = server_name - # This is when someone is trying to send us a bunch of data. async def on_PUT(self, origin, content, query, transaction_id): """Called on PUT /send/<transaction_id>/ @@ -436,7 +458,7 @@ class FederationSendServlet(BaseFederationServlet): return code, response -class FederationEventServlet(BaseFederationServlet): +class FederationEventServlet(BaseFederationServerServlet): PATH = "/event/(?P<event_id>[^/]*)/?" # This is when someone asks for a data item for a given server data_id pair. @@ -444,7 +466,7 @@ class FederationEventServlet(BaseFederationServlet): return await self.handler.on_pdu_request(origin, event_id) -class FederationStateV1Servlet(BaseFederationServlet): +class FederationStateV1Servlet(BaseFederationServerServlet): PATH = "/state/(?P<room_id>[^/]*)/?" # This is when someone asks for all data for a given room. @@ -456,7 +478,7 @@ class FederationStateV1Servlet(BaseFederationServlet): ) -class FederationStateIdsServlet(BaseFederationServlet): +class FederationStateIdsServlet(BaseFederationServerServlet): PATH = "/state_ids/(?P<room_id>[^/]*)/?" async def on_GET(self, origin, content, query, room_id): @@ -467,7 +489,7 @@ class FederationStateIdsServlet(BaseFederationServlet): ) -class FederationBackfillServlet(BaseFederationServlet): +class FederationBackfillServlet(BaseFederationServerServlet): PATH = "/backfill/(?P<room_id>[^/]*)/?" async def on_GET(self, origin, content, query, room_id): @@ -480,7 +502,7 @@ class FederationBackfillServlet(BaseFederationServlet): return await self.handler.on_backfill_request(origin, room_id, versions, limit) -class FederationQueryServlet(BaseFederationServlet): +class FederationQueryServlet(BaseFederationServerServlet): PATH = "/query/(?P<query_type>[^/]*)" # This is when we receive a server-server Query @@ -490,7 +512,7 @@ class FederationQueryServlet(BaseFederationServlet): return await self.handler.on_query_request(query_type, args) -class FederationMakeJoinServlet(BaseFederationServlet): +class FederationMakeJoinServlet(BaseFederationServerServlet): PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)" async def on_GET(self, origin, _content, query, room_id, user_id): @@ -520,7 +542,7 @@ class FederationMakeJoinServlet(BaseFederationServlet): return 200, content -class FederationMakeLeaveServlet(BaseFederationServlet): +class FederationMakeLeaveServlet(BaseFederationServerServlet): PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)" async def on_GET(self, origin, content, query, room_id, user_id): @@ -528,7 +550,7 @@ class FederationMakeLeaveServlet(BaseFederationServlet): return 200, content -class FederationV1SendLeaveServlet(BaseFederationServlet): +class FederationV1SendLeaveServlet(BaseFederationServerServlet): PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): @@ -536,7 +558,7 @@ class FederationV1SendLeaveServlet(BaseFederationServlet): return 200, (200, content) -class FederationV2SendLeaveServlet(BaseFederationServlet): +class FederationV2SendLeaveServlet(BaseFederationServerServlet): PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PREFIX = FEDERATION_V2_PREFIX @@ -546,15 +568,13 @@ class FederationV2SendLeaveServlet(BaseFederationServlet): return 200, content -class FederationMakeKnockServlet(BaseFederationServlet): +class FederationMakeKnockServlet(BaseFederationServerServlet): PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)" - PREFIX = FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock" - async def on_GET(self, origin, content, query, room_id, user_id): try: # Retrieve the room versions the remote homeserver claims to support - supported_versions = parse_list_from_args(query, "ver", encoding="utf-8") + supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8") except KeyError: raise SynapseError(400, "Missing required query parameter 'ver'") @@ -564,24 +584,22 @@ class FederationMakeKnockServlet(BaseFederationServlet): return 200, content -class FederationV2SendKnockServlet(BaseFederationServlet): +class FederationV1SendKnockServlet(BaseFederationServerServlet): PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" - PREFIX = FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock" - async def on_PUT(self, origin, content, query, room_id, event_id): content = await self.handler.on_send_knock_request(origin, content, room_id) return 200, content -class FederationEventAuthServlet(BaseFederationServlet): +class FederationEventAuthServlet(BaseFederationServerServlet): PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" async def on_GET(self, origin, content, query, room_id, event_id): return await self.handler.on_event_auth(origin, room_id, event_id) -class FederationV1SendJoinServlet(BaseFederationServlet): +class FederationV1SendJoinServlet(BaseFederationServerServlet): PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): @@ -591,7 +609,7 @@ class FederationV1SendJoinServlet(BaseFederationServlet): return 200, (200, content) -class FederationV2SendJoinServlet(BaseFederationServlet): +class FederationV2SendJoinServlet(BaseFederationServerServlet): PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PREFIX = FEDERATION_V2_PREFIX @@ -603,7 +621,7 @@ class FederationV2SendJoinServlet(BaseFederationServlet): return 200, content -class FederationV1InviteServlet(BaseFederationServlet): +class FederationV1InviteServlet(BaseFederationServerServlet): PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): @@ -620,7 +638,7 @@ class FederationV1InviteServlet(BaseFederationServlet): return 200, (200, content) -class FederationV2InviteServlet(BaseFederationServlet): +class FederationV2InviteServlet(BaseFederationServerServlet): PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PREFIX = FEDERATION_V2_PREFIX @@ -644,7 +662,7 @@ class FederationV2InviteServlet(BaseFederationServlet): return 200, content -class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet): +class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet): PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)" async def on_PUT(self, origin, content, query, room_id): @@ -652,21 +670,21 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet): return 200, {} -class FederationClientKeysQueryServlet(BaseFederationServlet): +class FederationClientKeysQueryServlet(BaseFederationServerServlet): PATH = "/user/keys/query" async def on_POST(self, origin, content, query): return await self.handler.on_query_client_keys(origin, content) -class FederationUserDevicesQueryServlet(BaseFederationServlet): +class FederationUserDevicesQueryServlet(BaseFederationServerServlet): PATH = "/user/devices/(?P<user_id>[^/]*)" async def on_GET(self, origin, content, query, user_id): return await self.handler.on_query_user_devices(origin, user_id) -class FederationClientKeysClaimServlet(BaseFederationServlet): +class FederationClientKeysClaimServlet(BaseFederationServerServlet): PATH = "/user/keys/claim" async def on_POST(self, origin, content, query): @@ -674,7 +692,7 @@ class FederationClientKeysClaimServlet(BaseFederationServlet): return 200, response -class FederationGetMissingEventsServlet(BaseFederationServlet): +class FederationGetMissingEventsServlet(BaseFederationServerServlet): # TODO(paul): Why does this path alone end with "/?" optional? PATH = "/get_missing_events/(?P<room_id>[^/]*)/?" @@ -694,7 +712,7 @@ class FederationGetMissingEventsServlet(BaseFederationServlet): return 200, content -class On3pidBindServlet(BaseFederationServlet): +class On3pidBindServlet(BaseFederationServerServlet): PATH = "/3pid/onbind" REQUIRE_AUTH = False @@ -724,7 +742,7 @@ class On3pidBindServlet(BaseFederationServlet): return 200, {} -class OpenIdUserInfo(BaseFederationServlet): +class OpenIdUserInfo(BaseFederationServerServlet): """ Exchange a bearer token for information about a user. @@ -800,8 +818,16 @@ class PublicRoomList(BaseFederationServlet): PATH = "/publicRooms" - def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access): - super().__init__(handler, authenticator, ratelimiter, server_name) + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + allow_access: bool, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_room_list_handler() self.allow_access = allow_access async def on_GET(self, origin, content, query): @@ -937,7 +963,24 @@ class FederationVersionServlet(BaseFederationServlet): ) -class FederationGroupsProfileServlet(BaseFederationServlet): +class BaseGroupsServerServlet(BaseFederationServlet): + """Abstract base class for federation servlet classes which provides a groups server handler. + + See BaseFederationServlet for more information. + """ + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_groups_server_handler() + + +class FederationGroupsProfileServlet(BaseGroupsServerServlet): """Get/set the basic profile of a group on behalf of a user""" PATH = "/groups/(?P<group_id>[^/]*)/profile" @@ -963,7 +1006,7 @@ class FederationGroupsProfileServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsSummaryServlet(BaseFederationServlet): +class FederationGroupsSummaryServlet(BaseGroupsServerServlet): PATH = "/groups/(?P<group_id>[^/]*)/summary" async def on_GET(self, origin, content, query, group_id): @@ -976,7 +1019,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsRoomsServlet(BaseFederationServlet): +class FederationGroupsRoomsServlet(BaseGroupsServerServlet): """Get the rooms in a group on behalf of a user""" PATH = "/groups/(?P<group_id>[^/]*)/rooms" @@ -991,7 +1034,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsAddRoomsServlet(BaseFederationServlet): +class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet): """Add/remove room from group""" PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" @@ -1019,7 +1062,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): +class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet): """Update room config in group""" PATH = ( @@ -1039,7 +1082,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): return 200, result -class FederationGroupsUsersServlet(BaseFederationServlet): +class FederationGroupsUsersServlet(BaseGroupsServerServlet): """Get the users in a group on behalf of a user""" PATH = "/groups/(?P<group_id>[^/]*)/users" @@ -1054,7 +1097,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsInvitedUsersServlet(BaseFederationServlet): +class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet): """Get the users that have been invited to a group""" PATH = "/groups/(?P<group_id>[^/]*)/invited_users" @@ -1071,7 +1114,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsInviteServlet(BaseFederationServlet): +class FederationGroupsInviteServlet(BaseGroupsServerServlet): """Ask a group server to invite someone to the group""" PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" @@ -1088,7 +1131,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsAcceptInviteServlet(BaseFederationServlet): +class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet): """Accept an invitation from the group server""" PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite" @@ -1102,7 +1145,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsJoinServlet(BaseFederationServlet): +class FederationGroupsJoinServlet(BaseGroupsServerServlet): """Attempt to join a group""" PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join" @@ -1116,7 +1159,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsRemoveUserServlet(BaseFederationServlet): +class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet): """Leave or kick a user from the group""" PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" @@ -1133,7 +1176,24 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsLocalInviteServlet(BaseFederationServlet): +class BaseGroupsLocalServlet(BaseFederationServlet): + """Abstract base class for federation servlet classes which provides a groups local handler. + + See BaseFederationServlet for more information. + """ + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_groups_local_handler() + + +class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet): """A group server has invited a local user""" PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" @@ -1142,12 +1202,16 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet): if get_domain_from_id(group_id) != origin: raise SynapseError(403, "group_id doesn't match origin") + assert isinstance( + self.handler, GroupsLocalHandler + ), "Workers cannot handle group invites." + new_content = await self.handler.on_invite(group_id, user_id, content) return 200, new_content -class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): +class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet): """A group server has removed a local user""" PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" @@ -1156,6 +1220,10 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): if get_domain_from_id(group_id) != origin: raise SynapseError(403, "user_id doesn't match origin") + assert isinstance( + self.handler, GroupsLocalHandler + ), "Workers cannot handle group removals." + new_content = await self.handler.user_removed_from_group( group_id, user_id, content ) @@ -1168,6 +1236,16 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)" + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_groups_attestation_renewer() + async def on_POST(self, origin, content, query, group_id, user_id): # We don't need to check auth here as we check the attestation signatures @@ -1178,7 +1256,7 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): +class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet): """Add/remove a room from the group summary, with optional category. Matches both: @@ -1235,7 +1313,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): return 200, resp -class FederationGroupsCategoriesServlet(BaseFederationServlet): +class FederationGroupsCategoriesServlet(BaseGroupsServerServlet): """Get all categories for a group""" PATH = "/groups/(?P<group_id>[^/]*)/categories/?" @@ -1250,7 +1328,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet): return 200, resp -class FederationGroupsCategoryServlet(BaseFederationServlet): +class FederationGroupsCategoryServlet(BaseGroupsServerServlet): """Add/remove/get a category in a group""" PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)" @@ -1303,7 +1381,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet): return 200, resp -class FederationGroupsRolesServlet(BaseFederationServlet): +class FederationGroupsRolesServlet(BaseGroupsServerServlet): """Get roles in a group""" PATH = "/groups/(?P<group_id>[^/]*)/roles/?" @@ -1318,7 +1396,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet): return 200, resp -class FederationGroupsRoleServlet(BaseFederationServlet): +class FederationGroupsRoleServlet(BaseGroupsServerServlet): """Add/remove/get a role in a group""" PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)" @@ -1371,7 +1449,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet): return 200, resp -class FederationGroupsSummaryUsersServlet(BaseFederationServlet): +class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet): """Add/remove a user from the group summary, with optional role. Matches both: @@ -1426,7 +1504,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): return 200, resp -class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): +class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet): """Get roles in a group""" PATH = "/get_groups_publicised" @@ -1439,7 +1517,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): return 200, resp -class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): +class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet): """Sets whether a group is joinable without an invite or knock""" PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy" @@ -1460,6 +1538,16 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" PATH = "/spaces/(?P<room_id>[^/]*)" + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_space_summary_handler() + async def on_GET( self, origin: str, @@ -1525,16 +1613,25 @@ class RoomComplexityServlet(BaseFederationServlet): PATH = "/rooms/(?P<room_id>[^/]*)/complexity" PREFIX = FEDERATION_UNSTABLE_PREFIX - async def on_GET(self, origin, content, query, room_id): - - store = self.handler.hs.get_datastore() + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self._store = self.hs.get_datastore() - is_public = await store.is_room_world_readable_or_publicly_joinable(room_id) + async def on_GET(self, origin, content, query, room_id): + is_public = await self._store.is_room_world_readable_or_publicly_joinable( + room_id + ) if not is_public: raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM) - complexity = await store.get_room_complexity(room_id) + complexity = await self._store.get_room_complexity(room_id) return 200, complexity @@ -1553,7 +1650,6 @@ FEDERATION_SERVLET_CLASSES = ( FederationV2SendJoinServlet, FederationV1SendLeaveServlet, FederationV2SendLeaveServlet, - FederationV2SendKnockServlet, FederationV1InviteServlet, FederationV2InviteServlet, FederationGetMissingEventsServlet, @@ -1566,6 +1662,9 @@ FEDERATION_SERVLET_CLASSES = ( FederationVersionServlet, RoomComplexityServlet, FederationUserInfoServlet, + FederationSpaceSummaryServlet, + FederationV1SendKnockServlet, + FederationMakeKnockServlet, ) # type: Tuple[Type[BaseFederationServlet], ...] OPENID_SERVLET_CLASSES = ( @@ -1607,6 +1706,7 @@ GROUP_ATTESTATION_SERVLET_CLASSES = ( FederationGroupsRenewAttestaionServlet, ) # type: Tuple[Type[BaseFederationServlet], ...] + DEFAULT_SERVLET_GROUPS = ( "federation", "room_list", @@ -1643,23 +1743,16 @@ def register_servlets( if "federation" in servlet_groups: for servletclass in FEDERATION_SERVLET_CLASSES: servletclass( - handler=hs.get_federation_server(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, ).register(resource) - FederationSpaceSummaryServlet( - handler=hs.get_space_summary_handler(), - authenticator=authenticator, - ratelimiter=ratelimiter, - server_name=hs.hostname, - ).register(resource) - if "openid" in servlet_groups: for servletclass in OPENID_SERVLET_CLASSES: servletclass( - handler=hs.get_federation_server(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, @@ -1668,7 +1761,7 @@ def register_servlets( if "room_list" in servlet_groups: for servletclass in ROOM_LIST_CLASSES: servletclass( - handler=hs.get_room_list_handler(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, @@ -1678,7 +1771,7 @@ def register_servlets( if "group_server" in servlet_groups: for servletclass in GROUP_SERVER_SERVLET_CLASSES: servletclass( - handler=hs.get_groups_server_handler(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, @@ -1687,7 +1780,7 @@ def register_servlets( if "group_local" in servlet_groups: for servletclass in GROUP_LOCAL_SERVLET_CLASSES: servletclass( - handler=hs.get_groups_local_handler(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, @@ -1696,7 +1789,7 @@ def register_servlets( if "group_attestation" in servlet_groups: for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES: servletclass( - handler=hs.get_groups_attestation_renewer(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname,