summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
authorBen Banfield-Zanin <benbz@matrix.org>2021-02-16 13:33:20 +0000
committerBen Banfield-Zanin <benbz@matrix.org>2021-02-16 13:33:20 +0000
commitdcf1b9c276e22bb6f5200fc029301c4d40e87a1f (patch)
tree1f5badce24645d99534133a7a989069906088fff /synapse/rest
parentMerge remote-tracking branch 'origin/release-v1.24.0' into bbz/info-mainline-... (diff)
parentFixup CHANGES (diff)
downloadsynapse-dcf1b9c276e22bb6f5200fc029301c4d40e87a1f.tar.xz
Merge remote-tracking branch 'origin/release-v1.27.0' into bbz/info-mainline-1.27.0 github/bbz/info-mainline-1.27.0 bbz/info-mainline-1.27.0
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/admin/__init__.py10
-rw-r--r--synapse/rest/admin/media.py64
-rw-r--r--synapse/rest/admin/rooms.py292
-rw-r--r--synapse/rest/admin/users.py109
-rw-r--r--synapse/rest/client/v1/login.py169
-rw-r--r--synapse/rest/client/v1/pusher.py15
-rw-r--r--synapse/rest/client/v1/room.py37
-rw-r--r--synapse/rest/client/v2_alpha/account.py66
-rw-r--r--synapse/rest/client/v2_alpha/account_data.py22
-rw-r--r--synapse/rest/client/v2_alpha/auth.py53
-rw-r--r--synapse/rest/client/v2_alpha/devices.py12
-rw-r--r--synapse/rest/client/v2_alpha/groups.py48
-rw-r--r--synapse/rest/client/v2_alpha/keys.py6
-rw-r--r--synapse/rest/client/v2_alpha/register.py43
-rw-r--r--synapse/rest/client/v2_alpha/sendtodevice.py3
-rw-r--r--synapse/rest/client/v2_alpha/tags.py11
-rw-r--r--synapse/rest/consent/consent_resource.py1
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py9
-rw-r--r--synapse/rest/media/v1/_base.py84
-rw-r--r--synapse/rest/media/v1/config_resource.py14
-rw-r--r--synapse/rest/media/v1/download_resource.py18
-rw-r--r--synapse/rest/media/v1/filepath.py50
-rw-r--r--synapse/rest/media/v1/media_repository.py52
-rw-r--r--synapse/rest/media/v1/media_storage.py12
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py123
-rw-r--r--synapse/rest/media/v1/storage_provider.py51
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py313
-rw-r--r--synapse/rest/media/v1/thumbnailer.py18
-rw-r--r--synapse/rest/media/v1/upload_resource.py16
-rw-r--r--synapse/rest/synapse/client/__init__.py54
-rw-r--r--synapse/rest/synapse/client/new_user_consent.py97
-rw-r--r--synapse/rest/synapse/client/oidc/__init__.py (renamed from synapse/rest/oidc/__init__.py)6
-rw-r--r--synapse/rest/synapse/client/oidc/callback_resource.py (renamed from synapse/rest/oidc/callback_resource.py)0
-rw-r--r--synapse/rest/synapse/client/pick_idp.py84
-rw-r--r--synapse/rest/synapse/client/pick_username.py131
-rw-r--r--synapse/rest/synapse/client/saml2/__init__.py (renamed from synapse/rest/saml2/__init__.py)8
-rw-r--r--synapse/rest/synapse/client/saml2/metadata_resource.py (renamed from synapse/rest/saml2/metadata_resource.py)0
-rw-r--r--synapse/rest/synapse/client/saml2/response_resource.py (renamed from synapse/rest/saml2/response_resource.py)0
-rw-r--r--synapse/rest/synapse/client/sso_register.py50
39 files changed, 1575 insertions, 576 deletions
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 55ddebb4fe..f5c5d164f9 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -1,6 +1,8 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2018-2019 New Vector Ltd
+# Copyright 2020, 2021 The Matrix.org Foundation C.I.C.
+
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -36,10 +38,13 @@ from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_medi
 from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
 from synapse.rest.admin.rooms import (
     DeleteRoomRestServlet,
+    ForwardExtremitiesRestServlet,
     JoinRoomAliasServlet,
     ListRoomRestServlet,
+    MakeRoomAdminRestServlet,
     RoomMembersRestServlet,
     RoomRestServlet,
+    RoomStateRestServlet,
     ShutdownRoomRestServlet,
 )
 from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
@@ -50,6 +55,7 @@ from synapse.rest.admin.users import (
     PushersRestServlet,
     ResetPasswordRestServlet,
     SearchUsersRestServlet,
+    ShadowBanRestServlet,
     UserAdminServlet,
     UserMediaRestServlet,
     UserMembershipRestServlet,
@@ -208,6 +214,7 @@ def register_servlets(hs, http_server):
     """
     register_servlets_for_client_rest_resource(hs, http_server)
     ListRoomRestServlet(hs).register(http_server)
+    RoomStateRestServlet(hs).register(http_server)
     RoomRestServlet(hs).register(http_server)
     RoomMembersRestServlet(hs).register(http_server)
     DeleteRoomRestServlet(hs).register(http_server)
@@ -228,6 +235,9 @@ def register_servlets(hs, http_server):
     EventReportDetailRestServlet(hs).register(http_server)
     EventReportsRestServlet(hs).register(http_server)
     PushersRestServlet(hs).register(http_server)
+    MakeRoomAdminRestServlet(hs).register(http_server)
+    ShadowBanRestServlet(hs).register(http_server)
+    ForwardExtremitiesRestServlet(hs).register(http_server)
 
 
 def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index c82b4f87d6..8720b1401f 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -15,6 +15,9 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
 
 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
@@ -23,6 +26,10 @@ from synapse.rest.admin._base import (
     assert_requester_is_admin,
     assert_user_is_admin,
 )
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet):
         admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, room_id: str):
+    async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet):
 
     PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, user_id: str):
+    async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet):
         "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, server_name: str, media_id: str):
+    async def on_POST(
+        self, request: Request, server_name: str, media_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet):
         return 200, {}
 
 
+class ProtectMediaByID(RestServlet):
+    """Protect local media from being quarantined.
+    """
+
+    PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
+
+    def __init__(self, hs: "HomeServer"):
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        logging.info("Protecting local media by ID: %s", media_id)
+
+        # Quarantine this media id
+        await self.store.mark_local_media_as_safe(media_id)
+
+        return 200, {}
+
+
 class ListMediaInRoom(RestServlet):
     """Lists all of the media in a given room.
     """
 
     PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         is_admin = await self.auth.is_server_admin(requester.user)
         if not is_admin:
@@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet):
 class PurgeMediaCacheRestServlet(RestServlet):
     PATTERNS = admin_patterns("/purge_media_cache")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.media_repository = hs.get_media_repository()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         before_ts = parse_integer(request, "before_ts", required=True)
@@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet):
 
     PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.server_name = hs.hostname
         self.media_repository = hs.get_media_repository()
 
-    async def on_DELETE(self, request, server_name: str, media_id: str):
+    async def on_DELETE(
+        self, request: Request, server_name: str, media_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         if self.server_name != server_name:
@@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet):
 
     PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.server_name = hs.hostname
         self.media_repository = hs.get_media_repository()
 
-    async def on_POST(self, request, server_name: str):
+    async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         before_ts = parse_integer(request, "before_ts", required=True)
@@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet):
         return 200, {"deleted_media": deleted_media, "total": total}
 
 
-def register_servlets_for_media_repo(hs, http_server):
+def register_servlets_for_media_repo(hs: "HomeServer", http_server):
     """
     Media repo specific APIs.
     """
@@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server):
     QuarantineMediaInRoom(hs).register(http_server)
     QuarantineMediaByID(hs).register(http_server)
     QuarantineMediaByUser(hs).register(http_server)
+    ProtectMediaByID(hs).register(http_server)
     ListMediaInRoom(hs).register(http_server)
     DeleteMediaByID(hs).register(http_server)
     DeleteMediaByDateSize(hs).register(http_server)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 25f89e4685..3e57e6a4d0 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,10 +14,10 @@
 # limitations under the License.
 import logging
 from http import HTTPStatus
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
-from synapse.api.constants import EventTypes, JoinRules
-from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
@@ -25,13 +25,18 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
     parse_string,
 )
+from synapse.http.site import SynapseRequest
 from synapse.rest.admin._base import (
     admin_patterns,
     assert_requester_is_admin,
     assert_user_is_admin,
 )
 from synapse.storage.databases.main.room import RoomSortOrder
-from synapse.types import RoomAlias, RoomID, UserID, create_requester
+from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 
 logger = logging.getLogger(__name__)
 
@@ -45,12 +50,14 @@ class ShutdownRoomRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.room_shutdown_handler = hs.get_room_shutdown_handler()
 
-    async def on_POST(self, request, room_id):
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -86,13 +93,15 @@ class DeleteRoomRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.room_shutdown_handler = hs.get_room_shutdown_handler()
         self.pagination_handler = hs.get_pagination_handler()
 
-    async def on_POST(self, request, room_id):
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -146,12 +155,12 @@ class ListRoomRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/rooms$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -236,19 +245,24 @@ class RoomRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         ret = await self.store.get_room_with_stats(room_id)
         if not ret:
             raise NotFoundError("Room not found")
 
-        return 200, ret
+        members = await self.store.get_users_in_room(room_id)
+        ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
+
+        return (200, ret)
 
 
 class RoomMembersRestServlet(RestServlet):
@@ -258,12 +272,14 @@ class RoomMembersRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         ret = await self.store.get_room(room_id)
@@ -276,18 +292,59 @@ class RoomMembersRestServlet(RestServlet):
         return 200, ret
 
 
+class RoomStateRestServlet(RestServlet):
+    """
+    Get full state within a room.
+    """
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        self._event_serializer = hs.get_event_client_serializer()
+
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        ret = await self.store.get_room(room_id)
+        if not ret:
+            raise NotFoundError("Room not found")
+
+        event_ids = await self.store.get_current_state_ids(room_id)
+        events = await self.store.get_events(event_ids.values())
+        now = self.clock.time_msec()
+        room_state = await self._event_serializer.serialize_events(
+            events.values(),
+            now,
+            # We don't bother bundling aggregations in when asked for state
+            # events, as clients won't use them.
+            bundle_aggregations=False,
+        )
+        ret = {"state": room_state}
+
+        return 200, ret
+
+
 class JoinRoomAliasServlet(RestServlet):
 
     PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.room_member_handler = hs.get_room_member_handler()
         self.admin_handler = hs.get_admin_handler()
         self.state_handler = hs.get_state_handler()
 
-    async def on_POST(self, request, room_identifier):
+    async def on_POST(
+        self, request: SynapseRequest, room_identifier: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -314,7 +371,6 @@ class JoinRoomAliasServlet(RestServlet):
             handler = self.room_member_handler
             room_alias = RoomAlias.from_string(room_identifier)
             room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
-            room_id = room_id.to_string()
         else:
             raise SynapseError(
                 400, "%s was not legal room ID or room alias" % (room_identifier,)
@@ -351,3 +407,201 @@ class JoinRoomAliasServlet(RestServlet):
         )
 
         return 200, {"room_id": room_id}
+
+
+class MakeRoomAdminRestServlet(RestServlet):
+    """Allows a server admin to get power in a room if a local user has power in
+    a room. Will also invite the user if they're not in the room and it's a
+    private room. Can specify another user (rather than the admin user) to be
+    granted power, e.g.:
+
+        POST/_synapse/admin/v1/rooms/<room_id_or_alias>/make_room_admin
+        {
+            "user_id": "@foo:example.com"
+        }
+    """
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.room_member_handler = hs.get_room_member_handler()
+        self.event_creation_handler = hs.get_event_creation_handler()
+        self.state_handler = hs.get_state_handler()
+        self.is_mine_id = hs.is_mine_id
+
+    async def on_POST(self, request, room_identifier):
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+        content = parse_json_object_from_request(request, allow_empty_body=True)
+
+        # Resolve to a room ID, if necessary.
+        if RoomID.is_valid(room_identifier):
+            room_id = room_identifier
+        elif RoomAlias.is_valid(room_identifier):
+            room_alias = RoomAlias.from_string(room_identifier)
+            room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
+            room_id = room_id.to_string()
+        else:
+            raise SynapseError(
+                400, "%s was not legal room ID or room alias" % (room_identifier,)
+            )
+
+        # Which user to grant room admin rights to.
+        user_to_add = content.get("user_id", requester.user.to_string())
+
+        # Figure out which local users currently have power in the room, if any.
+        room_state = await self.state_handler.get_current_state(room_id)
+        if not room_state:
+            raise SynapseError(400, "Server not in room")
+
+        create_event = room_state[(EventTypes.Create, "")]
+        power_levels = room_state.get((EventTypes.PowerLevels, ""))
+
+        if power_levels is not None:
+            # We pick the local user with the highest power.
+            user_power = power_levels.content.get("users", {})
+            admin_users = [
+                user_id for user_id in user_power if self.is_mine_id(user_id)
+            ]
+            admin_users.sort(key=lambda user: user_power[user])
+
+            if not admin_users:
+                raise SynapseError(400, "No local admin user in room")
+
+            admin_user_id = None
+
+            for admin_user in reversed(admin_users):
+                if room_state.get((EventTypes.Member, admin_user)):
+                    admin_user_id = admin_user
+                    break
+
+            if not admin_user_id:
+                raise SynapseError(
+                    400, "No local admin user in room",
+                )
+
+            pl_content = power_levels.content
+        else:
+            # If there is no power level events then the creator has rights.
+            pl_content = {}
+            admin_user_id = create_event.sender
+            if not self.is_mine_id(admin_user_id):
+                raise SynapseError(
+                    400, "No local admin user in room",
+                )
+
+        # Grant the user power equal to the room admin by attempting to send an
+        # updated power level event.
+        new_pl_content = dict(pl_content)
+        new_pl_content["users"] = dict(pl_content.get("users", {}))
+        new_pl_content["users"][user_to_add] = new_pl_content["users"][admin_user_id]
+
+        fake_requester = create_requester(
+            admin_user_id, authenticated_entity=requester.authenticated_entity,
+        )
+
+        try:
+            await self.event_creation_handler.create_and_send_nonmember_event(
+                fake_requester,
+                event_dict={
+                    "content": new_pl_content,
+                    "sender": admin_user_id,
+                    "type": EventTypes.PowerLevels,
+                    "state_key": "",
+                    "room_id": room_id,
+                },
+            )
+        except AuthError:
+            # The admin user we found turned out not to have enough power.
+            raise SynapseError(
+                400, "No local admin user in room with power to update power levels."
+            )
+
+        # Now we check if the user we're granting admin rights to is already in
+        # the room. If not and it's not a public room we invite them.
+        member_event = room_state.get((EventTypes.Member, user_to_add))
+        is_joined = False
+        if member_event:
+            is_joined = member_event.content["membership"] in (
+                Membership.JOIN,
+                Membership.INVITE,
+            )
+
+        if is_joined:
+            return 200, {}
+
+        join_rules = room_state.get((EventTypes.JoinRules, ""))
+        is_public = False
+        if join_rules:
+            is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
+
+        if is_public:
+            return 200, {}
+
+        await self.room_member_handler.update_membership(
+            fake_requester,
+            target=UserID.from_string(user_to_add),
+            room_id=room_id,
+            action=Membership.INVITE,
+        )
+
+        return 200, {}
+
+
+class ForwardExtremitiesRestServlet(RestServlet):
+    """Allows a server admin to get or clear forward extremities.
+
+    Clearing does not require restarting the server.
+
+        Clear forward extremities:
+        DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+
+        Get forward_extremities:
+        GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+    """
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.room_member_handler = hs.get_room_member_handler()
+        self.store = hs.get_datastore()
+
+    async def resolve_room_id(self, room_identifier: str) -> str:
+        """Resolve to a room ID, if necessary."""
+        if RoomID.is_valid(room_identifier):
+            resolved_room_id = room_identifier
+        elif RoomAlias.is_valid(room_identifier):
+            room_alias = RoomAlias.from_string(room_identifier)
+            room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
+            resolved_room_id = room_id.to_string()
+        else:
+            raise SynapseError(
+                400, "%s was not legal room ID or room alias" % (room_identifier,)
+            )
+        if not resolved_room_id:
+            raise SynapseError(
+                400, "Unknown room ID or room alias %s" % room_identifier
+            )
+        return resolved_room_id
+
+    async def on_DELETE(self, request, room_identifier):
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        room_id = await self.resolve_room_id(room_identifier)
+
+        deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
+        return 200, {"deleted": deleted_count}
+
+    async def on_GET(self, request, room_identifier):
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        room_id = await self.resolve_room_id(room_identifier)
+
+        extremities = await self.store.get_forward_extremities_for_room(room_id)
+        return 200, {"count": len(extremities), "results": extremities}
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index b0ff5e1ead..68c3c64a0d 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -42,17 +42,6 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-_GET_PUSHERS_ALLOWED_KEYS = {
-    "app_display_name",
-    "app_id",
-    "data",
-    "device_display_name",
-    "kind",
-    "lang",
-    "profile_tag",
-    "pushkey",
-}
-
 
 class UsersRestServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
@@ -94,17 +83,32 @@ class UsersRestServletV2(RestServlet):
     The parameter `deactivated` can be used to include deactivated users.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         start = parse_integer(request, "from", default=0)
         limit = parse_integer(request, "limit", default=100)
+
+        if start < 0:
+            raise SynapseError(
+                400,
+                "Query parameter from must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        if limit < 0:
+            raise SynapseError(
+                400,
+                "Query parameter limit must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
         user_id = parse_string(request, "user_id", default=None)
         name = parse_string(request, "name", default=None)
         guests = parse_boolean(request, "guests", default=True)
@@ -114,7 +118,7 @@ class UsersRestServletV2(RestServlet):
             start, limit, user_id, name, guests, deactivated
         )
         ret = {"users": users, "total": total}
-        if len(users) >= limit:
+        if (start + limit) < total:
             ret["next_token"] = str(start + len(users))
 
         return 200, ret
@@ -255,7 +259,7 @@ class UserRestServletV2(RestServlet):
 
                 if deactivate and not user["deactivated"]:
                     await self.deactivate_account_handler.deactivate_account(
-                        target_user.to_string(), False
+                        target_user.to_string(), False, requester, by_admin=True
                     )
                 elif not deactivate and user["deactivated"]:
                     if "password" not in body:
@@ -320,9 +324,9 @@ class UserRestServletV2(RestServlet):
                             data={},
                         )
 
-            if "avatar_url" in body and type(body["avatar_url"]) == str:
+            if "avatar_url" in body and isinstance(body["avatar_url"], str):
                 await self.profile_handler.set_avatar_url(
-                    user_id, requester, body["avatar_url"], True
+                    target_user, requester, body["avatar_url"], True
                 )
 
             ret = await self.admin_handler.get_user(target_user)
@@ -420,6 +424,9 @@ class UserRegisterServlet(RestServlet):
         if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
             raise SynapseError(400, "Invalid user type")
 
+        if "mac" not in body:
+            raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON)
+
         got_mac = body["mac"]
 
         want_mac_builder = hmac.new(
@@ -494,12 +501,22 @@ class WhoisRestServlet(RestServlet):
 class DeactivateAccountRestServlet(RestServlet):
     PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self._deactivate_account_handler = hs.get_deactivate_account_handler()
         self.auth = hs.get_auth()
+        self.is_mine = hs.is_mine
+        self.store = hs.get_datastore()
+
+    async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        if not self.is_mine(UserID.from_string(target_user_id)):
+            raise SynapseError(400, "Can only deactivate local users")
+
+        if not await self.store.get_user_by_id(target_user_id):
+            raise NotFoundError("User not found")
 
-    async def on_POST(self, request, target_user_id):
-        await assert_requester_is_admin(self.auth, request)
         body = parse_json_object_from_request(request, allow_empty_body=True)
         erase = body.get("erase", False)
         if not isinstance(erase, bool):
@@ -509,10 +526,8 @@ class DeactivateAccountRestServlet(RestServlet):
                 Codes.BAD_JSON,
             )
 
-        UserID.from_string(target_user_id)
-
         result = await self._deactivate_account_handler.deactivate_account(
-            target_user_id, erase
+            target_user_id, erase, requester, by_admin=True
         )
         if result:
             id_server_unbind_result = "success"
@@ -722,13 +737,6 @@ class UserMembershipRestServlet(RestServlet):
     async def on_GET(self, request, user_id):
         await assert_requester_is_admin(self.auth, request)
 
-        if not self.is_mine(UserID.from_string(user_id)):
-            raise SynapseError(400, "Can only lookup local users")
-
-        user = await self.store.get_user_by_id(user_id)
-        if user is None:
-            raise NotFoundError("Unknown user")
-
         room_ids = await self.store.get_rooms_for_user(user_id)
         ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
         return 200, ret
@@ -767,10 +775,7 @@ class PushersRestServlet(RestServlet):
 
         pushers = await self.store.get_pushers_by_user_id(user_id)
 
-        filtered_pushers = [
-            {k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS}
-            for p in pushers
-        ]
+        filtered_pushers = [p.as_dict() for p in pushers]
 
         return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
 
@@ -885,3 +890,39 @@ class UserTokenRestServlet(RestServlet):
         )
 
         return 200, {"access_token": token}
+
+
+class ShadowBanRestServlet(RestServlet):
+    """An admin API for shadow-banning a user.
+
+    A shadow-banned users receives successful responses to their client-server
+    API requests, but the events are not propagated into rooms.
+
+    Shadow-banning a user should be used as a tool of last resort and may lead
+    to confusing or broken behaviour for the client.
+
+    Example:
+
+        POST /_synapse/admin/v1/users/@test:example.com/shadow_ban
+        {}
+
+        200 OK
+        {}
+    """
+
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_POST(self, request, user_id):
+        await assert_requester_is_admin(self.auth, request)
+
+        if not self.hs.is_mine_id(user_id):
+            raise SynapseError(400, "Only local users can be shadow-banned")
+
+        await self.store.set_shadow_banned(UserID.from_string(user_id), True)
+
+        return 200, {}
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d7ae148214..0fb9419e58 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,12 +14,13 @@
 # limitations under the License.
 
 import logging
-from typing import Awaitable, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.appservice import ApplicationService
-from synapse.http.server import finish_request
+from synapse.handlers.sso import SsoIdentityProvider
+from synapse.http.server import HttpServer, finish_request
 from synapse.http.servlet import (
     RestServlet,
     parse_json_object_from_request,
@@ -30,6 +31,9 @@ from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.well_known import WellKnownBuilder
 from synapse.types import JsonDict, UserID
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -42,7 +46,7 @@ class LoginRestServlet(RestServlet):
     JWT_TYPE_DEPRECATED = "m.login.jwt"
     APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
 
@@ -57,11 +61,14 @@ class LoginRestServlet(RestServlet):
         self.saml2_enabled = hs.config.saml2_enabled
         self.cas_enabled = hs.config.cas_enabled
         self.oidc_enabled = hs.config.oidc_enabled
+        self._msc2858_enabled = hs.config.experimental.msc2858_enabled
 
         self.auth = hs.get_auth()
 
         self.auth_handler = self.hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
+        self._sso_handler = hs.get_sso_handler()
+
         self._well_known_builder = WellKnownBuilder(hs)
         self._address_ratelimiter = Ratelimiter(
             clock=hs.get_clock(),
@@ -86,8 +93,17 @@ class LoginRestServlet(RestServlet):
             flows.append({"type": LoginRestServlet.CAS_TYPE})
 
         if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
-            flows.append({"type": LoginRestServlet.SSO_TYPE})
-            # While its valid for us to advertise this login type generally,
+            sso_flow = {"type": LoginRestServlet.SSO_TYPE}  # type: JsonDict
+
+            if self._msc2858_enabled:
+                sso_flow["org.matrix.msc2858.identity_providers"] = [
+                    _get_auth_flow_dict_for_idp(idp)
+                    for idp in self._sso_handler.get_identity_providers().values()
+                ]
+
+            flows.append(sso_flow)
+
+            # While it's valid for us to advertise this login type generally,
             # synapse currently only gives out these tokens as part of the
             # SSO login flow.
             # Generally we don't want to advertise login flows that clients
@@ -105,22 +121,27 @@ class LoginRestServlet(RestServlet):
         return 200, {"flows": flows}
 
     async def on_POST(self, request: SynapseRequest):
-        self._address_ratelimiter.ratelimit(request.getClientIP())
-
         login_submission = parse_json_object_from_request(request)
 
         try:
             if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
                 appservice = self.auth.get_appservice_by_req(request)
+
+                if appservice.is_rate_limited():
+                    self._address_ratelimiter.ratelimit(request.getClientIP())
+
                 result = await self._do_appservice_login(login_submission, appservice)
             elif self.jwt_enabled and (
                 login_submission["type"] == LoginRestServlet.JWT_TYPE
                 or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
             ):
+                self._address_ratelimiter.ratelimit(request.getClientIP())
                 result = await self._do_jwt_login(login_submission)
             elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
+                self._address_ratelimiter.ratelimit(request.getClientIP())
                 result = await self._do_token_login(login_submission)
             else:
+                self._address_ratelimiter.ratelimit(request.getClientIP())
                 result = await self._do_other_login(login_submission)
         except KeyError:
             raise SynapseError(400, "Missing JSON keys.")
@@ -159,7 +180,9 @@ class LoginRestServlet(RestServlet):
         if not appservice.is_interested_in_user(qualified_user_id):
             raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
 
-        return await self._complete_login(qualified_user_id, login_submission)
+        return await self._complete_login(
+            qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
+        )
 
     async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
         """Handle non-token/saml/jwt logins
@@ -194,6 +217,7 @@ class LoginRestServlet(RestServlet):
         login_submission: JsonDict,
         callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
         create_non_existent_users: bool = False,
+        ratelimit: bool = True,
     ) -> Dict[str, str]:
         """Called when we've successfully authed the user and now need to
         actually login them in (e.g. create devices). This gets called on
@@ -208,6 +232,7 @@ class LoginRestServlet(RestServlet):
             callback: Callback function to run after login.
             create_non_existent_users: Whether to create the user if they don't
                 exist. Defaults to False.
+            ratelimit: Whether to ratelimit the login request.
 
         Returns:
             result: Dictionary of account information after successful login.
@@ -216,7 +241,8 @@ class LoginRestServlet(RestServlet):
         # Before we actually log them in we check if they've already logged in
         # too often. This happens here rather than before as we don't
         # necessarily know the user before now.
-        self._account_ratelimiter.ratelimit(user_id.lower())
+        if ratelimit:
+            self._account_ratelimiter.ratelimit(user_id.lower())
 
         if create_non_existent_users:
             canonical_uid = await self.auth_handler.check_user_exists(user_id)
@@ -298,48 +324,63 @@ class LoginRestServlet(RestServlet):
         return result
 
 
-class BaseSSORedirectServlet(RestServlet):
-    """Common base class for /login/sso/redirect impls"""
-
-    PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
+def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
+    """Return an entry for the login flow dict
+
+    Returns an entry suitable for inclusion in "identity_providers" in the
+    response to GET /_matrix/client/r0/login
+    """
+    e = {"id": idp.idp_id, "name": idp.idp_name}  # type: JsonDict
+    if idp.idp_icon:
+        e["icon"] = idp.idp_icon
+    if idp.idp_brand:
+        e["brand"] = idp.idp_brand
+    return e
+
+
+class SsoRedirectServlet(RestServlet):
+    PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
+
+    def __init__(self, hs: "HomeServer"):
+        # make sure that the relevant handlers are instantiated, so that they
+        # register themselves with the main SSOHandler.
+        if hs.config.cas_enabled:
+            hs.get_cas_handler()
+        if hs.config.saml2_enabled:
+            hs.get_saml_handler()
+        if hs.config.oidc_enabled:
+            hs.get_oidc_handler()
+        self._sso_handler = hs.get_sso_handler()
+        self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+
+    def register(self, http_server: HttpServer) -> None:
+        super().register(http_server)
+        if self._msc2858_enabled:
+            # expose additional endpoint for MSC2858 support
+            http_server.register_paths(
+                "GET",
+                client_patterns(
+                    "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$",
+                    releases=(),
+                    unstable=True,
+                ),
+                self.on_GET,
+                self.__class__.__name__,
+            )
 
-    async def on_GET(self, request: SynapseRequest):
-        args = request.args
-        if b"redirectUrl" not in args:
-            return 400, "Redirect URL not specified for SSO auth"
-        client_redirect_url = args[b"redirectUrl"][0]
-        sso_url = await self.get_sso_url(request, client_redirect_url)
+    async def on_GET(
+        self, request: SynapseRequest, idp_id: Optional[str] = None
+    ) -> None:
+        client_redirect_url = parse_string(
+            request, "redirectUrl", required=True, encoding=None
+        )
+        sso_url = await self._sso_handler.handle_redirect_request(
+            request, client_redirect_url, idp_id,
+        )
+        logger.info("Redirecting to %s", sso_url)
         request.redirect(sso_url)
         finish_request(request)
 
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        """Get the URL to redirect to, to perform SSO auth
-
-        Args:
-            request: The client request to redirect.
-            client_redirect_url: the URL that we should redirect the
-                client to when everything is done
-
-        Returns:
-            URL to redirect to
-        """
-        # to be implemented by subclasses
-        raise NotImplementedError()
-
-
-class CasRedirectServlet(BaseSSORedirectServlet):
-    def __init__(self, hs):
-        self._cas_handler = hs.get_cas_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return self._cas_handler.get_redirect_url(
-            {"redirectUrl": client_redirect_url}
-        ).encode("ascii")
-
 
 class CasTicketServlet(RestServlet):
     PATTERNS = client_patterns("/login/cas/ticket", v1=True)
@@ -366,40 +407,8 @@ class CasTicketServlet(RestServlet):
         )
 
 
-class SAMLRedirectServlet(BaseSSORedirectServlet):
-    PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
-    def __init__(self, hs):
-        self._saml_handler = hs.get_saml_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return self._saml_handler.handle_redirect_request(client_redirect_url)
-
-
-class OIDCRedirectServlet(BaseSSORedirectServlet):
-    """Implementation for /login/sso/redirect for the OIDC login flow."""
-
-    PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
-    def __init__(self, hs):
-        self._oidc_handler = hs.get_oidc_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return await self._oidc_handler.handle_redirect_request(
-            request, client_redirect_url
-        )
-
-
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
+    SsoRedirectServlet(hs).register(http_server)
     if hs.config.cas_enabled:
-        CasRedirectServlet(hs).register(http_server)
         CasTicketServlet(hs).register(http_server)
-    elif hs.config.saml2_enabled:
-        SAMLRedirectServlet(hs).register(http_server)
-    elif hs.config.oidc_enabled:
-        OIDCRedirectServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 8fe83f321a..89823fcc39 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -28,17 +28,6 @@ from synapse.rest.client.v2_alpha._base import client_patterns
 
 logger = logging.getLogger(__name__)
 
-ALLOWED_KEYS = {
-    "app_display_name",
-    "app_id",
-    "data",
-    "device_display_name",
-    "kind",
-    "lang",
-    "profile_tag",
-    "pushkey",
-}
-
 
 class PushersRestServlet(RestServlet):
     PATTERNS = client_patterns("/pushers$", v1=True)
@@ -54,9 +43,7 @@ class PushersRestServlet(RestServlet):
 
         pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
 
-        filtered_pushers = [
-            {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers
-        ]
+        filtered_pushers = [p.as_dict() for p in pushers]
 
         return 200, {"pushers": filtered_pushers}
 
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 93c06afe27..f95627ee61 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -46,7 +46,7 @@ from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
 from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
 from synapse.util import json_decoder
-from synapse.util.stringutils import random_string
+from synapse.util.stringutils import parse_and_validate_server_name, random_string
 
 if TYPE_CHECKING:
     import synapse.server
@@ -347,8 +347,6 @@ class PublicRoomListRestServlet(TransactionRestServlet):
             # provided.
             if server:
                 raise e
-            else:
-                pass
 
         limit = parse_integer(request, "limit", 0)
         since_token = parse_string(request, "since", None)
@@ -359,6 +357,14 @@ class PublicRoomListRestServlet(TransactionRestServlet):
 
         handler = self.hs.get_room_list_handler()
         if server and server != self.hs.config.server_name:
+            # Ensure the server is valid.
+            try:
+                parse_and_validate_server_name(server)
+            except ValueError:
+                raise SynapseError(
+                    400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
+                )
+
             try:
                 data = await handler.get_remote_public_room_list(
                     server, limit=limit, since_token=since_token
@@ -402,6 +408,14 @@ class PublicRoomListRestServlet(TransactionRestServlet):
 
         handler = self.hs.get_room_list_handler()
         if server and server != self.hs.config.server_name:
+            # Ensure the server is valid.
+            try:
+                parse_and_validate_server_name(server)
+            except ValueError:
+                raise SynapseError(
+                    400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
+                )
+
             try:
                 data = await handler.get_remote_public_room_list(
                     server,
@@ -963,25 +977,28 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
         )
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs, http_server, is_worker=False):
     RoomStateEventRestServlet(hs).register(http_server)
-    RoomCreateRestServlet(hs).register(http_server)
     RoomMemberListRestServlet(hs).register(http_server)
     JoinedRoomMemberListRestServlet(hs).register(http_server)
     RoomMessageListRestServlet(hs).register(http_server)
     JoinRoomAliasServlet(hs).register(http_server)
-    RoomForgetRestServlet(hs).register(http_server)
     RoomMembershipRestServlet(hs).register(http_server)
     RoomSendEventRestServlet(hs).register(http_server)
     PublicRoomListRestServlet(hs).register(http_server)
     RoomStateRestServlet(hs).register(http_server)
     RoomRedactEventRestServlet(hs).register(http_server)
     RoomTypingRestServlet(hs).register(http_server)
-    SearchRestServlet(hs).register(http_server)
-    JoinedRoomsRestServlet(hs).register(http_server)
-    RoomEventServlet(hs).register(http_server)
     RoomEventContextServlet(hs).register(http_server)
-    RoomAliasListServlet(hs).register(http_server)
+
+    # Some servlets only get registered for the main process.
+    if not is_worker:
+        RoomCreateRestServlet(hs).register(http_server)
+        RoomForgetRestServlet(hs).register(http_server)
+        SearchRestServlet(hs).register(http_server)
+        JoinedRoomsRestServlet(hs).register(http_server)
+        RoomEventServlet(hs).register(http_server)
+        RoomAliasListServlet(hs).register(http_server)
 
 
 def register_deprecated_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index e0feebea94..b67c1702ca 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -20,9 +20,6 @@ from http import HTTPStatus
 from typing import TYPE_CHECKING
 from urllib.parse import urlparse
 
-if TYPE_CHECKING:
-    from synapse.app.homeserver import HomeServer
-
 from synapse.api.constants import LoginType
 from synapse.api.errors import (
     Codes,
@@ -31,6 +28,7 @@ from synapse.api.errors import (
     ThreepidValidationError,
 )
 from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
@@ -46,13 +44,17 @@ from synapse.util.threepids import canonicalise_email, check_3pid_allowed
 
 from ._base import client_patterns, interactive_auth_handler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 
 class EmailPasswordRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/password/email/requestToken$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.datastore = hs.get_datastore()
@@ -101,6 +103,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
 
+        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
         # The email will be sent to the stored address.
         # This avoids a potential account hijack by requesting a password reset to
         # an email address which is controlled by the attacker but which, after
@@ -189,11 +193,7 @@ class PasswordRestServlet(RestServlet):
             requester = await self.auth.get_user_by_req(request)
             try:
                 params, session_id = await self.auth_handler.validate_user_via_ui_auth(
-                    requester,
-                    request,
-                    body,
-                    self.hs.get_ip_from_request(request),
-                    "modify your account password",
+                    requester, request, body, "modify your account password",
                 )
             except InteractiveAuthIncompleteError as e:
                 # The user needs to provide more steps to complete auth, but
@@ -204,7 +204,9 @@ class PasswordRestServlet(RestServlet):
                 if new_password:
                     password_hash = await self.auth_handler.hash(new_password)
                     await self.auth_handler.set_session_data(
-                        e.session_id, "password_hash", password_hash
+                        e.session_id,
+                        UIAuthSessionDataConstants.PASSWORD_HASH,
+                        password_hash,
                     )
                 raise
             user_id = requester.user.to_string()
@@ -215,7 +217,6 @@ class PasswordRestServlet(RestServlet):
                     [[LoginType.EMAIL_IDENTITY]],
                     request,
                     body,
-                    self.hs.get_ip_from_request(request),
                     "modify your account password",
                 )
             except InteractiveAuthIncompleteError as e:
@@ -227,7 +228,9 @@ class PasswordRestServlet(RestServlet):
                 if new_password:
                     password_hash = await self.auth_handler.hash(new_password)
                     await self.auth_handler.set_session_data(
-                        e.session_id, "password_hash", password_hash
+                        e.session_id,
+                        UIAuthSessionDataConstants.PASSWORD_HASH,
+                        password_hash,
                     )
                 raise
 
@@ -254,14 +257,18 @@ class PasswordRestServlet(RestServlet):
                 logger.error("Auth succeeded but no known type! %r", result.keys())
                 raise SynapseError(500, "", Codes.UNKNOWN)
 
-        # If we have a password in this request, prefer it. Otherwise, there
-        # must be a password hash from an earlier request.
+        # If we have a password in this request, prefer it. Otherwise, use the
+        # password hash from an earlier request.
         if new_password:
             password_hash = await self.auth_handler.hash(new_password)
-        else:
+        elif session_id is not None:
             password_hash = await self.auth_handler.get_session_data(
-                session_id, "password_hash", None
+                session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
             )
+        else:
+            # UI validation was skipped, but the request did not include a new
+            # password.
+            password_hash = None
         if not password_hash:
             raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
 
@@ -300,19 +307,18 @@ class DeactivateAccountRestServlet(RestServlet):
         # allow ASes to deactivate their own users
         if requester.app_service:
             await self._deactivate_account_handler.deactivate_account(
-                requester.user.to_string(), erase
+                requester.user.to_string(), erase, requester
             )
             return 200, {}
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "deactivate your account",
+            requester, request, body, "deactivate your account",
         )
         result = await self._deactivate_account_handler.deactivate_account(
-            requester.user.to_string(), erase, id_server=body.get("id_server")
+            requester.user.to_string(),
+            erase,
+            requester,
+            id_server=body.get("id_server"),
         )
         if result:
             id_server_unbind_result = "success"
@@ -375,6 +381,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
         if next_link:
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
@@ -426,7 +434,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
 class MsisdnThreepidRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         super().__init__()
         self.store = self.hs.get_datastore()
@@ -454,6 +462,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(
+            request, "msisdn", msisdn
+        )
+
         if next_link:
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
@@ -691,11 +703,7 @@ class ThreepidAddRestServlet(RestServlet):
         assert_valid_client_secret(client_secret)
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "add a third-party identifier to your account",
+            requester, request, body, "add a third-party identifier to your account",
         )
 
         validation_session = await self.identity_handler.validate_threepid_session(
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 87a5b1b86b..3f28c0bc3e 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self._is_worker = hs.config.worker_app is not None
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, account_data_type):
-        if self._is_worker:
-            raise Exception("Cannot handle PUT /account_data on worker")
-
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
 
         body = parse_json_object_from_request(request)
 
-        max_id = await self.store.add_account_data_for_user(
-            user_id, account_data_type, body
-        )
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.add_account_data_for_user(user_id, account_data_type, body)
 
         return 200, {}
 
@@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self._is_worker = hs.config.worker_app is not None
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, room_id, account_data_type):
-        if self._is_worker:
-            raise Exception("Cannot handle PUT /account_data on worker")
-
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
@@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet):
                 " Use /rooms/!roomId:server.name/read_markers",
             )
 
-        max_id = await self.store.add_account_data_to_room(
+        await self.handler.add_account_data_to_room(
             user_id, room_id, account_data_type, body
         )
 
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
-
         return 200, {}
 
     async def on_GET(self, request, user_id, room_id, account_data_type):
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index fab077747f..75ece1c911 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError
@@ -23,6 +24,9 @@ from synapse.http.servlet import RestServlet, parse_string
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -35,28 +39,12 @@ class AuthRestServlet(RestServlet):
 
     PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
-
-        # SSO configuration.
-        self._cas_enabled = hs.config.cas_enabled
-        if self._cas_enabled:
-            self._cas_handler = hs.get_cas_handler()
-            self._cas_server_url = hs.config.cas_server_url
-            self._cas_service_url = hs.config.cas_service_url
-        self._saml_enabled = hs.config.saml2_enabled
-        if self._saml_enabled:
-            self._saml_handler = hs.get_saml_handler()
-        self._oidc_enabled = hs.config.oidc_enabled
-        if self._oidc_enabled:
-            self._oidc_handler = hs.get_oidc_handler()
-            self._cas_server_url = hs.config.cas_server_url
-            self._cas_service_url = hs.config.cas_service_url
-
         self.recaptcha_template = hs.config.recaptcha_template
         self.terms_template = hs.config.terms_template
         self.success_template = hs.config.fallback_success_template
@@ -85,32 +73,7 @@ class AuthRestServlet(RestServlet):
         elif stagetype == LoginType.SSO:
             # Display a confirmation page which prompts the user to
             # re-authenticate with their SSO provider.
-            if self._cas_enabled:
-                # Generate a request to CAS that redirects back to an endpoint
-                # to verify the successful authentication.
-                sso_redirect_url = self._cas_handler.get_redirect_url(
-                    {"session": session},
-                )
-
-            elif self._saml_enabled:
-                # Some SAML identity providers (e.g. Google) require a
-                # RelayState parameter on requests. It is not necessary here, so
-                # pass in a dummy redirect URL (which will never get used).
-                client_redirect_url = b"unused"
-                sso_redirect_url = self._saml_handler.handle_redirect_request(
-                    client_redirect_url, session
-                )
-
-            elif self._oidc_enabled:
-                client_redirect_url = b""
-                sso_redirect_url = await self._oidc_handler.handle_redirect_request(
-                    request, client_redirect_url, session
-                )
-
-            else:
-                raise SynapseError(400, "Homeserver not configured for SSO.")
-
-            html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
+            html = await self.auth_handler.start_sso_ui_auth(request, session)
 
         else:
             raise SynapseError(404, "Unknown auth stage type")
@@ -134,7 +97,7 @@ class AuthRestServlet(RestServlet):
             authdict = {"response": response, "session": session}
 
             success = await self.auth_handler.add_oob_auth(
-                LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
+                LoginType.RECAPTCHA, authdict, request.getClientIP()
             )
 
             if success:
@@ -150,7 +113,7 @@ class AuthRestServlet(RestServlet):
             authdict = {"session": session}
 
             success = await self.auth_handler.add_oob_auth(
-                LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
+                LoginType.TERMS, authdict, request.getClientIP()
             )
 
             if success:
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index af117cb27c..314e01dfe4 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -83,11 +83,7 @@ class DeleteDevicesRestServlet(RestServlet):
         assert_params_in_dict(body, ["devices"])
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "remove device(s) from your account",
+            requester, request, body, "remove device(s) from your account",
         )
 
         await self.device_handler.delete_devices(
@@ -133,11 +129,7 @@ class DeviceRestServlet(RestServlet):
                 raise
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "remove a device from your account",
+            requester, request, body, "remove a device from your account",
         )
 
         await self.device_handler.delete_device(requester.user.to_string(), device_id)
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 75215a3779..28b55f27ad 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+from functools import wraps
 
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -25,6 +26,22 @@ from ._base import client_patterns
 logger = logging.getLogger(__name__)
 
 
+def _validate_group_id(f):
+    """Wrapper to validate the form of the group ID.
+
+    Can be applied to any on_FOO methods that accepts a group ID as a URL parameter.
+    """
+
+    @wraps(f)
+    def wrapper(self, request, group_id, *args, **kwargs):
+        if not GroupID.is_valid(group_id):
+            raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
+
+        return f(self, request, group_id, *args, **kwargs)
+
+    return wrapper
+
+
 class GroupServlet(RestServlet):
     """Get the group profile
     """
@@ -37,6 +54,7 @@ class GroupServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -47,6 +65,7 @@ class GroupServlet(RestServlet):
 
         return 200, group_description
 
+    @_validate_group_id
     async def on_POST(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -71,6 +90,7 @@ class GroupSummaryServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -102,6 +122,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, category_id, room_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -117,6 +138,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
 
         return 200, resp
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, category_id, room_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -142,6 +164,7 @@ class GroupCategoryServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id, category_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -152,6 +175,7 @@ class GroupCategoryServlet(RestServlet):
 
         return 200, category
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, category_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -163,6 +187,7 @@ class GroupCategoryServlet(RestServlet):
 
         return 200, resp
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, category_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -186,6 +211,7 @@ class GroupCategoriesServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -209,6 +235,7 @@ class GroupRoleServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id, role_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -219,6 +246,7 @@ class GroupRoleServlet(RestServlet):
 
         return 200, category
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, role_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -230,6 +258,7 @@ class GroupRoleServlet(RestServlet):
 
         return 200, resp
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, role_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -253,6 +282,7 @@ class GroupRolesServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -284,6 +314,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, role_id, user_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -299,6 +330,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
 
         return 200, resp
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, role_id, user_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -322,13 +354,11 @@ class GroupRoomServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
-        if not GroupID.is_valid(group_id):
-            raise SynapseError(400, "%s was not legal group ID" % (group_id,))
-
         result = await self.groups_handler.get_rooms_in_group(
             group_id, requester_user_id
         )
@@ -348,6 +378,7 @@ class GroupUsersServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -371,6 +402,7 @@ class GroupInvitedUsersServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -393,6 +425,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
         self.auth = hs.get_auth()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -449,6 +482,7 @@ class GroupAdminRoomsServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, room_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -460,6 +494,7 @@ class GroupAdminRoomsServlet(RestServlet):
 
         return 200, result
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, room_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -486,6 +521,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, room_id, config_key):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -514,6 +550,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
         self.store = hs.get_datastore()
         self.is_mine_id = hs.is_mine_id
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, user_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -541,6 +578,7 @@ class GroupAdminUsersKickServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, user_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -590,6 +628,7 @@ class GroupSelfLeaveServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -614,6 +653,7 @@ class GroupSelfJoinServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -638,6 +678,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -662,6 +703,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index b91996c738..a6134ead8a 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -271,11 +271,7 @@ class SigningKeyUploadServlet(RestServlet):
         body = parse_json_object_from_request(request)
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "add a device signing key to your account",
+            requester, request, body, "add a device signing key to your account",
         )
 
         result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 5374d2c1b6..f0675abd32 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -38,6 +38,7 @@ from synapse.config.ratelimiting import FederationRateLimitConfig
 from synapse.config.registration import RegistrationConfig
 from synapse.config.server import is_threepid_reserved
 from synapse.handlers.auth import AuthHandler
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
@@ -125,6 +126,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
             "email", email
         )
@@ -204,6 +207,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(
+            request, "msisdn", msisdn
+        )
+
         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
             "msisdn", msisdn
         )
@@ -353,7 +360,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
                 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
             )
 
-        ip = self.hs.get_ip_from_request(request)
+        ip = request.getClientIP()
         with self.ratelimiter.ratelimit(ip) as wait_deferred:
             await wait_deferred
 
@@ -451,7 +458,7 @@ class RegisterRestServlet(RestServlet):
 
         # == Normal User Registration == (everyone else)
         if not self._registration_enabled:
-            raise SynapseError(403, "Registration has been disabled")
+            raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN)
 
         # For regular registration, convert the provided username to lowercase
         # before attempting to register it. This should mean that people who try
@@ -494,11 +501,11 @@ class RegisterRestServlet(RestServlet):
             # user here. We carry on and go through the auth checks though,
             # for paranoia.
             registered_user_id = await self.auth_handler.get_session_data(
-                session_id, "registered_user_id", None
+                session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None
             )
             # Extract the previously-hashed password from the session.
             password_hash = await self.auth_handler.get_session_data(
-                session_id, "password_hash", None
+                session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
             )
 
         # Ensure that the username is valid.
@@ -513,11 +520,7 @@ class RegisterRestServlet(RestServlet):
         # not this will raise a user-interactive auth error.
         try:
             auth_result, params, session_id = await self.auth_handler.check_ui_auth(
-                self._registration_flows,
-                request,
-                body,
-                self.hs.get_ip_from_request(request),
-                "register a new account",
+                self._registration_flows, request, body, "register a new account",
             )
         except InteractiveAuthIncompleteError as e:
             # The user needs to provide more steps to complete auth.
@@ -532,7 +535,9 @@ class RegisterRestServlet(RestServlet):
             if not password_hash and password:
                 password_hash = await self.auth_handler.hash(password)
                 await self.auth_handler.set_session_data(
-                    e.session_id, "password_hash", password_hash
+                    e.session_id,
+                    UIAuthSessionDataConstants.PASSWORD_HASH,
+                    password_hash,
                 )
             raise
 
@@ -635,7 +640,9 @@ class RegisterRestServlet(RestServlet):
             # Remember that the user account has been registered (and the user
             # ID it was registered with, since it might not have been specified).
             await self.auth_handler.set_session_data(
-                session_id, "registered_user_id", registered_user_id
+                session_id,
+                UIAuthSessionDataConstants.REGISTERED_USER_ID,
+                registered_user_id,
             )
 
             registered = True
@@ -657,9 +664,13 @@ class RegisterRestServlet(RestServlet):
         user_id = await self.registration_handler.appservice_register(
             username, as_token
         )
-        return await self._create_registration_details(user_id, body)
+        return await self._create_registration_details(
+            user_id, body, is_appservice_ghost=True,
+        )
 
-    async def _create_registration_details(self, user_id, params):
+    async def _create_registration_details(
+        self, user_id, params, is_appservice_ghost=False
+    ):
         """Complete registration of newly-registered user
 
         Allocates device_id if one was not given; also creates access_token.
@@ -676,7 +687,11 @@ class RegisterRestServlet(RestServlet):
             device_id = params.get("device_id")
             initial_display_name = params.get("initial_device_display_name")
             device_id, access_token = await self.registration_handler.register_device(
-                user_id, device_id, initial_display_name, is_guest=False
+                user_id,
+                device_id,
+                initial_display_name,
+                is_guest=False,
+                is_appservice_ghost=is_appservice_ghost,
             )
 
             result.update({"access_token": access_token, "device_id": device_id})
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index bc4f43639a..a3dee14ed4 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -17,7 +17,7 @@ import logging
 from typing import Tuple
 
 from synapse.http import servlet
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
 from synapse.logging.opentracing import set_tag, trace
 from synapse.rest.client.transactions import HttpTransactionCache
 
@@ -54,6 +54,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         content = parse_json_object_from_request(request)
+        assert_params_in_dict(content, ("messages",))
 
         sender_user_id = requester.user.to_string()
 
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index bf3a79db44..a97cd66c52 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -58,8 +58,7 @@ class TagServlet(RestServlet):
     def __init__(self, hs):
         super().__init__()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, room_id, tag):
         requester = await self.auth.get_user_by_req(request)
@@ -68,9 +67,7 @@ class TagServlet(RestServlet):
 
         body = parse_json_object_from_request(request)
 
-        max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.add_tag_to_room(user_id, room_id, tag, body)
 
         return 200, {}
 
@@ -79,9 +76,7 @@ class TagServlet(RestServlet):
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add tags for other users.")
 
-        max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.remove_tag_from_room(user_id, room_id, tag)
 
         return 200, {}
 
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index b3e4d5612e..8b9ef26cf2 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -100,6 +100,7 @@ class ConsentResource(DirectServeHtmlResource):
 
         consent_template_directory = hs.config.user_consent_template_dir
 
+        # TODO: switch to synapse.util.templates.build_jinja_env
         loader = jinja2.FileSystemLoader(consent_template_directory)
         self._jinja_env = jinja2.Environment(
             loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"])
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index f843f02454..c57ac22e58 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Set
+from typing import Dict
 
 from signedjson.sign import sign_json
 
@@ -142,12 +142,13 @@ class RemoteKey(DirectServeJsonResource):
 
         time_now_ms = self.clock.time_msec()
 
-        cache_misses = {}  # type: Dict[str, Set[str]]
+        # Note that the value is unused.
+        cache_misses = {}  # type: Dict[str, Dict[str, int]]
         for (server_name, key_id, from_server), results in cached.items():
             results = [(result["ts_added_ms"], result) for result in results]
 
             if not results and key_id is not None:
-                cache_misses.setdefault(server_name, set()).add(key_id)
+                cache_misses.setdefault(server_name, {})[key_id] = 0
                 continue
 
             if key_id is not None:
@@ -201,7 +202,7 @@ class RemoteKey(DirectServeJsonResource):
                     )
 
                 if miss:
-                    cache_misses.setdefault(server_name, set()).add(key_id)
+                    cache_misses.setdefault(server_name, {})[key_id] = 0
                 # Cast to bytes since postgresql returns a memoryview.
                 json_results.add(bytes(most_recent_result["key_json"]))
             else:
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 67aa993f19..f71a03a12d 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2019 New Vector Ltd
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,10 +17,11 @@
 import logging
 import os
 import urllib
-from typing import Awaitable
+from typing import Awaitable, Dict, Generator, List, Optional, Tuple
 
 from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
+from twisted.web.http import Request
 
 from synapse.api.errors import Codes, SynapseError, cs_error
 from synapse.http.server import finish_request, respond_with_json
@@ -46,7 +47,7 @@ TEXT_CONTENT_TYPES = [
 ]
 
 
-def parse_media_id(request):
+def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
     try:
         # This allows users to append e.g. /test.png to the URL. Useful for
         # clients that parse the URL to see content type.
@@ -69,7 +70,7 @@ def parse_media_id(request):
         )
 
 
-def respond_404(request):
+def respond_404(request: Request) -> None:
     respond_with_json(
         request,
         404,
@@ -79,8 +80,12 @@ def respond_404(request):
 
 
 async def respond_with_file(
-    request, media_type, file_path, file_size=None, upload_name=None
-):
+    request: Request,
+    media_type: str,
+    file_path: str,
+    file_size: Optional[int] = None,
+    upload_name: Optional[str] = None,
+) -> None:
     logger.debug("Responding with %r", file_path)
 
     if os.path.isfile(file_path):
@@ -98,15 +103,20 @@ async def respond_with_file(
         respond_404(request)
 
 
-def add_file_headers(request, media_type, file_size, upload_name):
+def add_file_headers(
+    request: Request,
+    media_type: str,
+    file_size: Optional[int],
+    upload_name: Optional[str],
+) -> None:
     """Adds the correct response headers in preparation for responding with the
     media.
 
     Args:
-        request (twisted.web.http.Request)
-        media_type (str): The media/content type.
-        file_size (int): Size in bytes of the media, if known.
-        upload_name (str): The name of the requested file, if any.
+        request
+        media_type: The media/content type.
+        file_size: Size in bytes of the media, if known.
+        upload_name: The name of the requested file, if any.
     """
 
     def _quote(x):
@@ -153,7 +163,13 @@ def add_file_headers(request, media_type, file_size, upload_name):
     # select private. don't bother setting Expires as all our
     # clients are smart enough to be happy with Cache-Control
     request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
-    request.setHeader(b"Content-Length", b"%d" % (file_size,))
+    if file_size is not None:
+        request.setHeader(b"Content-Length", b"%d" % (file_size,))
+
+    # Tell web crawlers to not index, archive, or follow links in media. This
+    # should help to prevent things in the media repo from showing up in web
+    # search results.
+    request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex")
 
 
 # separators as defined in RFC2616. SP and HT are handled separately.
@@ -179,7 +195,7 @@ _FILENAME_SEPARATOR_CHARS = {
 }
 
 
-def _can_encode_filename_as_token(x):
+def _can_encode_filename_as_token(x: str) -> bool:
     for c in x:
         # from RFC2616:
         #
@@ -201,17 +217,21 @@ def _can_encode_filename_as_token(x):
 
 
 async def respond_with_responder(
-    request, responder, media_type, file_size, upload_name=None
-):
+    request: Request,
+    responder: "Optional[Responder]",
+    media_type: str,
+    file_size: Optional[int],
+    upload_name: Optional[str] = None,
+) -> None:
     """Responds to the request with given responder. If responder is None then
     returns 404.
 
     Args:
-        request (twisted.web.http.Request)
-        responder (Responder|None)
-        media_type (str): The media/content type.
-        file_size (int|None): Size in bytes of the media. If not known it should be None
-        upload_name (str|None): The name of the requested file, if any.
+        request
+        responder
+        media_type: The media/content type.
+        file_size: Size in bytes of the media. If not known it should be None
+        upload_name: The name of the requested file, if any.
     """
     if request._disconnected:
         logger.warning(
@@ -280,6 +300,7 @@ class FileInfo:
         thumbnail_height (int)
         thumbnail_method (str)
         thumbnail_type (str): Content type of thumbnail, e.g. image/png
+        thumbnail_length (int): The size of the media file, in bytes.
     """
 
     def __init__(
@@ -292,6 +313,7 @@ class FileInfo:
         thumbnail_height=None,
         thumbnail_method=None,
         thumbnail_type=None,
+        thumbnail_length=None,
     ):
         self.server_name = server_name
         self.file_id = file_id
@@ -301,24 +323,25 @@ class FileInfo:
         self.thumbnail_height = thumbnail_height
         self.thumbnail_method = thumbnail_method
         self.thumbnail_type = thumbnail_type
+        self.thumbnail_length = thumbnail_length
 
 
-def get_filename_from_headers(headers):
+def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
     """
     Get the filename of the downloaded file by inspecting the
     Content-Disposition HTTP header.
 
     Args:
-        headers (dict[bytes, list[bytes]]): The HTTP request headers.
+        headers: The HTTP request headers.
 
     Returns:
-        A Unicode string of the filename, or None.
+        The filename, or None.
     """
     content_disposition = headers.get(b"Content-Disposition", [b""])
 
     # No header, bail out.
     if not content_disposition[0]:
-        return
+        return None
 
     _, params = _parse_header(content_disposition[0])
 
@@ -351,17 +374,16 @@ def get_filename_from_headers(headers):
     return upload_name
 
 
-def _parse_header(line):
+def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
     """Parse a Content-type like header.
 
     Cargo-culted from `cgi`, but works on bytes rather than strings.
 
     Args:
-        line (bytes): header to be parsed
+        line: header to be parsed
 
     Returns:
-        Tuple[bytes, dict[bytes, bytes]]:
-            the main content-type, followed by the parameter dictionary
+        The main content-type, followed by the parameter dictionary
     """
     parts = _parseparam(b";" + line)
     key = next(parts)
@@ -381,16 +403,16 @@ def _parse_header(line):
     return key, pdict
 
 
-def _parseparam(s):
+def _parseparam(s: bytes) -> Generator[bytes, None, None]:
     """Generator which splits the input on ;, respecting double-quoted sequences
 
     Cargo-culted from `cgi`, but works on bytes rather than strings.
 
     Args:
-        s (bytes): header to be parsed
+        s: header to be parsed
 
     Returns:
-        Iterable[bytes]: the split input
+        The split input
     """
     while s[:1] == b";":
         s = s[1:]
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index 68dd2a1c8a..4e4c6971f7 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2018 Will Hunt <will@half-shot.uk>
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,22 +15,29 @@
 # limitations under the License.
 #
 
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 
 class MediaConfigResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         config = hs.get_config()
         self.clock = hs.get_clock()
         self.auth = hs.get_auth()
         self.limits_dict = {"m.upload.size": config.max_upload_size}
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         await self.auth.get_user_by_req(request)
         respond_with_json(request, 200, self.limits_dict, send_cors=True)
 
-    async def _async_render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request: Request) -> None:
         respond_with_json(request, 200, {}, send_cors=True)
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index d3d8457303..3ed219ae43 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,24 +14,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
 
-import synapse.http.servlet
 from synapse.http.server import DirectServeJsonResource, set_cors_headers
+from synapse.http.servlet import parse_boolean
 
 from ._base import parse_media_id, respond_404
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)
 
 
 class DownloadResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo):
+    def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
         super().__init__()
         self.media_repo = media_repo
         self.server_name = hs.hostname
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         set_cors_headers(request)
         request.setHeader(
             b"Content-Security-Policy",
@@ -49,9 +57,7 @@ class DownloadResource(DirectServeJsonResource):
         if server_name == self.server_name:
             await self.media_repo.get_local_media(request, media_id, name)
         else:
-            allow_remote = synapse.http.servlet.parse_boolean(
-                request, "allow_remote", default=True
-            )
+            allow_remote = parse_boolean(request, "allow_remote", default=True)
             if not allow_remote:
                 logger.info(
                     "Rejecting request for remote media %s/%s due to allow_remote",
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 9e079f672f..7792f26e78 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,11 +17,12 @@
 import functools
 import os
 import re
+from typing import Callable, List
 
 NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
 
 
-def _wrap_in_base_path(func):
+def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]":
     """Takes a function that returns a relative path and turns it into an
     absolute path based on the location of the primary media store
     """
@@ -41,12 +43,18 @@ class MediaFilePaths:
     to write to the backup media store (when one is configured)
     """
 
-    def __init__(self, primary_base_path):
+    def __init__(self, primary_base_path: str):
         self.base_path = primary_base_path
 
     def default_thumbnail_rel(
-        self, default_top_level, default_sub_type, width, height, content_type, method
-    ):
+        self,
+        default_top_level: str,
+        default_sub_type: str,
+        width: int,
+        height: int,
+        content_type: str,
+        method: str,
+    ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
@@ -55,12 +63,14 @@ class MediaFilePaths:
 
     default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
 
-    def local_media_filepath_rel(self, media_id):
+    def local_media_filepath_rel(self, media_id: str) -> str:
         return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
 
     local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
 
-    def local_media_thumbnail_rel(self, media_id, width, height, content_type, method):
+    def local_media_thumbnail_rel(
+        self, media_id: str, width: int, height: int, content_type: str, method: str
+    ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
@@ -86,7 +96,7 @@ class MediaFilePaths:
             media_id[4:],
         )
 
-    def remote_media_filepath_rel(self, server_name, file_id):
+    def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
         return os.path.join(
             "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
         )
@@ -94,8 +104,14 @@ class MediaFilePaths:
     remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
 
     def remote_media_thumbnail_rel(
-        self, server_name, file_id, width, height, content_type, method
-    ):
+        self,
+        server_name: str,
+        file_id: str,
+        width: int,
+        height: int,
+        content_type: str,
+        method: str,
+    ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
@@ -113,7 +129,7 @@ class MediaFilePaths:
     # Should be removed after some time, when most of the thumbnails are stored
     # using the new path.
     def remote_media_thumbnail_rel_legacy(
-        self, server_name, file_id, width, height, content_type
+        self, server_name: str, file_id: str, width: int, height: int, content_type: str
     ):
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
@@ -126,7 +142,7 @@ class MediaFilePaths:
             file_name,
         )
 
-    def remote_media_thumbnail_dir(self, server_name, file_id):
+    def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
         return os.path.join(
             self.base_path,
             "remote_thumbnail",
@@ -136,7 +152,7 @@ class MediaFilePaths:
             file_id[4:],
         )
 
-    def url_cache_filepath_rel(self, media_id):
+    def url_cache_filepath_rel(self, media_id: str) -> str:
         if NEW_FORMAT_ID_RE.match(media_id):
             # Media id is of the form <DATE><RANDOM_STRING>
             # E.g.: 2017-09-28-fsdRDt24DS234dsf
@@ -146,7 +162,7 @@ class MediaFilePaths:
 
     url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
 
-    def url_cache_filepath_dirs_to_delete(self, media_id):
+    def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
         "The dirs to try and remove if we delete the media_id file"
         if NEW_FORMAT_ID_RE.match(media_id):
             return [os.path.join(self.base_path, "url_cache", media_id[:10])]
@@ -156,7 +172,9 @@ class MediaFilePaths:
                 os.path.join(self.base_path, "url_cache", media_id[0:2]),
             ]
 
-    def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method):
+    def url_cache_thumbnail_rel(
+        self, media_id: str, width: int, height: int, content_type: str, method: str
+    ) -> str:
         # Media id is of the form <DATE><RANDOM_STRING>
         # E.g.: 2017-09-28-fsdRDt24DS234dsf
 
@@ -178,7 +196,7 @@ class MediaFilePaths:
 
     url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
 
-    def url_cache_thumbnail_directory(self, media_id):
+    def url_cache_thumbnail_directory(self, media_id: str) -> str:
         # Media id is of the form <DATE><RANDOM_STRING>
         # E.g.: 2017-09-28-fsdRDt24DS234dsf
 
@@ -195,7 +213,7 @@ class MediaFilePaths:
                 media_id[4:],
             )
 
-    def url_cache_thumbnail_dirs_to_delete(self, media_id):
+    def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
         "The dirs to try and remove if we delete the media_id thumbnails"
         # Media id is of the form <DATE><RANDOM_STRING>
         # E.g.: 2017-09-28-fsdRDt24DS234dsf
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 9cac74ebd8..4c9946a616 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,12 +13,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import errno
 import logging
 import os
 import shutil
-from typing import IO, Dict, List, Optional, Tuple
+from io import BytesIO
+from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
 import twisted.internet.error
 import twisted.web.http
@@ -56,6 +56,9 @@ from .thumbnail_resource import ThumbnailResource
 from .thumbnailer import Thumbnailer, ThumbnailError
 from .upload_resource import UploadResource
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -63,26 +66,26 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
 
 
 class MediaRepository:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
-        self.client = hs.get_http_client()
+        self.client = hs.get_federation_http_client()
         self.clock = hs.get_clock()
         self.server_name = hs.hostname
         self.store = hs.get_datastore()
         self.max_upload_size = hs.config.max_upload_size
         self.max_image_pixels = hs.config.max_image_pixels
 
-        self.primary_base_path = hs.config.media_store_path
-        self.filepaths = MediaFilePaths(self.primary_base_path)
+        self.primary_base_path = hs.config.media_store_path  # type: str
+        self.filepaths = MediaFilePaths(self.primary_base_path)  # type: MediaFilePaths
 
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.thumbnail_requirements = hs.config.thumbnail_requirements
 
         self.remote_media_linearizer = Linearizer(name="media_remote")
 
-        self.recently_accessed_remotes = set()
-        self.recently_accessed_locals = set()
+        self.recently_accessed_remotes = set()  # type: Set[Tuple[str, str]]
+        self.recently_accessed_locals = set()  # type: Set[str]
 
         self.federation_domain_whitelist = hs.config.federation_domain_whitelist
 
@@ -113,7 +116,7 @@ class MediaRepository:
             "update_recently_accessed_media", self._update_recently_accessed
         )
 
-    async def _update_recently_accessed(self):
+    async def _update_recently_accessed(self) -> None:
         remote_media = self.recently_accessed_remotes
         self.recently_accessed_remotes = set()
 
@@ -124,12 +127,12 @@ class MediaRepository:
             local_media, remote_media, self.clock.time_msec()
         )
 
-    def mark_recently_accessed(self, server_name, media_id):
+    def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None:
         """Mark the given media as recently accessed.
 
         Args:
-            server_name (str|None): Origin server of media, or None if local
-            media_id (str): The media ID of the content
+            server_name: Origin server of media, or None if local
+            media_id: The media ID of the content
         """
         if server_name:
             self.recently_accessed_remotes.add((server_name, media_id))
@@ -459,7 +462,14 @@ class MediaRepository:
     def _get_thumbnail_requirements(self, media_type):
         return self.thumbnail_requirements.get(media_type, ())
 
-    def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):
+    def _generate_thumbnail(
+        self,
+        thumbnailer: Thumbnailer,
+        t_width: int,
+        t_height: int,
+        t_method: str,
+        t_type: str,
+    ) -> Optional[BytesIO]:
         m_width = thumbnailer.width
         m_height = thumbnailer.height
 
@@ -470,22 +480,20 @@ class MediaRepository:
                 m_height,
                 self.max_image_pixels,
             )
-            return
+            return None
 
         if thumbnailer.transpose_method is not None:
             m_width, m_height = thumbnailer.transpose()
 
         if t_method == "crop":
-            t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
+            return thumbnailer.crop(t_width, t_height, t_type)
         elif t_method == "scale":
             t_width, t_height = thumbnailer.aspect(t_width, t_height)
             t_width = min(m_width, t_width)
             t_height = min(m_height, t_height)
-            t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
-        else:
-            t_byte_source = None
+            return thumbnailer.scale(t_width, t_height, t_type)
 
-        return t_byte_source
+        return None
 
     async def generate_local_exact_thumbnail(
         self,
@@ -776,7 +784,7 @@ class MediaRepository:
 
         return {"width": m_width, "height": m_height}
 
-    async def delete_old_remote_media(self, before_ts):
+    async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
         old_media = await self.store.get_remote_media_before(before_ts)
 
         deleted = 0
@@ -928,7 +936,7 @@ class MediaRepositoryResource(Resource):
     within a given rectangle.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         # If we're not configured to use it, raise if we somehow got here.
         if not hs.config.can_load_media_repo:
             raise ConfigError("Synapse is not configured to use a media repo.")
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 268e0c8f50..89cdd605aa 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2018 New Vecotr Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@ import os
 import shutil
 from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
 
+from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
 
 from synapse.logging.context import defer_to_thread, make_deferred_yieldable
@@ -270,7 +272,7 @@ class MediaStorage:
         return self.filepaths.local_media_filepath_rel(file_info.file_id)
 
 
-def _write_file_synchronously(source, dest):
+def _write_file_synchronously(source: IO, dest: IO) -> None:
     """Write `source` to the file like `dest` synchronously. Should be called
     from a thread.
 
@@ -286,14 +288,14 @@ class FileResponder(Responder):
     """Wraps an open file that can be sent to a request.
 
     Args:
-        open_file (file): A file like object to be streamed ot the client,
+        open_file: A file like object to be streamed ot the client,
             is closed when finished streaming.
     """
 
-    def __init__(self, open_file):
+    def __init__(self, open_file: IO):
         self.open_file = open_file
 
-    def write_to_consumer(self, consumer):
+    def write_to_consumer(self, consumer: IConsumer) -> Deferred:
         return make_deferred_yieldable(
             FileSender().beginFileTransfer(self.open_file, consumer)
         )
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index dce6c4d168..bf3be653aa 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import datetime
 import errno
 import fnmatch
@@ -23,12 +23,13 @@ import re
 import shutil
 import sys
 import traceback
-from typing import Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
 from urllib import parse as urlparse
 
 import attr
 
 from twisted.internet.error import DNSLookupError
+from twisted.web.http import Request
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.client import SimpleHttpClient
@@ -41,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.media.v1._base import get_filename_from_headers
+from synapse.rest.media.v1.media_storage import MediaStorage
 from synapse.util import json_encoder
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -48,6 +50,12 @@ from synapse.util.stringutils import random_string
 
 from ._base import FileInfo
 
+if TYPE_CHECKING:
+    from lxml import etree
+
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)
 
 _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
@@ -119,7 +127,12 @@ class OEmbedError(Exception):
 class PreviewUrlResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo, media_storage):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        media_repo: "MediaRepository",
+        media_storage: MediaStorage,
+    ):
         super().__init__()
 
         self.auth = hs.get_auth()
@@ -166,11 +179,11 @@ class PreviewUrlResource(DirectServeJsonResource):
                 self._start_expire_url_cache_data, 10 * 1000
             )
 
-    async def _async_render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request: Request) -> None:
         request.setHeader(b"Allow", b"OPTIONS, GET")
         respond_with_json(request, 200, {}, send_cors=True)
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
 
         # XXX: if get_user_by_req fails, what should we do in an async render?
         requester = await self.auth.get_user_by_req(request)
@@ -373,7 +386,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         """
         Check whether the URL should be downloaded as oEmbed content instead.
 
-        Params:
+        Args:
             url: The URL to check.
 
         Returns:
@@ -390,7 +403,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         """
         Request content from an oEmbed endpoint.
 
-        Params:
+        Args:
             endpoint: The oEmbed API endpoint.
             url: The URL to pass to the API.
 
@@ -450,7 +463,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
             raise OEmbedError() from e
 
-    async def _download_url(self, url: str, user):
+    async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
         # TODO: we should probably honour robots.txt... except in practice
         # we're most likely being explicitly triggered by a human rather than a
         # bot, so are we really a robot?
@@ -580,7 +593,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             "expire_url_cache_data", self._expire_url_cache_data
         )
 
-    async def _expire_url_cache_data(self):
+    async def _expire_url_cache_data(self) -> None:
         """Clean up expired url cache content, media and thumbnails.
         """
         # TODO: Delete from backup media store
@@ -676,24 +689,54 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.debug("No media removed from url cache")
 
 
-def decode_and_calc_og(body, media_uri, request_encoding=None):
+def decode_and_calc_og(
+    body: bytes, media_uri: str, request_encoding: Optional[str] = None
+) -> Dict[str, Optional[str]]:
+    """
+    Calculate metadata for an HTML document.
+
+    This uses lxml to parse the HTML document into the OG response. If errors
+    occur during processing of the document, an empty response is returned.
+
+    Args:
+        body: The HTML document, as bytes.
+        media_url: The URI used to download the body.
+        request_encoding: The character encoding of the body, as a string.
+
+    Returns:
+        The OG response as a dictionary.
+    """
+    # If there's no body, nothing useful is going to be found.
+    if not body:
+        return {}
+
     from lxml import etree
 
+    # Create an HTML parser. If this fails, log and return no metadata.
     try:
         parser = etree.HTMLParser(recover=True, encoding=request_encoding)
-        tree = etree.fromstring(body, parser)
-        og = _calc_og(tree, media_uri)
+    except LookupError:
+        # blindly consider the encoding as utf-8.
+        parser = etree.HTMLParser(recover=True, encoding="utf-8")
+    except Exception as e:
+        logger.warning("Unable to create HTML parser: %s" % (e,))
+        return {}
+
+    def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
+        # Attempt to parse the body. If this fails, log and return no metadata.
+        tree = etree.fromstring(body_attempt, parser)
+        return _calc_og(tree, media_uri)
+
+    # Attempt to parse the body. If this fails, log and return no metadata.
+    try:
+        return _attempt_calc_og(body)
     except UnicodeDecodeError:
         # blindly try decoding the body as utf-8, which seems to fix
         # the charset mismatches on https://google.com
-        parser = etree.HTMLParser(recover=True, encoding=request_encoding)
-        tree = etree.fromstring(body.decode("utf-8", "ignore"), parser)
-        og = _calc_og(tree, media_uri)
-
-    return og
+        return _attempt_calc_og(body.decode("utf-8", "ignore"))
 
 
-def _calc_og(tree, media_uri):
+def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
     # suck our tree into lxml and define our OG response.
 
     # if we see any image URLs in the OG response, then spider them
@@ -797,7 +840,9 @@ def _calc_og(tree, media_uri):
                 for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
             )
             og["og:description"] = summarize_paragraphs(text_nodes)
-    else:
+    elif og["og:description"]:
+        # This must be a non-empty string at this point.
+        assert isinstance(og["og:description"], str)
         og["og:description"] = summarize_paragraphs([og["og:description"]])
 
     # TODO: delete the url downloads to stop diskfilling,
@@ -805,7 +850,9 @@ def _calc_og(tree, media_uri):
     return og
 
 
-def _iterate_over_text(tree, *tags_to_ignore):
+def _iterate_over_text(
+    tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
+) -> Generator[str, None, None]:
     """Iterate over the tree returning text nodes in a depth first fashion,
     skipping text nodes inside certain tags.
     """
@@ -839,32 +886,32 @@ def _iterate_over_text(tree, *tags_to_ignore):
             )
 
 
-def _rebase_url(url, base):
-    base = list(urlparse.urlparse(base))
-    url = list(urlparse.urlparse(url))
-    if not url[0]:  # fix up schema
-        url[0] = base[0] or "http"
-    if not url[1]:  # fix up hostname
-        url[1] = base[1]
-        if not url[2].startswith("/"):
-            url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2]
-    return urlparse.urlunparse(url)
+def _rebase_url(url: str, base: str) -> str:
+    base_parts = list(urlparse.urlparse(base))
+    url_parts = list(urlparse.urlparse(url))
+    if not url_parts[0]:  # fix up schema
+        url_parts[0] = base_parts[0] or "http"
+    if not url_parts[1]:  # fix up hostname
+        url_parts[1] = base_parts[1]
+        if not url_parts[2].startswith("/"):
+            url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
+    return urlparse.urlunparse(url_parts)
 
 
-def _is_media(content_type):
-    if content_type.lower().startswith("image/"):
-        return True
+def _is_media(content_type: str) -> bool:
+    return content_type.lower().startswith("image/")
 
 
-def _is_html(content_type):
+def _is_html(content_type: str) -> bool:
     content_type = content_type.lower()
-    if content_type.startswith("text/html") or content_type.startswith(
+    return content_type.startswith("text/html") or content_type.startswith(
         "application/xhtml"
-    ):
-        return True
+    )
 
 
-def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
+def summarize_paragraphs(
+    text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
+) -> Optional[str]:
     # Try to get a summary of between 200 and 500 words, respecting
     # first paragraph and then word boundaries.
     # TODO: Respect sentences?
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 18c9ed48d6..e92006faa9 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,27 +13,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
+import abc
 import logging
 import os
 import shutil
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from synapse.config._base import Config
 from synapse.logging.context import defer_to_thread, run_in_background
+from synapse.util.async_helpers import maybe_awaitable
 
 from ._base import FileInfo, Responder
 from .media_storage import FileResponder
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
-class StorageProvider:
+
+class StorageProvider(metaclass=abc.ABCMeta):
     """A storage provider is a service that can store uploaded media and
     retrieve them.
     """
 
-    async def store_file(self, path: str, file_info: FileInfo):
+    @abc.abstractmethod
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
         """Store the file described by file_info. The actual contents can be
         retrieved by reading the file in file_info.upload_path.
 
@@ -42,6 +47,7 @@ class StorageProvider:
             file_info: The metadata of the file.
         """
 
+    @abc.abstractmethod
     async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
         """Attempt to fetch the file described by file_info and stream it
         into writer.
@@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider):
         self.store_synchronous = store_synchronous
         self.store_remote = store_remote
 
-    def __str__(self):
+    def __str__(self) -> str:
         return "StorageProviderWrapper[%s]" % (self.backend,)
 
-    async def store_file(self, path, file_info):
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
         if not file_info.server_name and not self.store_local:
             return None
 
@@ -91,39 +97,34 @@ class StorageProviderWrapper(StorageProvider):
         if self.store_synchronous:
             # store_file is supposed to return an Awaitable, but guard
             # against improper implementations.
-            result = self.backend.store_file(path, file_info)
-            if inspect.isawaitable(result):
-                return await result
+            await maybe_awaitable(self.backend.store_file(path, file_info))  # type: ignore
         else:
             # TODO: Handle errors.
             async def store():
                 try:
-                    result = self.backend.store_file(path, file_info)
-                    if inspect.isawaitable(result):
-                        return await result
+                    return await maybe_awaitable(
+                        self.backend.store_file(path, file_info)
+                    )
                 except Exception:
                     logger.exception("Error storing file")
 
             run_in_background(store)
-            return None
 
-    async def fetch(self, path, file_info):
+    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
         # store_file is supposed to return an Awaitable, but guard
         # against improper implementations.
-        result = self.backend.fetch(path, file_info)
-        if inspect.isawaitable(result):
-            return await result
+        return await maybe_awaitable(self.backend.fetch(path, file_info))
 
 
 class FileStorageProviderBackend(StorageProvider):
     """A storage provider that stores files in a directory on a filesystem.
 
     Args:
-        hs (HomeServer)
+        hs
         config: The config returned by `parse_config`.
     """
 
-    def __init__(self, hs, config):
+    def __init__(self, hs: "HomeServer", config: str):
         self.hs = hs
         self.cache_directory = hs.config.media_store_path
         self.base_directory = config
@@ -131,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider):
     def __str__(self):
         return "FileStorageProviderBackend[%s]" % (self.base_directory,)
 
-    async def store_file(self, path, file_info):
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
         """See StorageProvider.store_file"""
 
         primary_fname = os.path.join(self.cache_directory, path)
@@ -141,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider):
         if not os.path.exists(dirname):
             os.makedirs(dirname)
 
-        return await defer_to_thread(
+        await defer_to_thread(
             self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
         )
 
-    async def fetch(self, path, file_info):
+    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
         """See StorageProvider.fetch"""
 
         backup_fname = os.path.join(self.base_directory, path)
         if os.path.isfile(backup_fname):
             return FileResponder(open(backup_fname, "rb"))
 
+        return None
+
     @staticmethod
-    def parse_config(config):
+    def parse_config(config: dict) -> str:
         """Called on startup to parse config supplied. This should parse
         the config and raise if there is a problem.
 
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 30421b663a..d653a58be9 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,10 +16,14 @@
 
 
 import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from twisted.web.http import Request
 
 from synapse.api.errors import SynapseError
 from synapse.http.server import DirectServeJsonResource, set_cors_headers
 from synapse.http.servlet import parse_integer, parse_string
+from synapse.rest.media.v1.media_storage import MediaStorage
 
 from ._base import (
     FileInfo,
@@ -28,13 +33,22 @@ from ._base import (
     respond_with_responder,
 )
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)
 
 
 class ThumbnailResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo, media_storage):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        media_repo: "MediaRepository",
+        media_storage: MediaStorage,
+    ):
         super().__init__()
 
         self.store = hs.get_datastore()
@@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource):
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.server_name = hs.hostname
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         set_cors_headers(request)
         server_name, media_id, _ = parse_media_id(request)
         width = parse_integer(request, "width", required=True)
@@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource):
             self.media_repo.mark_recently_accessed(server_name, media_id)
 
     async def _respond_local_thumbnail(
-        self, request, media_id, width, height, method, m_type
-    ):
+        self,
+        request: Request,
+        media_id: str,
+        width: int,
+        height: int,
+        method: str,
+        m_type: str,
+    ) -> None:
         media_info = await self.store.get_local_media(media_id)
 
         if not media_info:
@@ -86,41 +106,27 @@ class ThumbnailResource(DirectServeJsonResource):
             return
 
         thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
-
-        if thumbnail_infos:
-            thumbnail_info = self._select_thumbnail(
-                width, height, method, m_type, thumbnail_infos
-            )
-
-            file_info = FileInfo(
-                server_name=None,
-                file_id=media_id,
-                url_cache=media_info["url_cache"],
-                thumbnail=True,
-                thumbnail_width=thumbnail_info["thumbnail_width"],
-                thumbnail_height=thumbnail_info["thumbnail_height"],
-                thumbnail_type=thumbnail_info["thumbnail_type"],
-                thumbnail_method=thumbnail_info["thumbnail_method"],
-            )
-
-            t_type = file_info.thumbnail_type
-            t_length = thumbnail_info["thumbnail_length"]
-
-            responder = await self.media_storage.fetch_media(file_info)
-            await respond_with_responder(request, responder, t_type, t_length)
-        else:
-            logger.info("Couldn't find any generated thumbnails")
-            respond_404(request)
+        await self._select_and_respond_with_thumbnail(
+            request,
+            width,
+            height,
+            method,
+            m_type,
+            thumbnail_infos,
+            media_id,
+            url_cache=media_info["url_cache"],
+            server_name=None,
+        )
 
     async def _select_or_generate_local_thumbnail(
         self,
-        request,
-        media_id,
-        desired_width,
-        desired_height,
-        desired_method,
-        desired_type,
-    ):
+        request: Request,
+        media_id: str,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+    ) -> None:
         media_info = await self.store.get_local_media(media_id)
 
         if not media_info:
@@ -178,14 +184,14 @@ class ThumbnailResource(DirectServeJsonResource):
 
     async def _select_or_generate_remote_thumbnail(
         self,
-        request,
-        server_name,
-        media_id,
-        desired_width,
-        desired_height,
-        desired_method,
-        desired_type,
-    ):
+        request: Request,
+        server_name: str,
+        media_id: str,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+    ) -> None:
         media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
 
         thumbnail_infos = await self.store.get_remote_media_thumbnails(
@@ -239,8 +245,15 @@ class ThumbnailResource(DirectServeJsonResource):
             raise SynapseError(400, "Failed to generate thumbnail.")
 
     async def _respond_remote_thumbnail(
-        self, request, server_name, media_id, width, height, method, m_type
-    ):
+        self,
+        request: Request,
+        server_name: str,
+        media_id: str,
+        width: int,
+        height: int,
+        method: str,
+        m_type: str,
+    ) -> None:
         # TODO: Don't download the whole remote file
         # We should proxy the thumbnail from the remote server instead of
         # downloading the remote file and generating our own thumbnails.
@@ -249,97 +262,185 @@ class ThumbnailResource(DirectServeJsonResource):
         thumbnail_infos = await self.store.get_remote_media_thumbnails(
             server_name, media_id
         )
+        await self._select_and_respond_with_thumbnail(
+            request,
+            width,
+            height,
+            method,
+            m_type,
+            thumbnail_infos,
+            media_info["filesystem_id"],
+            url_cache=None,
+            server_name=server_name,
+        )
 
+    async def _select_and_respond_with_thumbnail(
+        self,
+        request: Request,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+        thumbnail_infos: List[Dict[str, Any]],
+        file_id: str,
+        url_cache: Optional[str] = None,
+        server_name: Optional[str] = None,
+    ) -> None:
+        """
+        Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
+
+        Args:
+            request: The incoming request.
+            desired_width: The desired width, the returned thumbnail may be larger than this.
+            desired_height: The desired height, the returned thumbnail may be larger than this.
+            desired_method: The desired method used to generate the thumbnail.
+            desired_type: The desired content-type of the thumbnail.
+            thumbnail_infos: A list of dictionaries of candidate thumbnails.
+            file_id: The ID of the media that a thumbnail is being requested for.
+            url_cache: The URL cache value.
+            server_name: The server name, if this is a remote thumbnail.
+        """
         if thumbnail_infos:
-            thumbnail_info = self._select_thumbnail(
-                width, height, method, m_type, thumbnail_infos
+            file_info = self._select_thumbnail(
+                desired_width,
+                desired_height,
+                desired_method,
+                desired_type,
+                thumbnail_infos,
+                file_id,
+                url_cache,
+                server_name,
             )
-            file_info = FileInfo(
-                server_name=server_name,
-                file_id=media_info["filesystem_id"],
-                thumbnail=True,
-                thumbnail_width=thumbnail_info["thumbnail_width"],
-                thumbnail_height=thumbnail_info["thumbnail_height"],
-                thumbnail_type=thumbnail_info["thumbnail_type"],
-                thumbnail_method=thumbnail_info["thumbnail_method"],
-            )
-
-            t_type = file_info.thumbnail_type
-            t_length = thumbnail_info["thumbnail_length"]
+            if not file_info:
+                logger.info("Couldn't find a thumbnail matching the desired inputs")
+                respond_404(request)
+                return
 
             responder = await self.media_storage.fetch_media(file_info)
-            await respond_with_responder(request, responder, t_type, t_length)
+            await respond_with_responder(
+                request, responder, file_info.thumbnail_type, file_info.thumbnail_length
+            )
         else:
             logger.info("Failed to find any generated thumbnails")
             respond_404(request)
 
     def _select_thumbnail(
         self,
-        desired_width,
-        desired_height,
-        desired_method,
-        desired_type,
-        thumbnail_infos,
-    ):
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+        thumbnail_infos: List[Dict[str, Any]],
+        file_id: str,
+        url_cache: Optional[str],
+        server_name: Optional[str],
+    ) -> Optional[FileInfo]:
+        """
+        Choose an appropriate thumbnail from the previously generated thumbnails.
+
+        Args:
+            desired_width: The desired width, the returned thumbnail may be larger than this.
+            desired_height: The desired height, the returned thumbnail may be larger than this.
+            desired_method: The desired method used to generate the thumbnail.
+            desired_type: The desired content-type of the thumbnail.
+            thumbnail_infos: A list of dictionaries of candidate thumbnails.
+            file_id: The ID of the media that a thumbnail is being requested for.
+            url_cache: The URL cache value.
+            server_name: The server name, if this is a remote thumbnail.
+
+        Returns:
+             The thumbnail which best matches the desired parameters.
+        """
+        desired_method = desired_method.lower()
+
+        # The chosen thumbnail.
+        thumbnail_info = None
+
         d_w = desired_width
         d_h = desired_height
 
-        if desired_method.lower() == "crop":
+        if desired_method == "crop":
+            # Thumbnails that match equal or larger sizes of desired width/height.
             crop_info_list = []
+            # Other thumbnails.
             crop_info_list2 = []
             for info in thumbnail_infos:
+                # Skip thumbnails generated with different methods.
+                if info["thumbnail_method"] != "crop":
+                    continue
+
                 t_w = info["thumbnail_width"]
                 t_h = info["thumbnail_height"]
-                t_method = info["thumbnail_method"]
-                if t_method == "crop":
-                    aspect_quality = abs(d_w * t_h - d_h * t_w)
-                    min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
-                    size_quality = abs((d_w - t_w) * (d_h - t_h))
-                    type_quality = desired_type != info["thumbnail_type"]
-                    length_quality = info["thumbnail_length"]
-                    if t_w >= d_w or t_h >= d_h:
-                        crop_info_list.append(
-                            (
-                                aspect_quality,
-                                min_quality,
-                                size_quality,
-                                type_quality,
-                                length_quality,
-                                info,
-                            )
+                aspect_quality = abs(d_w * t_h - d_h * t_w)
+                min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
+                size_quality = abs((d_w - t_w) * (d_h - t_h))
+                type_quality = desired_type != info["thumbnail_type"]
+                length_quality = info["thumbnail_length"]
+                if t_w >= d_w or t_h >= d_h:
+                    crop_info_list.append(
+                        (
+                            aspect_quality,
+                            min_quality,
+                            size_quality,
+                            type_quality,
+                            length_quality,
+                            info,
                         )
-                    else:
-                        crop_info_list2.append(
-                            (
-                                aspect_quality,
-                                min_quality,
-                                size_quality,
-                                type_quality,
-                                length_quality,
-                                info,
-                            )
+                    )
+                else:
+                    crop_info_list2.append(
+                        (
+                            aspect_quality,
+                            min_quality,
+                            size_quality,
+                            type_quality,
+                            length_quality,
+                            info,
                         )
+                    )
             if crop_info_list:
-                return min(crop_info_list)[-1]
-            else:
-                return min(crop_info_list2)[-1]
-        else:
+                thumbnail_info = min(crop_info_list)[-1]
+            elif crop_info_list2:
+                thumbnail_info = min(crop_info_list2)[-1]
+        elif desired_method == "scale":
+            # Thumbnails that match equal or larger sizes of desired width/height.
             info_list = []
+            # Other thumbnails.
             info_list2 = []
+
             for info in thumbnail_infos:
+                # Skip thumbnails generated with different methods.
+                if info["thumbnail_method"] != "scale":
+                    continue
+
                 t_w = info["thumbnail_width"]
                 t_h = info["thumbnail_height"]
-                t_method = info["thumbnail_method"]
                 size_quality = abs((d_w - t_w) * (d_h - t_h))
                 type_quality = desired_type != info["thumbnail_type"]
                 length_quality = info["thumbnail_length"]
-                if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
+                if t_w >= d_w or t_h >= d_h:
                     info_list.append((size_quality, type_quality, length_quality, info))
-                elif t_method == "scale":
+                else:
                     info_list2.append(
                         (size_quality, type_quality, length_quality, info)
                     )
             if info_list:
-                return min(info_list)[-1]
-            else:
-                return min(info_list2)[-1]
+                thumbnail_info = min(info_list)[-1]
+            elif info_list2:
+                thumbnail_info = min(info_list2)[-1]
+
+        if thumbnail_info:
+            return FileInfo(
+                file_id=file_id,
+                url_cache=url_cache,
+                server_name=server_name,
+                thumbnail=True,
+                thumbnail_width=thumbnail_info["thumbnail_width"],
+                thumbnail_height=thumbnail_info["thumbnail_height"],
+                thumbnail_type=thumbnail_info["thumbnail_type"],
+                thumbnail_method=thumbnail_info["thumbnail_method"],
+                thumbnail_length=thumbnail_info["thumbnail_length"],
+            )
+
+        # No matching thumbnail was found.
+        return None
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 32a8e4f960..07903e4017 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +15,7 @@
 # limitations under the License.
 import logging
 from io import BytesIO
+from typing import Tuple
 
 from PIL import Image
 
@@ -39,7 +41,7 @@ class Thumbnailer:
 
     FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
 
-    def __init__(self, input_path):
+    def __init__(self, input_path: str):
         try:
             self.image = Image.open(input_path)
         except OSError as e:
@@ -59,11 +61,11 @@ class Thumbnailer:
             # A lot of parsing errors can happen when parsing EXIF
             logger.info("Error parsing image EXIF information: %s", e)
 
-    def transpose(self):
+    def transpose(self) -> Tuple[int, int]:
         """Transpose the image using its EXIF Orientation tag
 
         Returns:
-            Tuple[int, int]: (width, height) containing the new image size in pixels.
+            A tuple containing the new image size in pixels as (width, height).
         """
         if self.transpose_method is not None:
             self.image = self.image.transpose(self.transpose_method)
@@ -73,7 +75,7 @@ class Thumbnailer:
             self.image.info["exif"] = None
         return self.image.size
 
-    def aspect(self, max_width, max_height):
+    def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]:
         """Calculate the largest size that preserves aspect ratio which
         fits within the given rectangle::
 
@@ -91,7 +93,7 @@ class Thumbnailer:
         else:
             return (max_height * self.width) // self.height, max_height
 
-    def _resize(self, width, height):
+    def _resize(self, width: int, height: int) -> Image:
         # 1-bit or 8-bit color palette images need converting to RGB
         # otherwise they will be scaled using nearest neighbour which
         # looks awful
@@ -99,7 +101,7 @@ class Thumbnailer:
             self.image = self.image.convert("RGB")
         return self.image.resize((width, height), Image.ANTIALIAS)
 
-    def scale(self, width, height, output_type):
+    def scale(self, width: int, height: int, output_type: str) -> BytesIO:
         """Rescales the image to the given dimensions.
 
         Returns:
@@ -108,7 +110,7 @@ class Thumbnailer:
         scaled = self._resize(width, height)
         return self._encode_image(scaled, output_type)
 
-    def crop(self, width, height, output_type):
+    def crop(self, width: int, height: int, output_type: str) -> BytesIO:
         """Rescales and crops the image to the given dimensions preserving
         aspect::
             (w_in / h_in) = (w_scaled / h_scaled)
@@ -136,7 +138,7 @@ class Thumbnailer:
             cropped = scaled_image.crop((crop_left, 0, crop_right, height))
         return self._encode_image(cropped, output_type)
 
-    def _encode_image(self, output_image, output_type):
+    def _encode_image(self, output_image: Image, output_type: str) -> BytesIO:
         output_bytes_io = BytesIO()
         fmt = self.FORMATS[output_type]
         if fmt == "JPEG":
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index d76f7389e1..6da76ae994 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,18 +15,25 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_string
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)
 
 
 class UploadResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo):
+    def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
         super().__init__()
 
         self.media_repo = media_repo
@@ -37,14 +45,14 @@ class UploadResource(DirectServeJsonResource):
         self.max_upload_size = hs.config.max_upload_size
         self.clock = hs.get_clock()
 
-    async def _async_render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request: Request) -> None:
         respond_with_json(request, 200, {}, send_cors=True)
 
-    async def _async_render_POST(self, request):
+    async def _async_render_POST(self, request: Request) -> None:
         requester = await self.auth.get_user_by_req(request)
         # TODO: The checks here are a bit late. The content will have
         # already been uploaded to a tmp file at this point
-        content_length = request.getHeader(b"Content-Length").decode("ascii")
+        content_length = request.getHeader("Content-Length")
         if content_length is None:
             raise SynapseError(msg="Request must specify a Content-Length", code=400)
         if int(content_length) > self.max_upload_size:
diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py
index c0b733488b..e5ef515090 100644
--- a/synapse/rest/synapse/client/__init__.py
+++ b/synapse/rest/synapse/client/__init__.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,3 +12,55 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from typing import TYPE_CHECKING, Mapping
+
+from twisted.web.resource import Resource
+
+from synapse.rest.synapse.client.new_user_consent import NewUserConsentResource
+from synapse.rest.synapse.client.pick_idp import PickIdpResource
+from synapse.rest.synapse.client.pick_username import pick_username_resource
+from synapse.rest.synapse.client.sso_register import SsoRegisterResource
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resource]:
+    """Builds a resource tree to include synapse-specific client resources
+
+    These are resources which should be loaded on all workers which expose a C-S API:
+    ie, the main process, and any generic workers so configured.
+
+    Returns:
+         map from path to Resource.
+    """
+    resources = {
+        # SSO bits. These are always loaded, whether or not SSO login is actually
+        # enabled (they just won't work very well if it's not)
+        "/_synapse/client/pick_idp": PickIdpResource(hs),
+        "/_synapse/client/pick_username": pick_username_resource(hs),
+        "/_synapse/client/new_user_consent": NewUserConsentResource(hs),
+        "/_synapse/client/sso_register": SsoRegisterResource(hs),
+    }
+
+    # provider-specific SSO bits. Only load these if they are enabled, since they
+    # rely on optional dependencies.
+    if hs.config.oidc_enabled:
+        from synapse.rest.synapse.client.oidc import OIDCResource
+
+        resources["/_synapse/client/oidc"] = OIDCResource(hs)
+
+    if hs.config.saml2_enabled:
+        from synapse.rest.synapse.client.saml2 import SAML2Resource
+
+        res = SAML2Resource(hs)
+        resources["/_synapse/client/saml2"] = res
+
+        # This is also mounted under '/_matrix' for backwards-compatibility.
+        resources["/_matrix/saml2"] = res
+
+    return resources
+
+
+__all__ = ["build_synapse_client_resource_tree"]
diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py
new file mode 100644
index 0000000000..b2e0f93810
--- /dev/null
+++ b/synapse/rest/synapse/client/new_user_consent.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import DirectServeHtmlResource, respond_with_html
+from synapse.http.servlet import parse_string
+from synapse.types import UserID
+from synapse.util.templates import build_jinja_env
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class NewUserConsentResource(DirectServeHtmlResource):
+    """A resource which collects consent to the server's terms from a new user
+
+    This resource gets mounted at /_synapse/client/new_user_consent, and is shown
+    when we are automatically creating a new user due to an SSO login.
+
+    It shows a template which prompts the user to go and read the Ts and Cs, and click
+    a clickybox if they have done so.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._sso_handler = hs.get_sso_handler()
+        self._server_name = hs.hostname
+        self._consent_version = hs.config.consent.user_consent_version
+
+        def template_search_dirs():
+            if hs.config.sso.sso_template_dir:
+                yield hs.config.sso.sso_template_dir
+            yield hs.config.sso.default_template_dir
+
+        self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
+
+    async def _async_render_GET(self, request: Request) -> None:
+        try:
+            session_id = get_username_mapping_session_cookie_from_request(request)
+            session = self._sso_handler.get_mapping_session(session_id)
+        except SynapseError as e:
+            logger.warning("Error fetching session: %s", e)
+            self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+
+        user_id = UserID(session.chosen_localpart, self._server_name)
+        user_profile = {
+            "display_name": session.display_name,
+        }
+
+        template_params = {
+            "user_id": user_id.to_string(),
+            "user_profile": user_profile,
+            "consent_version": self._consent_version,
+            "terms_url": "/_matrix/consent?v=%s" % (self._consent_version,),
+        }
+
+        template = self._jinja_env.get_template("sso_new_user_consent.html")
+        html = template.render(template_params)
+        respond_with_html(request, 200, html)
+
+    async def _async_render_POST(self, request: Request):
+        try:
+            session_id = get_username_mapping_session_cookie_from_request(request)
+        except SynapseError as e:
+            logger.warning("Error fetching session cookie: %s", e)
+            self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+
+        try:
+            accepted_version = parse_string(request, "accepted_version", required=True)
+        except SynapseError as e:
+            self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code)
+            return
+
+        await self._sso_handler.handle_terms_accepted(
+            request, session_id, accepted_version
+        )
diff --git a/synapse/rest/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py
index d958dd65bb..64c0deb75d 100644
--- a/synapse/rest/oidc/__init__.py
+++ b/synapse/rest/synapse/client/oidc/__init__.py
@@ -12,11 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import logging
 
 from twisted.web.resource import Resource
 
-from synapse.rest.oidc.callback_resource import OIDCCallbackResource
+from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource
 
 logger = logging.getLogger(__name__)
 
@@ -25,3 +26,6 @@ class OIDCResource(Resource):
     def __init__(self, hs):
         Resource.__init__(self)
         self.putChild(b"callback", OIDCCallbackResource(hs))
+
+
+__all__ = ["OIDCResource"]
diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/synapse/client/oidc/callback_resource.py
index f7a0bc4bdb..f7a0bc4bdb 100644
--- a/synapse/rest/oidc/callback_resource.py
+++ b/synapse/rest/synapse/client/oidc/callback_resource.py
diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py
new file mode 100644
index 0000000000..9550b82998
--- /dev/null
+++ b/synapse/rest/synapse/client/pick_idp.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING
+
+from synapse.http.server import (
+    DirectServeHtmlResource,
+    finish_request,
+    respond_with_html,
+)
+from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class PickIdpResource(DirectServeHtmlResource):
+    """IdP picker resource.
+
+    This resource gets mounted under /_synapse/client/pick_idp. It serves an HTML page
+    which prompts the user to choose an Identity Provider from the list.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._sso_handler = hs.get_sso_handler()
+        self._sso_login_idp_picker_template = (
+            hs.config.sso.sso_login_idp_picker_template
+        )
+        self._server_name = hs.hostname
+
+    async def _async_render_GET(self, request: SynapseRequest) -> None:
+        client_redirect_url = parse_string(
+            request, "redirectUrl", required=True, encoding="utf-8"
+        )
+        idp = parse_string(request, "idp", required=False)
+
+        # if we need to pick an IdP, do so
+        if not idp:
+            return await self._serve_id_picker(request, client_redirect_url)
+
+        # otherwise, redirect to the IdP's redirect URI
+        providers = self._sso_handler.get_identity_providers()
+        auth_provider = providers.get(idp)
+        if not auth_provider:
+            logger.info("Unknown idp %r", idp)
+            self._sso_handler.render_error(
+                request, "unknown_idp", "Unknown identity provider ID"
+            )
+            return
+
+        sso_url = await auth_provider.handle_redirect_request(
+            request, client_redirect_url.encode("utf8")
+        )
+        logger.info("Redirecting to %s", sso_url)
+        request.redirect(sso_url)
+        finish_request(request)
+
+    async def _serve_id_picker(
+        self, request: SynapseRequest, client_redirect_url: str
+    ) -> None:
+        # otherwise, serve up the IdP picker
+        providers = self._sso_handler.get_identity_providers()
+        html = self._sso_login_idp_picker_template.render(
+            redirect_url=client_redirect_url,
+            server_name=self._server_name,
+            providers=providers.values(),
+        )
+        respond_with_html(request, 200, html)
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
new file mode 100644
index 0000000000..96077cfcd1
--- /dev/null
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, List
+
+from twisted.web.http import Request
+from twisted.web.resource import Resource
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import (
+    DirectServeHtmlResource,
+    DirectServeJsonResource,
+    respond_with_html,
+)
+from synapse.http.servlet import parse_boolean, parse_string
+from synapse.http.site import SynapseRequest
+from synapse.util.templates import build_jinja_env
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+def pick_username_resource(hs: "HomeServer") -> Resource:
+    """Factory method to generate the username picker resource.
+
+    This resource gets mounted under /_synapse/client/pick_username and has two
+       children:
+
+      * "account_details": renders the form and handles the POSTed response
+      * "check": a JSON endpoint which checks if a userid is free.
+    """
+
+    res = Resource()
+    res.putChild(b"account_details", AccountDetailsResource(hs))
+    res.putChild(b"check", AvailabilityCheckResource(hs))
+
+    return res
+
+
+class AvailabilityCheckResource(DirectServeJsonResource):
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._sso_handler = hs.get_sso_handler()
+
+    async def _async_render_GET(self, request: Request):
+        localpart = parse_string(request, "username", required=True)
+
+        session_id = get_username_mapping_session_cookie_from_request(request)
+
+        is_available = await self._sso_handler.check_username_availability(
+            localpart, session_id
+        )
+        return 200, {"available": is_available}
+
+
+class AccountDetailsResource(DirectServeHtmlResource):
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._sso_handler = hs.get_sso_handler()
+
+        def template_search_dirs():
+            if hs.config.sso.sso_template_dir:
+                yield hs.config.sso.sso_template_dir
+            yield hs.config.sso.default_template_dir
+
+        self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
+
+    async def _async_render_GET(self, request: Request) -> None:
+        try:
+            session_id = get_username_mapping_session_cookie_from_request(request)
+            session = self._sso_handler.get_mapping_session(session_id)
+        except SynapseError as e:
+            logger.warning("Error fetching session: %s", e)
+            self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+
+        idp_id = session.auth_provider_id
+        template_params = {
+            "idp": self._sso_handler.get_identity_providers()[idp_id],
+            "user_attributes": {
+                "display_name": session.display_name,
+                "emails": session.emails,
+            },
+        }
+
+        template = self._jinja_env.get_template("sso_auth_account_details.html")
+        html = template.render(template_params)
+        respond_with_html(request, 200, html)
+
+    async def _async_render_POST(self, request: SynapseRequest):
+        try:
+            session_id = get_username_mapping_session_cookie_from_request(request)
+        except SynapseError as e:
+            logger.warning("Error fetching session cookie: %s", e)
+            self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+
+        try:
+            localpart = parse_string(request, "username", required=True)
+            use_display_name = parse_boolean(request, "use_display_name", default=False)
+
+            try:
+                emails_to_use = [
+                    val.decode("utf-8") for val in request.args.get(b"use_email", [])
+                ]  # type: List[str]
+            except ValueError:
+                raise SynapseError(400, "Query parameter use_email must be utf-8")
+        except SynapseError as e:
+            logger.warning("[session %s] bad param: %s", session_id, e)
+            self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code)
+            return
+
+        await self._sso_handler.handle_submit_username_request(
+            request, session_id, localpart, use_display_name, emails_to_use
+        )
diff --git a/synapse/rest/saml2/__init__.py b/synapse/rest/synapse/client/saml2/__init__.py
index 68da37ca6a..3e8235ee1e 100644
--- a/synapse/rest/saml2/__init__.py
+++ b/synapse/rest/synapse/client/saml2/__init__.py
@@ -12,12 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import logging
 
 from twisted.web.resource import Resource
 
-from synapse.rest.saml2.metadata_resource import SAML2MetadataResource
-from synapse.rest.saml2.response_resource import SAML2ResponseResource
+from synapse.rest.synapse.client.saml2.metadata_resource import SAML2MetadataResource
+from synapse.rest.synapse.client.saml2.response_resource import SAML2ResponseResource
 
 logger = logging.getLogger(__name__)
 
@@ -27,3 +28,6 @@ class SAML2Resource(Resource):
         Resource.__init__(self)
         self.putChild(b"metadata.xml", SAML2MetadataResource(hs))
         self.putChild(b"authn_response", SAML2ResponseResource(hs))
+
+
+__all__ = ["SAML2Resource"]
diff --git a/synapse/rest/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py
index 1e8526e22e..1e8526e22e 100644
--- a/synapse/rest/saml2/metadata_resource.py
+++ b/synapse/rest/synapse/client/saml2/metadata_resource.py
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py
index f6668fb5e3..f6668fb5e3 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/synapse/client/saml2/response_resource.py
diff --git a/synapse/rest/synapse/client/sso_register.py b/synapse/rest/synapse/client/sso_register.py
new file mode 100644
index 0000000000..dfefeb7796
--- /dev/null
+++ b/synapse/rest/synapse/client/sso_register.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import DirectServeHtmlResource
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class SsoRegisterResource(DirectServeHtmlResource):
+    """A resource which completes SSO registration
+
+    This resource gets mounted at /_synapse/client/sso_register, and is shown
+    after we collect username and/or consent for a new SSO user. It (finally) registers
+    the user, and confirms redirect to the client
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._sso_handler = hs.get_sso_handler()
+
+    async def _async_render_GET(self, request: Request) -> None:
+        try:
+            session_id = get_username_mapping_session_cookie_from_request(request)
+        except SynapseError as e:
+            logger.warning("Error fetching session cookie: %s", e)
+            self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+        await self._sso_handler.register_sso_user(request, session_id)