diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 1076ebc036..ed09c6af1f 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -1,4 +1,5 @@
-# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -619,7 +620,8 @@ class FederationClient(FederationBase):
SynapseError: if the chosen remote server returns a 300/400 code, or
no servers successfully handle the request.
"""
- valid_memberships = {Membership.JOIN, Membership.LEAVE}
+ valid_memberships = {Membership.JOIN, Membership.LEAVE, Membership.KNOCK}
+
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s"
@@ -638,6 +640,13 @@ class FederationClient(FederationBase):
if not room_version:
raise UnsupportedRoomVersionError()
+ if not room_version.msc2403_knocking and membership == Membership.KNOCK:
+ raise SynapseError(
+ 400,
+ "This room version does not support knocking",
+ errcode=Codes.FORBIDDEN,
+ )
+
pdu_dict = ret.get("event", None)
if not isinstance(pdu_dict, dict):
raise InvalidResponseError("Bad 'event' field in response")
@@ -946,6 +955,62 @@ class FederationClient(FederationBase):
# content.
return resp[1]
+ async def send_knock(self, destinations: List[str], pdu: EventBase) -> JsonDict:
+ """Attempts to send a knock event to given a list of servers. Iterates
+ through the list until one attempt succeeds.
+
+ Doing so will cause the remote server to add the event to the graph,
+ and send the event out to the rest of the federation.
+
+ Args:
+ destinations: A list of candidate homeservers which are likely to be
+ participating in the room.
+ pdu: The event to be sent.
+
+ Returns:
+ The remote homeserver return some state from the room. The response
+ dictionary is in the form:
+
+ {"knock_state_events": [<state event dict>, ...]}
+
+ The list of state events may be empty.
+
+ Raises:
+ SynapseError: If the chosen remote server returns a 3xx/4xx code.
+ RuntimeError: If no servers were reachable.
+ """
+
+ async def send_request(destination: str) -> JsonDict:
+ return await self._do_send_knock(destination, pdu)
+
+ return await self._try_destination_list(
+ "send_knock", destinations, send_request
+ )
+
+ async def _do_send_knock(self, destination: str, pdu: EventBase) -> JsonDict:
+ """Send a knock event to a remote homeserver.
+
+ Args:
+ destination: The homeserver to send to.
+ pdu: The event to send.
+
+ Returns:
+ The remote homeserver can optionally return some state from the room. The response
+ dictionary is in the form:
+
+ {"knock_state_events": [<state event dict>, ...]}
+
+ The list of state events may be empty.
+ """
+ time_now = self._clock.time_msec()
+
+ return await self.transport_layer.send_knock_v1(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
async def get_public_rooms(
self,
remote_server: str,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index ace30aa450..2b07f18529 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -129,7 +129,7 @@ class FederationServer(FederationBase):
# come in waves.
self._state_resp_cache = ResponseCache(
hs.get_clock(), "state_resp", timeout_ms=30000
- ) # type: ResponseCache[Tuple[str, str]]
+ ) # type: ResponseCache[Tuple[str, Optional[str]]]
self._state_ids_resp_cache = ResponseCache(
hs.get_clock(), "state_ids_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
@@ -138,6 +138,8 @@ class FederationServer(FederationBase):
hs.config.federation.federation_metrics_domains
)
+ self._room_prejoin_state_types = hs.config.api.room_prejoin_state
+
async def on_backfill_request(
self, origin: str, room_id: str, versions: List[str], limit: int
) -> Tuple[int, Dict[str, Any]]:
@@ -406,7 +408,7 @@ class FederationServer(FederationBase):
)
async def on_room_state_request(
- self, origin: str, room_id: str, event_id: str
+ self, origin: str, room_id: str, event_id: Optional[str]
) -> Tuple[int, Dict[str, Any]]:
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@@ -463,7 +465,7 @@ class FederationServer(FederationBase):
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
async def _on_context_state_request_compute(
- self, room_id: str, event_id: str
+ self, room_id: str, event_id: Optional[str]
) -> Dict[str, list]:
if event_id:
pdus = await self.handler.get_state_for_pdu(
@@ -586,6 +588,103 @@ class FederationServer(FederationBase):
await self.handler.on_send_leave_request(origin, pdu)
return {}
+ async def on_make_knock_request(
+ self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
+ ) -> Dict[str, Union[EventBase, str]]:
+ """We've received a /make_knock/ request, so we create a partial knock
+ event for the room and hand that back, along with the room version, to the knocking
+ homeserver. We do *not* persist or process this event until the other server has
+ signed it and sent it back.
+
+ Args:
+ origin: The (verified) server name of the requesting server.
+ room_id: The room to create the knock event in.
+ user_id: The user to create the knock for.
+ supported_versions: The room versions supported by the requesting server.
+
+ Returns:
+ The partial knock event.
+ """
+ origin_host, _ = parse_server_name(origin)
+ await self.check_server_matches_acl(origin_host, room_id)
+
+ room_version = await self.store.get_room_version(room_id)
+
+ # Check that this room version is supported by the remote homeserver
+ if room_version.identifier not in supported_versions:
+ logger.warning(
+ "Room version %s not in %s", room_version.identifier, supported_versions
+ )
+ raise IncompatibleRoomVersionError(room_version=room_version.identifier)
+
+ # Check that this room supports knocking as defined by its room version
+ if not room_version.msc2403_knocking:
+ raise SynapseError(
+ 403,
+ "This room version does not support knocking",
+ errcode=Codes.FORBIDDEN,
+ )
+
+ pdu = await self.handler.on_make_knock_request(origin, room_id, user_id)
+ time_now = self._clock.time_msec()
+ return {
+ "event": pdu.get_pdu_json(time_now),
+ "room_version": room_version.identifier,
+ }
+
+ async def on_send_knock_request(
+ self,
+ origin: str,
+ content: JsonDict,
+ room_id: str,
+ ) -> Dict[str, List[JsonDict]]:
+ """
+ We have received a knock event for a room. Verify and send the event into the room
+ on the knocking homeserver's behalf. Then reply with some stripped state from the
+ room for the knockee.
+
+ Args:
+ origin: The remote homeserver of the knocking user.
+ content: The content of the request.
+ room_id: The ID of the room to knock on.
+
+ Returns:
+ The stripped room state.
+ """
+ logger.debug("on_send_knock_request: content: %s", content)
+
+ room_version = await self.store.get_room_version(room_id)
+
+ # Check that this room supports knocking as defined by its room version
+ if not room_version.msc2403_knocking:
+ raise SynapseError(
+ 403,
+ "This room version does not support knocking",
+ errcode=Codes.FORBIDDEN,
+ )
+
+ pdu = event_from_pdu_json(content, room_version)
+
+ origin_host, _ = parse_server_name(origin)
+ await self.check_server_matches_acl(origin_host, pdu.room_id)
+
+ logger.debug("on_send_knock_request: pdu sigs: %s", pdu.signatures)
+
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
+
+ # Handle the event, and retrieve the EventContext
+ event_context = await self.handler.on_send_knock_request(origin, pdu)
+
+ # Retrieve stripped state events from the room and send them back to the remote
+ # server. This will allow the remote server's clients to display information
+ # related to the room while the knock request is pending.
+ stripped_room_state = (
+ await self.store.get_stripped_room_state_from_event_context(
+ event_context, self._room_prejoin_state_types
+ )
+ )
+ return {"knock_state_events": stripped_room_state}
+
async def on_event_auth(
self, origin: str, room_id: str, event_id: str
) -> Tuple[int, Dict[str, Any]]:
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 5b4f5d17f7..c9e7c57461 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1,5 +1,5 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -220,7 +220,8 @@ class TransportLayerClient:
Fails with ``FederationDeniedError`` if the remote destination
is not in our federation whitelist
"""
- valid_memberships = {Membership.JOIN, Membership.LEAVE}
+ valid_memberships = {Membership.JOIN, Membership.LEAVE, Membership.KNOCK}
+
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s"
@@ -322,6 +323,40 @@ class TransportLayerClient:
return response
@log_function
+ async def send_knock_v1(
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ content: JsonDict,
+ ) -> JsonDict:
+ """
+ Sends a signed knock membership event to a remote server. This is the second
+ step for knocking after make_knock.
+
+ Args:
+ destination: The remote homeserver.
+ room_id: The ID of the room to knock on.
+ event_id: The ID of the knock membership event that we're sending.
+ content: The knock membership event that we're sending. Note that this is not the
+ `content` field of the membership event, but the entire signed membership event
+ itself represented as a JSON dict.
+
+ Returns:
+ The remote homeserver can optionally return some state from the room. The response
+ dictionary is in the form:
+
+ {"knock_state_events": [<state event dict>, ...]}
+
+ The list of state events may be empty.
+ """
+ path = _create_v1_path("/send_knock/%s/%s", room_id, event_id)
+
+ return await self.client.put_json(
+ destination=destination, path=path, data=content
+ )
+
+ @log_function
async def send_invite_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 5756fcb551..16d740cf58 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1,6 +1,5 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import functools
import logging
import re
@@ -28,12 +26,14 @@ from synapse.api.urls import (
FEDERATION_V1_PREFIX,
FEDERATION_V2_PREFIX,
)
+from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.http.server import HttpServer, JsonResource
from synapse.http.servlet import (
parse_boolean_from_args,
parse_integer_from_args,
parse_json_object_from_request,
parse_string_from_args,
+ parse_strings_from_args,
)
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
@@ -275,10 +275,17 @@ class BaseFederationServlet:
RATELIMIT = True # Whether to rate limit requests or not
- def __init__(self, handler, authenticator, ratelimiter, server_name):
- self.handler = handler
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ self.hs = hs
self.authenticator = authenticator
self.ratelimiter = ratelimiter
+ self.server_name = server_name
def _wrap(self, func):
authenticator = self.authenticator
@@ -375,17 +382,30 @@ class BaseFederationServlet:
)
-class FederationSendServlet(BaseFederationServlet):
+class BaseFederationServerServlet(BaseFederationServlet):
+ """Abstract base class for federation servlet classes which provides a federation server handler.
+
+ See BaseFederationServlet for more information.
+ """
+
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.handler = hs.get_federation_server()
+
+
+class FederationSendServlet(BaseFederationServerServlet):
PATH = "/send/(?P<transaction_id>[^/]*)/?"
# We ratelimit manually in the handler as we queue up the requests and we
# don't want to fill up the ratelimiter with blocked requests.
RATELIMIT = False
- def __init__(self, handler, server_name, **kwargs):
- super().__init__(handler, server_name=server_name, **kwargs)
- self.server_name = server_name
-
# This is when someone is trying to send us a bunch of data.
async def on_PUT(self, origin, content, query, transaction_id):
"""Called on PUT /send/<transaction_id>/
@@ -434,7 +454,7 @@ class FederationSendServlet(BaseFederationServlet):
return code, response
-class FederationEventServlet(BaseFederationServlet):
+class FederationEventServlet(BaseFederationServerServlet):
PATH = "/event/(?P<event_id>[^/]*)/?"
# This is when someone asks for a data item for a given server data_id pair.
@@ -442,7 +462,7 @@ class FederationEventServlet(BaseFederationServlet):
return await self.handler.on_pdu_request(origin, event_id)
-class FederationStateV1Servlet(BaseFederationServlet):
+class FederationStateV1Servlet(BaseFederationServerServlet):
PATH = "/state/(?P<room_id>[^/]*)/?"
# This is when someone asks for all data for a given room.
@@ -454,7 +474,7 @@ class FederationStateV1Servlet(BaseFederationServlet):
)
-class FederationStateIdsServlet(BaseFederationServlet):
+class FederationStateIdsServlet(BaseFederationServerServlet):
PATH = "/state_ids/(?P<room_id>[^/]*)/?"
async def on_GET(self, origin, content, query, room_id):
@@ -465,7 +485,7 @@ class FederationStateIdsServlet(BaseFederationServlet):
)
-class FederationBackfillServlet(BaseFederationServlet):
+class FederationBackfillServlet(BaseFederationServerServlet):
PATH = "/backfill/(?P<room_id>[^/]*)/?"
async def on_GET(self, origin, content, query, room_id):
@@ -478,7 +498,7 @@ class FederationBackfillServlet(BaseFederationServlet):
return await self.handler.on_backfill_request(origin, room_id, versions, limit)
-class FederationQueryServlet(BaseFederationServlet):
+class FederationQueryServlet(BaseFederationServerServlet):
PATH = "/query/(?P<query_type>[^/]*)"
# This is when we receive a server-server Query
@@ -488,7 +508,7 @@ class FederationQueryServlet(BaseFederationServlet):
return await self.handler.on_query_request(query_type, args)
-class FederationMakeJoinServlet(BaseFederationServlet):
+class FederationMakeJoinServlet(BaseFederationServerServlet):
PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
async def on_GET(self, origin, _content, query, room_id, user_id):
@@ -518,7 +538,7 @@ class FederationMakeJoinServlet(BaseFederationServlet):
return 200, content
-class FederationMakeLeaveServlet(BaseFederationServlet):
+class FederationMakeLeaveServlet(BaseFederationServerServlet):
PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
async def on_GET(self, origin, content, query, room_id, user_id):
@@ -526,7 +546,7 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
return 200, content
-class FederationV1SendLeaveServlet(BaseFederationServlet):
+class FederationV1SendLeaveServlet(BaseFederationServerServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id):
@@ -534,7 +554,7 @@ class FederationV1SendLeaveServlet(BaseFederationServlet):
return 200, (200, content)
-class FederationV2SendLeaveServlet(BaseFederationServlet):
+class FederationV2SendLeaveServlet(BaseFederationServerServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
@@ -544,14 +564,38 @@ class FederationV2SendLeaveServlet(BaseFederationServlet):
return 200, content
-class FederationEventAuthServlet(BaseFederationServlet):
+class FederationMakeKnockServlet(BaseFederationServerServlet):
+ PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
+
+ async def on_GET(self, origin, content, query, room_id, user_id):
+ try:
+ # Retrieve the room versions the remote homeserver claims to support
+ supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8")
+ except KeyError:
+ raise SynapseError(400, "Missing required query parameter 'ver'")
+
+ content = await self.handler.on_make_knock_request(
+ origin, room_id, user_id, supported_versions=supported_versions
+ )
+ return 200, content
+
+
+class FederationV1SendKnockServlet(BaseFederationServerServlet):
+ PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+ async def on_PUT(self, origin, content, query, room_id, event_id):
+ content = await self.handler.on_send_knock_request(origin, content, room_id)
+ return 200, content
+
+
+class FederationEventAuthServlet(BaseFederationServerServlet):
PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_GET(self, origin, content, query, room_id, event_id):
return await self.handler.on_event_auth(origin, room_id, event_id)
-class FederationV1SendJoinServlet(BaseFederationServlet):
+class FederationV1SendJoinServlet(BaseFederationServerServlet):
PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id):
@@ -561,7 +605,7 @@ class FederationV1SendJoinServlet(BaseFederationServlet):
return 200, (200, content)
-class FederationV2SendJoinServlet(BaseFederationServlet):
+class FederationV2SendJoinServlet(BaseFederationServerServlet):
PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
@@ -573,7 +617,7 @@ class FederationV2SendJoinServlet(BaseFederationServlet):
return 200, content
-class FederationV1InviteServlet(BaseFederationServlet):
+class FederationV1InviteServlet(BaseFederationServerServlet):
PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id):
@@ -590,7 +634,7 @@ class FederationV1InviteServlet(BaseFederationServlet):
return 200, (200, content)
-class FederationV2InviteServlet(BaseFederationServlet):
+class FederationV2InviteServlet(BaseFederationServerServlet):
PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
@@ -614,7 +658,7 @@ class FederationV2InviteServlet(BaseFederationServlet):
return 200, content
-class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
+class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id):
@@ -622,21 +666,21 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
return 200, {}
-class FederationClientKeysQueryServlet(BaseFederationServlet):
+class FederationClientKeysQueryServlet(BaseFederationServerServlet):
PATH = "/user/keys/query"
async def on_POST(self, origin, content, query):
return await self.handler.on_query_client_keys(origin, content)
-class FederationUserDevicesQueryServlet(BaseFederationServlet):
+class FederationUserDevicesQueryServlet(BaseFederationServerServlet):
PATH = "/user/devices/(?P<user_id>[^/]*)"
async def on_GET(self, origin, content, query, user_id):
return await self.handler.on_query_user_devices(origin, user_id)
-class FederationClientKeysClaimServlet(BaseFederationServlet):
+class FederationClientKeysClaimServlet(BaseFederationServerServlet):
PATH = "/user/keys/claim"
async def on_POST(self, origin, content, query):
@@ -644,7 +688,7 @@ class FederationClientKeysClaimServlet(BaseFederationServlet):
return 200, response
-class FederationGetMissingEventsServlet(BaseFederationServlet):
+class FederationGetMissingEventsServlet(BaseFederationServerServlet):
# TODO(paul): Why does this path alone end with "/?" optional?
PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
@@ -664,7 +708,7 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
return 200, content
-class On3pidBindServlet(BaseFederationServlet):
+class On3pidBindServlet(BaseFederationServerServlet):
PATH = "/3pid/onbind"
REQUIRE_AUTH = False
@@ -694,7 +738,7 @@ class On3pidBindServlet(BaseFederationServlet):
return 200, {}
-class OpenIdUserInfo(BaseFederationServlet):
+class OpenIdUserInfo(BaseFederationServerServlet):
"""
Exchange a bearer token for information about a user.
@@ -770,8 +814,16 @@ class PublicRoomList(BaseFederationServlet):
PATH = "/publicRooms"
- def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access):
- super().__init__(handler, authenticator, ratelimiter, server_name)
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ allow_access: bool,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.handler = hs.get_room_list_handler()
self.allow_access = allow_access
async def on_GET(self, origin, content, query):
@@ -856,7 +908,24 @@ class FederationVersionServlet(BaseFederationServlet):
)
-class FederationGroupsProfileServlet(BaseFederationServlet):
+class BaseGroupsServerServlet(BaseFederationServlet):
+ """Abstract base class for federation servlet classes which provides a groups server handler.
+
+ See BaseFederationServlet for more information.
+ """
+
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.handler = hs.get_groups_server_handler()
+
+
+class FederationGroupsProfileServlet(BaseGroupsServerServlet):
"""Get/set the basic profile of a group on behalf of a user"""
PATH = "/groups/(?P<group_id>[^/]*)/profile"
@@ -882,7 +951,7 @@ class FederationGroupsProfileServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsSummaryServlet(BaseFederationServlet):
+class FederationGroupsSummaryServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/summary"
async def on_GET(self, origin, content, query, group_id):
@@ -895,7 +964,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsRoomsServlet(BaseFederationServlet):
+class FederationGroupsRoomsServlet(BaseGroupsServerServlet):
"""Get the rooms in a group on behalf of a user"""
PATH = "/groups/(?P<group_id>[^/]*)/rooms"
@@ -910,7 +979,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsAddRoomsServlet(BaseFederationServlet):
+class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet):
"""Add/remove room from group"""
PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
@@ -938,7 +1007,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
+class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet):
"""Update room config in group"""
PATH = (
@@ -958,7 +1027,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
return 200, result
-class FederationGroupsUsersServlet(BaseFederationServlet):
+class FederationGroupsUsersServlet(BaseGroupsServerServlet):
"""Get the users in a group on behalf of a user"""
PATH = "/groups/(?P<group_id>[^/]*)/users"
@@ -973,7 +1042,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
+class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet):
"""Get the users that have been invited to a group"""
PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
@@ -990,7 +1059,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsInviteServlet(BaseFederationServlet):
+class FederationGroupsInviteServlet(BaseGroupsServerServlet):
"""Ask a group server to invite someone to the group"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
@@ -1007,7 +1076,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
+class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet):
"""Accept an invitation from the group server"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
@@ -1021,7 +1090,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsJoinServlet(BaseFederationServlet):
+class FederationGroupsJoinServlet(BaseGroupsServerServlet):
"""Attempt to join a group"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
@@ -1035,7 +1104,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsRemoveUserServlet(BaseFederationServlet):
+class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet):
"""Leave or kick a user from the group"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
@@ -1052,7 +1121,24 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsLocalInviteServlet(BaseFederationServlet):
+class BaseGroupsLocalServlet(BaseFederationServlet):
+ """Abstract base class for federation servlet classes which provides a groups local handler.
+
+ See BaseFederationServlet for more information.
+ """
+
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.handler = hs.get_groups_local_handler()
+
+
+class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet):
"""A group server has invited a local user"""
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
@@ -1061,12 +1147,16 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet):
if get_domain_from_id(group_id) != origin:
raise SynapseError(403, "group_id doesn't match origin")
+ assert isinstance(
+ self.handler, GroupsLocalHandler
+ ), "Workers cannot handle group invites."
+
new_content = await self.handler.on_invite(group_id, user_id, content)
return 200, new_content
-class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
+class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet):
"""A group server has removed a local user"""
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
@@ -1075,6 +1165,10 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
if get_domain_from_id(group_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
+ assert isinstance(
+ self.handler, GroupsLocalHandler
+ ), "Workers cannot handle group removals."
+
new_content = await self.handler.user_removed_from_group(
group_id, user_id, content
)
@@ -1087,6 +1181,16 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.handler = hs.get_groups_attestation_renewer()
+
async def on_POST(self, origin, content, query, group_id, user_id):
# We don't need to check auth here as we check the attestation signatures
@@ -1097,7 +1201,7 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
return 200, new_content
-class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
+class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet):
"""Add/remove a room from the group summary, with optional category.
Matches both:
@@ -1154,7 +1258,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsCategoriesServlet(BaseFederationServlet):
+class FederationGroupsCategoriesServlet(BaseGroupsServerServlet):
"""Get all categories for a group"""
PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
@@ -1169,7 +1273,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsCategoryServlet(BaseFederationServlet):
+class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
"""Add/remove/get a category in a group"""
PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
@@ -1222,7 +1326,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsRolesServlet(BaseFederationServlet):
+class FederationGroupsRolesServlet(BaseGroupsServerServlet):
"""Get roles in a group"""
PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
@@ -1237,7 +1341,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsRoleServlet(BaseFederationServlet):
+class FederationGroupsRoleServlet(BaseGroupsServerServlet):
"""Add/remove/get a role in a group"""
PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
@@ -1290,7 +1394,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
+class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet):
"""Add/remove a user from the group summary, with optional role.
Matches both:
@@ -1345,7 +1449,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
+class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet):
"""Get roles in a group"""
PATH = "/get_groups_publicised"
@@ -1358,7 +1462,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
return 200, resp
-class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
+class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet):
"""Sets whether a group is joinable without an invite or knock"""
PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
@@ -1379,6 +1483,16 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
PATH = "/spaces/(?P<room_id>[^/]*)"
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.handler = hs.get_space_summary_handler()
+
async def on_GET(
self,
origin: str,
@@ -1444,16 +1558,25 @@ class RoomComplexityServlet(BaseFederationServlet):
PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
PREFIX = FEDERATION_UNSTABLE_PREFIX
- async def on_GET(self, origin, content, query, room_id):
-
- store = self.handler.hs.get_datastore()
+ def __init__(
+ self,
+ hs: HomeServer,
+ authenticator: Authenticator,
+ ratelimiter: FederationRateLimiter,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self._store = self.hs.get_datastore()
- is_public = await store.is_room_world_readable_or_publicly_joinable(room_id)
+ async def on_GET(self, origin, content, query, room_id):
+ is_public = await self._store.is_room_world_readable_or_publicly_joinable(
+ room_id
+ )
if not is_public:
raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM)
- complexity = await store.get_room_complexity(room_id)
+ complexity = await self._store.get_room_complexity(room_id)
return 200, complexity
@@ -1482,6 +1605,9 @@ FEDERATION_SERVLET_CLASSES = (
On3pidBindServlet,
FederationVersionServlet,
RoomComplexityServlet,
+ FederationSpaceSummaryServlet,
+ FederationV1SendKnockServlet,
+ FederationMakeKnockServlet,
) # type: Tuple[Type[BaseFederationServlet], ...]
OPENID_SERVLET_CLASSES = (
@@ -1523,6 +1649,7 @@ GROUP_ATTESTATION_SERVLET_CLASSES = (
FederationGroupsRenewAttestaionServlet,
) # type: Tuple[Type[BaseFederationServlet], ...]
+
DEFAULT_SERVLET_GROUPS = (
"federation",
"room_list",
@@ -1559,23 +1686,16 @@ def register_servlets(
if "federation" in servlet_groups:
for servletclass in FEDERATION_SERVLET_CLASSES:
servletclass(
- handler=hs.get_federation_server(),
+ hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
- FederationSpaceSummaryServlet(
- handler=hs.get_space_summary_handler(),
- authenticator=authenticator,
- ratelimiter=ratelimiter,
- server_name=hs.hostname,
- ).register(resource)
-
if "openid" in servlet_groups:
for servletclass in OPENID_SERVLET_CLASSES:
servletclass(
- handler=hs.get_federation_server(),
+ hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
@@ -1584,7 +1704,7 @@ def register_servlets(
if "room_list" in servlet_groups:
for servletclass in ROOM_LIST_CLASSES:
servletclass(
- handler=hs.get_room_list_handler(),
+ hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
@@ -1594,7 +1714,7 @@ def register_servlets(
if "group_server" in servlet_groups:
for servletclass in GROUP_SERVER_SERVLET_CLASSES:
servletclass(
- handler=hs.get_groups_server_handler(),
+ hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
@@ -1603,7 +1723,7 @@ def register_servlets(
if "group_local" in servlet_groups:
for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
servletclass(
- handler=hs.get_groups_local_handler(),
+ hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
@@ -1612,7 +1732,7 @@ def register_servlets(
if "group_attestation" in servlet_groups:
for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
servletclass(
- handler=hs.get_groups_attestation_renewer(),
+ hs=hs,
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
|