diff --git a/changelog.d/10213.misc b/changelog.d/10213.misc
new file mode 100644
index 0000000000..9adb0fbd02
--- /dev/null
+++ b/changelog.d/10213.misc
@@ -0,0 +1 @@
+Add type hints to the federation servlets.
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 676fbd3750..d37d9565fc 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -15,7 +15,19 @@
import functools
import logging
import re
-from typing import Container, Mapping, Optional, Sequence, Tuple, Type
+from typing import (
+ Container,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+)
+
+from typing_extensions import Literal
import synapse
from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH
@@ -56,15 +68,15 @@ logger = logging.getLogger(__name__)
class TransportLayerServer(JsonResource):
"""Handles incoming federation HTTP requests"""
- def __init__(self, hs, servlet_groups=None):
+ def __init__(self, hs: HomeServer, servlet_groups: Optional[List[str]] = None):
"""Initialize the TransportLayerServer
Will by default register all servlets. For custom behaviour, pass in
a list of servlet_groups to register.
Args:
- hs (synapse.server.HomeServer): homeserver
- servlet_groups (list[str], optional): List of servlet groups to register.
+ hs: homeserver
+ servlet_groups: List of servlet groups to register.
Defaults to ``DEFAULT_SERVLET_GROUPS``.
"""
self.hs = hs
@@ -78,7 +90,7 @@ class TransportLayerServer(JsonResource):
self.register_servlets()
- def register_servlets(self):
+ def register_servlets(self) -> None:
register_servlets(
self.hs,
resource=self,
@@ -91,14 +103,10 @@ class TransportLayerServer(JsonResource):
class AuthenticationError(SynapseError):
"""There was a problem authenticating the request"""
- pass
-
class NoAuthenticationError(AuthenticationError):
"""The request had no authentication information"""
- pass
-
class Authenticator:
def __init__(self, hs: HomeServer):
@@ -410,13 +418,18 @@ class FederationSendServlet(BaseFederationServerServlet):
RATELIMIT = False
# This is when someone is trying to send us a bunch of data.
- async def on_PUT(self, origin, content, query, transaction_id):
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ transaction_id: str,
+ ) -> Tuple[int, JsonDict]:
"""Called on PUT /send/<transaction_id>/
Args:
- request (twisted.web.http.Request): The HTTP request.
- transaction_id (str): The transaction_id associated with this
- request. This is *not* None.
+ transaction_id: The transaction_id associated with this request. This
+ is *not* None.
Returns:
Tuple of `(code, response)`, where
@@ -461,7 +474,13 @@ class FederationEventServlet(BaseFederationServerServlet):
PATH = "/event/(?P<event_id>[^/]*)/?"
# This is when someone asks for a data item for a given server data_id pair.
- async def on_GET(self, origin, content, query, event_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ event_id: str,
+ ) -> Tuple[int, Union[JsonDict, str]]:
return await self.handler.on_pdu_request(origin, event_id)
@@ -469,7 +488,13 @@ class FederationStateV1Servlet(BaseFederationServerServlet):
PATH = "/state/(?P<room_id>[^/]*)/?"
# This is when someone asks for all data for a given room.
- async def on_GET(self, origin, content, query, room_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
return await self.handler.on_room_state_request(
origin,
room_id,
@@ -480,7 +505,13 @@ class FederationStateV1Servlet(BaseFederationServerServlet):
class FederationStateIdsServlet(BaseFederationServerServlet):
PATH = "/state_ids/(?P<room_id>[^/]*)/?"
- async def on_GET(self, origin, content, query, room_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
return await self.handler.on_state_ids_request(
origin,
room_id,
@@ -491,7 +522,13 @@ class FederationStateIdsServlet(BaseFederationServerServlet):
class FederationBackfillServlet(BaseFederationServerServlet):
PATH = "/backfill/(?P<room_id>[^/]*)/?"
- async def on_GET(self, origin, content, query, room_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
versions = [x.decode("ascii") for x in query[b"v"]]
limit = parse_integer_from_args(query, "limit", None)
@@ -505,7 +542,13 @@ class FederationQueryServlet(BaseFederationServerServlet):
PATH = "/query/(?P<query_type>[^/]*)"
# This is when we receive a server-server Query
- async def on_GET(self, origin, content, query, query_type):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ query_type: str,
+ ) -> Tuple[int, JsonDict]:
args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()}
args["origin"] = origin
return await self.handler.on_query_request(query_type, args)
@@ -514,47 +557,66 @@ class FederationQueryServlet(BaseFederationServerServlet):
class FederationMakeJoinServlet(BaseFederationServerServlet):
PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
- async def on_GET(self, origin, _content, query, room_id, user_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
"""
Args:
- origin (unicode): The authenticated server_name of the calling server
+ origin: The authenticated server_name of the calling server
- _content (None): (GETs don't have bodies)
+ content: (GETs don't have bodies)
- query (dict[bytes, list[bytes]]): Query params from the request.
+ query: Query params from the request.
- **kwargs (dict[unicode, unicode]): the dict mapping keys to path
- components as specified in the path match regexp.
+ **kwargs: the dict mapping keys to path components as specified in
+ the path match regexp.
Returns:
- Tuple[int, object]: (response code, response object)
+ Tuple of (response code, response object)
"""
- versions = query.get(b"ver")
- if versions is not None:
- supported_versions = [v.decode("utf-8") for v in versions]
- else:
+ supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8")
+ if supported_versions is None:
supported_versions = ["1"]
- content = await self.handler.on_make_join_request(
+ result = await self.handler.on_make_join_request(
origin, room_id, user_id, supported_versions=supported_versions
)
- return 200, content
+ return 200, result
class FederationMakeLeaveServlet(BaseFederationServerServlet):
PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
- async def on_GET(self, origin, content, query, room_id, user_id):
- content = await self.handler.on_make_leave_request(origin, room_id, user_id)
- return 200, content
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
+ result = await self.handler.on_make_leave_request(origin, room_id, user_id)
+ return 200, result
class FederationV1SendLeaveServlet(BaseFederationServerServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
- async def on_PUT(self, origin, content, query, room_id, event_id):
- content = await self.handler.on_send_leave_request(origin, content, room_id)
- return 200, (200, content)
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ event_id: str,
+ ) -> Tuple[int, Tuple[int, JsonDict]]:
+ result = await self.handler.on_send_leave_request(origin, content, room_id)
+ return 200, (200, result)
class FederationV2SendLeaveServlet(BaseFederationServerServlet):
@@ -562,50 +624,84 @@ class FederationV2SendLeaveServlet(BaseFederationServerServlet):
PREFIX = FEDERATION_V2_PREFIX
- async def on_PUT(self, origin, content, query, room_id, event_id):
- content = await self.handler.on_send_leave_request(origin, content, room_id)
- return 200, content
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ event_id: str,
+ ) -> Tuple[int, JsonDict]:
+ result = await self.handler.on_send_leave_request(origin, content, room_id)
+ return 200, result
class FederationMakeKnockServlet(BaseFederationServerServlet):
PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
- async def on_GET(self, origin, content, query, room_id, user_id):
- try:
- # Retrieve the room versions the remote homeserver claims to support
- supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8")
- except KeyError:
- raise SynapseError(400, "Missing required query parameter 'ver'")
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
+ # Retrieve the room versions the remote homeserver claims to support
+ supported_versions = parse_strings_from_args(
+ query, "ver", required=True, encoding="utf-8"
+ )
- content = await self.handler.on_make_knock_request(
+ result = await self.handler.on_make_knock_request(
origin, room_id, user_id, supported_versions=supported_versions
)
- return 200, content
+ return 200, result
class FederationV1SendKnockServlet(BaseFederationServerServlet):
PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
- async def on_PUT(self, origin, content, query, room_id, event_id):
- content = await self.handler.on_send_knock_request(origin, content, room_id)
- return 200, content
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ event_id: str,
+ ) -> Tuple[int, JsonDict]:
+ result = await self.handler.on_send_knock_request(origin, content, room_id)
+ return 200, result
class FederationEventAuthServlet(BaseFederationServerServlet):
PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
- async def on_GET(self, origin, content, query, room_id, event_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ event_id: str,
+ ) -> Tuple[int, JsonDict]:
return await self.handler.on_event_auth(origin, room_id, event_id)
class FederationV1SendJoinServlet(BaseFederationServerServlet):
PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
- async def on_PUT(self, origin, content, query, room_id, event_id):
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ event_id: str,
+ ) -> Tuple[int, Tuple[int, JsonDict]]:
# TODO(paul): assert that event_id parsed from path actually
# match those given in content
- content = await self.handler.on_send_join_request(origin, content, room_id)
- return 200, (200, content)
+ result = await self.handler.on_send_join_request(origin, content, room_id)
+ return 200, (200, result)
class FederationV2SendJoinServlet(BaseFederationServerServlet):
@@ -613,28 +709,42 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
PREFIX = FEDERATION_V2_PREFIX
- async def on_PUT(self, origin, content, query, room_id, event_id):
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ event_id: str,
+ ) -> Tuple[int, JsonDict]:
# TODO(paul): assert that event_id parsed from path actually
# match those given in content
- content = await self.handler.on_send_join_request(origin, content, room_id)
- return 200, content
+ result = await self.handler.on_send_join_request(origin, content, room_id)
+ return 200, result
class FederationV1InviteServlet(BaseFederationServerServlet):
PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
- async def on_PUT(self, origin, content, query, room_id, event_id):
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ event_id: str,
+ ) -> Tuple[int, Tuple[int, JsonDict]]:
# We don't get a room version, so we have to assume its EITHER v1 or
# v2. This is "fine" as the only difference between V1 and V2 is the
# state resolution algorithm, and we don't use that for processing
# invites
- content = await self.handler.on_invite_request(
+ result = await self.handler.on_invite_request(
origin, content, room_version_id=RoomVersions.V1.identifier
)
# V1 federation API is defined to return a content of `[200, {...}]`
# due to a historical bug.
- return 200, (200, content)
+ return 200, (200, result)
class FederationV2InviteServlet(BaseFederationServerServlet):
@@ -642,7 +752,14 @@ class FederationV2InviteServlet(BaseFederationServerServlet):
PREFIX = FEDERATION_V2_PREFIX
- async def on_PUT(self, origin, content, query, room_id, event_id):
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ event_id: str,
+ ) -> Tuple[int, JsonDict]:
# TODO(paul): assert that room_id/event_id parsed from path actually
# match those given in content
@@ -655,16 +772,22 @@ class FederationV2InviteServlet(BaseFederationServerServlet):
event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state
- content = await self.handler.on_invite_request(
+ result = await self.handler.on_invite_request(
origin, event, room_version_id=room_version
)
- return 200, content
+ return 200, result
class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
- async def on_PUT(self, origin, content, query, room_id):
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
await self.handler.on_exchange_third_party_invite_request(content)
return 200, {}
@@ -672,21 +795,31 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet):
class FederationClientKeysQueryServlet(BaseFederationServerServlet):
PATH = "/user/keys/query"
- async def on_POST(self, origin, content, query):
+ async def on_POST(
+ self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+ ) -> Tuple[int, JsonDict]:
return await self.handler.on_query_client_keys(origin, content)
class FederationUserDevicesQueryServlet(BaseFederationServerServlet):
PATH = "/user/devices/(?P<user_id>[^/]*)"
- async def on_GET(self, origin, content, query, user_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
return await self.handler.on_query_user_devices(origin, user_id)
class FederationClientKeysClaimServlet(BaseFederationServerServlet):
PATH = "/user/keys/claim"
- async def on_POST(self, origin, content, query):
+ async def on_POST(
+ self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+ ) -> Tuple[int, JsonDict]:
response = await self.handler.on_claim_client_keys(origin, content)
return 200, response
@@ -695,12 +828,18 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet):
# TODO(paul): Why does this path alone end with "/?" optional?
PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
- async def on_POST(self, origin, content, query, room_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
limit = int(content.get("limit", 10))
earliest_events = content.get("earliest_events", [])
latest_events = content.get("latest_events", [])
- content = await self.handler.on_get_missing_events(
+ result = await self.handler.on_get_missing_events(
origin,
room_id=room_id,
earliest_events=earliest_events,
@@ -708,7 +847,7 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet):
limit=limit,
)
- return 200, content
+ return 200, result
class On3pidBindServlet(BaseFederationServerServlet):
@@ -716,7 +855,9 @@ class On3pidBindServlet(BaseFederationServerServlet):
REQUIRE_AUTH = False
- async def on_POST(self, origin, content, query):
+ async def on_POST(
+ self, origin: Optional[str], content: JsonDict, query: Dict[bytes, List[bytes]]
+ ) -> Tuple[int, JsonDict]:
if "invites" in content:
last_exception = None
for invite in content["invites"]:
@@ -762,15 +903,20 @@ class OpenIdUserInfo(BaseFederationServerServlet):
REQUIRE_AUTH = False
- async def on_GET(self, origin, content, query):
- token = query.get(b"access_token", [None])[0]
+ async def on_GET(
+ self,
+ origin: Optional[str],
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ ) -> Tuple[int, JsonDict]:
+ token = parse_string_from_args(query, "access_token")
if token is None:
return (
401,
{"errcode": "M_MISSING_TOKEN", "error": "Access Token required"},
)
- user_id = await self.handler.on_openid_userinfo(token.decode("ascii"))
+ user_id = await self.handler.on_openid_userinfo(token)
if user_id is None:
return (
@@ -829,7 +975,9 @@ class PublicRoomList(BaseFederationServlet):
self.handler = hs.get_room_list_handler()
self.allow_access = allow_access
- async def on_GET(self, origin, content, query):
+ async def on_GET(
+ self, origin: str, content: Literal[None], query: Dict[bytes, List[bytes]]
+ ) -> Tuple[int, JsonDict]:
if not self.allow_access:
raise FederationDeniedError(origin)
@@ -858,7 +1006,9 @@ class PublicRoomList(BaseFederationServlet):
)
return 200, data
- async def on_POST(self, origin, content, query):
+ async def on_POST(
+ self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+ ) -> Tuple[int, JsonDict]:
# This implements MSC2197 (Search Filtering over Federation)
if not self.allow_access:
raise FederationDeniedError(origin)
@@ -904,7 +1054,12 @@ class FederationVersionServlet(BaseFederationServlet):
REQUIRE_AUTH = False
- async def on_GET(self, origin, content, query):
+ async def on_GET(
+ self,
+ origin: Optional[str],
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ ) -> Tuple[int, JsonDict]:
return (
200,
{"server": {"name": "Synapse", "version": get_version_string(synapse)}},
@@ -933,7 +1088,13 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/profile"
- async def on_GET(self, origin, content, query, group_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -942,7 +1103,13 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet):
return 200, new_content
- async def on_POST(self, origin, content, query, group_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -957,7 +1124,13 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet):
class FederationGroupsSummaryServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/summary"
- async def on_GET(self, origin, content, query, group_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -972,7 +1145,13 @@ class FederationGroupsRoomsServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/rooms"
- async def on_GET(self, origin, content, query, group_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -987,7 +1166,14 @@ class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
- async def on_POST(self, origin, content, query, group_id, room_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -998,7 +1184,14 @@ class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet):
return 200, new_content
- async def on_DELETE(self, origin, content, query, group_id, room_id):
+ async def on_DELETE(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1018,7 +1211,15 @@ class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet):
"/config/(?P<config_key>[^/]*)"
)
- async def on_POST(self, origin, content, query, group_id, room_id, config_key):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ room_id: str,
+ config_key: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1035,7 +1236,13 @@ class FederationGroupsUsersServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/users"
- async def on_GET(self, origin, content, query, group_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1050,7 +1257,13 @@ class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
- async def on_GET(self, origin, content, query, group_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1067,7 +1280,14 @@ class FederationGroupsInviteServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
- async def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1084,7 +1304,14 @@ class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
- async def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
if get_domain_from_id(user_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
@@ -1098,7 +1325,14 @@ class FederationGroupsJoinServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
- async def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
if get_domain_from_id(user_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
@@ -1112,7 +1346,14 @@ class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
- async def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1146,7 +1387,14 @@ class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet):
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
- async def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
if get_domain_from_id(group_id) != origin:
raise SynapseError(403, "group_id doesn't match origin")
@@ -1164,7 +1412,14 @@ class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet):
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
- async def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ user_id: str,
+ ) -> Tuple[int, None]:
if get_domain_from_id(group_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
@@ -1172,11 +1427,9 @@ class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet):
self.handler, GroupsLocalHandler
), "Workers cannot handle group removals."
- new_content = await self.handler.user_removed_from_group(
- group_id, user_id, content
- )
+ await self.handler.user_removed_from_group(group_id, user_id, content)
- return 200, new_content
+ return 200, None
class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
@@ -1194,7 +1447,14 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
super().__init__(hs, authenticator, ratelimiter, server_name)
self.handler = hs.get_groups_attestation_renewer()
- async def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
# We don't need to check auth here as we check the attestation signatures
new_content = await self.handler.on_renew_attestation(
@@ -1218,7 +1478,15 @@ class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet):
"/rooms/(?P<room_id>[^/]*)"
)
- async def on_POST(self, origin, content, query, group_id, category_id, room_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ category_id: str,
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1246,7 +1514,15 @@ class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet):
return 200, resp
- async def on_DELETE(self, origin, content, query, group_id, category_id, room_id):
+ async def on_DELETE(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ category_id: str,
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1266,7 +1542,13 @@ class FederationGroupsCategoriesServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
- async def on_GET(self, origin, content, query, group_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1281,7 +1563,14 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
- async def on_GET(self, origin, content, query, group_id, category_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ category_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1292,7 +1581,14 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
return 200, resp
- async def on_POST(self, origin, content, query, group_id, category_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ category_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1314,7 +1610,14 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
return 200, resp
- async def on_DELETE(self, origin, content, query, group_id, category_id):
+ async def on_DELETE(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ category_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1334,7 +1637,13 @@ class FederationGroupsRolesServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
- async def on_GET(self, origin, content, query, group_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1349,7 +1658,14 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
- async def on_GET(self, origin, content, query, group_id, role_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ role_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1358,7 +1674,14 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet):
return 200, resp
- async def on_POST(self, origin, content, query, group_id, role_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ role_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1382,7 +1705,14 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet):
return 200, resp
- async def on_DELETE(self, origin, content, query, group_id, role_id):
+ async def on_DELETE(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ role_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1411,7 +1741,15 @@ class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet):
"/users/(?P<user_id>[^/]*)"
)
- async def on_POST(self, origin, content, query, group_id, role_id, user_id):
+ async def on_POST(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ role_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1437,7 +1775,15 @@ class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet):
return 200, resp
- async def on_DELETE(self, origin, content, query, group_id, role_id, user_id):
+ async def on_DELETE(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ role_id: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1457,7 +1803,9 @@ class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet):
PATH = "/get_groups_publicised"
- async def on_POST(self, origin, content, query):
+ async def on_POST(
+ self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+ ) -> Tuple[int, JsonDict]:
resp = await self.handler.bulk_get_publicised_groups(
content["user_ids"], proxy=False
)
@@ -1470,7 +1818,13 @@ class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet):
PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
- async def on_PUT(self, origin, content, query, group_id):
+ async def on_PUT(
+ self,
+ origin: str,
+ content: JsonDict,
+ query: Dict[bytes, List[bytes]],
+ group_id: str,
+ ) -> Tuple[int, JsonDict]:
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1499,7 +1853,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
async def on_GET(
self,
origin: str,
- content: JsonDict,
+ content: Literal[None],
query: Mapping[bytes, Sequence[bytes]],
room_id: str,
) -> Tuple[int, JsonDict]:
@@ -1571,7 +1925,13 @@ class RoomComplexityServlet(BaseFederationServlet):
super().__init__(hs, authenticator, ratelimiter, server_name)
self._store = self.hs.get_datastore()
- async def on_GET(self, origin, content, query, room_id):
+ async def on_GET(
+ self,
+ origin: str,
+ content: Literal[None],
+ query: Dict[bytes, List[bytes]],
+ room_id: str,
+ ) -> Tuple[int, JsonDict]:
is_public = await self._store.is_room_world_readable_or_publicly_joinable(
room_id
)
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index fda8da21b7..6ba2ce1e53 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -113,8 +113,18 @@ def parse_boolean_from_args(args, name, default=None, required=False):
def parse_bytes_from_args(
args: Dict[bytes, List[bytes]],
name: str,
+ default: Optional[bytes] = None,
+) -> Optional[bytes]:
+ ...
+
+
+@overload
+def parse_bytes_from_args(
+ args: Dict[bytes, List[bytes]],
+ name: str,
default: Literal[None] = None,
- required: Literal[True] = True,
+ *,
+ required: Literal[True],
) -> bytes:
...
@@ -197,7 +207,12 @@ def parse_string(
"""
args = request.args # type: Dict[bytes, List[bytes]] # type: ignore
return parse_string_from_args(
- args, name, default, required, allowed_values, encoding
+ args,
+ name,
+ default,
+ required=required,
+ allowed_values=allowed_values,
+ encoding=encoding,
)
@@ -227,7 +242,20 @@ def parse_strings_from_args(
args: Dict[bytes, List[bytes]],
name: str,
default: Optional[List[str]] = None,
- required: Literal[True] = True,
+ *,
+ allowed_values: Optional[Iterable[str]] = None,
+ encoding: str = "ascii",
+) -> Optional[List[str]]:
+ ...
+
+
+@overload
+def parse_strings_from_args(
+ args: Dict[bytes, List[bytes]],
+ name: str,
+ default: Optional[List[str]] = None,
+ *,
+ required: Literal[True],
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
) -> List[str]:
@@ -239,6 +267,7 @@ def parse_strings_from_args(
args: Dict[bytes, List[bytes]],
name: str,
default: Optional[List[str]] = None,
+ *,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
@@ -299,7 +328,20 @@ def parse_string_from_args(
args: Dict[bytes, List[bytes]],
name: str,
default: Optional[str] = None,
- required: Literal[True] = True,
+ *,
+ allowed_values: Optional[Iterable[str]] = None,
+ encoding: str = "ascii",
+) -> Optional[str]:
+ ...
+
+
+@overload
+def parse_string_from_args(
+ args: Dict[bytes, List[bytes]],
+ name: str,
+ default: Optional[str] = None,
+ *,
+ required: Literal[True],
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
) -> str:
|