summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-02-08 13:59:54 -0500
committerGitHub <noreply@github.com>2021-02-08 13:59:54 -0500
commit3f58fc848d0002de4605bed91603a1f9f245d128 (patch)
treec34cffdce8e7b037f0c1f7114c53c51f24bb113f /synapse
parentHandle additional errors when previewing URLs. (#9333) (diff)
downloadsynapse-3f58fc848d0002de4605bed91603a1f9f245d128.tar.xz
Type hints and validation improvements. (#9321)
* Adds type hints to the groups servlet and stringutils code.
* Assert the maximum length of some input values for spec compliance.
Diffstat (limited to '')
-rw-r--r--synapse/groups/groups_server.py25
-rw-r--r--synapse/rest/client/v2_alpha/groups.py179
-rw-r--r--synapse/rest/client/v2_alpha/register.py2
-rw-r--r--synapse/server.py16
-rw-r--r--synapse/util/stringutils.py33
5 files changed, 176 insertions, 79 deletions
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 0d042cbfac..76bf52ea23 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -18,6 +18,7 @@
 import logging
 
 from synapse.api.errors import Codes, SynapseError
+from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
 from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
 from synapse.util.async_helpers import concurrently_execute
 
@@ -32,6 +33,11 @@ logger = logging.getLogger(__name__)
 # TODO: Flairs
 
 
+# Note that the maximum lengths are somewhat arbitrary.
+MAX_SHORT_DESC_LEN = 1000
+MAX_LONG_DESC_LEN = 10000
+
+
 class GroupsServerWorkerHandler:
     def __init__(self, hs):
         self.hs = hs
@@ -508,11 +514,26 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         )
 
         profile = {}
-        for keyname in ("name", "avatar_url", "short_description", "long_description"):
+        for keyname, max_length in (
+            ("name", MAX_DISPLAYNAME_LEN),
+            ("avatar_url", MAX_AVATAR_URL_LEN),
+            ("short_description", MAX_SHORT_DESC_LEN),
+            ("long_description", MAX_LONG_DESC_LEN),
+        ):
             if keyname in content:
                 value = content[keyname]
                 if not isinstance(value, str):
-                    raise SynapseError(400, "%r value is not a string" % (keyname,))
+                    raise SynapseError(
+                        400,
+                        "%r value is not a string" % (keyname,),
+                        errcode=Codes.INVALID_PARAM,
+                    )
+                if len(value) > max_length:
+                    raise SynapseError(
+                        400,
+                        "Invalid %s parameter" % (keyname,),
+                        errcode=Codes.INVALID_PARAM,
+                    )
                 profile[keyname] = value
 
         await self.store.update_group_profile(group_id, profile)
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 5b5da71815..4fe712b30c 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -16,13 +16,24 @@
 
 import logging
 from functools import wraps
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
 
 from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import GroupID
+from synapse.handlers.groups_local import GroupsLocalHandler
+from synapse.http.servlet import (
+    RestServlet,
+    assert_params_in_dict,
+    parse_json_object_from_request,
+)
+from synapse.types import GroupID, JsonDict
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -33,7 +44,7 @@ def _validate_group_id(f):
     """
 
     @wraps(f)
-    def wrapper(self, request, group_id, *args, **kwargs):
+    def wrapper(self, request: Request, group_id: str, *args, **kwargs):
         if not GroupID.is_valid(group_id):
             raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
 
@@ -48,14 +59,14 @@ class GroupServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -66,11 +77,15 @@ class GroupServlet(RestServlet):
         return 200, group_description
 
     @_validate_group_id
-    async def on_POST(self, request, group_id):
+    async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert_params_in_dict(
+            content, ("name", "avatar_url", "short_description", "long_description")
+        )
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         await self.groups_handler.update_group_profile(
             group_id, requester_user_id, content
         )
@@ -84,14 +99,14 @@ class GroupSummaryServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -116,18 +131,21 @@ class GroupSummaryRoomsCatServlet(RestServlet):
         "/rooms/(?P<room_id>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, category_id, room_id):
+    async def on_PUT(
+        self, request: Request, group_id: str, category_id: str, room_id: str
+    ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         resp = await self.groups_handler.update_group_summary_room(
             group_id,
             requester_user_id,
@@ -139,10 +157,13 @@ class GroupSummaryRoomsCatServlet(RestServlet):
         return 200, resp
 
     @_validate_group_id
-    async def on_DELETE(self, request, group_id, category_id, room_id):
+    async def on_DELETE(
+        self, request: Request, group_id: str, category_id: str, room_id: str
+    ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         resp = await self.groups_handler.delete_group_summary_room(
             group_id, requester_user_id, room_id=room_id, category_id=category_id
         )
@@ -158,14 +179,16 @@ class GroupCategoryServlet(RestServlet):
         "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id, category_id):
+    async def on_GET(
+        self, request: Request, group_id: str, category_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -176,11 +199,14 @@ class GroupCategoryServlet(RestServlet):
         return 200, category
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, category_id):
+    async def on_PUT(
+        self, request: Request, group_id: str, category_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         resp = await self.groups_handler.update_group_category(
             group_id, requester_user_id, category_id=category_id, content=content
         )
@@ -188,10 +214,13 @@ class GroupCategoryServlet(RestServlet):
         return 200, resp
 
     @_validate_group_id
-    async def on_DELETE(self, request, group_id, category_id):
+    async def on_DELETE(
+        self, request: Request, group_id: str, category_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         resp = await self.groups_handler.delete_group_category(
             group_id, requester_user_id, category_id=category_id
         )
@@ -205,14 +234,14 @@ class GroupCategoriesServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -229,14 +258,16 @@ class GroupRoleServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id, role_id):
+    async def on_GET(
+        self, request: Request, group_id: str, role_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -247,11 +278,14 @@ class GroupRoleServlet(RestServlet):
         return 200, category
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, role_id):
+    async def on_PUT(
+        self, request: Request, group_id: str, role_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         resp = await self.groups_handler.update_group_role(
             group_id, requester_user_id, role_id=role_id, content=content
         )
@@ -259,10 +293,13 @@ class GroupRoleServlet(RestServlet):
         return 200, resp
 
     @_validate_group_id
-    async def on_DELETE(self, request, group_id, role_id):
+    async def on_DELETE(
+        self, request: Request, group_id: str, role_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         resp = await self.groups_handler.delete_group_role(
             group_id, requester_user_id, role_id=role_id
         )
@@ -276,14 +313,14 @@ class GroupRolesServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -308,18 +345,21 @@ class GroupSummaryUsersRoleServlet(RestServlet):
         "/users/(?P<user_id>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, role_id, user_id):
+    async def on_PUT(
+        self, request: Request, group_id: str, role_id: str, user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         resp = await self.groups_handler.update_group_summary_user(
             group_id,
             requester_user_id,
@@ -331,10 +371,13 @@ class GroupSummaryUsersRoleServlet(RestServlet):
         return 200, resp
 
     @_validate_group_id
-    async def on_DELETE(self, request, group_id, role_id, user_id):
+    async def on_DELETE(
+        self, request: Request, group_id: str, role_id: str, user_id: str
+    ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         resp = await self.groups_handler.delete_group_summary_user(
             group_id, requester_user_id, user_id=user_id, role_id=role_id
         )
@@ -348,14 +391,14 @@ class GroupRoomServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -372,14 +415,14 @@ class GroupUsersServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -396,14 +439,14 @@ class GroupInvitedUsersServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -420,18 +463,19 @@ class GroupSettingJoinPolicyServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id):
+    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
 
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.set_group_join_policy(
             group_id, requester_user_id, content
         )
@@ -445,14 +489,14 @@ class GroupCreateServlet(RestServlet):
 
     PATTERNS = client_patterns("/create_group$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
         self.server_name = hs.hostname
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -461,6 +505,7 @@ class GroupCreateServlet(RestServlet):
         localpart = content.pop("localpart")
         group_id = GroupID(localpart, self.server_name).to_string()
 
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.create_group(
             group_id, requester_user_id, content
         )
@@ -476,18 +521,21 @@ class GroupAdminRoomsServlet(RestServlet):
         "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, room_id):
+    async def on_PUT(
+        self, request: Request, group_id: str, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.add_room_to_group(
             group_id, requester_user_id, room_id, content
         )
@@ -495,10 +543,13 @@ class GroupAdminRoomsServlet(RestServlet):
         return 200, result
 
     @_validate_group_id
-    async def on_DELETE(self, request, group_id, room_id):
+    async def on_DELETE(
+        self, request: Request, group_id: str, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.remove_room_from_group(
             group_id, requester_user_id, room_id
         )
@@ -515,18 +566,21 @@ class GroupAdminRoomsConfigServlet(RestServlet):
         "/config/(?P<config_key>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, room_id, config_key):
+    async def on_PUT(
+        self, request: Request, group_id: str, room_id: str, config_key: str
+    ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.update_room_in_group(
             group_id, requester_user_id, room_id, config_key, content
         )
@@ -542,7 +596,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
         "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
@@ -551,12 +605,13 @@ class GroupAdminUsersInviteServlet(RestServlet):
         self.is_mine_id = hs.is_mine_id
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, user_id):
+    async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
         config = content.get("config", {})
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.invite(
             group_id, user_id, requester_user_id, config
         )
@@ -572,18 +627,19 @@ class GroupAdminUsersKickServlet(RestServlet):
         "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, user_id):
+    async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.remove_user_from_group(
             group_id, user_id, requester_user_id, content
         )
@@ -597,18 +653,19 @@ class GroupSelfLeaveServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id):
+    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.remove_user_from_group(
             group_id, requester_user_id, requester_user_id, content
         )
@@ -622,18 +679,19 @@ class GroupSelfJoinServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id):
+    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.join_group(
             group_id, requester_user_id, content
         )
@@ -647,18 +705,19 @@ class GroupSelfAcceptInviteServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id):
+    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.accept_invite(
             group_id, requester_user_id, content
         )
@@ -672,14 +731,14 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id):
+    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -696,14 +755,14 @@ class PublicisedGroupsForUserServlet(RestServlet):
 
     PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.groups_handler = hs.get_groups_local_handler()
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
         await self.auth.get_user_by_req(request, allow_guest=True)
 
         result = await self.groups_handler.get_publicised_groups_for_user(user_id)
@@ -717,14 +776,14 @@ class PublicisedGroupsForUsersServlet(RestServlet):
 
     PATTERNS = client_patterns("/publicised_groups$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.groups_handler = hs.get_groups_local_handler()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
         await self.auth.get_user_by_req(request, allow_guest=True)
 
         content = parse_json_object_from_request(request)
@@ -741,13 +800,13 @@ class GroupsForUserServlet(RestServlet):
 
     PATTERNS = client_patterns("/joined_groups$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -756,7 +815,7 @@ class GroupsForUserServlet(RestServlet):
         return 200, result
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
     GroupServlet(hs).register(http_server)
     GroupSummaryServlet(hs).register(http_server)
     GroupInvitedUsersServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 10e1891174..e3d322f2ac 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -193,6 +193,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
             body, ["client_secret", "country", "phone_number", "send_attempt"]
         )
         client_secret = body["client_secret"]
+        assert_valid_client_secret(client_secret)
         country = body["country"]
         phone_number = body["phone_number"]
         send_attempt = body["send_attempt"]
@@ -293,6 +294,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
 
         sid = parse_string(request, "sid", required=True)
         client_secret = parse_string(request, "client_secret", required=True)
+        assert_valid_client_secret(client_secret)
         token = parse_string(request, "token", required=True)
 
         # Attempt to validate a 3PID session
diff --git a/synapse/server.py b/synapse/server.py
index 9bdd3177d7..6b3892e3cd 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -25,7 +25,17 @@ import abc
 import functools
 import logging
 import os
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import twisted.internet.base
 import twisted.internet.tcp
@@ -588,7 +598,9 @@ class HomeServer(metaclass=abc.ABCMeta):
         return UserDirectoryHandler(self)
 
     @cache_in_self
-    def get_groups_local_handler(self):
+    def get_groups_local_handler(
+        self,
+    ) -> Union[GroupsLocalWorkerHandler, GroupsLocalHandler]:
         if self.config.worker_app:
             return GroupsLocalWorkerHandler(self)
         else:
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index f8038bf861..9ce7873ab5 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -25,7 +25,7 @@ from synapse.api.errors import Codes, SynapseError
 _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
 
 # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
-client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
+CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
 
 # https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
 # together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
@@ -42,28 +42,31 @@ MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
 rand = random.SystemRandom()
 
 
-def random_string(length):
+def random_string(length: int) -> str:
     return "".join(rand.choice(string.ascii_letters) for _ in range(length))
 
 
-def random_string_with_symbols(length):
+def random_string_with_symbols(length: int) -> str:
     return "".join(rand.choice(_string_with_symbols) for _ in range(length))
 
 
-def is_ascii(s):
-    if isinstance(s, bytes):
-        try:
-            s.decode("ascii").encode("ascii")
-        except UnicodeDecodeError:
-            return False
-        except UnicodeEncodeError:
-            return False
-        return True
+def is_ascii(s: bytes) -> bool:
+    try:
+        s.decode("ascii").encode("ascii")
+    except UnicodeDecodeError:
+        return False
+    except UnicodeEncodeError:
+        return False
+    return True
 
 
-def assert_valid_client_secret(client_secret):
-    """Validate that a given string matches the client_secret regex defined by the spec"""
-    if client_secret_regex.match(client_secret) is None:
+def assert_valid_client_secret(client_secret: str) -> None:
+    """Validate that a given string matches the client_secret defined by the spec"""
+    if (
+        len(client_secret) <= 0
+        or len(client_secret) > 255
+        or CLIENT_SECRET_REGEX.match(client_secret) is None
+    ):
         raise SynapseError(
             400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
         )