diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index be938df962..6e2fbedd99 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice import ApplicationService
-from synapse.http.server import finish_request
+from synapse.handlers.sso import SsoIdentityProvider
+from synapse.http.server import HttpServer, finish_request
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
@@ -60,11 +61,14 @@ class LoginRestServlet(RestServlet):
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
+ self._msc2858_enabled = hs.config.experimental.msc2858_enabled
self.auth = hs.get_auth()
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
+ self._sso_handler = hs.get_sso_handler()
+
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
@@ -89,8 +93,17 @@ class LoginRestServlet(RestServlet):
flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
- flows.append({"type": LoginRestServlet.SSO_TYPE})
- # While its valid for us to advertise this login type generally,
+ sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict
+
+ if self._msc2858_enabled:
+ sso_flow["org.matrix.msc2858.identity_providers"] = [
+ _get_auth_flow_dict_for_idp(idp)
+ for idp in self._sso_handler.get_identity_providers().values()
+ ]
+
+ flows.append(sso_flow)
+
+ # While it's valid for us to advertise this login type generally,
# synapse currently only gives out these tokens as part of the
# SSO login flow.
# Generally we don't want to advertise login flows that clients
@@ -297,7 +310,9 @@ class LoginRestServlet(RestServlet):
except jwt.PyJWTError as e:
# A JWT error occurred, return some info back to the client.
raise LoginError(
- 403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN,
+ 403,
+ "JWT validation failed: %s" % (str(e),),
+ errcode=Codes.FORBIDDEN,
)
user = payload.get("sub", None)
@@ -311,8 +326,22 @@ class LoginRestServlet(RestServlet):
return result
+def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
+ """Return an entry for the login flow dict
+
+ Returns an entry suitable for inclusion in "identity_providers" in the
+ response to GET /_matrix/client/r0/login
+ """
+ e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
+ if idp.idp_icon:
+ e["icon"] = idp.idp_icon
+ if idp.idp_brand:
+ e["brand"] = idp.idp_brand
+ return e
+
+
class SsoRedirectServlet(RestServlet):
- PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
+ PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they
@@ -324,13 +353,33 @@ class SsoRedirectServlet(RestServlet):
if hs.config.oidc_enabled:
hs.get_oidc_handler()
self._sso_handler = hs.get_sso_handler()
+ self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+
+ def register(self, http_server: HttpServer) -> None:
+ super().register(http_server)
+ if self._msc2858_enabled:
+ # expose additional endpoint for MSC2858 support
+ http_server.register_paths(
+ "GET",
+ client_patterns(
+ "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$",
+ releases=(),
+ unstable=True,
+ ),
+ self.on_GET,
+ self.__class__.__name__,
+ )
- async def on_GET(self, request: SynapseRequest):
+ async def on_GET(
+ self, request: SynapseRequest, idp_id: Optional[str] = None
+ ) -> None:
client_redirect_url = parse_string(
request, "redirectUrl", required=True, encoding=None
)
sso_url = await self._sso_handler.handle_redirect_request(
- request, client_redirect_url
+ request,
+ client_redirect_url,
+ idp_id,
)
logger.info("Redirecting to %s", sso_url)
request.redirect(sso_url)
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 85a66458c5..717c5f2b10 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -60,7 +60,9 @@ class ProfileDisplaynameRestServlet(RestServlet):
new_name = content["displayname"]
except Exception:
raise SynapseError(
- code=400, msg="Unable to parse name", errcode=Codes.BAD_JSON,
+ code=400,
+ msg="Unable to parse name",
+ errcode=Codes.BAD_JSON,
)
await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 89823fcc39..0c148a213d 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -159,7 +159,9 @@ class PushersRemoveRestServlet(RestServlet):
self.notifier.on_new_replication_data()
respond_with_html_bytes(
- request, 200, PushersRemoveRestServlet.SUCCESS_HTML,
+ request,
+ 200,
+ PushersRemoveRestServlet.SUCCESS_HTML,
)
return None
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index f95627ee61..9a1df30c29 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -362,7 +362,9 @@ class PublicRoomListRestServlet(TransactionRestServlet):
parse_and_validate_server_name(server)
except ValueError:
raise SynapseError(
- 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
+ 400,
+ "Invalid server name: %s" % (server,),
+ Codes.INVALID_PARAM,
)
try:
@@ -413,7 +415,9 @@ class PublicRoomListRestServlet(TransactionRestServlet):
parse_and_validate_server_name(server)
except ValueError:
raise SynapseError(
- 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
+ 400,
+ "Invalid server name: %s" % (server,),
+ Codes.INVALID_PARAM,
)
try:
@@ -650,7 +654,7 @@ class RoomEventContextServlet(RestServlet):
event_filter = None
results = await self.room_context_handler.get_event_context(
- requester.user, room_id, event_id, limit, event_filter
+ requester, room_id, event_id, limit, event_filter
)
if not results:
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 65e68d641b..adf1d39728 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -54,7 +54,7 @@ logger = logging.getLogger(__name__)
class EmailPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password/email/requestToken$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.datastore = hs.get_datastore()
@@ -103,6 +103,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
# The email will be sent to the stored address.
# This avoids a potential account hijack by requesting a password reset to
# an email address which is controlled by the attacker but which, after
@@ -191,7 +193,10 @@ class PasswordRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
try:
params, session_id = await self.auth_handler.validate_user_via_ui_auth(
- requester, request, body, "modify your account password",
+ requester,
+ request,
+ body,
+ "modify your account password",
)
except InteractiveAuthIncompleteError as e:
# The user needs to provide more steps to complete auth, but
@@ -310,7 +315,10 @@ class DeactivateAccountRestServlet(RestServlet):
return 200, {}
await self.auth_handler.validate_user_via_ui_auth(
- requester, request, body, "deactivate your account",
+ requester,
+ request,
+ body,
+ "deactivate your account",
)
result = await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(),
@@ -379,6 +387,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
if next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
@@ -430,7 +440,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
super().__init__()
self.store = self.hs.get_datastore()
@@ -458,6 +468,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(
+ request, "msisdn", msisdn
+ )
+
if next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
@@ -695,7 +709,10 @@ class ThreepidAddRestServlet(RestServlet):
assert_valid_client_secret(client_secret)
await self.auth_handler.validate_user_via_ui_auth(
- requester, request, body, "add a third-party identifier to your account",
+ requester,
+ request,
+ body,
+ "add a third-party identifier to your account",
)
validation_session = await self.identity_handler.validate_threepid_session(
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 314e01dfe4..3d07aadd39 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -83,7 +83,10 @@ class DeleteDevicesRestServlet(RestServlet):
assert_params_in_dict(body, ["devices"])
await self.auth_handler.validate_user_via_ui_auth(
- requester, request, body, "remove device(s) from your account",
+ requester,
+ request,
+ body,
+ "remove device(s) from your account",
)
await self.device_handler.delete_devices(
@@ -129,7 +132,10 @@ class DeviceRestServlet(RestServlet):
raise
await self.auth_handler.validate_user_via_ui_auth(
- requester, request, body, "remove a device from your account",
+ requester,
+ request,
+ body,
+ "remove a device from your account",
)
await self.device_handler.delete_device(requester.user.to_string(), device_id)
@@ -206,7 +212,9 @@ class DehydratedDeviceServlet(RestServlet):
if "device_data" not in submission:
raise errors.SynapseError(
- 400, "device_data missing", errcode=errors.Codes.MISSING_PARAM,
+ 400,
+ "device_data missing",
+ errcode=errors.Codes.MISSING_PARAM,
)
elif not isinstance(submission["device_data"], dict):
raise errors.SynapseError(
@@ -259,11 +267,15 @@ class ClaimDehydratedDeviceServlet(RestServlet):
if "device_id" not in submission:
raise errors.SynapseError(
- 400, "device_id missing", errcode=errors.Codes.MISSING_PARAM,
+ 400,
+ "device_id missing",
+ errcode=errors.Codes.MISSING_PARAM,
)
elif not isinstance(submission["device_id"], str):
raise errors.SynapseError(
- 400, "device_id must be a string", errcode=errors.Codes.INVALID_PARAM,
+ 400,
+ "device_id must be a string",
+ errcode=errors.Codes.INVALID_PARAM,
)
result = await self.device_handler.rehydrate_device(
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 5b5da71815..d3434225cb 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -16,13 +16,29 @@
import logging
from functools import wraps
-
-from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import GroupID
+from typing import TYPE_CHECKING, Optional, Tuple
+
+from twisted.web.http import Request
+
+from synapse.api.constants import (
+ MAX_GROUP_CATEGORYID_LENGTH,
+ MAX_GROUP_ROLEID_LENGTH,
+ MAX_GROUPID_LENGTH,
+)
+from synapse.api.errors import Codes, SynapseError
+from synapse.handlers.groups_local import GroupsLocalHandler
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from synapse.types import GroupID, JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -33,7 +49,7 @@ def _validate_group_id(f):
"""
@wraps(f)
- def wrapper(self, request, group_id, *args, **kwargs):
+ def wrapper(self, request: Request, group_id: str, *args, **kwargs):
if not GroupID.is_valid(group_id):
raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
@@ -43,19 +59,18 @@ def _validate_group_id(f):
class GroupServlet(RestServlet):
- """Get the group profile
- """
+ """Get the group profile"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -66,11 +81,17 @@ class GroupServlet(RestServlet):
return 200, group_description
@_validate_group_id
- async def on_POST(self, request, group_id):
+ async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert_params_in_dict(
+ content, ("name", "avatar_url", "short_description", "long_description")
+ )
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot create group profiles."
await self.groups_handler.update_group_profile(
group_id, requester_user_id, content
)
@@ -79,19 +100,18 @@ class GroupServlet(RestServlet):
class GroupSummaryServlet(RestServlet):
- """Get the full group summary
- """
+ """Get the full group summary"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -116,18 +136,34 @@ class GroupSummaryRoomsCatServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, category_id, room_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, category_id: Optional[str], room_id: str
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ if category_id == "":
+ raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM)
+
+ if category_id and len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
+ raise SynapseError(
+ 400,
+ "category_id may not be longer than %s characters"
+ % (MAX_GROUP_CATEGORYID_LENGTH,),
+ Codes.INVALID_PARAM,
+ )
+
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group summaries."
resp = await self.groups_handler.update_group_summary_room(
group_id,
requester_user_id,
@@ -139,10 +175,15 @@ class GroupSummaryRoomsCatServlet(RestServlet):
return 200, resp
@_validate_group_id
- async def on_DELETE(self, request, group_id, category_id, room_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, category_id: str, room_id: str
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group profiles."
resp = await self.groups_handler.delete_group_summary_room(
group_id, requester_user_id, room_id=room_id, category_id=category_id
)
@@ -151,21 +192,22 @@ class GroupSummaryRoomsCatServlet(RestServlet):
class GroupCategoryServlet(RestServlet):
- """Get/add/update/delete a group category
- """
+ """Get/add/update/delete a group category"""
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id, category_id):
+ async def on_GET(
+ self, request: Request, group_id: str, category_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -176,11 +218,27 @@ class GroupCategoryServlet(RestServlet):
return 200, category
@_validate_group_id
- async def on_PUT(self, request, group_id, category_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, category_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ if not category_id:
+ raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM)
+
+ if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
+ raise SynapseError(
+ 400,
+ "category_id may not be longer than %s characters"
+ % (MAX_GROUP_CATEGORYID_LENGTH,),
+ Codes.INVALID_PARAM,
+ )
+
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group categories."
resp = await self.groups_handler.update_group_category(
group_id, requester_user_id, category_id=category_id, content=content
)
@@ -188,10 +246,15 @@ class GroupCategoryServlet(RestServlet):
return 200, resp
@_validate_group_id
- async def on_DELETE(self, request, group_id, category_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, category_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group categories."
resp = await self.groups_handler.delete_group_category(
group_id, requester_user_id, category_id=category_id
)
@@ -200,19 +263,18 @@ class GroupCategoryServlet(RestServlet):
class GroupCategoriesServlet(RestServlet):
- """Get all group categories
- """
+ """Get all group categories"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -224,19 +286,20 @@ class GroupCategoriesServlet(RestServlet):
class GroupRoleServlet(RestServlet):
- """Get/add/update/delete a group role
- """
+ """Get/add/update/delete a group role"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id, role_id):
+ async def on_GET(
+ self, request: Request, group_id: str, role_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -247,11 +310,27 @@ class GroupRoleServlet(RestServlet):
return 200, category
@_validate_group_id
- async def on_PUT(self, request, group_id, role_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, role_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ if not role_id:
+ raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM)
+
+ if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
+ raise SynapseError(
+ 400,
+ "role_id may not be longer than %s characters"
+ % (MAX_GROUP_ROLEID_LENGTH,),
+ Codes.INVALID_PARAM,
+ )
+
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group roles."
resp = await self.groups_handler.update_group_role(
group_id, requester_user_id, role_id=role_id, content=content
)
@@ -259,10 +338,15 @@ class GroupRoleServlet(RestServlet):
return 200, resp
@_validate_group_id
- async def on_DELETE(self, request, group_id, role_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, role_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group roles."
resp = await self.groups_handler.delete_group_role(
group_id, requester_user_id, role_id=role_id
)
@@ -271,19 +355,18 @@ class GroupRoleServlet(RestServlet):
class GroupRolesServlet(RestServlet):
- """Get all group roles
- """
+ """Get all group roles"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -308,18 +391,34 @@ class GroupSummaryUsersRoleServlet(RestServlet):
"/users/(?P<user_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, role_id, user_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, role_id: Optional[str], user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ if role_id == "":
+ raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM)
+
+ if role_id and len(role_id) > MAX_GROUP_ROLEID_LENGTH:
+ raise SynapseError(
+ 400,
+ "role_id may not be longer than %s characters"
+ % (MAX_GROUP_ROLEID_LENGTH,),
+ Codes.INVALID_PARAM,
+ )
+
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group summaries."
resp = await self.groups_handler.update_group_summary_user(
group_id,
requester_user_id,
@@ -331,10 +430,15 @@ class GroupSummaryUsersRoleServlet(RestServlet):
return 200, resp
@_validate_group_id
- async def on_DELETE(self, request, group_id, role_id, user_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, role_id: str, user_id: str
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group summaries."
resp = await self.groups_handler.delete_group_summary_user(
group_id, requester_user_id, user_id=user_id, role_id=role_id
)
@@ -343,19 +447,18 @@ class GroupSummaryUsersRoleServlet(RestServlet):
class GroupRoomServlet(RestServlet):
- """Get all rooms in a group
- """
+ """Get all rooms in a group"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -367,19 +470,18 @@ class GroupRoomServlet(RestServlet):
class GroupUsersServlet(RestServlet):
- """Get all users in a group
- """
+ """Get all users in a group"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -391,19 +493,18 @@ class GroupUsersServlet(RestServlet):
class GroupInvitedUsersServlet(RestServlet):
- """Get users invited to a group
- """
+ """Get users invited to a group"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -415,23 +516,25 @@ class GroupInvitedUsersServlet(RestServlet):
class GroupSettingJoinPolicyServlet(RestServlet):
- """Set group join policy
- """
+ """Set group join policy"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group join policy."
result = await self.groups_handler.set_group_join_policy(
group_id, requester_user_id, content
)
@@ -440,19 +543,18 @@ class GroupSettingJoinPolicyServlet(RestServlet):
class GroupCreateServlet(RestServlet):
- """Create a group
- """
+ """Create a group"""
PATTERNS = client_patterns("/create_group$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
self.server_name = hs.hostname
- async def on_POST(self, request):
+ async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -461,6 +563,19 @@ class GroupCreateServlet(RestServlet):
localpart = content.pop("localpart")
group_id = GroupID(localpart, self.server_name).to_string()
+ if not localpart:
+ raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
+
+ if len(group_id) > MAX_GROUPID_LENGTH:
+ raise SynapseError(
+ 400,
+ "Group ID may not be longer than %s characters" % (MAX_GROUPID_LENGTH,),
+ Codes.INVALID_PARAM,
+ )
+
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot create groups."
result = await self.groups_handler.create_group(
group_id, requester_user_id, content
)
@@ -469,25 +584,29 @@ class GroupCreateServlet(RestServlet):
class GroupAdminRoomsServlet(RestServlet):
- """Add a room to the group
- """
+ """Add a room to the group"""
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, room_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify rooms in a group."
result = await self.groups_handler.add_room_to_group(
group_id, requester_user_id, room_id, content
)
@@ -495,10 +614,15 @@ class GroupAdminRoomsServlet(RestServlet):
return 200, result
@_validate_group_id
- async def on_DELETE(self, request, group_id, room_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group categories."
result = await self.groups_handler.remove_room_from_group(
group_id, requester_user_id, room_id
)
@@ -507,26 +631,30 @@ class GroupAdminRoomsServlet(RestServlet):
class GroupAdminRoomsConfigServlet(RestServlet):
- """Update the config of a room in a group
- """
+ """Update the config of a room in a group"""
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
"/config/(?P<config_key>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, room_id, config_key):
+ async def on_PUT(
+ self, request: Request, group_id: str, room_id: str, config_key: str
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group categories."
result = await self.groups_handler.update_room_in_group(
group_id, requester_user_id, room_id, config_key, content
)
@@ -535,14 +663,13 @@ class GroupAdminRoomsConfigServlet(RestServlet):
class GroupAdminUsersInviteServlet(RestServlet):
- """Invite a user to the group
- """
+ """Invite a user to the group"""
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
@@ -551,12 +678,15 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.is_mine_id = hs.is_mine_id
@_validate_group_id
- async def on_PUT(self, request, group_id, user_id):
+ async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
config = content.get("config", {})
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot invite users to a group."
result = await self.groups_handler.invite(
group_id, user_id, requester_user_id, config
)
@@ -565,25 +695,27 @@ class GroupAdminUsersInviteServlet(RestServlet):
class GroupAdminUsersKickServlet(RestServlet):
- """Kick a user from the group
- """
+ """Kick a user from the group"""
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, user_id):
+ async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot kick users from a group."
result = await self.groups_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content
)
@@ -592,23 +724,25 @@ class GroupAdminUsersKickServlet(RestServlet):
class GroupSelfLeaveServlet(RestServlet):
- """Leave a joined group
- """
+ """Leave a joined group"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot leave a group for a users."
result = await self.groups_handler.remove_user_from_group(
group_id, requester_user_id, requester_user_id, content
)
@@ -617,23 +751,25 @@ class GroupSelfLeaveServlet(RestServlet):
class GroupSelfJoinServlet(RestServlet):
- """Attempt to join a group, or knock
- """
+ """Attempt to join a group, or knock"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot join a user to a group."
result = await self.groups_handler.join_group(
group_id, requester_user_id, content
)
@@ -642,23 +778,25 @@ class GroupSelfJoinServlet(RestServlet):
class GroupSelfAcceptInviteServlet(RestServlet):
- """Accept a group invite
- """
+ """Accept a group invite"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot accept an invite to a group."
result = await self.groups_handler.accept_invite(
group_id, requester_user_id, content
)
@@ -667,19 +805,18 @@ class GroupSelfAcceptInviteServlet(RestServlet):
class GroupSelfUpdatePublicityServlet(RestServlet):
- """Update whether we publicise a users membership of a group
- """
+ """Update whether we publicise a users membership of a group"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -691,19 +828,18 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
class PublicisedGroupsForUserServlet(RestServlet):
- """Get the list of groups a user is advertising
- """
+ """Get the list of groups a user is advertising"""
PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request, user_id):
+ async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
result = await self.groups_handler.get_publicised_groups_for_user(user_id)
@@ -712,19 +848,18 @@ class PublicisedGroupsForUserServlet(RestServlet):
class PublicisedGroupsForUsersServlet(RestServlet):
- """Get the list of groups a user is advertising
- """
+ """Get the list of groups a user is advertising"""
PATTERNS = client_patterns("/publicised_groups$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
@@ -736,18 +871,17 @@ class PublicisedGroupsForUsersServlet(RestServlet):
class GroupsForUserServlet(RestServlet):
- """Get all groups the logged in user is joined to
- """
+ """Get all groups the logged in user is joined to"""
PATTERNS = client_patterns("/joined_groups$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -756,7 +890,7 @@ class GroupsForUserServlet(RestServlet):
return 200, result
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
GroupServlet(hs).register(http_server)
GroupSummaryServlet(hs).register(http_server)
GroupInvitedUsersServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index a6134ead8a..f092e5b3a2 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -271,7 +271,10 @@ class SigningKeyUploadServlet(RestServlet):
body = parse_json_object_from_request(request)
await self.auth_handler.validate_user_via_ui_auth(
- requester, request, body, "add a device signing key to your account",
+ requester,
+ request,
+ body,
+ "add a device signing key to your account",
)
result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index b093183e79..8f68d8dfc8 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -126,6 +126,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email
)
@@ -191,6 +193,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
body, ["client_secret", "country", "phone_number", "send_attempt"]
)
client_secret = body["client_secret"]
+ assert_valid_client_secret(client_secret)
country = body["country"]
phone_number = body["phone_number"]
send_attempt = body["send_attempt"]
@@ -205,6 +208,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(
+ request, "msisdn", msisdn
+ )
+
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"msisdn", msisdn
)
@@ -287,6 +294,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
sid = parse_string(request, "sid", required=True)
client_secret = parse_string(request, "client_secret", required=True)
+ assert_valid_client_secret(client_secret)
token = parse_string(request, "token", required=True)
# Attempt to validate a 3PID session
@@ -514,7 +522,10 @@ class RegisterRestServlet(RestServlet):
# not this will raise a user-interactive auth error.
try:
auth_result, params, session_id = await self.auth_handler.check_ui_auth(
- self._registration_flows, request, body, "register a new account",
+ self._registration_flows,
+ request,
+ body,
+ "register a new account",
)
except InteractiveAuthIncompleteError as e:
# The user needs to provide more steps to complete auth.
@@ -657,7 +668,9 @@ class RegisterRestServlet(RestServlet):
username, as_token
)
return await self._create_registration_details(
- user_id, body, is_appservice_ghost=True,
+ user_id,
+ body,
+ is_appservice_ghost=True,
)
async def _create_registration_details(
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index 18c75738f8..fe765da23c 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -244,7 +244,9 @@ class RelationAggregationPaginationServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
- room_id, requester.user.to_string(), allow_departed_users=True,
+ room_id,
+ requester.user.to_string(),
+ allow_departed_users=True,
)
# This checks that a) the event exists and b) the user is allowed to
@@ -322,7 +324,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
- room_id, requester.user.to_string(), allow_departed_users=True,
+ room_id,
+ requester.user.to_string(),
+ allow_departed_users=True,
)
# This checks that a) the event exists and b) the user is allowed to
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index bf030e0ff4..147920767f 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
class RoomUpgradeRestServlet(RestServlet):
- """Handler for room uprade requests.
+ """Handler for room upgrade requests.
Handles requests of the form:
|