summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py64
-rw-r--r--tests/appservice/test_appservice.py11
-rw-r--r--tests/federation/test_federation_sender.py5
-rw-r--r--tests/handlers/test_e2e_keys.py30
-rw-r--r--tests/handlers/test_message.py103
-rw-r--r--tests/rest/admin/test_background_updates.py3
-rw-r--r--tests/rest/admin/test_federation.py37
-rw-r--r--tests/rest/admin/test_media.py6
-rw-r--r--tests/rest/admin/test_registration_tokens.py25
-rw-r--r--tests/rest/admin/test_room.py62
-rw-r--r--tests/rest/admin/test_user.py15
-rw-r--r--tests/rest/client/test_auth.py134
-rw-r--r--tests/rest/client/test_relations.py115
-rw-r--r--tests/rest/client/test_room_batch.py180
-rw-r--r--tests/rest/media/v1/test_url_preview.py1
-rw-r--r--tests/storage/test_background_update.py17
-rw-r--r--tests/storage/test_e2e_room_keys.py4
-rw-r--r--tests/storage/test_event_federation.py2
-rw-r--r--tests/test_preview.py46
-rw-r--r--tests/unittest.py11
-rw-r--r--tests/util/test_logcontext.py35
21 files changed, 701 insertions, 205 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 3aa9ba3c43..a2dfa1ed05 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -31,6 +31,7 @@ from synapse.types import Requester
 
 from tests import unittest
 from tests.test_utils import simple_async_mock
+from tests.unittest import override_config
 from tests.utils import mock_getRawHeaders
 
 
@@ -210,6 +211,69 @@ class AuthTestCase(unittest.HomeserverTestCase):
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         self.get_failure(self.auth.get_user_by_req(request), AuthError)
 
+    @override_config({"experimental_features": {"msc3202_device_masquerading": True}})
+    def test_get_user_by_req_appservice_valid_token_valid_device_id(self):
+        """
+        Tests that when an application service passes the device_id URL parameter
+        with the ID of a valid device for the user in question,
+        the requester instance tracks that device ID.
+        """
+        masquerading_user_id = b"@doppelganger:matrix.org"
+        masquerading_device_id = b"DOPPELDEVICE"
+        app_service = Mock(
+            token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
+        )
+        app_service.is_interested_in_user = Mock(return_value=True)
+        self.store.get_app_service_by_token = Mock(return_value=app_service)
+        # This just needs to return a truth-y value.
+        self.store.get_user_by_id = simple_async_mock({"is_guest": False})
+        self.store.get_user_by_access_token = simple_async_mock(None)
+        # This also needs to just return a truth-y value
+        self.store.get_device = simple_async_mock({"hidden": False})
+
+        request = Mock(args={})
+        request.getClientIP.return_value = "127.0.0.1"
+        request.args[b"access_token"] = [self.test_token]
+        request.args[b"user_id"] = [masquerading_user_id]
+        request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        requester = self.get_success(self.auth.get_user_by_req(request))
+        self.assertEquals(
+            requester.user.to_string(), masquerading_user_id.decode("utf8")
+        )
+        self.assertEquals(requester.device_id, masquerading_device_id.decode("utf8"))
+
+    @override_config({"experimental_features": {"msc3202_device_masquerading": True}})
+    def test_get_user_by_req_appservice_valid_token_invalid_device_id(self):
+        """
+        Tests that when an application service passes the device_id URL parameter
+        with an ID that is not a valid device ID for the user in question,
+        the request fails with the appropriate error code.
+        """
+        masquerading_user_id = b"@doppelganger:matrix.org"
+        masquerading_device_id = b"NOT_A_REAL_DEVICE_ID"
+        app_service = Mock(
+            token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
+        )
+        app_service.is_interested_in_user = Mock(return_value=True)
+        self.store.get_app_service_by_token = Mock(return_value=app_service)
+        # This just needs to return a truth-y value.
+        self.store.get_user_by_id = simple_async_mock({"is_guest": False})
+        self.store.get_user_by_access_token = simple_async_mock(None)
+        # This also needs to just return a falsey value
+        self.store.get_device = simple_async_mock(None)
+
+        request = Mock(args={})
+        request.getClientIP.return_value = "127.0.0.1"
+        request.args[b"access_token"] = [self.test_token]
+        request.args[b"user_id"] = [masquerading_user_id]
+        request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+
+        failure = self.get_failure(self.auth.get_user_by_req(request), AuthError)
+        self.assertEquals(failure.value.code, 400)
+        self.assertEquals(failure.value.errcode, Codes.EXCLUSIVE)
+
     def test_get_user_from_macaroon(self):
         self.store.get_user_by_access_token = simple_async_mock(
             TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index f386b5e128..ba2a2bfd64 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -16,13 +16,13 @@ from unittest.mock import Mock
 
 from twisted.internet import defer
 
-from synapse.appservice import ApplicationService
+from synapse.appservice import ApplicationService, Namespace
 
 from tests import unittest
 
 
-def _regex(regex, exclusive=True):
-    return {"regex": re.compile(regex), "exclusive": exclusive}
+def _regex(regex: str, exclusive: bool = True) -> Namespace:
+    return Namespace(exclusive, None, re.compile(regex))
 
 
 class ApplicationServiceTestCase(unittest.TestCase):
@@ -33,11 +33,6 @@ class ApplicationServiceTestCase(unittest.TestCase):
             url="some_url",
             token="some_token",
             hostname="matrix.org",  # only used by get_groups_for_user
-            namespaces={
-                ApplicationService.NS_USERS: [],
-                ApplicationService.NS_ROOMS: [],
-                ApplicationService.NS_ALIASES: [],
-            },
         )
         self.event = Mock(
             type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index b457dad6d2..b2376e2db9 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -266,7 +266,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         )
 
         # expect signing key update edu
-        self.assertEqual(len(self.edus), 1)
+        self.assertEqual(len(self.edus), 2)
+        self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update")
         self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update")
 
         # sign the devices
@@ -491,7 +492,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
     ) -> None:
         """Check that the txn has an EDU with a signing key update."""
         edus = txn["edus"]
-        self.assertEqual(len(edus), 1)
+        self.assertEqual(len(edus), 2)
 
     def generate_and_upload_device_signing_key(
         self, user_id: str, device_id: str
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index f0723892e4..ddcf3ee348 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -161,8 +161,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
     def test_fallback_key(self):
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
-        fallback_key = {"alg1:k1": "key1"}
-        fallback_key2 = {"alg1:k2": "key2"}
+        fallback_key = {"alg1:k1": "fallback_key1"}
+        fallback_key2 = {"alg1:k2": "fallback_key2"}
+        fallback_key3 = {"alg1:k2": "fallback_key3"}
         otk = {"alg1:k2": "key2"}
 
         # we shouldn't have any unused fallback keys yet
@@ -175,7 +176,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             self.handler.upload_keys_for_user(
                 local_user,
                 device_id,
-                {"org.matrix.msc2732.fallback_keys": fallback_key},
+                {"fallback_keys": fallback_key},
             )
         )
 
@@ -220,7 +221,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             self.handler.upload_keys_for_user(
                 local_user,
                 device_id,
-                {"org.matrix.msc2732.fallback_keys": fallback_key},
+                {"fallback_keys": fallback_key},
             )
         )
 
@@ -234,7 +235,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             self.handler.upload_keys_for_user(
                 local_user,
                 device_id,
-                {"org.matrix.msc2732.fallback_keys": fallback_key2},
+                {"fallback_keys": fallback_key2},
             )
         )
 
@@ -271,6 +272,25 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
         )
 
+        # using the unstable prefix should also set the fallback key
+        self.get_success(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id,
+                {"org.matrix.msc2732.fallback_keys": fallback_key3},
+            )
+        )
+
+        res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+            )
+        )
+        self.assertEqual(
+            res,
+            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
+        )
+
     def test_replace_master_key(self):
         """uploading a new signing key should make the old signing key unavailable"""
         local_user = "@boris:" + self.hs.hostname
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 8a8d369fac..5816295d8b 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -23,6 +23,7 @@ from synapse.types import create_requester
 from synapse.util.stringutils import random_string
 
 from tests import unittest
+from tests.test_utils.event_injection import create_event
 
 logger = logging.getLogger(__name__)
 
@@ -51,6 +52,24 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
         self.requester = create_requester(self.user_id, access_token_id=self.token_id)
 
+    def _create_and_persist_member_event(self) -> Tuple[EventBase, EventContext]:
+        # Create a member event we can use as an auth_event
+        memberEvent, memberEventContext = self.get_success(
+            create_event(
+                self.hs,
+                room_id=self.room_id,
+                type="m.room.member",
+                sender=self.requester.user.to_string(),
+                state_key=self.requester.user.to_string(),
+                content={"membership": "join"},
+            )
+        )
+        self.get_success(
+            self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+        )
+
+        return memberEvent, memberEventContext
+
     def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]:
         """Create a new event with the given transaction ID. All events produced
         by this method will be considered duplicates.
@@ -156,6 +175,90 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(events), 2)
         self.assertEqual(events[0].event_id, events[1].event_id)
 
+    def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(self):
+        """When we set allow_no_prev_events=True, should be able to create a
+        event without any prev_events (only auth_events).
+        """
+        # Create a member event we can use as an auth_event
+        memberEvent, _ = self._create_and_persist_member_event()
+
+        # Try to create the event with empty prev_events bit with some auth_events
+        event, _ = self.get_success(
+            self.handler.create_event(
+                self.requester,
+                {
+                    "type": EventTypes.Message,
+                    "room_id": self.room_id,
+                    "sender": self.requester.user.to_string(),
+                    "content": {"msgtype": "m.text", "body": random_string(5)},
+                },
+                # Empty prev_events is the key thing we're testing here
+                prev_event_ids=[],
+                # But with some auth_events
+                auth_event_ids=[memberEvent.event_id],
+                # Allow no prev_events!
+                allow_no_prev_events=True,
+            )
+        )
+        self.assertIsNotNone(event)
+
+    def test_when_empty_prev_events_not_allowed_reject_event_with_empty_prev_events(
+        self,
+    ):
+        """When we set allow_no_prev_events=False, shouldn't be able to create a
+        event without any prev_events even if it has auth_events. Expect an
+        exception to be raised.
+        """
+        # Create a member event we can use as an auth_event
+        memberEvent, _ = self._create_and_persist_member_event()
+
+        # Try to create the event with empty prev_events but with some auth_events
+        self.get_failure(
+            self.handler.create_event(
+                self.requester,
+                {
+                    "type": EventTypes.Message,
+                    "room_id": self.room_id,
+                    "sender": self.requester.user.to_string(),
+                    "content": {"msgtype": "m.text", "body": random_string(5)},
+                },
+                # Empty prev_events is the key thing we're testing here
+                prev_event_ids=[],
+                # But with some auth_events
+                auth_event_ids=[memberEvent.event_id],
+                # We expect the test to fail because empty prev_events are not
+                # allowed here!
+                allow_no_prev_events=False,
+            ),
+            AssertionError,
+        )
+
+    def test_when_empty_prev_events_allowed_reject_event_with_empty_prev_events_and_auth_events(
+        self,
+    ):
+        """When we set allow_no_prev_events=True, should be able to create a
+        event without any prev_events or auth_events. Expect an exception to be
+        raised.
+        """
+        # Try to create the event with empty prev_events and empty auth_events
+        self.get_failure(
+            self.handler.create_event(
+                self.requester,
+                {
+                    "type": EventTypes.Message,
+                    "room_id": self.room_id,
+                    "sender": self.requester.user.to_string(),
+                    "content": {"msgtype": "m.text", "body": random_string(5)},
+                },
+                prev_event_ids=[],
+                # The event should be rejected when there are no auth_events
+                auth_event_ids=[],
+                # Allow no prev_events!
+                allow_no_prev_events=True,
+            ),
+            AssertionError,
+        )
+
 
 class ServerAclValidationTestCase(unittest.HomeserverTestCase):
     servlets = [
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index 4d152c0d66..1e3fe9c62c 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -23,6 +23,7 @@ from synapse.api.errors import Codes
 from synapse.rest.client import login
 from synapse.server import HomeServer
 from synapse.storage.background_updates import BackgroundUpdater
+from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
@@ -96,7 +97,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
     def _register_bg_update(self) -> None:
         "Adds a bg update but doesn't start it"
 
-        async def _fake_update(progress, batch_size) -> int:
+        async def _fake_update(progress: JsonDict, batch_size: int) -> int:
             await self.clock.sleep(0.2)
             return batch_size
 
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index 5188499ef2..742f194257 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -16,11 +16,14 @@ from typing import List, Optional
 
 from parameterized import parameterized
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 from synapse.api.errors import Codes
 from synapse.rest.client import login
 from synapse.server import HomeServer
 from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests import unittest
 
@@ -31,7 +34,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs: HomeServer):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastore()
         self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -44,7 +47,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
             ("/_synapse/admin/v1/federation/destinations/dummy",),
         ]
     )
-    def test_requester_is_no_admin(self, url: str):
+    def test_requester_is_no_admin(self, url: str) -> None:
         """
         If the user is not a server admin, an error 403 is returned.
         """
@@ -62,7 +65,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-    def test_invalid_parameter(self):
+    def test_invalid_parameter(self) -> None:
         """
         If parameters are invalid, an error is returned.
         """
@@ -95,7 +98,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # invalid search order
         channel = self.make_request(
@@ -105,7 +108,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # invalid destination
         channel = self.make_request(
@@ -117,7 +120,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
-    def test_limit(self):
+    def test_limit(self) -> None:
         """
         Testing list of destinations with limit
         """
@@ -137,7 +140,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["next_token"], "5")
         self._check_fields(channel.json_body["destinations"])
 
-    def test_from(self):
+    def test_from(self) -> None:
         """
         Testing list of destinations with a defined starting point (from)
         """
@@ -157,7 +160,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertNotIn("next_token", channel.json_body)
         self._check_fields(channel.json_body["destinations"])
 
-    def test_limit_and_from(self):
+    def test_limit_and_from(self) -> None:
         """
         Testing list of destinations with a defined starting point and limit
         """
@@ -177,7 +180,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(channel.json_body["destinations"]), 10)
         self._check_fields(channel.json_body["destinations"])
 
-    def test_next_token(self):
+    def test_next_token(self) -> None:
         """
         Testing that `next_token` appears at the right place
         """
@@ -238,7 +241,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(channel.json_body["destinations"]), 1)
         self.assertNotIn("next_token", channel.json_body)
 
-    def test_list_all_destinations(self):
+    def test_list_all_destinations(self) -> None:
         """
         List all destinations.
         """
@@ -259,7 +262,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         # Check that all fields are available
         self._check_fields(channel.json_body["destinations"])
 
-    def test_order_by(self):
+    def test_order_by(self) -> None:
         """
         Testing order list with parameter `order_by`
         """
@@ -268,7 +271,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
             expected_destination_list: List[str],
             order_by: Optional[str],
             dir: Optional[str] = None,
-        ):
+        ) -> None:
             """Request the list of destinations in a certain order.
             Assert that order is what we expect
 
@@ -358,13 +361,13 @@ class FederationTestCase(unittest.HomeserverTestCase):
             [dest[0][0], dest[2][0], dest[1][0]], "last_successful_stream_ordering", "b"
         )
 
-    def test_search_term(self):
+    def test_search_term(self) -> None:
         """Test that searching for a destination works correctly"""
 
         def _search_test(
             expected_destination: Optional[str],
             search_term: str,
-        ):
+        ) -> None:
             """Search for a destination and check that the returned destinationis a match
 
             Args:
@@ -410,7 +413,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         _search_test(None, "foo")
         _search_test(None, "bar")
 
-    def test_get_single_destination(self):
+    def test_get_single_destination(self) -> None:
         """
         Get one specific destinations.
         """
@@ -429,7 +432,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         # convert channel.json_body into a List
         self._check_fields([channel.json_body])
 
-    def _create_destinations(self, number_destinations: int):
+    def _create_destinations(self, number_destinations: int) -> None:
         """Create a number of destinations
 
         Args:
@@ -442,7 +445,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
                 self.store.set_destination_last_successful_stream_ordering(dest, 100)
             )
 
-    def _check_fields(self, content: List[JsonDict]):
+    def _check_fields(self, content: List[JsonDict]) -> None:
         """Checks that the expected destination attributes are present in content
 
         Args:
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 81e578fd26..86aff7575c 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -360,7 +360,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
             channel.code,
             msg=channel.json_body,
         )
-        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
         self.assertEqual(
             "Boolean query parameter 'keep_profiles' must be one of ['true', 'false']",
             channel.json_body["error"],
@@ -580,7 +580,9 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
 
         return server_and_media_id
 
-    def _access_media(self, server_and_media_id, expect_success=True) -> None:
+    def _access_media(
+        self, server_and_media_id: str, expect_success: bool = True
+    ) -> None:
         """
         Try to access a media and check the result
         """
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
index 350a62dda6..81f3ac7f04 100644
--- a/tests/rest/admin/test_registration_tokens.py
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -14,6 +14,7 @@
 import random
 import string
 from http import HTTPStatus
+from typing import Optional
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -42,21 +43,27 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
 
         self.url = "/_synapse/admin/v1/registration_tokens"
 
-    def _new_token(self, **kwargs) -> str:
+    def _new_token(
+        self,
+        token: Optional[str] = None,
+        uses_allowed: Optional[int] = None,
+        pending: int = 0,
+        completed: int = 0,
+        expiry_time: Optional[int] = None,
+    ) -> str:
         """Helper function to create a token."""
-        token = kwargs.get(
-            "token",
-            "".join(random.choices(string.ascii_letters, k=8)),
-        )
+        if token is None:
+            token = "".join(random.choices(string.ascii_letters, k=8))
+
         self.get_success(
             self.store.db_pool.simple_insert(
                 "registration_tokens",
                 {
                     "token": token,
-                    "uses_allowed": kwargs.get("uses_allowed", None),
-                    "pending": kwargs.get("pending", 0),
-                    "completed": kwargs.get("completed", 0),
-                    "expiry_time": kwargs.get("expiry_time", None),
+                    "uses_allowed": uses_allowed,
+                    "pending": pending,
+                    "completed": completed,
+                    "expiry_time": expiry_time,
                 },
             )
         )
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 22f9aa6234..d2c8781cd4 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -66,7 +66,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         )
         self.url = "/_synapse/admin/v1/rooms/%s" % self.room_id
 
-    def test_requester_is_no_admin(self):
+    def test_requester_is_no_admin(self) -> None:
         """
         If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
         """
@@ -81,7 +81,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-    def test_room_does_not_exist(self):
+    def test_room_does_not_exist(self) -> None:
         """
         Check that unknown rooms/server return 200
         """
@@ -96,7 +96,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
 
-    def test_room_is_not_valid(self):
+    def test_room_is_not_valid(self) -> None:
         """
         Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
         """
@@ -115,7 +115,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
             channel.json_body["error"],
         )
 
-    def test_new_room_user_does_not_exist(self):
+    def test_new_room_user_does_not_exist(self) -> None:
         """
         Tests that the user ID must be from local server but it does not have to exist.
         """
@@ -133,7 +133,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         self.assertIn("failed_to_kick_users", channel.json_body)
         self.assertIn("local_aliases", channel.json_body)
 
-    def test_new_room_user_is_not_local(self):
+    def test_new_room_user_is_not_local(self) -> None:
         """
         Check that only local users can create new room to move members.
         """
@@ -151,7 +151,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
             channel.json_body["error"],
         )
 
-    def test_block_is_not_bool(self):
+    def test_block_is_not_bool(self) -> None:
         """
         If parameter `block` is not boolean, return an error
         """
@@ -166,7 +166,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
 
-    def test_purge_is_not_bool(self):
+    def test_purge_is_not_bool(self) -> None:
         """
         If parameter `purge` is not boolean, return an error
         """
@@ -181,7 +181,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
 
-    def test_purge_room_and_block(self):
+    def test_purge_room_and_block(self) -> None:
         """Test to purge a room and block it.
         Members will not be moved to a new room and will not receive a message.
         """
@@ -212,7 +212,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         self._is_blocked(self.room_id, expect=True)
         self._has_no_members(self.room_id)
 
-    def test_purge_room_and_not_block(self):
+    def test_purge_room_and_not_block(self) -> None:
         """Test to purge a room and do not block it.
         Members will not be moved to a new room and will not receive a message.
         """
@@ -243,7 +243,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         self._is_blocked(self.room_id, expect=False)
         self._has_no_members(self.room_id)
 
-    def test_block_room_and_not_purge(self):
+    def test_block_room_and_not_purge(self) -> None:
         """Test to block a room without purging it.
         Members will not be moved to a new room and will not receive a message.
         The room will not be purged.
@@ -299,7 +299,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
         self._is_blocked(room_id)
 
-    def test_shutdown_room_consent(self):
+    def test_shutdown_room_consent(self) -> None:
         """Test that we can shutdown rooms with local users who have not
         yet accepted the privacy policy. This used to fail when we tried to
         force part the user from the old room.
@@ -351,7 +351,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         self._is_purged(self.room_id)
         self._has_no_members(self.room_id)
 
-    def test_shutdown_room_block_peek(self):
+    def test_shutdown_room_block_peek(self) -> None:
         """Test that a world_readable room can no longer be peeked into after
         it has been shut down.
         Members will be moved to a new room and will receive a message.
@@ -400,7 +400,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         # Assert we can no longer peek into the room
         self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN)
 
-    def _is_blocked(self, room_id, expect=True):
+    def _is_blocked(self, room_id: str, expect: bool = True) -> None:
         """Assert that the room is blocked or not"""
         d = self.store.is_room_blocked(room_id)
         if expect:
@@ -408,17 +408,17 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         else:
             self.assertIsNone(self.get_success(d))
 
-    def _has_no_members(self, room_id):
+    def _has_no_members(self, room_id: str) -> None:
         """Assert there is now no longer anyone in the room"""
         users_in_room = self.get_success(self.store.get_users_in_room(room_id))
         self.assertEqual([], users_in_room)
 
-    def _is_member(self, room_id, user_id):
+    def _is_member(self, room_id: str, user_id: str) -> None:
         """Test that user is member of the room"""
         users_in_room = self.get_success(self.store.get_users_in_room(room_id))
         self.assertIn(user_id, users_in_room)
 
-    def _is_purged(self, room_id):
+    def _is_purged(self, room_id: str) -> None:
         """Test that the following tables have been purged of all rows related to the room."""
         for table in PURGE_TABLES:
             count = self.get_success(
@@ -432,7 +432,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
             self.assertEqual(count, 0, msg=f"Rows not purged in {table}")
 
-    def _assert_peek(self, room_id, expect_code):
+    def _assert_peek(self, room_id: str, expect_code: int) -> None:
         """Assert that the admin user can (or cannot) peek into the room."""
 
         url = "rooms/%s/initialSync" % (room_id,)
@@ -492,7 +492,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
             ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"),
         ]
     )
-    def test_requester_is_no_admin(self, method: str, url: str):
+    def test_requester_is_no_admin(self, method: str, url: str) -> None:
         """
         If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
         """
@@ -507,7 +507,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-    def test_room_does_not_exist(self):
+    def test_room_does_not_exist(self) -> None:
         """
         Check that unknown rooms/server return 200
 
@@ -544,7 +544,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
             ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
         ]
     )
-    def test_room_is_not_valid(self, method: str, url: str):
+    def test_room_is_not_valid(self, method: str, url: str) -> None:
         """
         Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
         """
@@ -562,7 +562,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
             channel.json_body["error"],
         )
 
-    def test_new_room_user_does_not_exist(self):
+    def test_new_room_user_does_not_exist(self) -> None:
         """
         Tests that the user ID must be from local server but it does not have to exist.
         """
@@ -580,7 +580,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
 
         self._test_result(delete_id, self.other_user, expect_new_room=True)
 
-    def test_new_room_user_is_not_local(self):
+    def test_new_room_user_is_not_local(self) -> None:
         """
         Check that only local users can create new room to move members.
         """
@@ -598,7 +598,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
             channel.json_body["error"],
         )
 
-    def test_block_is_not_bool(self):
+    def test_block_is_not_bool(self) -> None:
         """
         If parameter `block` is not boolean, return an error
         """
@@ -613,7 +613,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
 
-    def test_purge_is_not_bool(self):
+    def test_purge_is_not_bool(self) -> None:
         """
         If parameter `purge` is not boolean, return an error
         """
@@ -628,7 +628,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
 
-    def test_delete_expired_status(self):
+    def test_delete_expired_status(self) -> None:
         """Test that the task status is removed after expiration."""
 
         # first task, do not purge, that we can create a second task
@@ -699,7 +699,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
-    def test_delete_same_room_twice(self):
+    def test_delete_same_room_twice(self) -> None:
         """Test that the call for delete a room at second time gives an exception."""
 
         body = {"new_room_user_id": self.admin_user}
@@ -743,7 +743,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
             expect_new_room=True,
         )
 
-    def test_purge_room_and_block(self):
+    def test_purge_room_and_block(self) -> None:
         """Test to purge a room and block it.
         Members will not be moved to a new room and will not receive a message.
         """
@@ -774,7 +774,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self._is_blocked(self.room_id, expect=True)
         self._has_no_members(self.room_id)
 
-    def test_purge_room_and_not_block(self):
+    def test_purge_room_and_not_block(self) -> None:
         """Test to purge a room and do not block it.
         Members will not be moved to a new room and will not receive a message.
         """
@@ -805,7 +805,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self._is_blocked(self.room_id, expect=False)
         self._has_no_members(self.room_id)
 
-    def test_block_room_and_not_purge(self):
+    def test_block_room_and_not_purge(self) -> None:
         """Test to block a room without purging it.
         Members will not be moved to a new room and will not receive a message.
         The room will not be purged.
@@ -838,7 +838,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self._is_blocked(self.room_id, expect=True)
         self._has_no_members(self.room_id)
 
-    def test_shutdown_room_consent(self):
+    def test_shutdown_room_consent(self) -> None:
         """Test that we can shutdown rooms with local users who have not
         yet accepted the privacy policy. This used to fail when we tried to
         force part the user from the old room.
@@ -899,7 +899,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self._is_purged(self.room_id)
         self._has_no_members(self.room_id)
 
-    def test_shutdown_room_block_peek(self):
+    def test_shutdown_room_block_peek(self) -> None:
         """Test that a world_readable room can no longer be peeked into after
         it has been shut down.
         Members will be moved to a new room and will receive a message.
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 4fedd5fd08..eea675991c 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -608,7 +608,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # invalid deactivated
         channel = self.make_request(
@@ -618,7 +618,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # unkown order_by
         channel = self.make_request(
@@ -628,7 +628,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # invalid search order
         channel = self.make_request(
@@ -638,7 +638,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
     def test_limit(self):
         """
@@ -1550,7 +1550,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Create user
         body = {
             "password": "abc123",
-            "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+            # Note that the given email is not in canonical form.
+            "threepids": [{"medium": "email", "address": "Bob@bob.bob"}],
         }
 
         channel = self.make_request(
@@ -2896,7 +2897,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # invalid search order
         channel = self.make_request(
@@ -2906,7 +2907,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # negative limit
         channel = self.make_request(
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 72bbc87b4a..27cb856b0a 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -85,7 +85,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         channel = self.make_request(
             "GET", "auth/m.login.recaptcha/fallback/web?session=" + session
         )
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, HTTPStatus.OK)
 
         channel = self.make_request(
             "POST",
@@ -104,7 +104,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         """Ensure that fallback auth via a captcha works."""
         # Returns a 401 as per the spec
         channel = self.register(
-            401,
+            HTTPStatus.UNAUTHORIZED,
             {"username": "user", "type": "m.login.password", "password": "bar"},
         )
 
@@ -116,15 +116,17 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         )
 
         # Complete the recaptcha step.
-        self.recaptcha(session, 200)
+        self.recaptcha(session, HTTPStatus.OK)
 
         # also complete the dummy auth
-        self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}})
+        self.register(
+            HTTPStatus.OK, {"auth": {"session": session, "type": "m.login.dummy"}}
+        )
 
         # Now we should have fulfilled a complete auth flow, including
         # the recaptcha fallback step, we can then send a
         # request to the register API with the session in the authdict.
-        channel = self.register(200, {"auth": {"session": session}})
+        channel = self.register(HTTPStatus.OK, {"auth": {"session": session}})
 
         # We're given a registered user.
         self.assertEqual(channel.json_body["user_id"], "@user:test")
@@ -137,7 +139,8 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         # will be used.)
         # Returns a 401 as per the spec
         channel = self.register(
-            401, {"username": "user", "type": "m.login.password", "password": "bar"}
+            HTTPStatus.UNAUTHORIZED,
+            {"username": "user", "type": "m.login.password", "password": "bar"},
         )
 
         # Grab the session
@@ -231,7 +234,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
         """
         # Attempt to delete this device.
         # Returns a 401 as per the spec
-        channel = self.delete_device(self.user_tok, self.device_id, 401)
+        channel = self.delete_device(
+            self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+        )
 
         # Grab the session
         session = channel.json_body["session"]
@@ -242,7 +247,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         self.delete_device(
             self.user_tok,
             self.device_id,
-            200,
+            HTTPStatus.OK,
             {
                 "auth": {
                     "type": "m.login.password",
@@ -260,14 +265,16 @@ class UIAuthTests(unittest.HomeserverTestCase):
         UIA - check that still works.
         """
 
-        channel = self.delete_device(self.user_tok, self.device_id, 401)
+        channel = self.delete_device(
+            self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+        )
         session = channel.json_body["session"]
 
         # Make another request providing the UI auth flow.
         self.delete_device(
             self.user_tok,
             self.device_id,
-            200,
+            HTTPStatus.OK,
             {
                 "auth": {
                     "type": "m.login.password",
@@ -293,7 +300,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         # Attempt to delete the first device.
         # Returns a 401 as per the spec
-        channel = self.delete_devices(401, {"devices": [self.device_id]})
+        channel = self.delete_devices(
+            HTTPStatus.UNAUTHORIZED, {"devices": [self.device_id]}
+        )
 
         # Grab the session
         session = channel.json_body["session"]
@@ -303,7 +312,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         # Make another request providing the UI auth flow, but try to delete the
         # second device.
         self.delete_devices(
-            200,
+            HTTPStatus.OK,
             {
                 "devices": ["dev2"],
                 "auth": {
@@ -324,7 +333,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         # Attempt to delete the first device.
         # Returns a 401 as per the spec
-        channel = self.delete_device(self.user_tok, self.device_id, 401)
+        channel = self.delete_device(
+            self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+        )
 
         # Grab the session
         session = channel.json_body["session"]
@@ -338,7 +349,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         self.delete_device(
             self.user_tok,
             "dev2",
-            403,
+            HTTPStatus.FORBIDDEN,
             {
                 "auth": {
                     "type": "m.login.password",
@@ -361,13 +372,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
         self.login("test", self.user_pass, "dev3")
 
         # Attempt to delete a device. This works since the user just logged in.
-        self.delete_device(self.user_tok, "dev2", 200)
+        self.delete_device(self.user_tok, "dev2", HTTPStatus.OK)
 
         # Move the clock forward past the validation timeout.
         self.reactor.advance(6)
 
         # Deleting another devices throws the user into UI auth.
-        channel = self.delete_device(self.user_tok, "dev3", 401)
+        channel = self.delete_device(self.user_tok, "dev3", HTTPStatus.UNAUTHORIZED)
 
         # Grab the session
         session = channel.json_body["session"]
@@ -378,7 +389,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         self.delete_device(
             self.user_tok,
             "dev3",
-            200,
+            HTTPStatus.OK,
             {
                 "auth": {
                     "type": "m.login.password",
@@ -393,7 +404,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         # due to re-using the previous session.
         #
         # Note that *no auth* information is provided, not even a session iD!
-        self.delete_device(self.user_tok, self.device_id, 200)
+        self.delete_device(self.user_tok, self.device_id, HTTPStatus.OK)
 
     @skip_unless(HAS_OIDC, "requires OIDC")
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
@@ -413,7 +424,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
         self.assertEqual(login_resp["user_id"], self.user)
 
         # initiate a UI Auth process by attempting to delete the device
-        channel = self.delete_device(self.user_tok, self.device_id, 401)
+        channel = self.delete_device(
+            self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+        )
 
         # check that SSO is offered
         flows = channel.json_body["flows"]
@@ -426,13 +439,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
         )
 
         # that should serve a confirmation page
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
 
         # and now the delete request should succeed.
         self.delete_device(
             self.user_tok,
             self.device_id,
-            200,
+            HTTPStatus.OK,
             body={"auth": {"session": session_id}},
         )
 
@@ -445,13 +458,15 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         # now call the device deletion API: we should get the option to auth with SSO
         # and not password.
-        channel = self.delete_device(user_tok, device_id, 401)
+        channel = self.delete_device(user_tok, device_id, HTTPStatus.UNAUTHORIZED)
 
         flows = channel.json_body["flows"]
         self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
 
     def test_does_not_offer_sso_for_password_user(self):
-        channel = self.delete_device(self.user_tok, self.device_id, 401)
+        channel = self.delete_device(
+            self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+        )
 
         flows = channel.json_body["flows"]
         self.assertEqual(flows, [{"stages": ["m.login.password"]}])
@@ -463,7 +478,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
         login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
         self.assertEqual(login_resp["user_id"], self.user)
 
-        channel = self.delete_device(self.user_tok, self.device_id, 401)
+        channel = self.delete_device(
+            self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+        )
 
         flows = channel.json_body["flows"]
         # we have no particular expectations of ordering here
@@ -480,7 +497,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
         self.assertEqual(login_resp["user_id"], self.user)
 
         # start a UI Auth flow by attempting to delete a device
-        channel = self.delete_device(self.user_tok, self.device_id, 401)
+        channel = self.delete_device(
+            self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+        )
 
         flows = channel.json_body["flows"]
         self.assertIn({"stages": ["m.login.sso"]}, flows)
@@ -496,7 +515,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         # ... and the delete op should now fail with a 403
         self.delete_device(
-            self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}}
+            self.user_tok,
+            self.device_id,
+            HTTPStatus.FORBIDDEN,
+            body={"auth": {"session": session_id}},
         )
 
 
@@ -551,7 +573,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         login_without_refresh = self.make_request(
             "POST", "/_matrix/client/r0/login", body
         )
-        self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result)
+        self.assertEqual(
+            login_without_refresh.code, HTTPStatus.OK, login_without_refresh.result
+        )
         self.assertNotIn("refresh_token", login_without_refresh.json_body)
 
         login_with_refresh = self.make_request(
@@ -559,7 +583,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "/_matrix/client/r0/login",
             {"refresh_token": True, **body},
         )
-        self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result)
+        self.assertEqual(
+            login_with_refresh.code, HTTPStatus.OK, login_with_refresh.result
+        )
         self.assertIn("refresh_token", login_with_refresh.json_body)
         self.assertIn("expires_in_ms", login_with_refresh.json_body)
 
@@ -577,7 +603,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             },
         )
         self.assertEqual(
-            register_without_refresh.code, 200, register_without_refresh.result
+            register_without_refresh.code,
+            HTTPStatus.OK,
+            register_without_refresh.result,
         )
         self.assertNotIn("refresh_token", register_without_refresh.json_body)
 
@@ -591,7 +619,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
                 "refresh_token": True,
             },
         )
-        self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result)
+        self.assertEqual(
+            register_with_refresh.code, HTTPStatus.OK, register_with_refresh.result
+        )
         self.assertIn("refresh_token", register_with_refresh.json_body)
         self.assertIn("expires_in_ms", register_with_refresh.json_body)
 
@@ -610,14 +640,14 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "/_matrix/client/r0/login",
             body,
         )
-        self.assertEqual(login_response.code, 200, login_response.result)
+        self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
 
         refresh_response = self.make_request(
             "POST",
             "/_matrix/client/v1/refresh",
             {"refresh_token": login_response.json_body["refresh_token"]},
         )
-        self.assertEqual(refresh_response.code, 200, refresh_response.result)
+        self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
         self.assertIn("access_token", refresh_response.json_body)
         self.assertIn("refresh_token", refresh_response.json_body)
         self.assertIn("expires_in_ms", refresh_response.json_body)
@@ -648,7 +678,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "/_matrix/client/r0/login",
             body,
         )
-        self.assertEqual(login_response.code, 200, login_response.result)
+        self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
         self.assertApproximates(
             login_response.json_body["expires_in_ms"], 60 * 1000, 100
         )
@@ -658,7 +688,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "/_matrix/client/v1/refresh",
             {"refresh_token": login_response.json_body["refresh_token"]},
         )
-        self.assertEqual(refresh_response.code, 200, refresh_response.result)
+        self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
         self.assertApproximates(
             refresh_response.json_body["expires_in_ms"], 60 * 1000, 100
         )
@@ -705,7 +735,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "/_matrix/client/r0/login",
             {"refresh_token": True, **body},
         )
-        self.assertEqual(login_response1.code, 200, login_response1.result)
+        self.assertEqual(login_response1.code, HTTPStatus.OK, login_response1.result)
         self.assertApproximates(
             login_response1.json_body["expires_in_ms"], 60 * 1000, 100
         )
@@ -716,7 +746,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "/_matrix/client/r0/login",
             body,
         )
-        self.assertEqual(login_response2.code, 200, login_response2.result)
+        self.assertEqual(login_response2.code, HTTPStatus.OK, login_response2.result)
         nonrefreshable_access_token = login_response2.json_body["access_token"]
 
         # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry)
@@ -818,7 +848,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "/_matrix/client/r0/login",
             body,
         )
-        self.assertEqual(login_response.code, 200, login_response.result)
+        self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
         refresh_token = login_response.json_body["refresh_token"]
 
         # Advance shy of 2 minutes into the future
@@ -826,7 +856,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
 
         # Refresh our session. The refresh token should still be valid right now.
         refresh_response = self.use_refresh_token(refresh_token)
-        self.assertEqual(refresh_response.code, 200, refresh_response.result)
+        self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
         self.assertIn(
             "refresh_token",
             refresh_response.json_body,
@@ -846,7 +876,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         # This should fail because the refresh token's lifetime has also been
         # diminished as our session expired.
         refresh_response = self.use_refresh_token(refresh_token)
-        self.assertEqual(refresh_response.code, 403, refresh_response.result)
+        self.assertEqual(
+            refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result
+        )
 
     def test_refresh_token_invalidation(self):
         """Refresh tokens are invalidated after first use of the next token.
@@ -875,7 +907,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "/_matrix/client/r0/login",
             body,
         )
-        self.assertEqual(login_response.code, 200, login_response.result)
+        self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
 
         # This first refresh should work properly
         first_refresh_response = self.make_request(
@@ -884,7 +916,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             {"refresh_token": login_response.json_body["refresh_token"]},
         )
         self.assertEqual(
-            first_refresh_response.code, 200, first_refresh_response.result
+            first_refresh_response.code, HTTPStatus.OK, first_refresh_response.result
         )
 
         # This one as well, since the token in the first one was never used
@@ -894,7 +926,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             {"refresh_token": login_response.json_body["refresh_token"]},
         )
         self.assertEqual(
-            second_refresh_response.code, 200, second_refresh_response.result
+            second_refresh_response.code, HTTPStatus.OK, second_refresh_response.result
         )
 
         # This one should not, since the token from the first refresh is not valid anymore
@@ -904,7 +936,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             {"refresh_token": first_refresh_response.json_body["refresh_token"]},
         )
         self.assertEqual(
-            third_refresh_response.code, 401, third_refresh_response.result
+            third_refresh_response.code,
+            HTTPStatus.UNAUTHORIZED,
+            third_refresh_response.result,
         )
 
         # The associated access token should also be invalid
@@ -913,7 +947,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "/_matrix/client/r0/account/whoami",
             access_token=first_refresh_response.json_body["access_token"],
         )
-        self.assertEqual(whoami_response.code, 401, whoami_response.result)
+        self.assertEqual(
+            whoami_response.code, HTTPStatus.UNAUTHORIZED, whoami_response.result
+        )
 
         # But all other tokens should work (they will expire after some time)
         for access_token in [
@@ -923,7 +959,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             whoami_response = self.make_request(
                 "GET", "/_matrix/client/r0/account/whoami", access_token=access_token
             )
-            self.assertEqual(whoami_response.code, 200, whoami_response.result)
+            self.assertEqual(
+                whoami_response.code, HTTPStatus.OK, whoami_response.result
+            )
 
         # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail
         fourth_refresh_response = self.make_request(
@@ -932,7 +970,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             {"refresh_token": login_response.json_body["refresh_token"]},
         )
         self.assertEqual(
-            fourth_refresh_response.code, 403, fourth_refresh_response.result
+            fourth_refresh_response.code,
+            HTTPStatus.FORBIDDEN,
+            fourth_refresh_response.result,
         )
 
         # But refreshing from the last valid refresh token still works
@@ -942,5 +982,5 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             {"refresh_token": second_refresh_response.json_body["refresh_token"]},
         )
         self.assertEqual(
-            fifth_refresh_response.code, 200, fifth_refresh_response.result
+            fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
         )
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 1b58b73136..c026d526ef 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -16,6 +16,7 @@
 import itertools
 import urllib.parse
 from typing import Dict, List, Optional, Tuple
+from unittest.mock import patch
 
 from synapse.api.constants import EventTypes, RelationTypes
 from synapse.rest import admin
@@ -23,6 +24,8 @@ from synapse.rest.client import login, register, relations, room, sync
 
 from tests import unittest
 from tests.server import FakeChannel
+from tests.test_utils import make_awaitable
+from tests.test_utils.event_injection import inject_event
 
 
 class RelationsTestCase(unittest.HomeserverTestCase):
@@ -651,6 +654,118 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             },
         )
 
+    @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+    def test_ignore_invalid_room(self):
+        """Test that we ignore invalid relations over federation."""
+        # Create another room and send a message in it.
+        room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
+        res = self.helper.send(room2, body="Hi!", tok=self.user_token)
+        parent_id = res["event_id"]
+
+        # Disable the validation to pretend this came over federation.
+        with patch(
+            "synapse.handlers.message.EventCreationHandler._validate_event_relation",
+            new=lambda self, event: make_awaitable(None),
+        ):
+            # Generate a various relations from a different room.
+            self.get_success(
+                inject_event(
+                    self.hs,
+                    room_id=self.room,
+                    type="m.reaction",
+                    sender=self.user_id,
+                    content={
+                        "m.relates_to": {
+                            "rel_type": RelationTypes.ANNOTATION,
+                            "event_id": parent_id,
+                            "key": "A",
+                        }
+                    },
+                )
+            )
+
+            self.get_success(
+                inject_event(
+                    self.hs,
+                    room_id=self.room,
+                    type="m.room.message",
+                    sender=self.user_id,
+                    content={
+                        "body": "foo",
+                        "msgtype": "m.text",
+                        "m.relates_to": {
+                            "rel_type": RelationTypes.REFERENCE,
+                            "event_id": parent_id,
+                        },
+                    },
+                )
+            )
+
+            self.get_success(
+                inject_event(
+                    self.hs,
+                    room_id=self.room,
+                    type="m.room.message",
+                    sender=self.user_id,
+                    content={
+                        "body": "foo",
+                        "msgtype": "m.text",
+                        "m.relates_to": {
+                            "rel_type": RelationTypes.THREAD,
+                            "event_id": parent_id,
+                        },
+                    },
+                )
+            )
+
+            self.get_success(
+                inject_event(
+                    self.hs,
+                    room_id=self.room,
+                    type="m.room.message",
+                    sender=self.user_id,
+                    content={
+                        "body": "foo",
+                        "msgtype": "m.text",
+                        "new_content": {
+                            "body": "new content",
+                            "msgtype": "m.text",
+                        },
+                        "m.relates_to": {
+                            "rel_type": RelationTypes.REPLACE,
+                            "event_id": parent_id,
+                        },
+                    },
+                )
+            )
+
+        # They should be ignored when fetching relations.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{room2}/relations/{parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(channel.json_body["chunk"], [])
+
+        # And when fetching aggregations.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(channel.json_body["chunk"], [])
+
+        # And for bundled aggregations.
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{room2}/event/{parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertNotIn("m.relations", channel.json_body["unsigned"])
+
     def test_edit(self):
         """Test that a simple edit works."""
 
diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py
new file mode 100644
index 0000000000..721454c187
--- /dev/null
+++ b/tests/rest/client/test_room_batch.py
@@ -0,0 +1,180 @@
+import logging
+from typing import List, Tuple
+from unittest.mock import Mock, patch
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import EventContentFields, EventTypes
+from synapse.appservice import ApplicationService
+from synapse.rest import admin
+from synapse.rest.client import login, register, room, room_batch
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests import unittest
+
+logger = logging.getLogger(__name__)
+
+
+def _create_join_state_events_for_batch_send_request(
+    virtual_user_ids: List[str],
+    insert_time: int,
+) -> List[JsonDict]:
+    return [
+        {
+            "type": EventTypes.Member,
+            "sender": virtual_user_id,
+            "origin_server_ts": insert_time,
+            "content": {
+                "membership": "join",
+                "displayname": "display-name-for-%s" % (virtual_user_id,),
+            },
+            "state_key": virtual_user_id,
+        }
+        for virtual_user_id in virtual_user_ids
+    ]
+
+
+def _create_message_events_for_batch_send_request(
+    virtual_user_id: str, insert_time: int, count: int
+) -> List[JsonDict]:
+    return [
+        {
+            "type": EventTypes.Message,
+            "sender": virtual_user_id,
+            "origin_server_ts": insert_time,
+            "content": {
+                "msgtype": "m.text",
+                "body": "Historical %d" % (i),
+                EventContentFields.MSC2716_HISTORICAL: True,
+            },
+        }
+        for i in range(count)
+    ]
+
+
+class RoomBatchTestCase(unittest.HomeserverTestCase):
+    """Test importing batches of historical messages."""
+
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        room_batch.register_servlets,
+        room.register_servlets,
+        register.register_servlets,
+        login.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        config = self.default_config()
+
+        self.appservice = ApplicationService(
+            token="i_am_an_app_service",
+            hostname="test",
+            id="1234",
+            namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+            # Note: this user does not have to match the regex above
+            sender="@as_main:test",
+        )
+
+        mock_load_appservices = Mock(return_value=[self.appservice])
+        with patch(
+            "synapse.storage.databases.main.appservice.load_appservices",
+            mock_load_appservices,
+        ):
+            hs = self.setup_test_homeserver(config=config)
+        return hs
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.clock = clock
+        self.storage = hs.get_storage()
+
+        self.virtual_user_id = self.register_appservice_user(
+            "as_user_potato", self.appservice.token
+        )
+
+    def _create_test_room(self) -> Tuple[str, str, str, str]:
+        room_id = self.helper.create_room_as(
+            self.appservice.sender, tok=self.appservice.token
+        )
+
+        res_a = self.helper.send_event(
+            room_id=room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "A",
+            },
+            tok=self.appservice.token,
+        )
+        event_id_a = res_a["event_id"]
+
+        res_b = self.helper.send_event(
+            room_id=room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "B",
+            },
+            tok=self.appservice.token,
+        )
+        event_id_b = res_b["event_id"]
+
+        res_c = self.helper.send_event(
+            room_id=room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "C",
+            },
+            tok=self.appservice.token,
+        )
+        event_id_c = res_c["event_id"]
+
+        return room_id, event_id_a, event_id_b, event_id_c
+
+    @unittest.override_config({"experimental_features": {"msc2716_enabled": True}})
+    def test_same_state_groups_for_whole_historical_batch(self):
+        """Make sure that when using the `/batch_send` endpoint to import a
+        bunch of historical messages, it re-uses the same `state_group` across
+        the whole batch. This is an easy optimization to make sure we're getting
+        right because the state for the whole batch is contained in
+        `state_events_at_start` and can be shared across everything.
+        """
+
+        time_before_room = int(self.clock.time_msec())
+        room_id, event_id_a, _, _ = self._create_test_room()
+
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/unstable/org.matrix.msc2716/rooms/%s/batch_send?prev_event_id=%s"
+            % (room_id, event_id_a),
+            content={
+                "events": _create_message_events_for_batch_send_request(
+                    self.virtual_user_id, time_before_room, 3
+                ),
+                "state_events_at_start": _create_join_state_events_for_batch_send_request(
+                    [self.virtual_user_id], time_before_room
+                ),
+            },
+            access_token=self.appservice.token,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+        # Get the historical event IDs that we just imported
+        historical_event_ids = channel.json_body["event_ids"]
+        self.assertEqual(len(historical_event_ids), 3)
+
+        # Fetch the state_groups
+        state_group_map = self.get_success(
+            self.storage.state.get_state_groups_ids(room_id, historical_event_ids)
+        )
+
+        # We expect all of the historical events to be using the same state_group
+        # so there should only be a single state_group here!
+        self.assertEqual(
+            len(state_group_map.keys()),
+            1,
+            "Expected a single state_group to be returned by saw state_groups=%s"
+            % (state_group_map.keys(),),
+        )
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 8698135a76..16e904f15b 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -1,4 +1,5 @@
 # Copyright 2018 New Vector Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index d77c001506..6156dfac4e 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -12,15 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Use backported mock for AsyncMock support on Python 3.6.
-from mock import Mock
+from unittest.mock import Mock
 
 from twisted.internet.defer import Deferred, ensureDeferred
 
 from synapse.storage.background_updates import BackgroundUpdater
 
 from tests import unittest
-from tests.test_utils import make_awaitable
+from tests.test_utils import make_awaitable, simple_async_mock
 
 
 class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
@@ -116,14 +115,14 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
         )
 
         # Mock out the AsyncContextManager
-        self._update_ctx_manager = Mock(spec=["__aenter__", "__aexit__"])
-        self._update_ctx_manager.__aenter__ = Mock(
-            return_value=make_awaitable(None),
-        )
-        self._update_ctx_manager.__aexit__ = Mock(return_value=make_awaitable(None))
+        class MockCM:
+            __aenter__ = simple_async_mock(return_value=None)
+            __aexit__ = simple_async_mock(return_value=None)
+
+        self._update_ctx_manager = MockCM
 
         # Mock out the `update_handler` callback
-        self._on_update = Mock(return_value=self._update_ctx_manager)
+        self._on_update = Mock(return_value=self._update_ctx_manager())
 
         # Define a default batch size value that's not the same as the internal default
         # value (100).
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index 9b6b425425..7556171d8a 100644
--- a/tests/storage/test_e2e_room_keys.py
+++ b/tests/storage/test_e2e_room_keys.py
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.storage.databases.main.e2e_room_keys import RoomKey
+
 from tests import unittest
 
 # sample room_key data for use in the tests
-room_key = {
+room_key: RoomKey = {
     "first_message_index": 1,
     "forwarded_count": 1,
     "is_verified": False,
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index c3fcf7e7b4..ecfda7677e 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -550,7 +550,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             self.store.db_pool.simple_select_one_onecol(
                 table="federation_inbound_events_staging",
                 keyvalues={"room_id": room_id},
-                retcol="COALESCE(COUNT(*), 0)",
+                retcol="COUNT(*)",
                 desc="test_prune_inbound_federation_queue",
             )
         )
diff --git a/tests/test_preview.py b/tests/test_preview.py
index 40b89fb2ef..46e02f483f 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.rest.media.v1.preview_url_resource import (
-    _calc_og,
+from synapse.rest.media.v1.preview_html import (
+    _get_html_media_encodings,
     decode_body,
-    get_html_media_encodings,
+    parse_html_to_open_graph,
     summarize_paragraphs,
 )
 
@@ -160,7 +160,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
@@ -176,7 +176,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
@@ -195,7 +195,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
 
         self.assertEqual(
             og,
@@ -217,7 +217,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
@@ -231,7 +231,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
 
         self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
 
@@ -246,7 +246,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
 
         self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
 
@@ -261,7 +261,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
 
         self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
 
@@ -289,7 +289,7 @@ class CalcOgTestCase(unittest.TestCase):
         <head><title>Foo</title></head><body>Some text.</body></html>
         """.strip()
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_invalid_encoding(self):
@@ -303,7 +303,7 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_invalid_encoding2(self):
@@ -318,7 +318,7 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
         self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
 
     def test_windows_1252(self):
@@ -332,14 +332,14 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
         self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."})
 
 
 class MediaEncodingTestCase(unittest.TestCase):
     def test_meta_charset(self):
         """A character encoding is found via the meta tag."""
-        encodings = get_html_media_encodings(
+        encodings = _get_html_media_encodings(
             b"""
         <html>
         <head><meta charset="ascii">
@@ -351,7 +351,7 @@ class MediaEncodingTestCase(unittest.TestCase):
         self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
 
         # A less well-formed version.
-        encodings = get_html_media_encodings(
+        encodings = _get_html_media_encodings(
             b"""
         <html>
         <head>< meta charset = ascii>
@@ -364,7 +364,7 @@ class MediaEncodingTestCase(unittest.TestCase):
 
     def test_meta_charset_underscores(self):
         """A character encoding contains underscore."""
-        encodings = get_html_media_encodings(
+        encodings = _get_html_media_encodings(
             b"""
         <html>
         <head><meta charset="Shift_JIS">
@@ -377,7 +377,7 @@ class MediaEncodingTestCase(unittest.TestCase):
 
     def test_xml_encoding(self):
         """A character encoding is found via the meta tag."""
-        encodings = get_html_media_encodings(
+        encodings = _get_html_media_encodings(
             b"""
         <?xml version="1.0" encoding="ascii"?>
         <html>
@@ -389,7 +389,7 @@ class MediaEncodingTestCase(unittest.TestCase):
 
     def test_meta_xml_encoding(self):
         """Meta tags take precedence over XML encoding."""
-        encodings = get_html_media_encodings(
+        encodings = _get_html_media_encodings(
             b"""
         <?xml version="1.0" encoding="ascii"?>
         <html>
@@ -413,17 +413,17 @@ class MediaEncodingTestCase(unittest.TestCase):
             'text/html; charset=ascii";',
         )
         for header in headers:
-            encodings = get_html_media_encodings(b"", header)
+            encodings = _get_html_media_encodings(b"", header)
             self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
 
     def test_fallback(self):
         """A character encoding cannot be found in the body or header."""
-        encodings = get_html_media_encodings(b"", "text/html")
+        encodings = _get_html_media_encodings(b"", "text/html")
         self.assertEqual(list(encodings), ["utf-8", "cp1252"])
 
     def test_duplicates(self):
         """Ensure each encoding is only attempted once."""
-        encodings = get_html_media_encodings(
+        encodings = _get_html_media_encodings(
             b"""
         <?xml version="1.0" encoding="utf8"?>
         <html>
@@ -437,7 +437,7 @@ class MediaEncodingTestCase(unittest.TestCase):
 
     def test_unknown_invalid(self):
         """A character encoding should be ignored if it is unknown or invalid."""
-        encodings = get_html_media_encodings(
+        encodings = _get_html_media_encodings(
             b"""
         <html>
         <head><meta charset="invalid">
diff --git a/tests/unittest.py b/tests/unittest.py
index eea0903f05..1431848367 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -331,16 +331,13 @@ class HomeserverTestCase(TestCase):
             time.sleep(0.01)
 
     def wait_for_background_updates(self) -> None:
-        """Block until all background database updates have completed.
-
-        Note that callers must ensure there's a store property created on the
-        testcase.
-        """
+        """Block until all background database updates have completed."""
+        store = self.hs.get_datastore()
         while not self.get_success(
-            self.store.db_pool.updates.has_completed_background_updates()
+            store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db_pool.updates.do_next_background_update(False), by=0.1
+                store.db_pool.updates.do_next_background_update(False), by=0.1
             )
 
     def make_homeserver(self, reactor, clock):
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 5d9c4665aa..621b0f9fcd 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -152,46 +152,11 @@ class LoggingContextTestCase(unittest.TestCase):
             # now it should be restored
             self._check_test_key("one")
 
-    @defer.inlineCallbacks
-    def test_make_deferred_yieldable_on_non_deferred(self):
-        """Check that make_deferred_yieldable does the right thing when its
-        argument isn't actually a deferred"""
-
-        with LoggingContext("one"):
-            d1 = make_deferred_yieldable("bum")
-            self._check_test_key("one")
-
-            r = yield d1
-            self.assertEqual(r, "bum")
-            self._check_test_key("one")
-
     def test_nested_logging_context(self):
         with LoggingContext("foo"):
             nested_context = nested_logging_context(suffix="bar")
             self.assertEqual(nested_context.name, "foo-bar")
 
-    @defer.inlineCallbacks
-    def test_make_deferred_yieldable_with_await(self):
-        # an async function which returns an incomplete coroutine, but doesn't
-        # follow the synapse rules.
-
-        async def blocking_function():
-            d = defer.Deferred()
-            reactor.callLater(0, d.callback, None)
-            await d
-
-        sentinel_context = current_context()
-
-        with LoggingContext("one"):
-            d1 = make_deferred_yieldable(blocking_function())
-            # make sure that the context was reset by make_deferred_yieldable
-            self.assertIs(current_context(), sentinel_context)
-
-            yield d1
-
-            # now it should be restored
-            self._check_test_key("one")
-
 
 # a function which returns a deferred which has been "called", but
 # which had a function which returned another incomplete deferred on