summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/admin/__init__.py2
-rw-r--r--synapse/rest/admin/rooms.py65
-rw-r--r--synapse/rest/client/v1/room.py2
-rw-r--r--synapse/rest/client/v2_alpha/groups.py179
-rw-r--r--synapse/rest/client/v2_alpha/register.py2
-rw-r--r--synapse/rest/media/v1/media_storage.py59
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py71
-rw-r--r--synapse/rest/media/v1/upload_resource.py12
8 files changed, 306 insertions, 86 deletions
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index f5c5d164f9..8457db1e22 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -42,6 +42,7 @@ from synapse.rest.admin.rooms import (
     JoinRoomAliasServlet,
     ListRoomRestServlet,
     MakeRoomAdminRestServlet,
+    RoomEventContextServlet,
     RoomMembersRestServlet,
     RoomRestServlet,
     RoomStateRestServlet,
@@ -238,6 +239,7 @@ def register_servlets(hs, http_server):
     MakeRoomAdminRestServlet(hs).register(http_server)
     ShadowBanRestServlet(hs).register(http_server)
     ForwardExtremitiesRestServlet(hs).register(http_server)
+    RoomEventContextServlet(hs).register(http_server)
 
 
 def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 3e57e6a4d0..491f9ca095 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -15,9 +15,11 @@
 import logging
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Tuple
+from urllib import parse as urlparse
 
 from synapse.api.constants import EventTypes, JoinRules, Membership
 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.api.filtering import Filter
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
@@ -33,6 +35,7 @@ from synapse.rest.admin._base import (
 )
 from synapse.storage.databases.main.room import RoomSortOrder
 from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester
+from synapse.util import json_decoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -605,3 +608,65 @@ class ForwardExtremitiesRestServlet(RestServlet):
 
         extremities = await self.store.get_forward_extremities_for_room(room_id)
         return 200, {"count": len(extremities), "results": extremities}
+
+
+class RoomEventContextServlet(RestServlet):
+    """
+    Provide the context for an event.
+    This API is designed to be used when system administrators wish to look at
+    an abuse report and understand what happened during and immediately prior
+    to this event.
+    """
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$")
+
+    def __init__(self, hs):
+        super().__init__()
+        self.clock = hs.get_clock()
+        self.room_context_handler = hs.get_room_context_handler()
+        self._event_serializer = hs.get_event_client_serializer()
+        self.auth = hs.get_auth()
+
+    async def on_GET(self, request, room_id, event_id):
+        requester = await self.auth.get_user_by_req(request, allow_guest=False)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        limit = parse_integer(request, "limit", default=10)
+
+        # picking the API shape for symmetry with /messages
+        filter_str = parse_string(request, b"filter", encoding="utf-8")
+        if filter_str:
+            filter_json = urlparse.unquote(filter_str)
+            event_filter = Filter(
+                json_decoder.decode(filter_json)
+            )  # type: Optional[Filter]
+        else:
+            event_filter = None
+
+        results = await self.room_context_handler.get_event_context(
+            requester,
+            room_id,
+            event_id,
+            limit,
+            event_filter,
+            use_admin_priviledge=True,
+        )
+
+        if not results:
+            raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
+
+        time_now = self.clock.time_msec()
+        results["events_before"] = await self._event_serializer.serialize_events(
+            results["events_before"], time_now
+        )
+        results["event"] = await self._event_serializer.serialize_event(
+            results["event"], time_now
+        )
+        results["events_after"] = await self._event_serializer.serialize_events(
+            results["events_after"], time_now
+        )
+        results["state"] = await self._event_serializer.serialize_events(
+            results["state"], time_now
+        )
+
+        return 200, results
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index f95627ee61..90fd98c53e 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -650,7 +650,7 @@ class RoomEventContextServlet(RestServlet):
             event_filter = None
 
         results = await self.room_context_handler.get_event_context(
-            requester.user, room_id, event_id, limit, event_filter
+            requester, room_id, event_id, limit, event_filter
         )
 
         if not results:
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 5b5da71815..4fe712b30c 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -16,13 +16,24 @@
 
 import logging
 from functools import wraps
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
 
 from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import GroupID
+from synapse.handlers.groups_local import GroupsLocalHandler
+from synapse.http.servlet import (
+    RestServlet,
+    assert_params_in_dict,
+    parse_json_object_from_request,
+)
+from synapse.types import GroupID, JsonDict
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -33,7 +44,7 @@ def _validate_group_id(f):
     """
 
     @wraps(f)
-    def wrapper(self, request, group_id, *args, **kwargs):
+    def wrapper(self, request: Request, group_id: str, *args, **kwargs):
         if not GroupID.is_valid(group_id):
             raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
 
@@ -48,14 +59,14 @@ class GroupServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -66,11 +77,15 @@ class GroupServlet(RestServlet):
         return 200, group_description
 
     @_validate_group_id
-    async def on_POST(self, request, group_id):
+    async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert_params_in_dict(
+            content, ("name", "avatar_url", "short_description", "long_description")
+        )
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         await self.groups_handler.update_group_profile(
             group_id, requester_user_id, content
         )
@@ -84,14 +99,14 @@ class GroupSummaryServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -116,18 +131,21 @@ class GroupSummaryRoomsCatServlet(RestServlet):
         "/rooms/(?P<room_id>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, category_id, room_id):
+    async def on_PUT(
+        self, request: Request, group_id: str, category_id: str, room_id: str
+    ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         resp = await self.groups_handler.update_group_summary_room(
             group_id,
             requester_user_id,
@@ -139,10 +157,13 @@ class GroupSummaryRoomsCatServlet(RestServlet):
         return 200, resp
 
     @_validate_group_id
-    async def on_DELETE(self, request, group_id, category_id, room_id):
+    async def on_DELETE(
+        self, request: Request, group_id: str, category_id: str, room_id: str
+    ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         resp = await self.groups_handler.delete_group_summary_room(
             group_id, requester_user_id, room_id=room_id, category_id=category_id
         )
@@ -158,14 +179,16 @@ class GroupCategoryServlet(RestServlet):
         "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id, category_id):
+    async def on_GET(
+        self, request: Request, group_id: str, category_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -176,11 +199,14 @@ class GroupCategoryServlet(RestServlet):
         return 200, category
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, category_id):
+    async def on_PUT(
+        self, request: Request, group_id: str, category_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         resp = await self.groups_handler.update_group_category(
             group_id, requester_user_id, category_id=category_id, content=content
         )
@@ -188,10 +214,13 @@ class GroupCategoryServlet(RestServlet):
         return 200, resp
 
     @_validate_group_id
-    async def on_DELETE(self, request, group_id, category_id):
+    async def on_DELETE(
+        self, request: Request, group_id: str, category_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         resp = await self.groups_handler.delete_group_category(
             group_id, requester_user_id, category_id=category_id
         )
@@ -205,14 +234,14 @@ class GroupCategoriesServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -229,14 +258,16 @@ class GroupRoleServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id, role_id):
+    async def on_GET(
+        self, request: Request, group_id: str, role_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -247,11 +278,14 @@ class GroupRoleServlet(RestServlet):
         return 200, category
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, role_id):
+    async def on_PUT(
+        self, request: Request, group_id: str, role_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         resp = await self.groups_handler.update_group_role(
             group_id, requester_user_id, role_id=role_id, content=content
         )
@@ -259,10 +293,13 @@ class GroupRoleServlet(RestServlet):
         return 200, resp
 
     @_validate_group_id
-    async def on_DELETE(self, request, group_id, role_id):
+    async def on_DELETE(
+        self, request: Request, group_id: str, role_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         resp = await self.groups_handler.delete_group_role(
             group_id, requester_user_id, role_id=role_id
         )
@@ -276,14 +313,14 @@ class GroupRolesServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -308,18 +345,21 @@ class GroupSummaryUsersRoleServlet(RestServlet):
         "/users/(?P<user_id>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, role_id, user_id):
+    async def on_PUT(
+        self, request: Request, group_id: str, role_id: str, user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         resp = await self.groups_handler.update_group_summary_user(
             group_id,
             requester_user_id,
@@ -331,10 +371,13 @@ class GroupSummaryUsersRoleServlet(RestServlet):
         return 200, resp
 
     @_validate_group_id
-    async def on_DELETE(self, request, group_id, role_id, user_id):
+    async def on_DELETE(
+        self, request: Request, group_id: str, role_id: str, user_id: str
+    ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         resp = await self.groups_handler.delete_group_summary_user(
             group_id, requester_user_id, user_id=user_id, role_id=role_id
         )
@@ -348,14 +391,14 @@ class GroupRoomServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -372,14 +415,14 @@ class GroupUsersServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -396,14 +439,14 @@ class GroupInvitedUsersServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request, group_id):
+    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -420,18 +463,19 @@ class GroupSettingJoinPolicyServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id):
+    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
 
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.set_group_join_policy(
             group_id, requester_user_id, content
         )
@@ -445,14 +489,14 @@ class GroupCreateServlet(RestServlet):
 
     PATTERNS = client_patterns("/create_group$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
         self.server_name = hs.hostname
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -461,6 +505,7 @@ class GroupCreateServlet(RestServlet):
         localpart = content.pop("localpart")
         group_id = GroupID(localpart, self.server_name).to_string()
 
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.create_group(
             group_id, requester_user_id, content
         )
@@ -476,18 +521,21 @@ class GroupAdminRoomsServlet(RestServlet):
         "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, room_id):
+    async def on_PUT(
+        self, request: Request, group_id: str, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.add_room_to_group(
             group_id, requester_user_id, room_id, content
         )
@@ -495,10 +543,13 @@ class GroupAdminRoomsServlet(RestServlet):
         return 200, result
 
     @_validate_group_id
-    async def on_DELETE(self, request, group_id, room_id):
+    async def on_DELETE(
+        self, request: Request, group_id: str, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.remove_room_from_group(
             group_id, requester_user_id, room_id
         )
@@ -515,18 +566,21 @@ class GroupAdminRoomsConfigServlet(RestServlet):
         "/config/(?P<config_key>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, room_id, config_key):
+    async def on_PUT(
+        self, request: Request, group_id: str, room_id: str, config_key: str
+    ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.update_room_in_group(
             group_id, requester_user_id, room_id, config_key, content
         )
@@ -542,7 +596,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
         "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
@@ -551,12 +605,13 @@ class GroupAdminUsersInviteServlet(RestServlet):
         self.is_mine_id = hs.is_mine_id
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, user_id):
+    async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
         config = content.get("config", {})
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.invite(
             group_id, user_id, requester_user_id, config
         )
@@ -572,18 +627,19 @@ class GroupAdminUsersKickServlet(RestServlet):
         "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id, user_id):
+    async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.remove_user_from_group(
             group_id, user_id, requester_user_id, content
         )
@@ -597,18 +653,19 @@ class GroupSelfLeaveServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id):
+    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.remove_user_from_group(
             group_id, requester_user_id, requester_user_id, content
         )
@@ -622,18 +679,19 @@ class GroupSelfJoinServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id):
+    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.join_group(
             group_id, requester_user_id, content
         )
@@ -647,18 +705,19 @@ class GroupSelfAcceptInviteServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id):
+    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
+        assert isinstance(self.groups_handler, GroupsLocalHandler)
         result = await self.groups_handler.accept_invite(
             group_id, requester_user_id, content
         )
@@ -672,14 +731,14 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
 
     PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
 
     @_validate_group_id
-    async def on_PUT(self, request, group_id):
+    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -696,14 +755,14 @@ class PublicisedGroupsForUserServlet(RestServlet):
 
     PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.groups_handler = hs.get_groups_local_handler()
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
         await self.auth.get_user_by_req(request, allow_guest=True)
 
         result = await self.groups_handler.get_publicised_groups_for_user(user_id)
@@ -717,14 +776,14 @@ class PublicisedGroupsForUsersServlet(RestServlet):
 
     PATTERNS = client_patterns("/publicised_groups$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.groups_handler = hs.get_groups_local_handler()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
         await self.auth.get_user_by_req(request, allow_guest=True)
 
         content = parse_json_object_from_request(request)
@@ -741,13 +800,13 @@ class GroupsForUserServlet(RestServlet):
 
     PATTERNS = client_patterns("/joined_groups$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -756,7 +815,7 @@ class GroupsForUserServlet(RestServlet):
         return 200, result
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
     GroupServlet(hs).register(http_server)
     GroupSummaryServlet(hs).register(http_server)
     GroupInvitedUsersServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 10e1891174..e3d322f2ac 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -193,6 +193,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
             body, ["client_secret", "country", "phone_number", "send_attempt"]
         )
         client_secret = body["client_secret"]
+        assert_valid_client_secret(client_secret)
         country = body["country"]
         phone_number = body["phone_number"]
         send_attempt = body["send_attempt"]
@@ -293,6 +294,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
 
         sid = parse_string(request, "sid", required=True)
         client_secret = parse_string(request, "client_secret", required=True)
+        assert_valid_client_secret(client_secret)
         token = parse_string(request, "token", required=True)
 
         # Attempt to validate a 3PID session
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 89cdd605aa..aba6d689a8 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -16,13 +16,17 @@ import contextlib
 import logging
 import os
 import shutil
-from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
+from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence
+
+import attr
 
 from twisted.internet.defer import Deferred
 from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
 
+from synapse.api.errors import NotFoundError
 from synapse.logging.context import defer_to_thread, make_deferred_yieldable
+from synapse.util import Clock
 from synapse.util.file_consumer import BackgroundFileConsumer
 
 from ._base import FileInfo, Responder
@@ -58,6 +62,8 @@ class MediaStorage:
         self.local_media_directory = local_media_directory
         self.filepaths = filepaths
         self.storage_providers = storage_providers
+        self.spam_checker = hs.get_spam_checker()
+        self.clock = hs.get_clock()
 
     async def store_file(self, source: IO, file_info: FileInfo) -> str:
         """Write `source` to the on disk media store, and also any other
@@ -127,18 +133,29 @@ class MediaStorage:
                     f.flush()
                     f.close()
 
+                    spam = await self.spam_checker.check_media_file_for_spam(
+                        ReadableFileWrapper(self.clock, fname), file_info
+                    )
+                    if spam:
+                        logger.info("Blocking media due to spam checker")
+                        # Note that we'll delete the stored media, due to the
+                        # try/except below. The media also won't be stored in
+                        # the DB.
+                        raise SpamMediaException()
+
                     for provider in self.storage_providers:
                         await provider.store_file(path, file_info)
 
                     finished_called[0] = True
 
                 yield f, fname, finish
-        except Exception:
+        except Exception as e:
             try:
                 os.remove(fname)
             except Exception:
                 pass
-            raise
+
+            raise e from None
 
         if not finished_called:
             raise Exception("Finished callback not called")
@@ -302,3 +319,39 @@ class FileResponder(Responder):
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.open_file.close()
+
+
+class SpamMediaException(NotFoundError):
+    """The media was blocked by a spam checker, so we simply 404 the request (in
+    the same way as if it was quarantined).
+    """
+
+
+@attr.s(slots=True)
+class ReadableFileWrapper:
+    """Wrapper that allows reading a file in chunks, yielding to the reactor,
+    and writing to a callback.
+
+    This is simplified `FileSender` that takes an IO object rather than an
+    `IConsumer`.
+    """
+
+    CHUNK_SIZE = 2 ** 14
+
+    clock = attr.ib(type=Clock)
+    path = attr.ib(type=str)
+
+    async def write_chunks_to(self, callback: Callable[[bytes], None]):
+        """Reads the file in chunks and calls the callback with each chunk.
+        """
+
+        with open(self.path, "rb") as file:
+            while True:
+                chunk = file.read(self.CHUNK_SIZE)
+                if not chunk:
+                    break
+
+                callback(chunk)
+
+                # We yield to the reactor by sleeping for 0 seconds.
+                await self.clock.sleep(0)
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index bf3be653aa..ae53b1d23f 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -58,7 +58,10 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
+_charset_match = re.compile(br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9-]+)"?', flags=re.I)
+_xml_encoding_match = re.compile(
+    br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9-]+)"', flags=re.I
+)
 _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
 
 OG_TAG_NAME_MAXLEN = 50
@@ -300,24 +303,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             with open(media_info["filename"], "rb") as file:
                 body = file.read()
 
-            encoding = None
-
-            # Let's try and figure out if it has an encoding set in a meta tag.
-            # Limit it to the first 1kb, since it ought to be in the meta tags
-            # at the top.
-            match = _charset_match.search(body[:1000])
-
-            # If we find a match, it should take precedence over the
-            # Content-Type header, so set it here.
-            if match:
-                encoding = match.group(1).decode("ascii")
-
-            # If we don't find a match, we'll look at the HTTP Content-Type, and
-            # if that doesn't exist, we'll fall back to UTF-8.
-            if not encoding:
-                content_match = _content_type_match.match(media_info["media_type"])
-                encoding = content_match.group(1) if content_match else "utf-8"
-
+            encoding = get_html_media_encoding(body, media_info["media_type"])
             og = decode_and_calc_og(body, media_info["uri"], encoding)
 
             # pre-cache the image for posterity
@@ -689,6 +675,48 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.debug("No media removed from url cache")
 
 
+def get_html_media_encoding(body: bytes, content_type: str) -> str:
+    """
+    Get the encoding of the body based on the (presumably) HTML body or media_type.
+
+    The precedence used for finding a character encoding is:
+
+    1. meta tag with a charset declared.
+    2. The XML document's character encoding attribute.
+    3. The Content-Type header.
+    4. Fallback to UTF-8.
+
+    Args:
+        body: The HTML document, as bytes.
+        content_type: The Content-Type header.
+
+    Returns:
+        The character encoding of the body, as a string.
+    """
+    # Limit searches to the first 1kb, since it ought to be at the top.
+    body_start = body[:1024]
+
+    # Let's try and figure out if it has an encoding set in a meta tag.
+    match = _charset_match.search(body_start)
+    if match:
+        return match.group(1).decode("ascii")
+
+    # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
+
+    # If we didn't find a match, see if it an XML document with an encoding.
+    match = _xml_encoding_match.match(body_start)
+    if match:
+        return match.group(1).decode("ascii")
+
+    # If we don't find a match, we'll look at the HTTP Content-Type, and
+    # if that doesn't exist, we'll fall back to UTF-8.
+    content_match = _content_type_match.match(content_type)
+    if content_match:
+        return content_match.group(1)
+
+    return "utf-8"
+
+
 def decode_and_calc_og(
     body: bytes, media_uri: str, request_encoding: Optional[str] = None
 ) -> Dict[str, Optional[str]]:
@@ -725,6 +753,11 @@ def decode_and_calc_og(
     def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
         # Attempt to parse the body. If this fails, log and return no metadata.
         tree = etree.fromstring(body_attempt, parser)
+
+        # The data was successfully parsed, but no tree was found.
+        if tree is None:
+            return {}
+
         return _calc_og(tree, media_uri)
 
     # Attempt to parse the body. If this fails, log and return no metadata.
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 6da76ae994..1136277794 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -22,6 +22,7 @@ from twisted.web.http import Request
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_string
+from synapse.rest.media.v1.media_storage import SpamMediaException
 
 if TYPE_CHECKING:
     from synapse.app.homeserver import HomeServer
@@ -86,9 +87,14 @@ class UploadResource(DirectServeJsonResource):
         #     disposition = headers.getRawHeaders(b"Content-Disposition")[0]
         # TODO(markjh): parse content-dispostion
 
-        content_uri = await self.media_repo.create_content(
-            media_type, upload_name, request.content, content_length, requester.user
-        )
+        try:
+            content_uri = await self.media_repo.create_content(
+                media_type, upload_name, request.content, content_length, requester.user
+            )
+        except SpamMediaException:
+            # For uploading of media we want to respond with a 400, instead of
+            # the default 404, as that would just be confusing.
+            raise SynapseError(400, "Bad content")
 
         logger.info("Uploaded content with URI %r", content_uri)