summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/federation/test_complexity.py30
-rw-r--r--tests/federation/test_federation_catch_up.py158
-rw-r--r--tests/federation/test_federation_sender.py2
-rw-r--r--tests/handlers/test_auth.py20
-rw-r--r--tests/handlers/test_register.py10
-rw-r--r--tests/handlers/test_stats.py15
-rw-r--r--tests/handlers/test_typing.py3
-rw-r--r--tests/http/__init__.py2
-rw-r--r--tests/replication/test_federation_sender_shard.py10
-rw-r--r--tests/rest/admin/test_user.py6
-rw-r--r--tests/rest/client/v1/test_push_rule_attrs.py448
-rw-r--r--tests/rest/client/v2_alpha/test_account.py132
-rw-r--r--tests/rest/media/v1/test_media_storage.py39
-rw-r--r--tests/server.py15
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py14
-rw-r--r--tests/storage/test_client_ips.py2
-rw-r--r--tests/storage/test_id_generators.py50
-rw-r--r--tests/storage/test_monthly_active_users.py16
-rw-r--r--tests/test_utils/__init__.py13
-rw-r--r--tests/test_utils/event_injection.py5
-rw-r--r--tests/unittest.py4
21 files changed, 893 insertions, 101 deletions
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 3d880c499d..1471cc1a28 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -77,11 +77,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(
-            side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
-        )
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
         handler.federation_handler.do_invite_join = Mock(
-            side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+            return_value=make_awaitable(("", 1))
         )
 
         d = handler._remote_join(
@@ -110,11 +108,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(
-            side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
-        )
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
         handler.federation_handler.do_invite_join = Mock(
-            side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+            return_value=make_awaitable(("", 1))
         )
 
         d = handler._remote_join(
@@ -150,11 +146,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(
-            side_effect=lambda *args, **kwargs: make_awaitable(None)
-        )
+        fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
         handler.federation_handler.do_invite_join = Mock(
-            side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+            return_value=make_awaitable(("", 1))
         )
 
         # Artificially raise the complexity
@@ -208,11 +202,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(
-            side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
-        )
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
         handler.federation_handler.do_invite_join = Mock(
-            side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+            return_value=make_awaitable(("", 1))
         )
 
         d = handler._remote_join(
@@ -240,11 +232,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(
-            side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
-        )
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
         handler.federation_handler.do_invite_join = Mock(
-            side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+            return_value=make_awaitable(("", 1))
         )
 
         d = handler._remote_join(
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
new file mode 100644
index 0000000000..6cdcc378f0
--- /dev/null
+++ b/tests/federation/test_federation_catch_up.py
@@ -0,0 +1,158 @@
+from mock import Mock
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.test_utils import event_injection, make_awaitable
+from tests.unittest import FederatingHomeserverTestCase, override_config
+
+
+class FederationCatchUpTestCases(FederatingHomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        return self.setup_test_homeserver(
+            federation_transport_client=Mock(spec=["send_transaction"]),
+        )
+
+    def prepare(self, reactor, clock, hs):
+        # stub out get_current_hosts_in_room
+        state_handler = hs.get_state_handler()
+
+        # This mock is crucial for destination_rooms to be populated.
+        state_handler.get_current_hosts_in_room = Mock(
+            return_value=make_awaitable(["test", "host2"])
+        )
+
+        # whenever send_transaction is called, record the pdu data
+        self.pdus = []
+        self.failed_pdus = []
+        self.is_online = True
+        self.hs.get_federation_transport_client().send_transaction.side_effect = (
+            self.record_transaction
+        )
+
+    async def record_transaction(self, txn, json_cb):
+        if self.is_online:
+            data = json_cb()
+            self.pdus.extend(data["pdus"])
+            return {}
+        else:
+            data = json_cb()
+            self.failed_pdus.extend(data["pdus"])
+            raise IOError("Failed to connect because this is a test!")
+
+    def get_destination_room(self, room: str, destination: str = "host2") -> dict:
+        """
+        Gets the destination_rooms entry for a (destination, room_id) pair.
+
+        Args:
+            room: room ID
+            destination: what destination, default is "host2"
+
+        Returns:
+            Dictionary of { event_id: str, stream_ordering: int }
+        """
+        event_id, stream_ordering = self.get_success(
+            self.hs.get_datastore().db_pool.execute(
+                "test:get_destination_rooms",
+                None,
+                """
+                SELECT event_id, stream_ordering
+                    FROM destination_rooms dr
+                    JOIN events USING (stream_ordering)
+                    WHERE dr.destination = ? AND dr.room_id = ?
+                """,
+                destination,
+                room,
+            )
+        )[0]
+        return {"event_id": event_id, "stream_ordering": stream_ordering}
+
+    @override_config({"send_federation": True})
+    def test_catch_up_destination_rooms_tracking(self):
+        """
+        Tests that we populate the `destination_rooms` table as needed.
+        """
+        self.register_user("u1", "you the one")
+        u1_token = self.login("u1", "you the one")
+        room = self.helper.create_room_as("u1", tok=u1_token)
+
+        self.get_success(
+            event_injection.inject_member_event(self.hs, room, "@user:host2", "join")
+        )
+
+        event_id_1 = self.helper.send(room, "wombats!", tok=u1_token)["event_id"]
+
+        row_1 = self.get_destination_room(room)
+
+        event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"]
+
+        row_2 = self.get_destination_room(room)
+
+        # check: events correctly registered in order
+        self.assertEqual(row_1["event_id"], event_id_1)
+        self.assertEqual(row_2["event_id"], event_id_2)
+        self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1)
+
+    @override_config({"send_federation": True})
+    def test_catch_up_last_successful_stream_ordering_tracking(self):
+        """
+        Tests that we populate the `destination_rooms` table as needed.
+        """
+        self.register_user("u1", "you the one")
+        u1_token = self.login("u1", "you the one")
+        room = self.helper.create_room_as("u1", tok=u1_token)
+
+        # take the remote offline
+        self.is_online = False
+
+        self.get_success(
+            event_injection.inject_member_event(self.hs, room, "@user:host2", "join")
+        )
+
+        self.helper.send(room, "wombats!", tok=u1_token)
+        self.pump()
+
+        lsso_1 = self.get_success(
+            self.hs.get_datastore().get_destination_last_successful_stream_ordering(
+                "host2"
+            )
+        )
+
+        self.assertIsNone(
+            lsso_1,
+            "There should be no last successful stream ordering for an always-offline destination",
+        )
+
+        # bring the remote online
+        self.is_online = True
+
+        event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"]
+
+        lsso_2 = self.get_success(
+            self.hs.get_datastore().get_destination_last_successful_stream_ordering(
+                "host2"
+            )
+        )
+        row_2 = self.get_destination_room(room)
+
+        self.assertEqual(
+            self.pdus[0]["content"]["body"],
+            "rabbits!",
+            "Test fault: didn't receive the right PDU",
+        )
+        self.assertEqual(
+            row_2["event_id"],
+            event_id_2,
+            "Test fault: destination_rooms not updated correctly",
+        )
+        self.assertEqual(
+            lsso_2,
+            row_2["stream_ordering"],
+            "Send succeeded but not marked as last_successful_stream_ordering",
+        )
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 5f512ff8bf..917762e6b6 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -34,7 +34,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
         # Ensure a new Awaitable is created for each call.
-        mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable(
+        mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable(
             ["test", "host2"]
         )
         return self.setup_test_homeserver(
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index c7efd3822d..97877c2e42 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -143,7 +143,7 @@ class AuthTestCase(unittest.TestCase):
     def test_mau_limits_exceeded_large(self):
         self.auth_blocking._limit_usage_by_mau = True
         self.hs.get_datastore().get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(self.large_number_of_users)
+            return_value=make_awaitable(self.large_number_of_users)
         )
 
         with self.assertRaises(ResourceLimitError):
@@ -154,7 +154,7 @@ class AuthTestCase(unittest.TestCase):
             )
 
         self.hs.get_datastore().get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(self.large_number_of_users)
+            return_value=make_awaitable(self.large_number_of_users)
         )
         with self.assertRaises(ResourceLimitError):
             yield defer.ensureDeferred(
@@ -169,7 +169,7 @@ class AuthTestCase(unittest.TestCase):
 
         # If not in monthly active cohort
         self.hs.get_datastore().get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
+            return_value=make_awaitable(self.auth_blocking._max_mau_value)
         )
         with self.assertRaises(ResourceLimitError):
             yield defer.ensureDeferred(
@@ -179,7 +179,7 @@ class AuthTestCase(unittest.TestCase):
             )
 
         self.hs.get_datastore().get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
+            return_value=make_awaitable(self.auth_blocking._max_mau_value)
         )
         with self.assertRaises(ResourceLimitError):
             yield defer.ensureDeferred(
@@ -189,10 +189,10 @@ class AuthTestCase(unittest.TestCase):
             )
         # If in monthly active cohort
         self.hs.get_datastore().user_last_seen_monthly_active = Mock(
-            side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec())
+            return_value=make_awaitable(self.hs.get_clock().time_msec())
         )
         self.hs.get_datastore().get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
+            return_value=make_awaitable(self.auth_blocking._max_mau_value)
         )
         yield defer.ensureDeferred(
             self.auth_handler.get_access_token_for_user_id(
@@ -200,10 +200,10 @@ class AuthTestCase(unittest.TestCase):
             )
         )
         self.hs.get_datastore().user_last_seen_monthly_active = Mock(
-            side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec())
+            return_value=make_awaitable(self.hs.get_clock().time_msec())
         )
         self.hs.get_datastore().get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
+            return_value=make_awaitable(self.auth_blocking._max_mau_value)
         )
         yield defer.ensureDeferred(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
@@ -216,7 +216,7 @@ class AuthTestCase(unittest.TestCase):
         self.auth_blocking._limit_usage_by_mau = True
 
         self.hs.get_datastore().get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(self.small_number_of_users)
+            return_value=make_awaitable(self.small_number_of_users)
         )
         # Ensure does not raise exception
         yield defer.ensureDeferred(
@@ -226,7 +226,7 @@ class AuthTestCase(unittest.TestCase):
         )
 
         self.hs.get_datastore().get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(self.small_number_of_users)
+            return_value=make_awaitable(self.small_number_of_users)
         )
         yield defer.ensureDeferred(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index eddf5e2498..cb7c0ed51a 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -100,7 +100,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_get_or_create_user_mau_not_blocked(self):
         self.hs.config.limit_usage_by_mau = True
         self.store.count_monthly_users = Mock(
-            side_effect=lambda: make_awaitable(self.hs.config.max_mau_value - 1)
+            return_value=make_awaitable(self.hs.config.max_mau_value - 1)
         )
         # Ensure does not throw exception
         self.get_success(self.get_or_create_user(self.requester, "c", "User"))
@@ -108,7 +108,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_get_or_create_user_mau_blocked(self):
         self.hs.config.limit_usage_by_mau = True
         self.store.get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(self.lots_of_users)
+            return_value=make_awaitable(self.lots_of_users)
         )
         self.get_failure(
             self.get_or_create_user(self.requester, "b", "display_name"),
@@ -116,7 +116,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         )
 
         self.store.get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
+            return_value=make_awaitable(self.hs.config.max_mau_value)
         )
         self.get_failure(
             self.get_or_create_user(self.requester, "b", "display_name"),
@@ -126,14 +126,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_register_mau_blocked(self):
         self.hs.config.limit_usage_by_mau = True
         self.store.get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(self.lots_of_users)
+            return_value=make_awaitable(self.lots_of_users)
         )
         self.get_failure(
             self.handler.register_user(localpart="local_part"), ResourceLimitError
         )
 
         self.store.get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
+            return_value=make_awaitable(self.hs.config.max_mau_value)
         )
         self.get_failure(
             self.handler.register_user(localpart="local_part"), ResourceLimitError
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index a609f148c0..312c0a0d41 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -54,7 +54,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
             self.store.db_pool.simple_insert(
                 "background_updates",
                 {
-                    "update_name": "populate_stats_process_rooms_2",
+                    "update_name": "populate_stats_process_rooms",
                     "progress_json": "{}",
                     "depends_on": "populate_stats_prepare",
                 },
@@ -66,7 +66,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
                 {
                     "update_name": "populate_stats_process_users",
                     "progress_json": "{}",
-                    "depends_on": "populate_stats_process_rooms_2",
+                    "depends_on": "populate_stats_process_rooms",
                 },
             )
         )
@@ -219,10 +219,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         self.get_success(
             self.store.db_pool.simple_insert(
                 "background_updates",
-                {
-                    "update_name": "populate_stats_process_rooms_2",
-                    "progress_json": "{}",
-                },
+                {"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
             )
         )
         self.get_success(
@@ -231,7 +228,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
                 {
                     "update_name": "populate_stats_cleanup",
                     "progress_json": "{}",
-                    "depends_on": "populate_stats_process_rooms_2",
+                    "depends_on": "populate_stats_process_rooms",
                 },
             )
         )
@@ -728,7 +725,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
             self.store.db_pool.simple_insert(
                 "background_updates",
                 {
-                    "update_name": "populate_stats_process_rooms_2",
+                    "update_name": "populate_stats_process_rooms",
                     "progress_json": "{}",
                     "depends_on": "populate_stats_prepare",
                 },
@@ -740,7 +737,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
                 {
                     "update_name": "populate_stats_process_users",
                     "progress_json": "{}",
-                    "depends_on": "populate_stats_process_rooms_2",
+                    "depends_on": "populate_stats_process_rooms",
                 },
             )
         )
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 7bf15c4ba9..f306a09bfa 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -80,6 +80,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
                 "get_user_directory_stream_pos",
                 "get_current_state_deltas",
                 "get_device_updates_by_remote",
+                "get_room_max_stream_ordering",
             ]
         )
 
@@ -116,7 +117,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             retry_timings_res
         )
 
-        self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable(
+        self.datastore.get_device_updates_by_remote.return_value = make_awaitable(
             (0, [])
         )
 
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 5d41443293..3e5a856584 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -145,7 +145,7 @@ class TestServerTLSConnectionFactory:
         self._cert_file = create_test_cert_file(sanlist)
 
     def serverConnectionForTLS(self, tlsProtocol):
-        ctx = SSL.Context(SSL.TLSv1_METHOD)
+        ctx = SSL.Context(SSL.SSLv23_METHOD)
         ctx.use_certificate_file(self._cert_file)
         ctx.use_privatekey_file(get_test_key_file())
         return Connection(ctx, None)
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 8b4982ecb1..1d7edee5ba 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -45,7 +45,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         new event.
         """
         mock_client = Mock(spec=["put_json"])
-        mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({})
+        mock_client.put_json.return_value = make_awaitable({})
 
         self.make_worker_hs(
             "synapse.app.federation_sender",
@@ -73,7 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         new events.
         """
         mock_client1 = Mock(spec=["put_json"])
-        mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
+        mock_client1.put_json.return_value = make_awaitable({})
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {
@@ -85,7 +85,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         mock_client2 = Mock(spec=["put_json"])
-        mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
+        mock_client2.put_json.return_value = make_awaitable({})
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {
@@ -136,7 +136,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         new typing EDUs.
         """
         mock_client1 = Mock(spec=["put_json"])
-        mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
+        mock_client1.put_json.return_value = make_awaitable({})
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {
@@ -148,7 +148,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         mock_client2 = Mock(spec=["put_json"])
-        mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
+        mock_client2.put_json.return_value = make_awaitable({})
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 160c630235..b8b7758d24 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -337,7 +337,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
 
         # Set monthly active users to the limit
         store.get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
+            return_value=make_awaitable(self.hs.config.max_mau_value)
         )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
@@ -591,7 +591,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         # Set monthly active users to the limit
         self.store.get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
+            return_value=make_awaitable(self.hs.config.max_mau_value)
         )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
@@ -631,7 +631,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         # Set monthly active users to the limit
         self.store.get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
+            return_value=make_awaitable(self.hs.config.max_mau_value)
         )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py
new file mode 100644
index 0000000000..081052f6a6
--- /dev/null
+++ b/tests/rest/client/v1/test_push_rule_attrs.py
@@ -0,0 +1,448 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import synapse
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import login, push_rule, room
+
+from tests.unittest import HomeserverTestCase
+
+
+class PushRuleAttributesTestCase(HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+        push_rule.register_servlets,
+    ]
+    hijack_auth = False
+
+    def test_enabled_on_creation(self):
+        """
+        Tests the GET and PUT of push rules' `enabled` endpoints.
+        Tests that a rule is enabled upon creation, even though a rule with that
+            ruleId existed previously and was disabled.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        body = {
+            "conditions": [
+                {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+            ],
+            "actions": ["notify", {"set_tweak": "highlight"}],
+        }
+
+        # PUT a new rule
+        request, channel = self.make_request(
+            "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # GET enabled for that new rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["enabled"], True)
+
+    def test_enabled_on_recreation(self):
+        """
+        Tests the GET and PUT of push rules' `enabled` endpoints.
+        Tests that a rule is enabled upon creation, even if a rule with that
+            ruleId existed previously and was disabled.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        body = {
+            "conditions": [
+                {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+            ],
+            "actions": ["notify", {"set_tweak": "highlight"}],
+        }
+
+        # PUT a new rule
+        request, channel = self.make_request(
+            "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # disable the rule
+        request, channel = self.make_request(
+            "PUT",
+            "/pushrules/global/override/best.friend/enabled",
+            {"enabled": False},
+            access_token=token,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # check rule disabled
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["enabled"], False)
+
+        # DELETE the rule
+        request, channel = self.make_request(
+            "DELETE", "/pushrules/global/override/best.friend", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # PUT a new rule
+        request, channel = self.make_request(
+            "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # GET enabled for that new rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["enabled"], True)
+
+    def test_enabled_disable(self):
+        """
+        Tests the GET and PUT of push rules' `enabled` endpoints.
+        Tests that a rule is disabled and enabled when we ask for it.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        body = {
+            "conditions": [
+                {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+            ],
+            "actions": ["notify", {"set_tweak": "highlight"}],
+        }
+
+        # PUT a new rule
+        request, channel = self.make_request(
+            "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # disable the rule
+        request, channel = self.make_request(
+            "PUT",
+            "/pushrules/global/override/best.friend/enabled",
+            {"enabled": False},
+            access_token=token,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # check rule disabled
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["enabled"], False)
+
+        # re-enable the rule
+        request, channel = self.make_request(
+            "PUT",
+            "/pushrules/global/override/best.friend/enabled",
+            {"enabled": True},
+            access_token=token,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # check rule enabled
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["enabled"], True)
+
+    def test_enabled_404_when_get_non_existent(self):
+        """
+        Tests that `enabled` gives 404 when the rule doesn't exist.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        body = {
+            "conditions": [
+                {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+            ],
+            "actions": ["notify", {"set_tweak": "highlight"}],
+        }
+
+        # check 404 for never-heard-of rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+        # PUT a new rule
+        request, channel = self.make_request(
+            "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # GET enabled for that new rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # DELETE the rule
+        request, channel = self.make_request(
+            "DELETE", "/pushrules/global/override/best.friend", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # check 404 for deleted rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_enabled_404_when_get_non_existent_server_rule(self):
+        """
+        Tests that `enabled` gives 404 when the server-default rule doesn't exist.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        # check 404 for never-heard-of rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_enabled_404_when_put_non_existent_rule(self):
+        """
+        Tests that `enabled` gives 404 when we put to a rule that doesn't exist.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        # enable & check 404 for never-heard-of rule
+        request, channel = self.make_request(
+            "PUT",
+            "/pushrules/global/override/best.friend/enabled",
+            {"enabled": True},
+            access_token=token,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_enabled_404_when_put_non_existent_server_rule(self):
+        """
+        Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        # enable & check 404 for never-heard-of rule
+        request, channel = self.make_request(
+            "PUT",
+            "/pushrules/global/override/.m.muahahah/enabled",
+            {"enabled": True},
+            access_token=token,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_actions_get(self):
+        """
+        Tests that `actions` gives you what you expect on a fresh rule.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        body = {
+            "conditions": [
+                {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+            ],
+            "actions": ["notify", {"set_tweak": "highlight"}],
+        }
+
+        # PUT a new rule
+        request, channel = self.make_request(
+            "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # GET actions for that new rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/actions", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}]
+        )
+
+    def test_actions_put(self):
+        """
+        Tests that PUT on actions updates the value you'd get from GET.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        body = {
+            "conditions": [
+                {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+            ],
+            "actions": ["notify", {"set_tweak": "highlight"}],
+        }
+
+        # PUT a new rule
+        request, channel = self.make_request(
+            "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # change the rule actions
+        request, channel = self.make_request(
+            "PUT",
+            "/pushrules/global/override/best.friend/actions",
+            {"actions": ["dont_notify"]},
+            access_token=token,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # GET actions for that new rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/actions", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["actions"], ["dont_notify"])
+
+    def test_actions_404_when_get_non_existent(self):
+        """
+        Tests that `actions` gives 404 when the rule doesn't exist.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        body = {
+            "conditions": [
+                {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"}
+            ],
+            "actions": ["notify", {"set_tweak": "highlight"}],
+        }
+
+        # check 404 for never-heard-of rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+        # PUT a new rule
+        request, channel = self.make_request(
+            "PUT", "/pushrules/global/override/best.friend", body, access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # DELETE the rule
+        request, channel = self.make_request(
+            "DELETE", "/pushrules/global/override/best.friend", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200)
+
+        # check 404 for deleted rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_actions_404_when_get_non_existent_server_rule(self):
+        """
+        Tests that `actions` gives 404 when the server-default rule doesn't exist.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        # check 404 for never-heard-of rule
+        request, channel = self.make_request(
+            "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_actions_404_when_put_non_existent_rule(self):
+        """
+        Tests that `actions` gives 404 when putting to a rule that doesn't exist.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        # enable & check 404 for never-heard-of rule
+        request, channel = self.make_request(
+            "PUT",
+            "/pushrules/global/override/best.friend/actions",
+            {"actions": ["dont_notify"]},
+            access_token=token,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_actions_404_when_put_non_existent_server_rule(self):
+        """
+        Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist.
+        """
+        self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        # enable & check 404 for never-heard-of rule
+        request, channel = self.make_request(
+            "PUT",
+            "/pushrules/global/override/.m.muahahah/actions",
+            {"actions": ["dont_notify"]},
+            access_token=token,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 152a5182fa..93f899d861 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -14,11 +14,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import json
 import os
 import re
 from email.parser import Parser
+from typing import Optional
+from urllib.parse import urlencode
 
 import pkg_resources
 
@@ -27,8 +28,10 @@ from synapse.api.constants import LoginType, Membership
 from synapse.api.errors import Codes
 from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import account, register
+from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
 
 from tests import unittest
+from tests.unittest import override_config
 
 
 class PasswordResetTestCase(unittest.HomeserverTestCase):
@@ -69,6 +72,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
+        self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
 
     def test_basic_password_reset(self):
         """Test basic password reset flow
@@ -250,8 +254,32 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         # Remove the host
         path = link.replace("https://example.com", "")
 
+        # Load the password reset confirmation page
         request, channel = self.make_request("GET", path, shorthand=False)
-        self.render(request)
+        request.render(self.submit_token_resource)
+        self.pump()
+        self.assertEquals(200, channel.code, channel.result)
+
+        # Now POST to the same endpoint, mimicking the same behaviour as clicking the
+        # password reset confirm button
+
+        # Send arguments as url-encoded form data, matching the template's behaviour
+        form_args = []
+        for key, value_list in request.args.items():
+            for value in value_list:
+                arg = (key, value)
+                form_args.append(arg)
+
+        # Confirm the password reset
+        request, channel = self.make_request(
+            "POST",
+            path,
+            content=urlencode(form_args).encode("utf8"),
+            shorthand=False,
+            content_is_form=True,
+        )
+        request.render(self.submit_token_resource)
+        self.pump()
         self.assertEquals(200, channel.code, channel.result)
 
     def _get_link_from_email(self):
@@ -668,16 +696,104 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertFalse(channel.json_body["threepids"])
 
-    def _request_token(self, email, client_secret):
+    @override_config({"next_link_domain_whitelist": None})
+    def test_next_link(self):
+        """Tests a valid next_link parameter value with no whitelist (good case)"""
+        self._request_token(
+            "something@example.com",
+            "some_secret",
+            next_link="https://example.com/a/good/site",
+            expect_code=200,
+        )
+
+    @override_config({"next_link_domain_whitelist": None})
+    def test_next_link_exotic_protocol(self):
+        """Tests using a esoteric protocol as a next_link parameter value.
+        Someone may be hosting a client on IPFS etc.
+        """
+        self._request_token(
+            "something@example.com",
+            "some_secret",
+            next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
+            expect_code=200,
+        )
+
+    @override_config({"next_link_domain_whitelist": None})
+    def test_next_link_file_uri(self):
+        """Tests next_link parameters cannot be file URI"""
+        # Attempt to use a next_link value that points to the local disk
+        self._request_token(
+            "something@example.com",
+            "some_secret",
+            next_link="file:///host/path",
+            expect_code=400,
+        )
+
+    @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
+    def test_next_link_domain_whitelist(self):
+        """Tests next_link parameters must fit the whitelist if provided"""
+        self._request_token(
+            "something@example.com",
+            "some_secret",
+            next_link="https://example.com/some/good/page",
+            expect_code=200,
+        )
+
+        self._request_token(
+            "something@example.com",
+            "some_secret",
+            next_link="https://example.org/some/also/good/page",
+            expect_code=200,
+        )
+
+        self._request_token(
+            "something@example.com",
+            "some_secret",
+            next_link="https://bad.example.org/some/bad/page",
+            expect_code=400,
+        )
+
+    @override_config({"next_link_domain_whitelist": []})
+    def test_empty_next_link_domain_whitelist(self):
+        """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially
+        disallowed
+        """
+        self._request_token(
+            "something@example.com",
+            "some_secret",
+            next_link="https://example.com/a/page",
+            expect_code=400,
+        )
+
+    def _request_token(
+        self,
+        email: str,
+        client_secret: str,
+        next_link: Optional[str] = None,
+        expect_code: int = 200,
+    ) -> str:
+        """Request a validation token to add an email address to a user's account
+
+        Args:
+            email: The email address to validate
+            client_secret: A secret string
+            next_link: A link to redirect the user to after validation
+            expect_code: Expected return code of the call
+
+        Returns:
+            The ID of the new threepid validation session
+        """
+        body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
+        if next_link:
+            body["next_link"] = next_link
+
         request, channel = self.make_request(
-            "POST",
-            b"account/3pid/email/requestToken",
-            {"client_secret": client_secret, "email": email, "send_attempt": 1},
+            "POST", b"account/3pid/email/requestToken", body,
         )
         self.render(request)
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEquals(expect_code, channel.code, channel.result)
 
-        return channel.json_body["sid"]
+        return channel.json_body.get("sid")
 
     def _request_token_invalid_email(
         self, email, expected_errcode, expected_error, client_secret="foobar",
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index f4f3e56777..5f897d49cf 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -120,12 +120,13 @@ class _TestImage:
     extension = attr.ib(type=bytes)
     expected_cropped = attr.ib(type=Optional[bytes])
     expected_scaled = attr.ib(type=Optional[bytes])
+    expected_found = attr.ib(default=True, type=bool)
 
 
 @parameterized_class(
     ("test_image",),
     [
-        # smol png
+        # smoll png
         (
             _TestImage(
                 unhexlify(
@@ -161,6 +162,8 @@ class _TestImage:
                 None,
             ),
         ),
+        # an empty file
+        (_TestImage(b"", b"image/gif", b".gif", None, None, False,),),
     ],
 )
 class MediaRepoTests(unittest.HomeserverTestCase):
@@ -303,12 +306,16 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
 
     def test_thumbnail_crop(self):
-        self._test_thumbnail("crop", self.test_image.expected_cropped)
+        self._test_thumbnail(
+            "crop", self.test_image.expected_cropped, self.test_image.expected_found
+        )
 
     def test_thumbnail_scale(self):
-        self._test_thumbnail("scale", self.test_image.expected_scaled)
+        self._test_thumbnail(
+            "scale", self.test_image.expected_scaled, self.test_image.expected_found
+        )
 
-    def _test_thumbnail(self, method, expected_body):
+    def _test_thumbnail(self, method, expected_body, expected_found):
         params = "?width=32&height=32&method=" + method
         request, channel = self.make_request(
             "GET", self.media_id + params, shorthand=False
@@ -325,11 +332,23 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         )
         self.pump()
 
-        self.assertEqual(channel.code, 200)
-        if expected_body is not None:
+        if expected_found:
+            self.assertEqual(channel.code, 200)
+            if expected_body is not None:
+                self.assertEqual(
+                    channel.result["body"], expected_body, channel.result["body"]
+                )
+            else:
+                # ensure that the result is at least some valid image
+                Image.open(BytesIO(channel.result["body"]))
+        else:
+            # A 404 with a JSON body.
+            self.assertEqual(channel.code, 404)
             self.assertEqual(
-                channel.result["body"], expected_body, channel.result["body"]
+                channel.json_body,
+                {
+                    "errcode": "M_NOT_FOUND",
+                    "error": "Not found [b'example.com', b'12345?width=32&height=32&method=%s']"
+                    % method,
+                },
             )
-        else:
-            # ensure that the result is at least some valid image
-            Image.open(BytesIO(channel.result["body"]))
diff --git a/tests/server.py b/tests/server.py
index 48e45c6c8b..61ec670155 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,6 +1,6 @@
 import json
 import logging
-from io import BytesIO
+from io import SEEK_END, BytesIO
 
 import attr
 from zope.interface import implementer
@@ -135,6 +135,7 @@ def make_request(
     request=SynapseRequest,
     shorthand=True,
     federation_auth_origin=None,
+    content_is_form=False,
 ):
     """
     Make a web request using the given method and path, feed it the
@@ -150,6 +151,8 @@ def make_request(
         with the usual REST API path, if it doesn't contain it.
         federation_auth_origin (bytes|None): if set to not-None, we will add a fake
             Authorization header pretenting to be the given server name.
+        content_is_form: Whether the content is URL encoded form data. Adds the
+            'Content-Type': 'application/x-www-form-urlencoded' header.
 
     Returns:
         Tuple[synapse.http.site.SynapseRequest, channel]
@@ -181,6 +184,8 @@ def make_request(
     req = request(channel)
     req.process = lambda: b""
     req.content = BytesIO(content)
+    # Twisted expects to be at the end of the content when parsing the request.
+    req.content.seek(SEEK_END)
     req.postpath = list(map(unquote, path[1:].split(b"/")))
 
     if access_token:
@@ -195,7 +200,13 @@ def make_request(
         )
 
     if content:
-        req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
+        if content_is_form:
+            req.requestHeaders.addRawHeader(
+                b"Content-Type", b"application/x-www-form-urlencoded"
+            )
+        else:
+            # Assume the body is JSON
+            req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
 
     req.requestReceived(method, path, b"1.1")
 
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 973338ea71..6382b19dc3 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -67,7 +67,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
             raise Exception("Failed to find reference to ResourceLimitsServerNotices")
 
         self._rlsn._store.user_last_seen_monthly_active = Mock(
-            side_effect=lambda user_id: make_awaitable(1000)
+            return_value=make_awaitable(1000)
         )
         self._rlsn._server_notices_manager.send_notice = Mock(
             return_value=defer.succeed(Mock())
@@ -80,9 +80,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
             return_value=defer.succeed("!something:localhost")
         )
         self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
-        self._rlsn._store.get_tags_for_room = Mock(
-            side_effect=lambda user_id, room_id: make_awaitable({})
-        )
+        self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))
 
     @override_config({"hs_disabled": True})
     def test_maybe_send_server_notice_disabled_hs(self):
@@ -158,7 +156,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
         self._rlsn._store.user_last_seen_monthly_active = Mock(
-            side_effect=lambda user_id: make_awaitable(None)
+            return_value=make_awaitable(None)
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
@@ -261,12 +259,10 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
         self.user_id = "@user_id:test"
 
     def test_server_notice_only_sent_once(self):
-        self.store.get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(1000)
-        )
+        self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000))
 
         self.store.user_last_seen_monthly_active = Mock(
-            side_effect=lambda user_id: make_awaitable(1000)
+            return_value=make_awaitable(1000)
         )
 
         # Call the function multiple times to ensure we only send the notice once
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 370c247e16..755c70db31 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -154,7 +154,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         user_id = "@user:server"
 
         self.store.get_monthly_active_count = Mock(
-            side_effect=lambda: make_awaitable(lots_of_users)
+            return_value=make_awaitable(lots_of_users)
         )
         self.get_success(
             self.store.insert_client_ip(
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index f0a8e32f1e..20636fc400 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -122,6 +122,56 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_positions(), {"master": 8})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
 
+    def test_out_of_order_finish(self):
+        """Test that IDs persisted out of order are correctly handled
+        """
+
+        # Prefill table with 7 rows written by 'master'
+        self._insert_rows("master", 7)
+
+        id_gen = self._create_id_generator()
+
+        self.assertEqual(id_gen.get_positions(), {"master": 7})
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+        ctx1 = self.get_success(id_gen.get_next())
+        ctx2 = self.get_success(id_gen.get_next())
+        ctx3 = self.get_success(id_gen.get_next())
+        ctx4 = self.get_success(id_gen.get_next())
+
+        s1 = ctx1.__enter__()
+        s2 = ctx2.__enter__()
+        s3 = ctx3.__enter__()
+        s4 = ctx4.__enter__()
+
+        self.assertEqual(s1, 8)
+        self.assertEqual(s2, 9)
+        self.assertEqual(s3, 10)
+        self.assertEqual(s4, 11)
+
+        self.assertEqual(id_gen.get_positions(), {"master": 7})
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+        ctx2.__exit__(None, None, None)
+
+        self.assertEqual(id_gen.get_positions(), {"master": 7})
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+        ctx1.__exit__(None, None, None)
+
+        self.assertEqual(id_gen.get_positions(), {"master": 9})
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
+
+        ctx4.__exit__(None, None, None)
+
+        self.assertEqual(id_gen.get_positions(), {"master": 9})
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
+
+        ctx3.__exit__(None, None, None)
+
+        self.assertEqual(id_gen.get_positions(), {"master": 11})
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
+
     def test_multi_instance(self):
         """Test that reads and writes from multiple processes are handled
         correctly.
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 9870c74883..643072bbaf 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -231,9 +231,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         )
         self.get_success(d)
 
-        self.store.upsert_monthly_active_user = Mock(
-            side_effect=lambda user_id: make_awaitable(None)
-        )
+        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
 
         d = self.store.populate_monthly_active_users(user_id)
         self.get_success(d)
@@ -241,9 +239,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.upsert_monthly_active_user.assert_not_called()
 
     def test_populate_monthly_users_should_update(self):
-        self.store.upsert_monthly_active_user = Mock(
-            side_effect=lambda user_id: make_awaitable(None)
-        )
+        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
 
         self.store.is_trial_user = Mock(return_value=defer.succeed(False))
 
@@ -256,9 +252,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.upsert_monthly_active_user.assert_called_once()
 
     def test_populate_monthly_users_should_not_update(self):
-        self.store.upsert_monthly_active_user = Mock(
-            side_effect=lambda user_id: make_awaitable(None)
-        )
+        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
 
         self.store.is_trial_user = Mock(return_value=defer.succeed(False))
         self.store.user_last_seen_monthly_active = Mock(
@@ -344,9 +338,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
     def test_no_users_when_not_tracking(self):
-        self.store.upsert_monthly_active_user = Mock(
-            side_effect=lambda user_id: make_awaitable(None)
-        )
+        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))
 
         self.get_success(self.store.populate_monthly_active_users("@user:sever"))
 
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 508aeba078..a298cc0fd3 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,6 +17,7 @@
 """
 Utilities for running the unit tests
 """
+from asyncio import Future
 from typing import Any, Awaitable, TypeVar
 
 TV = TypeVar("TV")
@@ -38,6 +39,12 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
     raise Exception("awaitable has not yet completed")
 
 
-async def make_awaitable(result: Any):
-    """Create an awaitable that just returns a result."""
-    return result
+def make_awaitable(result: Any) -> Awaitable[Any]:
+    """
+    Makes an awaitable, suitable for mocking an `async` function.
+    This uses Futures as they can be awaited multiple times so can be returned
+    to multiple callers.
+    """
+    future = Future()  # type: ignore
+    future.set_result(result)
+    return future
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index fb1ca90336..e93aa84405 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -71,7 +71,10 @@ async def inject_event(
     """
     event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
 
-    await hs.get_storage().persistence.persist_event(event, context)
+    persistence = hs.get_storage().persistence
+    assert persistence is not None
+
+    await persistence.persist_event(event, context)
 
     return event
 
diff --git a/tests/unittest.py b/tests/unittest.py
index 3cb55a7e96..128dd4e19c 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -353,6 +353,7 @@ class HomeserverTestCase(TestCase):
         request: Type[T] = SynapseRequest,
         shorthand: bool = True,
         federation_auth_origin: str = None,
+        content_is_form: bool = False,
     ) -> Tuple[T, FakeChannel]:
         """
         Create a SynapseRequest at the path using the method and containing the
@@ -368,6 +369,8 @@ class HomeserverTestCase(TestCase):
             with the usual REST API path, if it doesn't contain it.
             federation_auth_origin (bytes|None): if set to not-None, we will add a fake
                 Authorization header pretenting to be the given server name.
+            content_is_form: Whether the content is URL encoded form data. Adds the
+                'Content-Type': 'application/x-www-form-urlencoded' header.
 
         Returns:
             Tuple[synapse.http.site.SynapseRequest, channel]
@@ -384,6 +387,7 @@ class HomeserverTestCase(TestCase):
             request,
             shorthand,
             federation_auth_origin,
+            content_is_form,
         )
 
     def render(self, request):