summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_client.py69
-rw-r--r--synapse/federation/federation_server.py105
-rw-r--r--synapse/federation/transport/client.py41
-rw-r--r--synapse/federation/transport/server.py266
4 files changed, 400 insertions, 81 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py

index 1076ebc036..ed09c6af1f 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py
@@ -1,4 +1,5 @@ -# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2015-2021 The Matrix.org Foundation C.I.C. +# Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -619,7 +620,8 @@ class FederationClient(FederationBase): SynapseError: if the chosen remote server returns a 300/400 code, or no servers successfully handle the request. """ - valid_memberships = {Membership.JOIN, Membership.LEAVE} + valid_memberships = {Membership.JOIN, Membership.LEAVE, Membership.KNOCK} + if membership not in valid_memberships: raise RuntimeError( "make_membership_event called with membership='%s', must be one of %s" @@ -638,6 +640,13 @@ class FederationClient(FederationBase): if not room_version: raise UnsupportedRoomVersionError() + if not room_version.msc2403_knocking and membership == Membership.KNOCK: + raise SynapseError( + 400, + "This room version does not support knocking", + errcode=Codes.FORBIDDEN, + ) + pdu_dict = ret.get("event", None) if not isinstance(pdu_dict, dict): raise InvalidResponseError("Bad 'event' field in response") @@ -946,6 +955,62 @@ class FederationClient(FederationBase): # content. return resp[1] + async def send_knock(self, destinations: List[str], pdu: EventBase) -> JsonDict: + """Attempts to send a knock event to given a list of servers. Iterates + through the list until one attempt succeeds. + + Doing so will cause the remote server to add the event to the graph, + and send the event out to the rest of the federation. + + Args: + destinations: A list of candidate homeservers which are likely to be + participating in the room. + pdu: The event to be sent. + + Returns: + The remote homeserver return some state from the room. The response + dictionary is in the form: + + {"knock_state_events": [<state event dict>, ...]} + + The list of state events may be empty. + + Raises: + SynapseError: If the chosen remote server returns a 3xx/4xx code. + RuntimeError: If no servers were reachable. + """ + + async def send_request(destination: str) -> JsonDict: + return await self._do_send_knock(destination, pdu) + + return await self._try_destination_list( + "send_knock", destinations, send_request + ) + + async def _do_send_knock(self, destination: str, pdu: EventBase) -> JsonDict: + """Send a knock event to a remote homeserver. + + Args: + destination: The homeserver to send to. + pdu: The event to send. + + Returns: + The remote homeserver can optionally return some state from the room. The response + dictionary is in the form: + + {"knock_state_events": [<state event dict>, ...]} + + The list of state events may be empty. + """ + time_now = self._clock.time_msec() + + return await self.transport_layer.send_knock_v1( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + async def get_public_rooms( self, remote_server: str, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index ace30aa450..2b07f18529 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py
@@ -129,7 +129,7 @@ class FederationServer(FederationBase): # come in waves. self._state_resp_cache = ResponseCache( hs.get_clock(), "state_resp", timeout_ms=30000 - ) # type: ResponseCache[Tuple[str, str]] + ) # type: ResponseCache[Tuple[str, Optional[str]]] self._state_ids_resp_cache = ResponseCache( hs.get_clock(), "state_ids_resp", timeout_ms=30000 ) # type: ResponseCache[Tuple[str, str]] @@ -138,6 +138,8 @@ class FederationServer(FederationBase): hs.config.federation.federation_metrics_domains ) + self._room_prejoin_state_types = hs.config.api.room_prejoin_state + async def on_backfill_request( self, origin: str, room_id: str, versions: List[str], limit: int ) -> Tuple[int, Dict[str, Any]]: @@ -406,7 +408,7 @@ class FederationServer(FederationBase): ) async def on_room_state_request( - self, origin: str, room_id: str, event_id: str + self, origin: str, room_id: str, event_id: Optional[str] ) -> Tuple[int, Dict[str, Any]]: origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) @@ -463,7 +465,7 @@ class FederationServer(FederationBase): return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} async def _on_context_state_request_compute( - self, room_id: str, event_id: str + self, room_id: str, event_id: Optional[str] ) -> Dict[str, list]: if event_id: pdus = await self.handler.get_state_for_pdu( @@ -586,6 +588,103 @@ class FederationServer(FederationBase): await self.handler.on_send_leave_request(origin, pdu) return {} + async def on_make_knock_request( + self, origin: str, room_id: str, user_id: str, supported_versions: List[str] + ) -> Dict[str, Union[EventBase, str]]: + """We've received a /make_knock/ request, so we create a partial knock + event for the room and hand that back, along with the room version, to the knocking + homeserver. We do *not* persist or process this event until the other server has + signed it and sent it back. + + Args: + origin: The (verified) server name of the requesting server. + room_id: The room to create the knock event in. + user_id: The user to create the knock for. + supported_versions: The room versions supported by the requesting server. + + Returns: + The partial knock event. + """ + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, room_id) + + room_version = await self.store.get_room_version(room_id) + + # Check that this room version is supported by the remote homeserver + if room_version.identifier not in supported_versions: + logger.warning( + "Room version %s not in %s", room_version.identifier, supported_versions + ) + raise IncompatibleRoomVersionError(room_version=room_version.identifier) + + # Check that this room supports knocking as defined by its room version + if not room_version.msc2403_knocking: + raise SynapseError( + 403, + "This room version does not support knocking", + errcode=Codes.FORBIDDEN, + ) + + pdu = await self.handler.on_make_knock_request(origin, room_id, user_id) + time_now = self._clock.time_msec() + return { + "event": pdu.get_pdu_json(time_now), + "room_version": room_version.identifier, + } + + async def on_send_knock_request( + self, + origin: str, + content: JsonDict, + room_id: str, + ) -> Dict[str, List[JsonDict]]: + """ + We have received a knock event for a room. Verify and send the event into the room + on the knocking homeserver's behalf. Then reply with some stripped state from the + room for the knockee. + + Args: + origin: The remote homeserver of the knocking user. + content: The content of the request. + room_id: The ID of the room to knock on. + + Returns: + The stripped room state. + """ + logger.debug("on_send_knock_request: content: %s", content) + + room_version = await self.store.get_room_version(room_id) + + # Check that this room supports knocking as defined by its room version + if not room_version.msc2403_knocking: + raise SynapseError( + 403, + "This room version does not support knocking", + errcode=Codes.FORBIDDEN, + ) + + pdu = event_from_pdu_json(content, room_version) + + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, pdu.room_id) + + logger.debug("on_send_knock_request: pdu sigs: %s", pdu.signatures) + + pdu = await self._check_sigs_and_hash(room_version, pdu) + + # Handle the event, and retrieve the EventContext + event_context = await self.handler.on_send_knock_request(origin, pdu) + + # Retrieve stripped state events from the room and send them back to the remote + # server. This will allow the remote server's clients to display information + # related to the room while the knock request is pending. + stripped_room_state = ( + await self.store.get_stripped_room_state_from_event_context( + event_context, self._room_prejoin_state_types + ) + ) + return {"knock_state_events": stripped_room_state} + async def on_event_auth( self, origin: str, room_id: str, event_id: str ) -> Tuple[int, Dict[str, Any]]: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 5b4f5d17f7..c9e7c57461 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py
@@ -1,5 +1,5 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. +# Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -220,7 +220,8 @@ class TransportLayerClient: Fails with ``FederationDeniedError`` if the remote destination is not in our federation whitelist """ - valid_memberships = {Membership.JOIN, Membership.LEAVE} + valid_memberships = {Membership.JOIN, Membership.LEAVE, Membership.KNOCK} + if membership not in valid_memberships: raise RuntimeError( "make_membership_event called with membership='%s', must be one of %s" @@ -322,6 +323,40 @@ class TransportLayerClient: return response @log_function + async def send_knock_v1( + self, + destination: str, + room_id: str, + event_id: str, + content: JsonDict, + ) -> JsonDict: + """ + Sends a signed knock membership event to a remote server. This is the second + step for knocking after make_knock. + + Args: + destination: The remote homeserver. + room_id: The ID of the room to knock on. + event_id: The ID of the knock membership event that we're sending. + content: The knock membership event that we're sending. Note that this is not the + `content` field of the membership event, but the entire signed membership event + itself represented as a JSON dict. + + Returns: + The remote homeserver can optionally return some state from the room. The response + dictionary is in the form: + + {"knock_state_events": [<state event dict>, ...]} + + The list of state events may be empty. + """ + path = _create_v1_path("/send_knock/%s/%s", room_id, event_id) + + return await self.client.put_json( + destination=destination, path=path, data=content + ) + + @log_function async def send_invite_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/invite/%s/%s", room_id, event_id) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 5756fcb551..16d740cf58 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py
@@ -1,6 +1,5 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. +# Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import functools import logging import re @@ -28,12 +26,14 @@ from synapse.api.urls import ( FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX, ) +from synapse.handlers.groups_local import GroupsLocalHandler from synapse.http.server import HttpServer, JsonResource from synapse.http.servlet import ( parse_boolean_from_args, parse_integer_from_args, parse_json_object_from_request, parse_string_from_args, + parse_strings_from_args, ) from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( @@ -275,10 +275,17 @@ class BaseFederationServlet: RATELIMIT = True # Whether to rate limit requests or not - def __init__(self, handler, authenticator, ratelimiter, server_name): - self.handler = handler + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + self.hs = hs self.authenticator = authenticator self.ratelimiter = ratelimiter + self.server_name = server_name def _wrap(self, func): authenticator = self.authenticator @@ -375,17 +382,30 @@ class BaseFederationServlet: ) -class FederationSendServlet(BaseFederationServlet): +class BaseFederationServerServlet(BaseFederationServlet): + """Abstract base class for federation servlet classes which provides a federation server handler. + + See BaseFederationServlet for more information. + """ + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_federation_server() + + +class FederationSendServlet(BaseFederationServerServlet): PATH = "/send/(?P<transaction_id>[^/]*)/?" # We ratelimit manually in the handler as we queue up the requests and we # don't want to fill up the ratelimiter with blocked requests. RATELIMIT = False - def __init__(self, handler, server_name, **kwargs): - super().__init__(handler, server_name=server_name, **kwargs) - self.server_name = server_name - # This is when someone is trying to send us a bunch of data. async def on_PUT(self, origin, content, query, transaction_id): """Called on PUT /send/<transaction_id>/ @@ -434,7 +454,7 @@ class FederationSendServlet(BaseFederationServlet): return code, response -class FederationEventServlet(BaseFederationServlet): +class FederationEventServlet(BaseFederationServerServlet): PATH = "/event/(?P<event_id>[^/]*)/?" # This is when someone asks for a data item for a given server data_id pair. @@ -442,7 +462,7 @@ class FederationEventServlet(BaseFederationServlet): return await self.handler.on_pdu_request(origin, event_id) -class FederationStateV1Servlet(BaseFederationServlet): +class FederationStateV1Servlet(BaseFederationServerServlet): PATH = "/state/(?P<room_id>[^/]*)/?" # This is when someone asks for all data for a given room. @@ -454,7 +474,7 @@ class FederationStateV1Servlet(BaseFederationServlet): ) -class FederationStateIdsServlet(BaseFederationServlet): +class FederationStateIdsServlet(BaseFederationServerServlet): PATH = "/state_ids/(?P<room_id>[^/]*)/?" async def on_GET(self, origin, content, query, room_id): @@ -465,7 +485,7 @@ class FederationStateIdsServlet(BaseFederationServlet): ) -class FederationBackfillServlet(BaseFederationServlet): +class FederationBackfillServlet(BaseFederationServerServlet): PATH = "/backfill/(?P<room_id>[^/]*)/?" async def on_GET(self, origin, content, query, room_id): @@ -478,7 +498,7 @@ class FederationBackfillServlet(BaseFederationServlet): return await self.handler.on_backfill_request(origin, room_id, versions, limit) -class FederationQueryServlet(BaseFederationServlet): +class FederationQueryServlet(BaseFederationServerServlet): PATH = "/query/(?P<query_type>[^/]*)" # This is when we receive a server-server Query @@ -488,7 +508,7 @@ class FederationQueryServlet(BaseFederationServlet): return await self.handler.on_query_request(query_type, args) -class FederationMakeJoinServlet(BaseFederationServlet): +class FederationMakeJoinServlet(BaseFederationServerServlet): PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)" async def on_GET(self, origin, _content, query, room_id, user_id): @@ -518,7 +538,7 @@ class FederationMakeJoinServlet(BaseFederationServlet): return 200, content -class FederationMakeLeaveServlet(BaseFederationServlet): +class FederationMakeLeaveServlet(BaseFederationServerServlet): PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)" async def on_GET(self, origin, content, query, room_id, user_id): @@ -526,7 +546,7 @@ class FederationMakeLeaveServlet(BaseFederationServlet): return 200, content -class FederationV1SendLeaveServlet(BaseFederationServlet): +class FederationV1SendLeaveServlet(BaseFederationServerServlet): PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): @@ -534,7 +554,7 @@ class FederationV1SendLeaveServlet(BaseFederationServlet): return 200, (200, content) -class FederationV2SendLeaveServlet(BaseFederationServlet): +class FederationV2SendLeaveServlet(BaseFederationServerServlet): PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PREFIX = FEDERATION_V2_PREFIX @@ -544,14 +564,38 @@ class FederationV2SendLeaveServlet(BaseFederationServlet): return 200, content -class FederationEventAuthServlet(BaseFederationServlet): +class FederationMakeKnockServlet(BaseFederationServerServlet): + PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)" + + async def on_GET(self, origin, content, query, room_id, user_id): + try: + # Retrieve the room versions the remote homeserver claims to support + supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8") + except KeyError: + raise SynapseError(400, "Missing required query parameter 'ver'") + + content = await self.handler.on_make_knock_request( + origin, room_id, user_id, supported_versions=supported_versions + ) + return 200, content + + +class FederationV1SendKnockServlet(BaseFederationServerServlet): + PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" + + async def on_PUT(self, origin, content, query, room_id, event_id): + content = await self.handler.on_send_knock_request(origin, content, room_id) + return 200, content + + +class FederationEventAuthServlet(BaseFederationServerServlet): PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" async def on_GET(self, origin, content, query, room_id, event_id): return await self.handler.on_event_auth(origin, room_id, event_id) -class FederationV1SendJoinServlet(BaseFederationServlet): +class FederationV1SendJoinServlet(BaseFederationServerServlet): PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): @@ -561,7 +605,7 @@ class FederationV1SendJoinServlet(BaseFederationServlet): return 200, (200, content) -class FederationV2SendJoinServlet(BaseFederationServlet): +class FederationV2SendJoinServlet(BaseFederationServerServlet): PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PREFIX = FEDERATION_V2_PREFIX @@ -573,7 +617,7 @@ class FederationV2SendJoinServlet(BaseFederationServlet): return 200, content -class FederationV1InviteServlet(BaseFederationServlet): +class FederationV1InviteServlet(BaseFederationServerServlet): PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): @@ -590,7 +634,7 @@ class FederationV1InviteServlet(BaseFederationServlet): return 200, (200, content) -class FederationV2InviteServlet(BaseFederationServlet): +class FederationV2InviteServlet(BaseFederationServerServlet): PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" PREFIX = FEDERATION_V2_PREFIX @@ -614,7 +658,7 @@ class FederationV2InviteServlet(BaseFederationServlet): return 200, content -class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet): +class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet): PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)" async def on_PUT(self, origin, content, query, room_id): @@ -622,21 +666,21 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet): return 200, {} -class FederationClientKeysQueryServlet(BaseFederationServlet): +class FederationClientKeysQueryServlet(BaseFederationServerServlet): PATH = "/user/keys/query" async def on_POST(self, origin, content, query): return await self.handler.on_query_client_keys(origin, content) -class FederationUserDevicesQueryServlet(BaseFederationServlet): +class FederationUserDevicesQueryServlet(BaseFederationServerServlet): PATH = "/user/devices/(?P<user_id>[^/]*)" async def on_GET(self, origin, content, query, user_id): return await self.handler.on_query_user_devices(origin, user_id) -class FederationClientKeysClaimServlet(BaseFederationServlet): +class FederationClientKeysClaimServlet(BaseFederationServerServlet): PATH = "/user/keys/claim" async def on_POST(self, origin, content, query): @@ -644,7 +688,7 @@ class FederationClientKeysClaimServlet(BaseFederationServlet): return 200, response -class FederationGetMissingEventsServlet(BaseFederationServlet): +class FederationGetMissingEventsServlet(BaseFederationServerServlet): # TODO(paul): Why does this path alone end with "/?" optional? PATH = "/get_missing_events/(?P<room_id>[^/]*)/?" @@ -664,7 +708,7 @@ class FederationGetMissingEventsServlet(BaseFederationServlet): return 200, content -class On3pidBindServlet(BaseFederationServlet): +class On3pidBindServlet(BaseFederationServerServlet): PATH = "/3pid/onbind" REQUIRE_AUTH = False @@ -694,7 +738,7 @@ class On3pidBindServlet(BaseFederationServlet): return 200, {} -class OpenIdUserInfo(BaseFederationServlet): +class OpenIdUserInfo(BaseFederationServerServlet): """ Exchange a bearer token for information about a user. @@ -770,8 +814,16 @@ class PublicRoomList(BaseFederationServlet): PATH = "/publicRooms" - def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access): - super().__init__(handler, authenticator, ratelimiter, server_name) + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + allow_access: bool, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_room_list_handler() self.allow_access = allow_access async def on_GET(self, origin, content, query): @@ -856,7 +908,24 @@ class FederationVersionServlet(BaseFederationServlet): ) -class FederationGroupsProfileServlet(BaseFederationServlet): +class BaseGroupsServerServlet(BaseFederationServlet): + """Abstract base class for federation servlet classes which provides a groups server handler. + + See BaseFederationServlet for more information. + """ + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_groups_server_handler() + + +class FederationGroupsProfileServlet(BaseGroupsServerServlet): """Get/set the basic profile of a group on behalf of a user""" PATH = "/groups/(?P<group_id>[^/]*)/profile" @@ -882,7 +951,7 @@ class FederationGroupsProfileServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsSummaryServlet(BaseFederationServlet): +class FederationGroupsSummaryServlet(BaseGroupsServerServlet): PATH = "/groups/(?P<group_id>[^/]*)/summary" async def on_GET(self, origin, content, query, group_id): @@ -895,7 +964,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsRoomsServlet(BaseFederationServlet): +class FederationGroupsRoomsServlet(BaseGroupsServerServlet): """Get the rooms in a group on behalf of a user""" PATH = "/groups/(?P<group_id>[^/]*)/rooms" @@ -910,7 +979,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsAddRoomsServlet(BaseFederationServlet): +class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet): """Add/remove room from group""" PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" @@ -938,7 +1007,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): +class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet): """Update room config in group""" PATH = ( @@ -958,7 +1027,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): return 200, result -class FederationGroupsUsersServlet(BaseFederationServlet): +class FederationGroupsUsersServlet(BaseGroupsServerServlet): """Get the users in a group on behalf of a user""" PATH = "/groups/(?P<group_id>[^/]*)/users" @@ -973,7 +1042,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsInvitedUsersServlet(BaseFederationServlet): +class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet): """Get the users that have been invited to a group""" PATH = "/groups/(?P<group_id>[^/]*)/invited_users" @@ -990,7 +1059,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsInviteServlet(BaseFederationServlet): +class FederationGroupsInviteServlet(BaseGroupsServerServlet): """Ask a group server to invite someone to the group""" PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" @@ -1007,7 +1076,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsAcceptInviteServlet(BaseFederationServlet): +class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet): """Accept an invitation from the group server""" PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite" @@ -1021,7 +1090,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsJoinServlet(BaseFederationServlet): +class FederationGroupsJoinServlet(BaseGroupsServerServlet): """Attempt to join a group""" PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join" @@ -1035,7 +1104,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsRemoveUserServlet(BaseFederationServlet): +class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet): """Leave or kick a user from the group""" PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" @@ -1052,7 +1121,24 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsLocalInviteServlet(BaseFederationServlet): +class BaseGroupsLocalServlet(BaseFederationServlet): + """Abstract base class for federation servlet classes which provides a groups local handler. + + See BaseFederationServlet for more information. + """ + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_groups_local_handler() + + +class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet): """A group server has invited a local user""" PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" @@ -1061,12 +1147,16 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet): if get_domain_from_id(group_id) != origin: raise SynapseError(403, "group_id doesn't match origin") + assert isinstance( + self.handler, GroupsLocalHandler + ), "Workers cannot handle group invites." + new_content = await self.handler.on_invite(group_id, user_id, content) return 200, new_content -class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): +class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet): """A group server has removed a local user""" PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" @@ -1075,6 +1165,10 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): if get_domain_from_id(group_id) != origin: raise SynapseError(403, "user_id doesn't match origin") + assert isinstance( + self.handler, GroupsLocalHandler + ), "Workers cannot handle group removals." + new_content = await self.handler.user_removed_from_group( group_id, user_id, content ) @@ -1087,6 +1181,16 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)" + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_groups_attestation_renewer() + async def on_POST(self, origin, content, query, group_id, user_id): # We don't need to check auth here as we check the attestation signatures @@ -1097,7 +1201,7 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): return 200, new_content -class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): +class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet): """Add/remove a room from the group summary, with optional category. Matches both: @@ -1154,7 +1258,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): return 200, resp -class FederationGroupsCategoriesServlet(BaseFederationServlet): +class FederationGroupsCategoriesServlet(BaseGroupsServerServlet): """Get all categories for a group""" PATH = "/groups/(?P<group_id>[^/]*)/categories/?" @@ -1169,7 +1273,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet): return 200, resp -class FederationGroupsCategoryServlet(BaseFederationServlet): +class FederationGroupsCategoryServlet(BaseGroupsServerServlet): """Add/remove/get a category in a group""" PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)" @@ -1222,7 +1326,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet): return 200, resp -class FederationGroupsRolesServlet(BaseFederationServlet): +class FederationGroupsRolesServlet(BaseGroupsServerServlet): """Get roles in a group""" PATH = "/groups/(?P<group_id>[^/]*)/roles/?" @@ -1237,7 +1341,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet): return 200, resp -class FederationGroupsRoleServlet(BaseFederationServlet): +class FederationGroupsRoleServlet(BaseGroupsServerServlet): """Add/remove/get a role in a group""" PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)" @@ -1290,7 +1394,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet): return 200, resp -class FederationGroupsSummaryUsersServlet(BaseFederationServlet): +class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet): """Add/remove a user from the group summary, with optional role. Matches both: @@ -1345,7 +1449,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): return 200, resp -class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): +class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet): """Get roles in a group""" PATH = "/get_groups_publicised" @@ -1358,7 +1462,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): return 200, resp -class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): +class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet): """Sets whether a group is joinable without an invite or knock""" PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy" @@ -1379,6 +1483,16 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" PATH = "/spaces/(?P<room_id>[^/]*)" + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_space_summary_handler() + async def on_GET( self, origin: str, @@ -1444,16 +1558,25 @@ class RoomComplexityServlet(BaseFederationServlet): PATH = "/rooms/(?P<room_id>[^/]*)/complexity" PREFIX = FEDERATION_UNSTABLE_PREFIX - async def on_GET(self, origin, content, query, room_id): - - store = self.handler.hs.get_datastore() + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self._store = self.hs.get_datastore() - is_public = await store.is_room_world_readable_or_publicly_joinable(room_id) + async def on_GET(self, origin, content, query, room_id): + is_public = await self._store.is_room_world_readable_or_publicly_joinable( + room_id + ) if not is_public: raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM) - complexity = await store.get_room_complexity(room_id) + complexity = await self._store.get_room_complexity(room_id) return 200, complexity @@ -1482,6 +1605,9 @@ FEDERATION_SERVLET_CLASSES = ( On3pidBindServlet, FederationVersionServlet, RoomComplexityServlet, + FederationSpaceSummaryServlet, + FederationV1SendKnockServlet, + FederationMakeKnockServlet, ) # type: Tuple[Type[BaseFederationServlet], ...] OPENID_SERVLET_CLASSES = ( @@ -1523,6 +1649,7 @@ GROUP_ATTESTATION_SERVLET_CLASSES = ( FederationGroupsRenewAttestaionServlet, ) # type: Tuple[Type[BaseFederationServlet], ...] + DEFAULT_SERVLET_GROUPS = ( "federation", "room_list", @@ -1559,23 +1686,16 @@ def register_servlets( if "federation" in servlet_groups: for servletclass in FEDERATION_SERVLET_CLASSES: servletclass( - handler=hs.get_federation_server(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, ).register(resource) - FederationSpaceSummaryServlet( - handler=hs.get_space_summary_handler(), - authenticator=authenticator, - ratelimiter=ratelimiter, - server_name=hs.hostname, - ).register(resource) - if "openid" in servlet_groups: for servletclass in OPENID_SERVLET_CLASSES: servletclass( - handler=hs.get_federation_server(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, @@ -1584,7 +1704,7 @@ def register_servlets( if "room_list" in servlet_groups: for servletclass in ROOM_LIST_CLASSES: servletclass( - handler=hs.get_room_list_handler(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, @@ -1594,7 +1714,7 @@ def register_servlets( if "group_server" in servlet_groups: for servletclass in GROUP_SERVER_SERVLET_CLASSES: servletclass( - handler=hs.get_groups_server_handler(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, @@ -1603,7 +1723,7 @@ def register_servlets( if "group_local" in servlet_groups: for servletclass in GROUP_LOCAL_SERVLET_CLASSES: servletclass( - handler=hs.get_groups_local_handler(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, @@ -1612,7 +1732,7 @@ def register_servlets( if "group_attestation" in servlet_groups: for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES: servletclass( - handler=hs.get_groups_attestation_renewer(), + hs=hs, authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname,