summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py97
-rw-r--r--tests/api/test_errors.py43
-rw-r--r--tests/api/test_ratelimiting.py67
-rw-r--r--tests/appservice/test_api.py6
-rw-r--r--tests/appservice/test_appservice.py31
-rw-r--r--tests/appservice/test_scheduler.py43
-rw-r--r--tests/config/test_ratelimiting.py31
-rw-r--r--tests/crypto/test_keyring.py5
-rw-r--r--tests/events/test_presence_router.py5
-rw-r--r--tests/federation/test_complexity.py37
-rw-r--r--tests/federation/test_federation_catch_up.py10
-rw-r--r--tests/federation/test_federation_sender.py71
-rw-r--r--tests/federation/transport/test_knocking.py4
-rw-r--r--tests/handlers/test_appservice.py129
-rw-r--r--tests/handlers/test_auth.py27
-rw-r--r--tests/handlers/test_cas.py11
-rw-r--r--tests/handlers/test_device.py37
-rw-r--r--tests/handlers/test_directory.py12
-rw-r--r--tests/handlers/test_e2e_keys.py130
-rw-r--r--tests/handlers/test_federation.py56
-rw-r--r--tests/handlers/test_federation_event.py109
-rw-r--r--tests/handlers/test_message.py15
-rw-r--r--tests/handlers/test_oauth_delegation.py40
-rw-r--r--tests/handlers/test_oidc.py10
-rw-r--r--tests/handlers/test_password_providers.py51
-rw-r--r--tests/handlers/test_presence.py133
-rw-r--r--tests/handlers/test_profile.py9
-rw-r--r--tests/handlers/test_register.py33
-rw-r--r--tests/handlers/test_room_member.py33
-rw-r--r--tests/handlers/test_saml.py17
-rw-r--r--tests/handlers/test_send_email.py69
-rw-r--r--tests/handlers/test_sync.py7
-rw-r--r--tests/handlers/test_typing.py65
-rw-r--r--tests/handlers/test_user_directory.py6
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py60
-rw-r--r--tests/logging/test_terse_json.py2
-rw-r--r--tests/module_api/test_api.py13
-rw-r--r--tests/push/test_bulk_push_rule_evaluator.py6
-rw-r--r--tests/replication/storage/test_events.py2
-rw-r--r--tests/replication/test_federation_sender_shard.py13
-rw-r--r--tests/rest/admin/test_user.py80
-rw-r--r--tests/rest/admin/test_username_available.py2
-rw-r--r--tests/rest/client/test_account.py2
-rw-r--r--tests/rest/client/test_account_data.py5
-rw-r--r--tests/rest/client/test_events.py2
-rw-r--r--tests/rest/client/test_filter.py4
-rw-r--r--tests/rest/client/test_login.py29
-rw-r--r--tests/rest/client/test_notifications.py5
-rw-r--r--tests/rest/client/test_presence.py5
-rw-r--r--tests/rest/client/test_register.py8
-rw-r--r--tests/rest/client/test_relations.py49
-rw-r--r--tests/rest/client/test_rooms.py45
-rw-r--r--tests/rest/client/test_shadow_banned.py2
-rw-r--r--tests/rest/client/test_third_party_rules.py37
-rw-r--r--tests/rest/client/test_transactions.py9
-rw-r--r--tests/server.py56
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py83
-rw-r--r--tests/storage/databases/main/test_lock.py2
-rw-r--r--tests/storage/test_appservice.py5
-rw-r--r--tests/storage/test_background_update.py37
-rw-r--r--tests/storage/test_client_ips.py37
-rw-r--r--tests/storage/test_devices.py18
-rw-r--r--tests/storage/test_end_to_end_keys.py10
-rw-r--r--tests/storage/test_monthly_active_users.py23
-rw-r--r--tests/storage/test_registration.py1
-rw-r--r--tests/storage/test_room.py12
-rw-r--r--tests/storage/util/test_partial_state_events_tracker.py8
-rw-r--r--tests/test_federation.py44
-rw-r--r--tests/test_state.py4
-rw-r--r--tests/test_terms_auth.py4
-rw-r--r--tests/test_utils/__init__.py31
-rw-r--r--tests/unittest.py8
-rw-r--r--tests/util/test_async_helpers.py14
-rw-r--r--tests/util/test_task_scheduler.py58
74 files changed, 1279 insertions, 1005 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index ce96574915..dcd01d5688 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 import pymacaroons
 
@@ -35,7 +35,6 @@ from synapse.types import Requester, UserID
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import simple_async_mock
 from tests.unittest import override_config
 from tests.utils import mock_getRawHeaders
 
@@ -60,16 +59,16 @@ class AuthTestCase(unittest.HomeserverTestCase):
         # this is overridden for the appservice tests
         self.store.get_app_service_by_token = Mock(return_value=None)
 
-        self.store.insert_client_ip = simple_async_mock(None)
-        self.store.is_support_user = simple_async_mock(False)
+        self.store.insert_client_ip = AsyncMock(return_value=None)
+        self.store.is_support_user = AsyncMock(return_value=False)
 
     def test_get_user_by_req_user_valid_token(self) -> None:
         user_info = TokenLookupResult(
             user_id=self.test_user, token_id=5, device_id="device"
         )
-        self.store.get_user_by_access_token = simple_async_mock(user_info)
-        self.store.mark_access_token_as_used = simple_async_mock(None)
-        self.store.get_user_locked_status = simple_async_mock(False)
+        self.store.get_user_by_access_token = AsyncMock(return_value=user_info)
+        self.store.mark_access_token_as_used = AsyncMock(return_value=None)
+        self.store.get_user_locked_status = AsyncMock(return_value=False)
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
@@ -78,7 +77,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.assertEqual(requester.user.to_string(), self.test_user)
 
     def test_get_user_by_req_user_bad_token(self) -> None:
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
@@ -91,7 +90,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
     def test_get_user_by_req_user_missing_token(self) -> None:
         user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
-        self.store.get_user_by_access_token = simple_async_mock(user_info)
+        self.store.get_user_by_access_token = AsyncMock(return_value=user_info)
 
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@@ -106,7 +105,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
@@ -125,7 +124,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             ip_range_whitelist=IPSet(["192.168/16"]),
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "192.168.10.10"
@@ -144,7 +143,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             ip_range_whitelist=IPSet(["192.168/16"]),
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "131.111.8.42"
@@ -158,7 +157,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
     def test_get_user_by_req_appservice_bad_token(self) -> None:
         self.store.get_app_service_by_token = Mock(return_value=None)
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
@@ -172,7 +171,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
     def test_get_user_by_req_appservice_missing_token(self) -> None:
         app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@@ -190,8 +189,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
         # This just needs to return a truth-y value.
-        self.store.get_user_by_id = simple_async_mock({"is_guest": False})
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
@@ -210,7 +209,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
         app_service.is_interested_in_user = Mock(return_value=False)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
@@ -234,10 +233,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
         # This just needs to return a truth-y value.
-        self.store.get_user_by_id = simple_async_mock({"is_guest": False})
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
         # This also needs to just return a truth-y value
-        self.store.get_device = simple_async_mock({"hidden": False})
+        self.store.get_device = AsyncMock(return_value={"hidden": False})
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
@@ -266,10 +265,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
         # This just needs to return a truth-y value.
-        self.store.get_user_by_id = simple_async_mock({"is_guest": False})
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
         # This also needs to just return a falsey value
-        self.store.get_device = simple_async_mock(None)
+        self.store.get_device = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
@@ -283,8 +282,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE)
 
     def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None:
-        self.store.get_user_by_access_token = simple_async_mock(
-            TokenLookupResult(
+        self.store.get_user_by_access_token = AsyncMock(
+            return_value=TokenLookupResult(
                 user_id="@baldrick:matrix.org",
                 device_id="device",
                 token_id=5,
@@ -292,9 +291,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
                 token_used=True,
             )
         )
-        self.store.insert_client_ip = simple_async_mock(None)
-        self.store.mark_access_token_as_used = simple_async_mock(None)
-        self.store.get_user_locked_status = simple_async_mock(False)
+        self.store.insert_client_ip = AsyncMock(return_value=None)
+        self.store.mark_access_token_as_used = AsyncMock(return_value=None)
+        self.store.get_user_locked_status = AsyncMock(return_value=False)
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
@@ -304,8 +303,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
     def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None:
         self.auth._track_puppeted_user_ips = True
-        self.store.get_user_by_access_token = simple_async_mock(
-            TokenLookupResult(
+        self.store.get_user_by_access_token = AsyncMock(
+            return_value=TokenLookupResult(
                 user_id="@baldrick:matrix.org",
                 device_id="device",
                 token_id=5,
@@ -313,9 +312,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
                 token_used=True,
             )
         )
-        self.store.get_user_locked_status = simple_async_mock(False)
-        self.store.insert_client_ip = simple_async_mock(None)
-        self.store.mark_access_token_as_used = simple_async_mock(None)
+        self.store.get_user_locked_status = AsyncMock(return_value=False)
+        self.store.insert_client_ip = AsyncMock(return_value=None)
+        self.store.mark_access_token_as_used = AsyncMock(return_value=None)
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
@@ -324,7 +323,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.assertEqual(self.store.insert_client_ip.call_count, 2)
 
     def test_get_user_from_macaroon(self) -> None:
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         user_id = "@baldrick:matrix.org"
         macaroon = pymacaroons.Macaroon(
@@ -342,8 +341,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
 
     def test_get_guest_user_from_macaroon(self) -> None:
-        self.store.get_user_by_id = simple_async_mock({"is_guest": True})
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True})
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         user_id = "@baldrick:matrix.org"
         macaroon = pymacaroons.Macaroon(
@@ -373,7 +372,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
         self.auth_blocking._limit_usage_by_mau = True
 
-        self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
+        self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users)
 
         e = self.get_failure(
             self.auth_blocking.check_auth_blocking(), ResourceLimitError
@@ -383,25 +382,27 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.assertEqual(e.value.code, 403)
 
         # Ensure does not throw an error
-        self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
+        self.store.get_monthly_active_count = AsyncMock(
+            return_value=small_number_of_users
+        )
         self.get_success(self.auth_blocking.check_auth_blocking())
 
     def test_blocking_mau__depending_on_user_type(self) -> None:
         self.auth_blocking._max_mau_value = 50
         self.auth_blocking._limit_usage_by_mau = True
 
-        self.store.get_monthly_active_count = simple_async_mock(100)
+        self.store.get_monthly_active_count = AsyncMock(return_value=100)
         # Support users allowed
         self.get_success(
             self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT)
         )
-        self.store.get_monthly_active_count = simple_async_mock(100)
+        self.store.get_monthly_active_count = AsyncMock(return_value=100)
         # Bots not allowed
         self.get_failure(
             self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT),
             ResourceLimitError,
         )
-        self.store.get_monthly_active_count = simple_async_mock(100)
+        self.store.get_monthly_active_count = AsyncMock(return_value=100)
         # Real users not allowed
         self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
 
@@ -412,9 +413,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.auth_blocking._limit_usage_by_mau = True
         self.auth_blocking._track_appservice_user_ips = False
 
-        self.store.get_monthly_active_count = simple_async_mock(100)
-        self.store.user_last_seen_monthly_active = simple_async_mock()
-        self.store.is_trial_user = simple_async_mock()
+        self.store.get_monthly_active_count = AsyncMock(return_value=100)
+        self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
+        self.store.is_trial_user = AsyncMock(return_value=False)
 
         appservice = ApplicationService(
             "abcd",
@@ -443,9 +444,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.auth_blocking._limit_usage_by_mau = True
         self.auth_blocking._track_appservice_user_ips = True
 
-        self.store.get_monthly_active_count = simple_async_mock(100)
-        self.store.user_last_seen_monthly_active = simple_async_mock()
-        self.store.is_trial_user = simple_async_mock()
+        self.store.get_monthly_active_count = AsyncMock(return_value=100)
+        self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
+        self.store.is_trial_user = AsyncMock(return_value=False)
 
         appservice = ApplicationService(
             "abcd",
@@ -473,7 +474,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
     def test_reserved_threepid(self) -> None:
         self.auth_blocking._limit_usage_by_mau = True
         self.auth_blocking._max_mau_value = 1
-        self.store.get_monthly_active_count = simple_async_mock(2)
+        self.store.get_monthly_active_count = AsyncMock(return_value=2)
         threepid = {"medium": "email", "address": "reserved@server.com"}
         unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
         self.auth_blocking._mau_limits_reserved_threepids = [threepid]
diff --git a/tests/api/test_errors.py b/tests/api/test_errors.py
new file mode 100644
index 0000000000..8e159029d9
--- /dev/null
+++ b/tests/api/test_errors.py
@@ -0,0 +1,43 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from synapse.api.errors import LimitExceededError
+
+from tests import unittest
+
+
+class LimitExceededErrorTestCase(unittest.TestCase):
+    def test_key_appears_in_context_but_not_error_dict(self) -> None:
+        err = LimitExceededError("needle")
+        serialised = json.dumps(err.error_dict(None))
+        self.assertIn("needle", err.debug_context)
+        self.assertNotIn("needle", serialised)
+
+    # Create a sub-class to avoid mutating the class-level property.
+    class LimitExceededErrorHeaders(LimitExceededError):
+        include_retry_after_header = True
+
+    def test_limit_exceeded_header(self) -> None:
+        err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=100)
+        self.assertEqual(err.error_dict(None).get("retry_after_ms"), 100)
+        assert err.headers is not None
+        self.assertEqual(err.headers.get("Retry-After"), "1")
+
+    def test_limit_exceeded_rounding(self) -> None:
+        err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=3001)
+        self.assertEqual(err.error_dict(None).get("retry_after_ms"), 3001)
+        assert err.headers is not None
+        self.assertEqual(err.headers.get("Retry-After"), "4")
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index fa6c1c02ce..a24638c9ef 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -1,5 +1,6 @@
 from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
 from synapse.appservice import ApplicationService
+from synapse.config.ratelimiting import RatelimitSettings
 from synapse.types import create_requester
 
 from tests import unittest
@@ -10,8 +11,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
         )
         allowed, time_allowed = self.get_success_or_raise(
             limiter.can_do_action(None, key="test_id", _time_now_s=0)
@@ -43,8 +43,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(
+                key="",
+                per_second=0.1,
+                burst_count=1,
+            ),
         )
         allowed, time_allowed = self.get_success_or_raise(
             limiter.can_do_action(as_requester, _time_now_s=0)
@@ -76,8 +79,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(
+                key="",
+                per_second=0.1,
+                burst_count=1,
+            ),
         )
         allowed, time_allowed = self.get_success_or_raise(
             limiter.can_do_action(as_requester, _time_now_s=0)
@@ -101,8 +107,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
         )
 
         # Shouldn't raise
@@ -128,8 +133,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
         )
 
         # First attempt should be allowed
@@ -177,8 +181,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
         )
 
         # First attempt should be allowed
@@ -208,8 +211,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=1,
+            cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
         )
         self.get_success_or_raise(
             limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
@@ -244,7 +246,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
             )
         )
 
-        limiter = Ratelimiter(store=store, clock=self.clock, rate_hz=0.1, burst_count=1)
+        limiter = Ratelimiter(
+            store=store,
+            clock=self.clock,
+            cfg=RatelimitSettings("", per_second=0.1, burst_count=1),
+        )
 
         # Shouldn't raise
         for _ in range(20):
@@ -254,8 +260,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=3,
+            cfg=RatelimitSettings(
+                key="",
+                per_second=0.1,
+                burst_count=3,
+            ),
         )
         # Test that 4 actions aren't allowed with a maximum burst of 3.
         allowed, time_allowed = self.get_success_or_raise(
@@ -321,8 +330,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=3,
+            cfg=RatelimitSettings("", per_second=0.1, burst_count=3),
         )
 
         def consume_at(time: float) -> bool:
@@ -346,8 +354,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=3,
+            cfg=RatelimitSettings(
+                "",
+                per_second=0.1,
+                burst_count=3,
+            ),
         )
 
         # Observe two actions, leaving room in the bucket for one more.
@@ -369,8 +380,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=3,
+            cfg=RatelimitSettings(
+                "",
+                per_second=0.1,
+                burst_count=3,
+            ),
         )
 
         # Observe three actions, filling up the bucket.
@@ -398,8 +412,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         limiter = Ratelimiter(
             store=self.hs.get_datastores().main,
             clock=self.clock,
-            rate_hz=0.1,
-            burst_count=3,
+            cfg=RatelimitSettings(
+                "",
+                per_second=0.1,
+                burst_count=3,
+            ),
         )
 
         # Observe four actions, exceeding the bucket.
diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py
index 3c635e3dcb..75fb5fae6b 100644
--- a/tests/appservice/test_api.py
+++ b/tests/appservice/test_api.py
@@ -96,7 +96,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
                 )
 
         # We assign to a method, which mypy doesn't like.
-        self.api.get_json = Mock(side_effect=get_json)  # type: ignore[assignment]
+        self.api.get_json = Mock(side_effect=get_json)  # type: ignore[method-assign]
 
         result = self.get_success(
             self.api.query_3pe(self.service, "user", PROTOCOL, {b"some": [b"field"]})
@@ -168,7 +168,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
                 )
 
         # We assign to a method, which mypy doesn't like.
-        self.api.get_json = Mock(side_effect=get_json)  # type: ignore[assignment]
+        self.api.get_json = Mock(side_effect=get_json)  # type: ignore[method-assign]
 
         result = self.get_success(
             self.api.query_3pe(self.service, "user", PROTOCOL, {b"some": [b"field"]})
@@ -215,7 +215,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
             return RESPONSE
 
         # We assign to a method, which mypy doesn't like.
-        self.api.post_json_get_json = Mock(side_effect=post_json_get_json)  # type: ignore[assignment]
+        self.api.post_json_get_json = Mock(side_effect=post_json_get_json)  # type: ignore[method-assign]
 
         MISSING_KEYS = [
             # Known user, known device, missing algorithm.
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 66753c60c4..6ac5fc1ae7 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -13,14 +13,13 @@
 # limitations under the License.
 import re
 from typing import Any, Generator
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.internet import defer
 
 from synapse.appservice import ApplicationService, Namespace
 
 from tests import unittest
-from tests.test_utils import simple_async_mock
 
 
 def _regex(regex: str, exclusive: bool = True) -> Namespace:
@@ -43,8 +42,8 @@ class ApplicationServiceTestCase(unittest.TestCase):
         )
 
         self.store = Mock()
-        self.store.get_aliases_for_room = simple_async_mock([])
-        self.store.get_local_users_in_room = simple_async_mock([])
+        self.store.get_aliases_for_room = AsyncMock(return_value=[])
+        self.store.get_local_users_in_room = AsyncMock(return_value=[])
 
     @defer.inlineCallbacks
     def test_regex_user_id_prefix_match(
@@ -127,10 +126,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.service.namespaces[ApplicationService.NS_ALIASES].append(
             _regex("#irc_.*:matrix.org")
         )
-        self.store.get_aliases_for_room = simple_async_mock(
-            ["#irc_foobar:matrix.org", "#athing:matrix.org"]
+        self.store.get_aliases_for_room = AsyncMock(
+            return_value=["#irc_foobar:matrix.org", "#athing:matrix.org"]
         )
-        self.store.get_local_users_in_room = simple_async_mock([])
+        self.store.get_local_users_in_room = AsyncMock(return_value=[])
         self.assertTrue(
             (
                 yield self.service.is_interested_in_event(
@@ -182,10 +181,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.service.namespaces[ApplicationService.NS_ALIASES].append(
             _regex("#irc_.*:matrix.org")
         )
-        self.store.get_aliases_for_room = simple_async_mock(
-            ["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
+        self.store.get_aliases_for_room = AsyncMock(
+            return_value=["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
         )
-        self.store.get_local_users_in_room = simple_async_mock([])
+        self.store.get_local_users_in_room = AsyncMock(return_value=[])
         self.assertFalse(
             (
                 yield defer.ensureDeferred(
@@ -205,8 +204,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
         )
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         self.event.sender = "@irc_foobar:matrix.org"
-        self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"])
-        self.store.get_local_users_in_room = simple_async_mock([])
+        self.store.get_aliases_for_room = AsyncMock(
+            return_value=["#irc_barfoo:matrix.org"]
+        )
+        self.store.get_local_users_in_room = AsyncMock(return_value=[])
         self.assertTrue(
             (
                 yield self.service.is_interested_in_event(
@@ -235,10 +236,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
     def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         # Note that @irc_fo:here is the AS user.
-        self.store.get_local_users_in_room = simple_async_mock(
-            ["@alice:here", "@irc_fo:here", "@bob:here"]
+        self.store.get_local_users_in_room = AsyncMock(
+            return_value=["@alice:here", "@irc_fo:here", "@bob:here"]
         )
-        self.store.get_aliases_for_room = simple_async_mock([])
+        self.store.get_aliases_for_room = AsyncMock(return_value=[])
 
         self.event.sender = "@xmpp_foobar:matrix.org"
         self.assertTrue(
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index e2a3bad065..445919417e 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import List, Optional, Sequence, Tuple, cast
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from typing_extensions import TypeAlias
 
@@ -37,7 +37,6 @@ from synapse.types import DeviceListUpdates, JsonDict
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import simple_async_mock
 
 from ..utils import MockClock
 
@@ -62,10 +61,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
         txn = Mock(id=txn_id, service=service, events=events)
 
         # mock methods
-        self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP)
-        txn.send = simple_async_mock(True)
-        txn.complete = simple_async_mock(True)
-        self.store.create_appservice_txn = simple_async_mock(txn)
+        self.store.get_appservice_state = AsyncMock(
+            return_value=ApplicationServiceState.UP
+        )
+        txn.send = AsyncMock(return_value=True)
+        txn.complete = AsyncMock(return_value=True)
+        self.store.create_appservice_txn = AsyncMock(return_value=txn)
 
         # actual call
         self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@@ -89,10 +90,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
         events = [Mock(), Mock()]
 
         txn = Mock(id="idhere", service=service, events=events)
-        self.store.get_appservice_state = simple_async_mock(
-            ApplicationServiceState.DOWN
+        self.store.get_appservice_state = AsyncMock(
+            return_value=ApplicationServiceState.DOWN
         )
-        self.store.create_appservice_txn = simple_async_mock(txn)
+        self.store.create_appservice_txn = AsyncMock(return_value=txn)
 
         # actual call
         self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@@ -118,10 +119,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
         txn = Mock(id=txn_id, service=service, events=events)
 
         # mock methods
-        self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP)
-        self.store.set_appservice_state = simple_async_mock(True)
-        txn.send = simple_async_mock(False)  # fails to send
-        self.store.create_appservice_txn = simple_async_mock(txn)
+        self.store.get_appservice_state = AsyncMock(
+            return_value=ApplicationServiceState.UP
+        )
+        self.store.set_appservice_state = AsyncMock(return_value=True)
+        txn.send = AsyncMock(return_value=False)  # fails to send
+        self.store.create_appservice_txn = AsyncMock(return_value=txn)
 
         # actual call
         self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@@ -150,7 +153,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
         self.as_api = Mock()
         self.store = Mock()
         self.service = Mock()
-        self.callback = simple_async_mock()
+        self.callback = AsyncMock()
         self.recoverer = _Recoverer(
             clock=cast(Clock, self.clock),
             as_api=self.as_api,
@@ -174,8 +177,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
         self.recoverer.recover()
         # shouldn't have called anything prior to waiting for exp backoff
         self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
-        txn.send = simple_async_mock(True)
-        txn.complete = simple_async_mock(None)
+        txn.send = AsyncMock(return_value=True)
+        txn.complete = AsyncMock(return_value=None)
         # wait for exp backoff
         self.clock.advance_time(2)
         self.assertEqual(1, txn.send.call_count)
@@ -202,8 +205,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
 
         self.recoverer.recover()
         self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
-        txn.send = simple_async_mock(False)
-        txn.complete = simple_async_mock(None)
+        txn.send = AsyncMock(return_value=False)
+        txn.complete = AsyncMock(return_value=None)
         self.clock.advance_time(2)
         self.assertEqual(1, txn.send.call_count)
         self.assertEqual(0, txn.complete.call_count)
@@ -216,7 +219,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
         self.assertEqual(3, txn.send.call_count)
         self.assertEqual(0, txn.complete.call_count)
         self.assertEqual(0, self.callback.call_count)
-        txn.send = simple_async_mock(True)  # successfully send the txn
+        txn.send = AsyncMock(return_value=True)  # successfully send the txn
         pop_txn = True  # returns the txn the first time, then no more.
         self.clock.advance_time(16)
         self.assertEqual(1, txn.send.call_count)  # new mock reset call count
@@ -244,7 +247,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer) -> None:
         self.scheduler = ApplicationServiceScheduler(hs)
         self.txn_ctrl = Mock()
-        self.txn_ctrl.send = simple_async_mock()
+        self.txn_ctrl.send = AsyncMock()
 
         # Replace instantiated _TransactionController instances with our Mock
         self.scheduler.txn_ctrl = self.txn_ctrl
diff --git a/tests/config/test_ratelimiting.py b/tests/config/test_ratelimiting.py
index f12147eaa0..0c27dd21e2 100644
--- a/tests/config/test_ratelimiting.py
+++ b/tests/config/test_ratelimiting.py
@@ -12,11 +12,42 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from synapse.config.homeserver import HomeServerConfig
+from synapse.config.ratelimiting import RatelimitSettings
 
 from tests.unittest import TestCase
 from tests.utils import default_config
 
 
+class ParseRatelimitSettingsTestcase(TestCase):
+    def test_depth_1(self) -> None:
+        cfg = {
+            "a": {
+                "per_second": 5,
+                "burst_count": 10,
+            }
+        }
+        parsed = RatelimitSettings.parse(cfg, "a")
+        self.assertEqual(parsed, RatelimitSettings("a", 5, 10))
+
+    def test_depth_2(self) -> None:
+        cfg = {
+            "a": {
+                "b": {
+                    "per_second": 5,
+                    "burst_count": 10,
+                },
+            }
+        }
+        parsed = RatelimitSettings.parse(cfg, "a.b")
+        self.assertEqual(parsed, RatelimitSettings("a.b", 5, 10))
+
+    def test_missing(self) -> None:
+        parsed = RatelimitSettings.parse(
+            {}, "a", defaults={"per_second": 5, "burst_count": 10}
+        )
+        self.assertEqual(parsed, RatelimitSettings("a", 5, 10))
+
+
 class RatelimitConfigTestCase(TestCase):
     def test_parse_rc_federation(self) -> None:
         config_dict = default_config("test")
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 2be341ac7b..f93ba5d4cf 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import time
 from typing import Any, Dict, List, Optional, cast
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 import attr
 import canonicaljson
@@ -45,7 +45,6 @@ from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 from tests.unittest import logcontext_clean, override_config
 
 
@@ -291,7 +290,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         with a null `ts_valid_until_ms`
         """
         mock_fetcher = Mock()
-        mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
+        mock_fetcher.get_keys = AsyncMock(return_value={})
 
         key1 = signedjson.key.generate_signing_key("1")
         r = self.hs.get_datastores().main.store_server_signature_keys(
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index 6fb1f1bd6e..0fcfe38efa 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 import attr
 
@@ -30,7 +30,6 @@ from synapse.types import JsonDict, StreamToken, create_requester
 from synapse.util import Clock
 
 from tests.handlers.test_sync import generate_sync_config
-from tests.test_utils import simple_async_mock
 from tests.unittest import (
     FederatingHomeserverTestCase,
     HomeserverTestCase,
@@ -157,7 +156,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         # Mock out the calls over federation.
         self.fed_transport_client = Mock(spec=["send_transaction"])
-        self.fed_transport_client.send_transaction = simple_async_mock({})
+        self.fed_transport_client.send_transaction = AsyncMock(return_value={})
 
         hs = self.setup_test_homeserver(
             federation_transport_client=self.fed_transport_client,
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 129d7cfd93..73a2766baf 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.rest import admin
@@ -20,7 +20,6 @@ from synapse.rest.client import login, room
 from synapse.types import JsonDict, UserID, create_requester
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
@@ -58,7 +57,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         async def get_current_state_event_counts(room_id: str) -> int:
             return int(500 * 1.23)
 
-        store.get_current_state_event_counts = get_current_state_event_counts  # type: ignore[assignment]
+        store.get_current_state_event_counts = get_current_state_event_counts  # type: ignore[method-assign]
 
         # Get the room complexity again -- make sure it's our artificial value
         channel = self.make_signed_federation_request(
@@ -75,9 +74,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))  # type: ignore[assignment]
-        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(("", 1))
+        fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999})  # type: ignore[method-assign]
+        handler.federation_handler.do_invite_join = AsyncMock(  # type: ignore[method-assign]
+            return_value=("", 1)
         )
 
         d = handler._remote_join(
@@ -106,9 +105,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))  # type: ignore[assignment]
-        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(("", 1))
+        fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999})  # type: ignore[method-assign]
+        handler.federation_handler.do_invite_join = AsyncMock(  # type: ignore[method-assign]
+            return_value=("", 1)
         )
 
         d = handler._remote_join(
@@ -143,16 +142,16 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
-        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(("", 1))
+        fed_transport.client.get_json = AsyncMock(return_value=None)  # type: ignore[method-assign]
+        handler.federation_handler.do_invite_join = AsyncMock(  # type: ignore[method-assign]
+            return_value=("", 1)
         )
 
         # Artificially raise the complexity
         async def get_current_state_event_counts(room_id: str) -> int:
             return 600
 
-        self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts  # type: ignore[assignment]
+        self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts  # type: ignore[method-assign]
 
         d = handler._remote_join(
             create_requester(u1),
@@ -200,9 +199,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))  # type: ignore[assignment]
-        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(("", 1))
+        fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999})  # type: ignore[method-assign]
+        handler.federation_handler.do_invite_join = AsyncMock(  # type: ignore[method-assign]
+            return_value=("", 1)
         )
 
         d = handler._remote_join(
@@ -230,9 +229,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))  # type: ignore[assignment]
-        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(("", 1))
+        fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999})  # type: ignore[method-assign]
+        handler.federation_handler.do_invite_join = AsyncMock(  # type: ignore[method-assign]
+            return_value=("", 1)
         )
 
         d = handler._remote_join(
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index b290b020a2..75ae740b43 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -1,6 +1,6 @@
 from typing import Callable, Collection, List, Optional, Tuple
 from unittest import mock
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -19,7 +19,7 @@ from synapse.types import JsonDict
 from synapse.util import Clock
 from synapse.util.retryutils import NotRetryingDestination
 
-from tests.test_utils import event_injection, make_awaitable
+from tests.test_utils import event_injection
 from tests.unittest import FederatingHomeserverTestCase
 
 
@@ -50,8 +50,8 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         # This mock is crucial for destination_rooms to be populated.
         # TODO: this seems to no longer be the case---tests pass with this mock
         # commented out.
-        state_storage_controller.get_current_hosts_in_room = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable({"test", "host2"})
+        state_storage_controller.get_current_hosts_in_room = AsyncMock(  # type: ignore[method-assign]
+            return_value={"test", "host2"}
         )
 
         # whenever send_transaction is called, record the pdu data
@@ -436,7 +436,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         def wake_destination_track(destination: str) -> None:
             woken.add(destination)
 
-        self.federation_sender.wake_destination = wake_destination_track  # type: ignore[assignment]
+        self.federation_sender.wake_destination = wake_destination_track  # type: ignore[method-assign]
 
         # We wait quite long so that all dests can be woken up, since there is a delay
         # between them.
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 9e104fd96a..caf04b54cb 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Callable, FrozenSet, List, Optional, Set
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from signedjson import key, sign
 from signedjson.types import BaseKey, SigningKey
@@ -29,7 +29,6 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict, ReadReceipt
 from synapse.util import Clock
 
-from tests.test_utils import make_awaitable
 from tests.unittest import HomeserverTestCase
 
 
@@ -43,15 +42,16 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.federation_transport_client = Mock(spec=["send_transaction"])
+        self.federation_transport_client.send_transaction = AsyncMock()
         hs = self.setup_test_homeserver(
             federation_transport_client=self.federation_transport_client,
         )
 
-        hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable({"test", "host2"})
+        hs.get_storage_controllers().state.get_current_hosts_in_room = AsyncMock(  # type: ignore[method-assign]
+            return_value={"test", "host2"}
         )
 
-        hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (  # type: ignore[assignment]
+        hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (  # type: ignore[method-assign]
             hs.get_storage_controllers().state.get_current_hosts_in_room
         )
 
@@ -64,7 +64,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
 
     def test_send_receipts(self) -> None:
         mock_send_transaction = self.federation_transport_client.send_transaction
-        mock_send_transaction.return_value = make_awaitable({})
+        mock_send_transaction.return_value = {}
 
         sender = self.hs.get_federation_sender()
         receipt = ReadReceipt(
@@ -75,7 +75,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
             thread_id=None,
             data={"ts": 1234},
         )
-        self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
+        self.get_success(sender.send_read_receipt(receipt))
 
         self.pump()
 
@@ -104,13 +104,16 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
 
     def test_send_receipts_thread(self) -> None:
         mock_send_transaction = self.federation_transport_client.send_transaction
-        mock_send_transaction.return_value = make_awaitable({})
+        mock_send_transaction.return_value = {}
 
         # Create receipts for:
         #
         # * The same room / user on multiple threads.
         # * A different user in the same room.
         sender = self.hs.get_federation_sender()
+        # Hack so that we have a txn in-flight so we batch up read receipts
+        # below
+        sender.wake_destination("host2")
         for user, thread in (
             ("alice", None),
             ("alice", "thread"),
@@ -125,9 +128,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
                 thread_id=thread,
                 data={"ts": 1234},
             )
-            self.successResultOf(
-                defer.ensureDeferred(sender.send_read_receipt(receipt))
-            )
+            defer.ensureDeferred(sender.send_read_receipt(receipt))
 
         self.pump()
 
@@ -180,7 +181,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         """Send two receipts in quick succession; the second should be flushed, but
         only after 20ms"""
         mock_send_transaction = self.federation_transport_client.send_transaction
-        mock_send_transaction.return_value = make_awaitable({})
+        mock_send_transaction.return_value = {}
 
         sender = self.hs.get_federation_sender()
         receipt = ReadReceipt(
@@ -191,7 +192,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
             thread_id=None,
             data={"ts": 1234},
         )
-        self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
+        self.get_success(sender.send_read_receipt(receipt))
 
         self.pump()
 
@@ -276,6 +277,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         self.federation_transport_client = Mock(
             spec=["send_transaction", "query_user_devices"]
         )
+        self.federation_transport_client.send_transaction = AsyncMock()
+        self.federation_transport_client.query_user_devices = AsyncMock()
         return self.setup_test_homeserver(
             federation_transport_client=self.federation_transport_client,
         )
@@ -317,13 +320,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
             self.record_transaction
         )
 
-    def record_transaction(
+    async def record_transaction(
         self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None
-    ) -> "defer.Deferred[JsonDict]":
+    ) -> JsonDict:
         assert json_cb is not None
         data = json_cb()
         self.edus.extend(data["edus"])
-        return defer.succeed({})
+        return {}
 
     def test_send_device_updates(self) -> None:
         """Basic case: each device update should result in an EDU"""
@@ -340,7 +343,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         self.reactor.advance(1)
 
         # a second call should produce no new device EDUs
-        self.hs.get_federation_sender().send_device_messages("host2")
+        self.get_success(
+            self.hs.get_federation_sender().send_device_messages(["host2"])
+        )
         self.assertEqual(self.edus, [])
 
         # a second device
@@ -354,15 +359,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         # Send the server a device list EDU for the other user, this will cause
         # it to try and resync the device lists.
-        self.federation_transport_client.query_user_devices.return_value = (
-            make_awaitable(
-                {
-                    "stream_id": "1",
-                    "user_id": "@user2:host2",
-                    "devices": [{"device_id": "D1"}],
-                }
-            )
-        )
+        self.federation_transport_client.query_user_devices.return_value = {
+            "stream_id": "1",
+            "user_id": "@user2:host2",
+            "devices": [{"device_id": "D1"}],
+        }
 
         self.get_success(
             self.device_handler.device_list_updater.incoming_device_list_update(
@@ -533,7 +534,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         recovery
         """
         mock_send_txn = self.federation_transport_client.send_transaction
-        mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
+        mock_send_txn.side_effect = AssertionError("fail")
 
         # create devices
         u1 = self.register_user("user", "pass")
@@ -552,7 +553,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         # recover the server
         mock_send_txn.side_effect = self.record_transaction
-        self.hs.get_federation_sender().send_device_messages("host2")
+        self.get_success(
+            self.hs.get_federation_sender().send_device_messages(["host2"])
+        )
 
         # We queue up device list updates to be sent over federation, so we
         # advance to clear the queue.
@@ -578,7 +581,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         This case tests the behaviour when the server has never been reachable.
         """
         mock_send_txn = self.federation_transport_client.send_transaction
-        mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
+        mock_send_txn.side_effect = AssertionError("fail")
 
         # create devices
         u1 = self.register_user("user", "pass")
@@ -603,7 +606,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         # recover the server
         mock_send_txn.side_effect = self.record_transaction
-        self.hs.get_federation_sender().send_device_messages("host2")
+        self.get_success(
+            self.hs.get_federation_sender().send_device_messages(["host2"])
+        )
 
         # We queue up device list updates to be sent over federation, so we
         # advance to clear the queue.
@@ -636,7 +641,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         # now the server goes offline
         mock_send_txn = self.federation_transport_client.send_transaction
-        mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
+        mock_send_txn.side_effect = AssertionError("fail")
 
         self.login("user", "pass", device_id="D2")
         self.login("user", "pass", device_id="D3")
@@ -658,7 +663,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         # recover the server
         mock_send_txn.side_effect = self.record_transaction
-        self.hs.get_federation_sender().send_device_messages("host2")
+        self.get_success(
+            self.hs.get_federation_sender().send_device_messages(["host2"])
+        )
 
         # We queue up device list updates to be sent over federation, so we
         # advance to clear the queue.
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index 70209ab090..3f42f79f26 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -218,7 +218,7 @@ class FederationKnockingTestCase(
         ) -> EventBase:
             return pdu
 
-        homeserver.get_federation_server()._check_sigs_and_hash = (  # type: ignore[assignment]
+        homeserver.get_federation_server()._check_sigs_and_hash = (  # type: ignore[method-assign]
             approve_all_signature_checking
         )
 
@@ -229,7 +229,7 @@ class FederationKnockingTestCase(
         ) -> None:
             pass
 
-        homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth  # type: ignore[assignment]
+        homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth  # type: ignore[method-assign]
 
         return super().prepare(reactor, clock, homeserver)
 
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 9014e60577..46d022092e 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from typing import Dict, Iterable, List, Optional
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from parameterized import parameterized
 
@@ -36,7 +36,7 @@ from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
 from tests import unittest
-from tests.test_utils import event_injection, make_awaitable, simple_async_mock
+from tests.test_utils import event_injection
 from tests.unittest import override_config
 from tests.utils import MockClock
 
@@ -46,15 +46,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
 
     def setUp(self) -> None:
         self.mock_store = Mock()
-        self.mock_as_api = Mock()
+        self.mock_as_api = AsyncMock()
         self.mock_scheduler = Mock()
         hs = Mock()
         hs.get_datastores.return_value = Mock(main=self.mock_store)
-        self.mock_store.get_appservice_last_pos.return_value = make_awaitable(None)
-        self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
-        self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable(
-            None
-        )
+        self.mock_store.get_appservice_last_pos = AsyncMock(return_value=None)
+        self.mock_store.set_appservice_last_pos = AsyncMock(return_value=None)
+        self.mock_store.set_appservice_stream_type_pos = AsyncMock(return_value=None)
         hs.get_application_service_api.return_value = self.mock_as_api
         hs.get_application_service_scheduler.return_value = self.mock_scheduler
         hs.get_clock.return_value = MockClock()
@@ -69,21 +67,25 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             self._mkservice(is_interested_in_event=False),
         ]
 
-        self.mock_as_api.query_user.return_value = make_awaitable(True)
+        self.mock_as_api.query_user.return_value = True
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_user_by_id.return_value = make_awaitable([])
+        self.mock_store.get_user_by_id = AsyncMock(return_value=[])
 
         event = Mock(
             sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
         )
-        self.mock_store.get_all_new_event_ids_stream.side_effect = [
-            make_awaitable((0, {})),
-            make_awaitable((1, {event.event_id: 0})),
-        ]
-        self.mock_store.get_events_as_list.side_effect = [
-            make_awaitable([]),
-            make_awaitable([event]),
-        ]
+        self.mock_store.get_all_new_event_ids_stream = AsyncMock(
+            side_effect=[
+                (0, {}),
+                (1, {event.event_id: 0}),
+            ]
+        )
+        self.mock_store.get_events_as_list = AsyncMock(
+            side_effect=[
+                [],
+                [event],
+            ]
+        )
         self.handler.notify_interested_services(RoomStreamToken(None, 1))
 
         self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
@@ -95,14 +97,16 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         services = [self._mkservice(is_interested_in_event=True)]
         services[0].is_interested_in_user.return_value = True
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_user_by_id.return_value = make_awaitable(None)
+        self.mock_store.get_user_by_id = AsyncMock(return_value=None)
 
         event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
-        self.mock_as_api.query_user.return_value = make_awaitable(True)
-        self.mock_store.get_all_new_event_ids_stream.side_effect = [
-            make_awaitable((0, {event.event_id: 0})),
-        ]
-        self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])]
+        self.mock_as_api.query_user.return_value = True
+        self.mock_store.get_all_new_event_ids_stream = AsyncMock(
+            side_effect=[
+                (0, {event.event_id: 0}),
+            ]
+        )
+        self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]])
         self.handler.notify_interested_services(RoomStreamToken(None, 0))
 
         self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
@@ -112,13 +116,15 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         services = [self._mkservice(is_interested_in_event=True)]
         services[0].is_interested_in_user.return_value = True
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
+        self.mock_store.get_user_by_id = AsyncMock(return_value={"name": user_id})
 
         event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
-        self.mock_as_api.query_user.return_value = make_awaitable(True)
-        self.mock_store.get_all_new_event_ids_stream.side_effect = [
-            make_awaitable((0, [event], {event.event_id: 0})),
-        ]
+        self.mock_as_api.query_user.return_value = True
+        self.mock_store.get_all_new_event_ids_stream = AsyncMock(
+            side_effect=[
+                (0, [event], {event.event_id: 0}),
+            ]
+        )
 
         self.handler.notify_interested_services(RoomStreamToken(None, 0))
 
@@ -141,10 +147,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             self._mkservice_alias(is_room_alias_in_namespace=False),
         ]
 
-        self.mock_as_api.query_alias.return_value = make_awaitable(True)
+        self.mock_as_api.query_alias = AsyncMock(return_value=True)
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_association_from_room_alias.return_value = make_awaitable(
-            Mock(room_id=room_id, servers=servers)
+        self.mock_store.get_association_from_room_alias = AsyncMock(
+            return_value=Mock(room_id=room_id, servers=servers)
         )
 
         result = self.successResultOf(
@@ -177,7 +183,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
     def test_get_3pe_protocols_protocol_no_response(self) -> None:
         service = self._mkservice(False, ["my-protocol"])
         self.mock_store.get_app_services.return_value = [service]
-        self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None)
+        self.mock_as_api.get_3pe_protocol.return_value = None
         response = self.successResultOf(
             defer.ensureDeferred(self.handler.get_3pe_protocols())
         )
@@ -189,9 +195,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
     def test_get_3pe_protocols_select_one_protocol(self) -> None:
         service = self._mkservice(False, ["my-protocol"])
         self.mock_store.get_app_services.return_value = [service]
-        self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
-            {"x-protocol-data": 42, "instances": []}
-        )
+        self.mock_as_api.get_3pe_protocol.return_value = {
+            "x-protocol-data": 42,
+            "instances": [],
+        }
         response = self.successResultOf(
             defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
         )
@@ -205,9 +212,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
     def test_get_3pe_protocols_one_protocol(self) -> None:
         service = self._mkservice(False, ["my-protocol"])
         self.mock_store.get_app_services.return_value = [service]
-        self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
-            {"x-protocol-data": 42, "instances": []}
-        )
+        self.mock_as_api.get_3pe_protocol.return_value = {
+            "x-protocol-data": 42,
+            "instances": [],
+        }
         response = self.successResultOf(
             defer.ensureDeferred(self.handler.get_3pe_protocols())
         )
@@ -222,9 +230,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         service_one = self._mkservice(False, ["my-protocol"])
         service_two = self._mkservice(False, ["other-protocol"])
         self.mock_store.get_app_services.return_value = [service_one, service_two]
-        self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
-            {"x-protocol-data": 42, "instances": []}
-        )
+        self.mock_as_api.get_3pe_protocol.return_value = {
+            "x-protocol-data": 42,
+            "instances": [],
+        }
         response = self.successResultOf(
             defer.ensureDeferred(self.handler.get_3pe_protocols())
         )
@@ -287,13 +296,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         interested_service = self._mkservice(is_interested_in_event=True)
         services = [interested_service]
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
-            579
-        )
+        self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=579)
 
         event = Mock(event_id="event_1")
-        self.event_source.sources.receipt.get_new_events_as.return_value = (
-            make_awaitable(([event], None))
+        self.event_source.sources.receipt.get_new_events_as = AsyncMock(
+            return_value=([event], None)
         )
 
         self.handler.notify_interested_services_ephemeral(
@@ -317,13 +324,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         services = [interested_service]
 
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
-            580
-        )
+        self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=580)
 
         event = Mock(event_id="event_1")
-        self.event_source.sources.receipt.get_new_events_as.return_value = (
-            make_awaitable(([event], None))
+        self.event_source.sources.receipt.get_new_events_as = AsyncMock(
+            return_value=([event], None)
         )
 
         self.handler.notify_interested_services_ephemeral(
@@ -350,9 +355,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             A mock representing the ApplicationService.
         """
         service = Mock()
-        service.is_interested_in_event.return_value = make_awaitable(
-            is_interested_in_event
-        )
+        service.is_interested_in_event = AsyncMock(return_value=is_interested_in_event)
         service.token = "mock_service_token"
         service.url = "mock_service_url"
         service.protocols = protocols
@@ -396,12 +399,12 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
         self.hs = hs
         # Mock the ApplicationServiceScheduler's _TransactionController's send method so that
         # we can track any outgoing ephemeral events
-        self.send_mock = simple_async_mock()
-        hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock  # type: ignore[assignment]
+        self.send_mock = AsyncMock()
+        hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock  # type: ignore[method-assign]
 
         # Mock out application services, and allow defining our own in tests
         self._services: List[ApplicationService] = []
-        self.hs.get_datastores().main.get_app_services = Mock(  # type: ignore[assignment]
+        self.hs.get_datastores().main.get_app_services = Mock(  # type: ignore[method-assign]
             return_value=self._services
         )
 
@@ -894,12 +897,12 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
 
         # Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that
         # will be sent over the wire
-        self.put_json = simple_async_mock()
-        hs.get_application_service_api().put_json = self.put_json  # type: ignore[assignment]
+        self.put_json = AsyncMock()
+        hs.get_application_service_api().put_json = self.put_json  # type: ignore[method-assign]
 
         # Mock out application services, and allow defining our own in tests
         self._services: List[ApplicationService] = []
-        self.hs.get_datastores().main.get_app_services = Mock(  # type: ignore[assignment]
+        self.hs.get_datastores().main.get_app_services = Mock(  # type: ignore[method-assign]
             return_value=self._services
         )
 
@@ -1000,8 +1003,8 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # Mock the ApplicationServiceScheduler's _TransactionController's send method so that
         # we can track what's going out
-        self.send_mock = simple_async_mock()
-        hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock  # type: ignore[assignment]  # We assign to a method.
+        self.send_mock = AsyncMock()
+        hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock  # type: ignore[method-assign]  # We assign to a method.
 
         # Define an application service for the tests
         self._service_token = "VERYSECRET"
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 036dbbc45b..413ff8795b 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Optional
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
 
 import pymacaroons
 
@@ -25,7 +25,6 @@ from synapse.server import HomeServer
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class AuthTestCase(unittest.HomeserverTestCase):
@@ -166,8 +165,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
     def test_mau_limits_exceeded_large(self) -> None:
         self.auth_blocking._limit_usage_by_mau = True
-        self.hs.get_datastores().main.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.large_number_of_users)
+        self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
+            return_value=self.large_number_of_users
         )
 
         self.get_failure(
@@ -177,8 +176,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
             ResourceLimitError,
         )
 
-        self.hs.get_datastores().main.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.large_number_of_users)
+        self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
+            return_value=self.large_number_of_users
         )
         token = self.get_success(
             self.auth_handler.create_login_token_for_user_id(self.user1)
@@ -191,8 +190,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.auth_blocking._limit_usage_by_mau = True
 
         # Set the server to be at the edge of too many users.
-        self.hs.get_datastores().main.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.auth_blocking._max_mau_value)
+        self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
+            return_value=self.auth_blocking._max_mau_value
         )
 
         # If not in monthly active cohort
@@ -208,8 +207,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.assertIsNone(self.token_login(token))
 
         # If in monthly active cohort
-        self.hs.get_datastores().main.user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(self.clock.time_msec())
+        self.hs.get_datastores().main.user_last_seen_monthly_active = AsyncMock(
+            return_value=self.clock.time_msec()
         )
         self.get_success(
             self.auth_handler.create_access_token_for_user_id(
@@ -224,8 +223,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
     def test_mau_limits_not_exceeded(self) -> None:
         self.auth_blocking._limit_usage_by_mau = True
 
-        self.hs.get_datastores().main.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.small_number_of_users)
+        self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
+            return_value=self.small_number_of_users
         )
         # Ensure does not raise exception
         self.get_success(
@@ -234,8 +233,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.hs.get_datastores().main.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.small_number_of_users)
+        self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
+            return_value=self.small_number_of_users
         )
         token = self.get_success(
             self.auth_handler.create_login_token_for_user_id(self.user1)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 63aad0d10c..8582b1cd1e 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -12,7 +12,7 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 from typing import Any, Dict
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -20,7 +20,6 @@ from synapse.handlers.cas import CasResponse
 from synapse.server import HomeServer
 from synapse.util import Clock
 
-from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
 # These are a few constants that are used as config parameters in the tests.
@@ -61,7 +60,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]
 
         cas_response = CasResponse("test_user", {})
         request = _mock_request()
@@ -89,7 +88,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]
 
         # Map a user via SSO.
         cas_response = CasResponse("test_user", {})
@@ -129,7 +128,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]
 
         cas_response = CasResponse("föö", {})
         request = _mock_request()
@@ -160,7 +159,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]
 
         # The response doesn't have the proper userGroup or department.
         cas_response = CasResponse("test_user", {})
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index e1e58fa6e6..55a4f95ef3 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -32,7 +32,6 @@ from synapse.types import JsonDict, create_requester
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
 user1 = "@boris:aaa"
@@ -41,7 +40,7 @@ user2 = "@theresa:bbb"
 
 class DeviceTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        self.appservice_api = mock.Mock()
+        self.appservice_api = mock.AsyncMock()
         hs = self.setup_test_homeserver(
             "server",
             application_service_api=self.appservice_api,
@@ -123,50 +122,50 @@ class DeviceTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(3, len(res))
         device_map = {d["device_id"]: d for d in res}
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": user1,
                 "device_id": "xyz",
                 "display_name": "display 0",
                 "last_seen_ip": None,
                 "last_seen_ts": None,
-            },
-            device_map["xyz"],
+            }.items(),
+            device_map["xyz"].items(),
         )
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": user1,
                 "device_id": "fco",
                 "display_name": "display 1",
                 "last_seen_ip": "ip1",
                 "last_seen_ts": 1000000,
-            },
-            device_map["fco"],
+            }.items(),
+            device_map["fco"].items(),
         )
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": user1,
                 "device_id": "abc",
                 "display_name": "display 2",
                 "last_seen_ip": "ip3",
                 "last_seen_ts": 3000000,
-            },
-            device_map["abc"],
+            }.items(),
+            device_map["abc"].items(),
         )
 
     def test_get_device(self) -> None:
         self._record_users()
 
         res = self.get_success(self.handler.get_device(user1, "abc"))
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": user1,
                 "device_id": "abc",
                 "display_name": "display 2",
                 "last_seen_ip": "ip3",
                 "last_seen_ts": 3000000,
-            },
-            res,
+            }.items(),
+            res.items(),
         )
 
     def test_delete_device(self) -> None:
@@ -375,13 +374,11 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         )
 
         # Setup a response.
-        self.appservice_api.query_keys.return_value = make_awaitable(
-            {
-                "device_keys": {
-                    local_user: {device_2: device_key_2b, device_3: device_key_3}
-                }
+        self.appservice_api.query_keys.return_value = {
+            "device_keys": {
+                local_user: {device_2: device_key_2b, device_3: device_key_3}
             }
-        )
+        }
 
         # Request all devices.
         res = self.get_success(
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 90aec484c4..367d94eca3 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, Awaitable, Callable, Dict
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -27,14 +27,13 @@ from synapse.types import JsonDict, RoomAlias, create_requester
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class DirectoryTestCase(unittest.HomeserverTestCase):
     """Tests the directory service."""
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        self.mock_federation = Mock()
+        self.mock_federation = AsyncMock()
         self.mock_registry = Mock()
 
         self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
@@ -73,9 +72,10 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
 
     def test_get_remote_association(self) -> None:
-        self.mock_federation.make_query.return_value = make_awaitable(
-            {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
-        )
+        self.mock_federation.make_query.return_value = {
+            "room_id": "!8765qwer:test",
+            "servers": ["test", "remote"],
+        }
 
         result = self.get_success(self.handler.get_association(self.remote_room))
 
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 2eaffe511e..c5556f2844 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Iterable
+from typing import Dict, Iterable
 from unittest import mock
 
 from parameterized import parameterized
@@ -31,13 +31,12 @@ from synapse.types import JsonDict, UserID
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
 
 class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        self.appservice_api = mock.Mock()
+        self.appservice_api = mock.AsyncMock()
         return self.setup_test_homeserver(
             federation_client=mock.Mock(), application_service_api=self.appservice_api
         )
@@ -801,29 +800,27 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
         remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
 
-        self.hs.get_federation_client().query_client_keys = mock.Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(
-                {
-                    "device_keys": {remote_user_id: {}},
-                    "master_keys": {
-                        remote_user_id: {
-                            "user_id": remote_user_id,
-                            "usage": ["master"],
-                            "keys": {"ed25519:" + remote_master_key: remote_master_key},
-                        },
-                    },
-                    "self_signing_keys": {
-                        remote_user_id: {
-                            "user_id": remote_user_id,
-                            "usage": ["self_signing"],
-                            "keys": {
-                                "ed25519:"
-                                + remote_self_signing_key: remote_self_signing_key
-                            },
-                        }
+        self.hs.get_federation_client().query_client_keys = mock.AsyncMock(  # type: ignore[method-assign]
+            return_value={
+                "device_keys": {remote_user_id: {}},
+                "master_keys": {
+                    remote_user_id: {
+                        "user_id": remote_user_id,
+                        "usage": ["master"],
+                        "keys": {"ed25519:" + remote_master_key: remote_master_key},
                     },
-                }
-            )
+                },
+                "self_signing_keys": {
+                    remote_user_id: {
+                        "user_id": remote_user_id,
+                        "usage": ["self_signing"],
+                        "keys": {
+                            "ed25519:"
+                            + remote_self_signing_key: remote_self_signing_key
+                        },
+                    }
+                },
+            }
         )
 
         e2e_handler = self.hs.get_e2e_keys_handler()
@@ -874,34 +871,29 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
         # Pretend we're sharing a room with the user we're querying. If not,
         # `_query_devices_for_destination` will return early.
-        self.store.get_rooms_for_user = mock.Mock(
-            return_value=make_awaitable({"some_room_id"})
-        )
+        self.store.get_rooms_for_user = mock.AsyncMock(return_value={"some_room_id"})
 
         remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
         remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
 
-        self.hs.get_federation_client().query_user_devices = mock.Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(
-                {
+        self.hs.get_federation_client().query_user_devices = mock.AsyncMock(  # type: ignore[method-assign]
+            return_value={
+                "user_id": remote_user_id,
+                "stream_id": 1,
+                "devices": [],
+                "master_key": {
                     "user_id": remote_user_id,
-                    "stream_id": 1,
-                    "devices": [],
-                    "master_key": {
-                        "user_id": remote_user_id,
-                        "usage": ["master"],
-                        "keys": {"ed25519:" + remote_master_key: remote_master_key},
-                    },
-                    "self_signing_key": {
-                        "user_id": remote_user_id,
-                        "usage": ["self_signing"],
-                        "keys": {
-                            "ed25519:"
-                            + remote_self_signing_key: remote_self_signing_key
-                        },
+                    "usage": ["master"],
+                    "keys": {"ed25519:" + remote_master_key: remote_master_key},
+                },
+                "self_signing_key": {
+                    "user_id": remote_user_id,
+                    "usage": ["self_signing"],
+                    "keys": {
+                        "ed25519:" + remote_self_signing_key: remote_self_signing_key
                     },
-                }
-            )
+                },
+            }
         )
 
         e2e_handler = self.hs.get_e2e_keys_handler()
@@ -987,20 +979,20 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         mock_get_rooms = mock.patch.object(
             self.store,
             "get_rooms_for_user",
-            new_callable=mock.MagicMock,
-            return_value=make_awaitable(["some_room_id"]),
+            new_callable=mock.AsyncMock,
+            return_value=["some_room_id"],
         )
         mock_get_users = mock.patch.object(
             self.store,
             "get_users_server_still_shares_room_with",
-            new_callable=mock.MagicMock,
-            return_value=make_awaitable({remote_user_id}),
+            new_callable=mock.AsyncMock,
+            return_value={remote_user_id},
         )
         mock_request = mock.patch.object(
             self.hs.get_federation_client(),
             "query_user_devices",
-            new_callable=mock.MagicMock,
-            return_value=make_awaitable(response_body),
+            new_callable=mock.AsyncMock,
+            return_value=response_body,
         )
 
         with mock_get_rooms, mock_get_users, mock_request as mocked_federation_request:
@@ -1060,8 +1052,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         )
 
         # Setup a response, but only for device 2.
-        self.appservice_api.claim_client_keys.return_value = make_awaitable(
-            ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1", 1)])
+        self.appservice_api.claim_client_keys.return_value = (
+            {local_user: {device_id_2: otk}},
+            [(local_user, device_id_1, "alg1", 1)],
         )
 
         # we shouldn't have any unused fallback keys yet
@@ -1127,9 +1120,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         )
 
         # Setup a response.
-        self.appservice_api.claim_client_keys.return_value = make_awaitable(
-            ({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, [])
-        )
+        response: Dict[str, Dict[str, Dict[str, JsonDict]]] = {
+            local_user: {device_id_1: {**as_otk, **as_fallback_key}}
+        }
+        self.appservice_api.claim_client_keys.return_value = (response, [])
 
         # Claim OTKs, which will ask the appservice and do nothing else.
         claim_res = self.get_success(
@@ -1171,8 +1165,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         self.assertEqual(fallback_res, ["alg1"])
 
         # The appservice will return only the OTK.
-        self.appservice_api.claim_client_keys.return_value = make_awaitable(
-            ({local_user: {device_id_1: as_otk}}, [])
+        self.appservice_api.claim_client_keys.return_value = (
+            {local_user: {device_id_1: as_otk}},
+            [],
         )
 
         # Claim OTKs, which should return the OTK from the appservice and the
@@ -1234,8 +1229,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         self.assertEqual(fallback_res, ["alg1"])
 
         # Finally, return only the fallback key from the appservice.
-        self.appservice_api.claim_client_keys.return_value = make_awaitable(
-            ({local_user: {device_id_1: as_fallback_key}}, [])
+        self.appservice_api.claim_client_keys.return_value = (
+            {local_user: {device_id_1: as_fallback_key}},
+            [],
         )
 
         # Claim OTKs, which will return only the fallback key from the database.
@@ -1350,13 +1346,11 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         )
 
         # Setup a response.
-        self.appservice_api.query_keys.return_value = make_awaitable(
-            {
-                "device_keys": {
-                    local_user: {device_2: device_key_2b, device_3: device_key_3}
-                }
+        self.appservice_api.query_keys.return_value = {
+            "device_keys": {
+                local_user: {device_2: device_key_2b, device_3: device_key_3}
             }
-        )
+        }
 
         # Request all devices.
         res = self.get_success(self.handler.query_local_devices({local_user: None}))
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 5f11d5df11..21d63ab1f2 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -14,7 +14,7 @@
 import logging
 from typing import Collection, Optional, cast
 from unittest import TestCase
-from unittest.mock import Mock, patch
+from unittest.mock import AsyncMock, Mock, patch
 
 from twisted.internet.defer import Deferred
 from twisted.test.proto_helpers import MemoryReactor
@@ -40,7 +40,7 @@ from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
 from tests import unittest
-from tests.test_utils import event_injection, make_awaitable
+from tests.test_utils import event_injection
 
 logger = logging.getLogger(__name__)
 
@@ -370,15 +370,15 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
 
         # We mock out the FederationClient.backfill method, to pretend that a remote
         # server has returned our fake event.
-        federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
-        self.hs.get_federation_client().backfill = federation_client_backfill_mock  # type: ignore[assignment]
+        federation_client_backfill_mock = AsyncMock(return_value=[event])
+        self.hs.get_federation_client().backfill = federation_client_backfill_mock  # type: ignore[method-assign]
 
         # We also mock the persist method with a side effect of itself. This allows us
         # to track when it has been called while preserving its function.
         persist_events_and_notify_mock = Mock(
             side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
         )
-        self.hs.get_federation_event_handler().persist_events_and_notify = (  # type: ignore[assignment]
+        self.hs.get_federation_event_handler().persist_events_and_notify = (  # type: ignore[method-assign]
             persist_events_and_notify_mock
         )
 
@@ -631,33 +631,29 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             },
             RoomVersions.V10,
         )
-        mock_make_membership_event = Mock(
-            return_value=make_awaitable(
-                (
-                    "example.com",
-                    membership_event,
-                    RoomVersions.V10,
-                )
+        mock_make_membership_event = AsyncMock(
+            return_value=(
+                "example.com",
+                membership_event,
+                RoomVersions.V10,
             )
         )
-        mock_send_join = Mock(
-            return_value=make_awaitable(
-                SendJoinResult(
-                    membership_event,
-                    "example.com",
-                    state=[
-                        EVENT_CREATE,
-                        EVENT_CREATOR_MEMBERSHIP,
-                        EVENT_INVITATION_MEMBERSHIP,
-                    ],
-                    auth_chain=[
-                        EVENT_CREATE,
-                        EVENT_CREATOR_MEMBERSHIP,
-                        EVENT_INVITATION_MEMBERSHIP,
-                    ],
-                    partial_state=True,
-                    servers_in_room={"example.com"},
-                )
+        mock_send_join = AsyncMock(
+            return_value=SendJoinResult(
+                membership_event,
+                "example.com",
+                state=[
+                    EVENT_CREATE,
+                    EVENT_CREATOR_MEMBERSHIP,
+                    EVENT_INVITATION_MEMBERSHIP,
+                ],
+                auth_chain=[
+                    EVENT_CREATE,
+                    EVENT_CREATOR_MEMBERSHIP,
+                    EVENT_INVITATION_MEMBERSHIP,
+                ],
+                partial_state=True,
+                servers_in_room={"example.com"},
             )
         )
 
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 23f1b33b2f..70e6a7e142 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -35,7 +35,7 @@ from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import event_injection, make_awaitable
+from tests.test_utils import event_injection
 
 
 class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
@@ -50,6 +50,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
         self.mock_federation_transport_client = mock.Mock(
             spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"]
         )
+        self.mock_federation_transport_client.get_room_state_ids = mock.AsyncMock()
+        self.mock_federation_transport_client.get_room_state = mock.AsyncMock()
+        self.mock_federation_transport_client.get_event = mock.AsyncMock()
+        self.mock_federation_transport_client.backfill = mock.AsyncMock()
         return super().setup_test_homeserver(
             federation_transport_client=self.mock_federation_transport_client
         )
@@ -198,20 +202,14 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
         )
 
         # we expect an outbound request to /state_ids, so stub that out
-        self.mock_federation_transport_client.get_room_state_ids.return_value = (
-            make_awaitable(
-                {
-                    "pdu_ids": [e.event_id for e in state_at_prev_event],
-                    "auth_chain_ids": [],
-                }
-            )
-        )
+        self.mock_federation_transport_client.get_room_state_ids.return_value = {
+            "pdu_ids": [e.event_id for e in state_at_prev_event],
+            "auth_chain_ids": [],
+        }
 
         # we also expect an outbound request to /state
         self.mock_federation_transport_client.get_room_state.return_value = (
-            make_awaitable(
-                StateRequestResponse(auth_events=[], state=state_at_prev_event)
-            )
+            StateRequestResponse(auth_events=[], state=state_at_prev_event)
         )
 
         # we have to bump the clock a bit, to keep the retry logic in
@@ -273,26 +271,23 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
         room_version = self.get_success(main_store.get_room_version(room_id))
 
         # We expect an outbound request to /state_ids, so stub that out
-        self.mock_federation_transport_client.get_room_state_ids.return_value = make_awaitable(
-            {
-                # Mimic the other server not knowing about the state at all.
-                # We want to cause Synapse to throw an error (`Unable to get
-                # missing prev_event $fake_prev_event`) and fail to backfill
-                # the pulled event.
-                "pdu_ids": [],
-                "auth_chain_ids": [],
-            }
-        )
+        self.mock_federation_transport_client.get_room_state_ids.return_value = {
+            # Mimic the other server not knowing about the state at all.
+            # We want to cause Synapse to throw an error (`Unable to get
+            # missing prev_event $fake_prev_event`) and fail to backfill
+            # the pulled event.
+            "pdu_ids": [],
+            "auth_chain_ids": [],
+        }
+
         # We also expect an outbound request to /state
-        self.mock_federation_transport_client.get_room_state.return_value = make_awaitable(
-            StateRequestResponse(
-                # Mimic the other server not knowing about the state at all.
-                # We want to cause Synapse to throw an error (`Unable to get
-                # missing prev_event $fake_prev_event`) and fail to backfill
-                # the pulled event.
-                auth_events=[],
-                state=[],
-            )
+        self.mock_federation_transport_client.get_room_state.return_value = StateRequestResponse(
+            # Mimic the other server not knowing about the state at all.
+            # We want to cause Synapse to throw an error (`Unable to get
+            # missing prev_event $fake_prev_event`) and fail to backfill
+            # the pulled event.
+            auth_events=[],
+            state=[],
         )
 
         pulled_event = make_event_from_dict(
@@ -545,25 +540,23 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
         )
 
         # We expect an outbound request to /backfill, so stub that out
-        self.mock_federation_transport_client.backfill.return_value = make_awaitable(
-            {
-                "origin": self.OTHER_SERVER_NAME,
-                "origin_server_ts": 123,
-                "pdus": [
-                    # This is one of the important aspects of this test: we include
-                    # `pulled_event_without_signatures` so it fails the signature check
-                    # when we filter down the backfill response down to events which
-                    # have valid signatures in
-                    # `_check_sigs_and_hash_for_pulled_events_and_fetch`
-                    pulled_event_without_signatures.get_pdu_json(),
-                    # Then later when we process this valid signature event, when we
-                    # fetch the missing `prev_event`s, we want to make sure that we
-                    # backoff and don't try and fetch `pulled_event_without_signatures`
-                    # again since we know it just had an invalid signature.
-                    pulled_event.get_pdu_json(),
-                ],
-            }
-        )
+        self.mock_federation_transport_client.backfill.return_value = {
+            "origin": self.OTHER_SERVER_NAME,
+            "origin_server_ts": 123,
+            "pdus": [
+                # This is one of the important aspects of this test: we include
+                # `pulled_event_without_signatures` so it fails the signature check
+                # when we filter down the backfill response down to events which
+                # have valid signatures in
+                # `_check_sigs_and_hash_for_pulled_events_and_fetch`
+                pulled_event_without_signatures.get_pdu_json(),
+                # Then later when we process this valid signature event, when we
+                # fetch the missing `prev_event`s, we want to make sure that we
+                # backoff and don't try and fetch `pulled_event_without_signatures`
+                # again since we know it just had an invalid signature.
+                pulled_event.get_pdu_json(),
+            ],
+        }
 
         # Keep track of the count and make sure we don't make any of these requests
         event_endpoint_requested_count = 0
@@ -731,15 +724,13 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
         )
 
         # We expect an outbound request to /backfill, so stub that out
-        self.mock_federation_transport_client.backfill.return_value = make_awaitable(
-            {
-                "origin": self.OTHER_SERVER_NAME,
-                "origin_server_ts": 123,
-                "pdus": [
-                    pulled_event.get_pdu_json(),
-                ],
-            }
-        )
+        self.mock_federation_transport_client.backfill.return_value = {
+            "origin": self.OTHER_SERVER_NAME,
+            "origin_server_ts": 123,
+            "pdus": [
+                pulled_event.get_pdu_json(),
+            ],
+        }
 
         # The function under test: try to backfill and process the pulled event
         with LoggingContext("test"):
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 9691d66b48..1c5897c84e 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -46,18 +46,11 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
         self._persist_event_storage_controller = persistence
 
         self.user_id = self.register_user("tester", "foobar")
-        self.access_token = self.login("tester", "foobar")
-        self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
-
-        info = self.get_success(
-            self.hs.get_datastores().main.get_user_by_access_token(
-                self.access_token,
-            )
-        )
-        assert info is not None
-        self.token_id = info.token_id
+        device_id = "dev-1"
+        access_token = self.login("tester", "foobar", device_id=device_id)
+        self.room_id = self.helper.create_room_as(self.user_id, tok=access_token)
 
-        self.requester = create_requester(self.user_id, access_token_id=self.token_id)
+        self.requester = create_requester(self.user_id, device_id=device_id)
 
     def _create_and_persist_member_event(self) -> Tuple[EventBase, EventContext]:
         # Create a member event we can use as an auth_event
diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py
index 503277cdff..a72ecfdc97 100644
--- a/tests/handlers/test_oauth_delegation.py
+++ b/tests/handlers/test_oauth_delegation.py
@@ -14,7 +14,7 @@
 
 from http import HTTPStatus
 from typing import Any, Dict, Union
-from unittest.mock import ANY, Mock
+from unittest.mock import ANY, AsyncMock, Mock
 from urllib.parse import parse_qs
 
 from signedjson.key import (
@@ -39,7 +39,7 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
 
-from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
+from tests.test_utils import FakeResponse, get_awaitable_result
 from tests.unittest import HomeserverTestCase, skip_unless
 from tests.utils import mock_getRawHeaders
 
@@ -148,7 +148,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_inactive_token(self) -> None:
         """The handler should return a 403 where the token is inactive."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={"active": False},
@@ -167,7 +167,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_no_scope(self) -> None:
         """The handler should return a 403 where no scope is given."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={"active": True},
@@ -186,7 +186,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_user_no_subject(self) -> None:
         """The handler should return a 500 when no subject is present."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])},
@@ -205,7 +205,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_no_user_scope(self) -> None:
         """The handler should return a 500 when no subject is present."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -228,7 +228,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_admin_not_user(self) -> None:
         """The handler should raise when the scope has admin right but not user."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -252,7 +252,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_admin(self) -> None:
         """The handler should return a requester with admin rights."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -282,7 +282,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_admin_highest_privilege(self) -> None:
         """The handler should resolve to the most permissive scope."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -314,7 +314,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_user(self) -> None:
         """The handler should return a requester with normal user rights."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -344,7 +344,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_user_with_device(self) -> None:
         """The handler should return a requester with normal user rights and a device ID."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -374,7 +374,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_multiple_devices(self) -> None:
         """The handler should raise an error if multiple devices are found in the scope."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -399,7 +399,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_guest_not_allowed(self) -> None:
         """The handler should return an insufficient scope error."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -429,7 +429,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_guest_allowed(self) -> None:
         """The handler should return a requester with guest user rights and a device ID."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -465,19 +465,19 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
 
         # The introspection endpoint is returning an error.
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse(code=500, body=b"Internal Server Error")
         )
         error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
         self.assertEqual(error.value.code, 503)
 
         # The introspection endpoint request fails.
-        self.http_client.request = simple_async_mock(raises=Exception())
+        self.http_client.request = AsyncMock(side_effect=Exception())
         error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
         self.assertEqual(error.value.code, 503)
 
         # The introspection endpoint does not return a JSON object.
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200, payload=["this is an array", "not an object"]
             )
@@ -486,7 +486,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
         self.assertEqual(error.value.code, 503)
 
         # The introspection endpoint does not return valid JSON.
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse(code=200, body=b"this is not valid JSON")
         )
         error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
@@ -512,7 +512,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_cross_signing(self) -> None:
         """Try uploading device keys with OAuth delegation enabled."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -666,7 +666,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
 
     def test_admin_token(self) -> None:
         """The handler should return a requester with admin rights when admin_token is used."""
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(code=200, payload={"active": False}),
         )
 
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 0a8bae54fb..e797aaae00 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple
-from unittest.mock import ANY, Mock, patch
+from unittest.mock import ANY, AsyncMock, Mock, patch
 from urllib.parse import parse_qs, urlparse
 
 import pymacaroons
@@ -28,7 +28,7 @@ from synapse.util import Clock
 from synapse.util.macaroons import get_value_from_macaroon
 from synapse.util.stringutils import random_string
 
-from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
+from tests.test_utils import FakeResponse, get_awaitable_result
 from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
 from tests.unittest import HomeserverTestCase, override_config
 
@@ -157,15 +157,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
         sso_handler = hs.get_sso_handler()
         # Mock the render error method.
         self.render_error = Mock(return_value=None)
-        sso_handler.render_error = self.render_error  # type: ignore[assignment]
+        sso_handler.render_error = self.render_error  # type: ignore[method-assign]
 
         # Reduce the number of attempts when generating MXIDs.
         sso_handler._MAP_USERNAME_RETRIES = 3
 
         auth_handler = hs.get_auth_handler()
         # Mock the complete SSO login method.
-        self.complete_sso_login = simple_async_mock()
-        auth_handler.complete_sso_login = self.complete_sso_login  # type: ignore[assignment]
+        self.complete_sso_login = AsyncMock()
+        auth_handler.complete_sso_login = self.complete_sso_login  # type: ignore[method-assign]
 
         return hs
 
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 394006f5f3..11ec8c7f11 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -16,7 +16,7 @@
 
 from http import HTTPStatus
 from typing import Any, Dict, List, Optional, Type, Union
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -32,7 +32,6 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeChannel
-from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
 # Login flows we expect to appear in the list after the normal ones.
@@ -187,7 +186,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
 
         # check_password must return an awaitable
-        mock_password_provider.check_password.return_value = make_awaitable(True)
+        mock_password_provider.check_password = AsyncMock(return_value=True)
         channel = self._send_password_login("u", "p")
         self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
         self.assertEqual("@u:test", channel.json_body["user_id"])
@@ -209,13 +208,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         """UI Auth should delegate correctly to the password provider"""
 
         # log in twice, to get two devices
-        mock_password_provider.check_password.return_value = make_awaitable(True)
+        mock_password_provider.check_password = AsyncMock(return_value=True)
         tok1 = self.login("u", "p")
         self.login("u", "p", device_id="dev2")
         mock_password_provider.reset_mock()
 
         # have the auth provider deny the request to start with
-        mock_password_provider.check_password.return_value = make_awaitable(False)
+        mock_password_provider.check_password = AsyncMock(return_value=False)
 
         # make the initial request which returns a 401
         session = self._start_delete_device_session(tok1, "dev2")
@@ -229,7 +228,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
 
         # Finally, check the request goes through when we allow it
-        mock_password_provider.check_password.return_value = make_awaitable(True)
+        mock_password_provider.check_password = AsyncMock(return_value=True)
         channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
         self.assertEqual(channel.code, 200)
         mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@@ -243,7 +242,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.register_user("localuser", "localpass")
 
         # check_password must return an awaitable
-        mock_password_provider.check_password.return_value = make_awaitable(False)
+        mock_password_provider.check_password = AsyncMock(return_value=False)
         channel = self._send_password_login("u", "p")
         self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
 
@@ -260,7 +259,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.register_user("localuser", "localpass")
 
         # have the auth provider deny the request
-        mock_password_provider.check_password.return_value = make_awaitable(False)
+        mock_password_provider.check_password = AsyncMock(return_value=False)
 
         # log in twice, to get two devices
         tok1 = self.login("localuser", "localpass")
@@ -303,7 +302,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.register_user("localuser", "localpass")
 
         # check_password must return an awaitable
-        mock_password_provider.check_password.return_value = make_awaitable(False)
+        mock_password_provider.check_password = AsyncMock(return_value=False)
         channel = self._send_password_login("localuser", "localpass")
         self.assertEqual(channel.code, 403)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -325,7 +324,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.register_user("localuser", "localpass")
 
         # allow login via the auth provider
-        mock_password_provider.check_password.return_value = make_awaitable(True)
+        mock_password_provider.check_password = AsyncMock(return_value=True)
 
         # log in twice, to get two devices
         tok1 = self.login("localuser", "p")
@@ -342,7 +341,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.check_password.assert_not_called()
 
         # now try deleting with the local password
-        mock_password_provider.check_password.return_value = make_awaitable(False)
+        mock_password_provider.check_password = AsyncMock(return_value=False)
         channel = self._authed_delete_device(
             tok1, "dev2", session, "localuser", "localpass"
         )
@@ -396,9 +395,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
         mock_password_provider.check_auth.assert_not_called()
 
-        mock_password_provider.check_auth.return_value = make_awaitable(
-            ("@user:test", None)
-        )
+        mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None))
         channel = self._send_login("test.login_type", "u", test_field="y")
         self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
         self.assertEqual("@user:test", channel.json_body["user_id"])
@@ -447,9 +444,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
 
         # right params, but authing as the wrong user
-        mock_password_provider.check_auth.return_value = make_awaitable(
-            ("@user:test", None)
-        )
+        mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None))
         body["auth"]["test_field"] = "foo"
         channel = self._delete_device(tok1, "dev2", body)
         self.assertEqual(channel.code, 403)
@@ -460,8 +455,8 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
 
         # and finally, succeed
-        mock_password_provider.check_auth.return_value = make_awaitable(
-            ("@localuser:test", None)
+        mock_password_provider.check_auth = AsyncMock(
+            return_value=("@localuser:test", None)
         )
         channel = self._delete_device(tok1, "dev2", body)
         self.assertEqual(channel.code, 200)
@@ -478,10 +473,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.custom_auth_provider_callback_test_body()
 
     def custom_auth_provider_callback_test_body(self) -> None:
-        callback = Mock(return_value=make_awaitable(None))
+        callback = AsyncMock(return_value=None)
 
-        mock_password_provider.check_auth.return_value = make_awaitable(
-            ("@user:test", callback)
+        mock_password_provider.check_auth = AsyncMock(
+            return_value=("@user:test", callback)
         )
         channel = self._send_login("test.login_type", "u", test_field="y")
         self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
@@ -616,8 +611,8 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         login is disabled"""
         # register the user and log in twice via the test login type to get two devices,
         self.register_user("localuser", "localpass")
-        mock_password_provider.check_auth.return_value = make_awaitable(
-            ("@localuser:test", None)
+        mock_password_provider.check_auth = AsyncMock(
+            return_value=("@localuser:test", None)
         )
         channel = self._send_login("test.login_type", "localuser", test_field="")
         self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
@@ -835,11 +830,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             username: The username to use for the test.
             registration: Whether to test with registration URLs.
         """
-        self.hs.get_identity_handler().send_threepid_validation = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(0),
+        self.hs.get_identity_handler().send_threepid_validation = AsyncMock(  # type: ignore[method-assign]
+            return_value=0
         )
 
-        m = Mock(return_value=make_awaitable(False))
+        m = AsyncMock(return_value=False)
         self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
 
         self.register_user(username, "password")
@@ -869,7 +864,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
         m.assert_called_once_with("email", "foo@test.com", registration)
 
-        m = Mock(return_value=make_awaitable(True))
+        m = AsyncMock(return_value=True)
         self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
 
         channel = self.make_request(
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 1aebcc16ad..88a16193a3 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -524,6 +524,7 @@ class PresenceHandlerInitTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.user_id = f"@test:{self.hs.config.server.server_name}"
+        self.device_id = "dev-1"
 
         # Move the reactor to the initial time.
         self.reactor.advance(1000)
@@ -608,7 +609,10 @@ class PresenceHandlerInitTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(SYNC_ONLINE_TIMEOUT / 1000 / 2)
         self.get_success(
             presence_handler.user_syncing(
-                self.user_id, sync_state != PresenceState.OFFLINE, sync_state
+                self.user_id,
+                self.device_id,
+                sync_state != PresenceState.OFFLINE,
+                sync_state,
             )
         )
 
@@ -632,6 +636,7 @@ class PresenceHandlerInitTestCase(unittest.HomeserverTestCase):
 class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
     user_id = "@test:server"
     user_id_obj = UserID.from_string(user_id)
+    device_id = "dev-1"
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.presence_handler = hs.get_presence_handler()
@@ -641,13 +646,20 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         """Test that if an external process doesn't update the records for a while
         we time out their syncing users presence.
         """
-        process_id = "1"
 
-        # Notify handler that a user is now syncing.
+        # Create a worker and use it to handle /sync traffic instead.
+        # This is used to test that presence changes get replicated from workers
+        # to the main process correctly.
+        worker_to_sync_against = self.make_worker_hs(
+            "synapse.app.generic_worker", {"worker_name": "synchrotron"}
+        )
+        worker_presence_handler = worker_to_sync_against.get_presence_handler()
+
         self.get_success(
-            self.presence_handler.update_external_syncs_row(
-                process_id, self.user_id, True, self.clock.time_msec()
-            )
+            worker_presence_handler.user_syncing(
+                self.user_id, self.device_id, True, PresenceState.ONLINE
+            ),
+            by=0.1,
         )
 
         # Check that if we wait a while without telling the handler the user has
@@ -701,7 +713,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         # Mark user as offline
         self.get_success(
             self.presence_handler.set_state(
-                self.user_id_obj, {"presence": PresenceState.OFFLINE}
+                self.user_id_obj, self.device_id, {"presence": PresenceState.OFFLINE}
             )
         )
 
@@ -733,7 +745,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         # Mark user as online again
         self.get_success(
             self.presence_handler.set_state(
-                self.user_id_obj, {"presence": PresenceState.ONLINE}
+                self.user_id_obj, self.device_id, {"presence": PresenceState.ONLINE}
             )
         )
 
@@ -762,7 +774,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
 
         self.get_success(
             self.presence_handler.user_syncing(
-                self.user_id, False, PresenceState.ONLINE
+                self.user_id, self.device_id, False, PresenceState.ONLINE
             )
         )
 
@@ -779,7 +791,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg)
 
         self.get_success(
-            self.presence_handler.user_syncing(self.user_id, True, PresenceState.ONLINE)
+            self.presence_handler.user_syncing(
+                self.user_id, self.device_id, True, PresenceState.ONLINE
+            )
         )
 
         state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
@@ -793,7 +807,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg)
 
         self.get_success(
-            self.presence_handler.user_syncing(self.user_id, True, PresenceState.ONLINE)
+            self.presence_handler.user_syncing(
+                self.user_id, self.device_id, True, PresenceState.ONLINE
+            )
         )
 
         state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
@@ -820,7 +836,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
             # This is used to test that presence changes get replicated from workers
             # to the main process correctly.
             worker_to_sync_against = self.make_worker_hs(
-                "synapse.app.generic_worker", {"worker_name": "presence_writer"}
+                "synapse.app.generic_worker", {"worker_name": "synchrotron"}
             )
 
         # Set presence to BUSY
@@ -831,8 +847,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         # /presence/*.
         self.get_success(
             worker_to_sync_against.get_presence_handler().user_syncing(
-                self.user_id, True, PresenceState.ONLINE
-            )
+                self.user_id, self.device_id, True, PresenceState.ONLINE
+            ),
+            by=0.1,
         )
 
         # Check against the main process that the user's presence did not change.
@@ -840,6 +857,21 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         # we should still be busy
         self.assertEqual(state.state, PresenceState.BUSY)
 
+        # Advance such that the device would be discarded if it was not busy,
+        # then pump so _handle_timeouts function to called.
+        self.reactor.advance(IDLE_TIMER / 1000)
+        self.reactor.pump([5])
+
+        # The account should still be busy.
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
+        self.assertEqual(state.state, PresenceState.BUSY)
+
+        # Ensure that a /presence call can set the user *off* busy.
+        self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg)
+
+        state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
+        self.assertEqual(state.state, PresenceState.ONLINE)
+
     def _set_presencestate_with_status_msg(
         self, state: str, status_msg: Optional[str]
     ) -> None:
@@ -852,6 +884,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
         self.get_success(
             self.presence_handler.set_state(
                 self.user_id_obj,
+                self.device_id,
                 {"presence": state, "status_msg": status_msg},
             )
         )
@@ -876,8 +909,14 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
 
         prev_token = self.queue.get_current_token(self.instance_name)
 
-        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
-        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        self.get_success(
+            self.queue.send_presence_to_destinations(
+                (state1, state2), ("dest1", "dest2")
+            )
+        )
+        self.get_success(
+            self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        )
 
         now_token = self.queue.get_current_token(self.instance_name)
 
@@ -913,11 +952,17 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
 
         prev_token = self.queue.get_current_token(self.instance_name)
 
-        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+        self.get_success(
+            self.queue.send_presence_to_destinations(
+                (state1, state2), ("dest1", "dest2")
+            )
+        )
 
         now_token = self.queue.get_current_token(self.instance_name)
 
-        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        self.get_success(
+            self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        )
 
         rows, upto_token, limited = self.get_success(
             self.queue.get_replication_rows("master", prev_token, now_token, 10)
@@ -956,8 +1001,14 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
 
         prev_token = self.queue.get_current_token(self.instance_name)
 
-        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
-        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        self.get_success(
+            self.queue.send_presence_to_destinations(
+                (state1, state2), ("dest1", "dest2")
+            )
+        )
+        self.get_success(
+            self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        )
 
         self.reactor.advance(10 * 60 * 1000)
 
@@ -972,8 +1023,14 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
 
         prev_token = self.queue.get_current_token(self.instance_name)
 
-        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
-        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        self.get_success(
+            self.queue.send_presence_to_destinations(
+                (state1, state2), ("dest1", "dest2")
+            )
+        )
+        self.get_success(
+            self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        )
 
         now_token = self.queue.get_current_token(self.instance_name)
 
@@ -1000,11 +1057,17 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
 
         prev_token = self.queue.get_current_token(self.instance_name)
 
-        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+        self.get_success(
+            self.queue.send_presence_to_destinations(
+                (state1, state2), ("dest1", "dest2")
+            )
+        )
 
         self.reactor.advance(2 * 60 * 1000)
 
-        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        self.get_success(
+            self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        )
 
         self.reactor.advance(4 * 60 * 1000)
 
@@ -1020,8 +1083,14 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
 
         prev_token = self.queue.get_current_token(self.instance_name)
 
-        self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
-        self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        self.get_success(
+            self.queue.send_presence_to_destinations(
+                (state1, state2), ("dest1", "dest2")
+            )
+        )
+        self.get_success(
+            self.queue.send_presence_to_destinations((state3,), ("dest3",))
+        )
 
         now_token = self.queue.get_current_token(self.instance_name)
 
@@ -1093,7 +1162,9 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         # Mark test2 as online, test will be offline with a last_active of 0
         self.get_success(
             self.presence_handler.set_state(
-                UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+                UserID.from_string("@test2:server"),
+                "dev-1",
+                {"presence": PresenceState.ONLINE},
             )
         )
         self.reactor.pump([0])  # Wait for presence updates to be handled
@@ -1140,7 +1211,9 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         # Mark test as online
         self.get_success(
             self.presence_handler.set_state(
-                UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}
+                UserID.from_string("@test:server"),
+                "dev-1",
+                {"presence": PresenceState.ONLINE},
             )
         )
 
@@ -1148,7 +1221,9 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         # Note we don't join them to the room yet
         self.get_success(
             self.presence_handler.set_state(
-                UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+                UserID.from_string("@test2:server"),
+                "dev-1",
+                {"presence": PresenceState.ONLINE},
             )
         )
 
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index ec2f5d30be..f9b292b9ec 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, Awaitable, Callable, Dict
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from parameterized import parameterized
 
@@ -26,7 +26,6 @@ from synapse.types import JsonDict, UserID
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class ProfileTestCase(unittest.HomeserverTestCase):
@@ -35,7 +34,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
     servlets = [admin.register_servlets]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        self.mock_federation = Mock()
+        self.mock_federation = AsyncMock()
         self.mock_registry = Mock()
 
         self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
@@ -135,9 +134,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
 
     def test_get_other_name(self) -> None:
-        self.mock_federation.make_query.return_value = make_awaitable(
-            {"displayname": "Alice"}
-        )
+        self.mock_federation.make_query.return_value = {"displayname": "Alice"}
 
         displayname = self.get_success(self.handler.get_displayname(self.alice))
 
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 54eeec228e..e9fbf32c7c 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from typing import Any, Collection, List, Optional, Tuple
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -38,7 +38,6 @@ from synapse.types import (
 )
 from synapse.util import Clock
 
-from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 from tests.utils import mock_getRawHeaders
 
@@ -203,24 +202,22 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": True})
     def test_get_or_create_user_mau_not_blocked(self) -> None:
-        self.store.count_monthly_users = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
+        self.store.count_monthly_users = AsyncMock(  # type: ignore[method-assign]
+            return_value=self.hs.config.server.max_mau_value - 1
         )
         # Ensure does not throw exception
         self.get_success(self.get_or_create_user(self.requester, "c", "User"))
 
     @override_config({"limit_usage_by_mau": True})
     def test_get_or_create_user_mau_blocked(self) -> None:
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.lots_of_users)
-        )
+        self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users)
         self.get_failure(
             self.get_or_create_user(self.requester, "b", "display_name"),
             ResourceLimitError,
         )
 
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.hs.config.server.max_mau_value)
+        self.store.get_monthly_active_count = AsyncMock(
+            return_value=self.hs.config.server.max_mau_value
         )
         self.get_failure(
             self.get_or_create_user(self.requester, "b", "display_name"),
@@ -229,15 +226,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": True})
     def test_register_mau_blocked(self) -> None:
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.lots_of_users)
-        )
+        self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users)
         self.get_failure(
             self.handler.register_user(localpart="local_part"), ResourceLimitError
         )
 
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.hs.config.server.max_mau_value)
+        self.store.get_monthly_active_count = AsyncMock(
+            return_value=self.hs.config.server.max_mau_value
         )
         self.get_failure(
             self.handler.register_user(localpart="local_part"), ResourceLimitError
@@ -292,7 +287,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     @override_config({"auto_join_rooms": ["#room:test"]})
     def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None:
         room_alias_str = "#room:test"
-        self.store.is_real_user = Mock(return_value=make_awaitable(False))
+        self.store.is_real_user = AsyncMock(return_value=False)
         user_id = self.get_success(self.handler.register_user(localpart="support"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
@@ -304,8 +299,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
         room_alias_str = "#room:test"
 
-        self.store.count_real_users = Mock(return_value=make_awaitable(1))  # type: ignore[assignment]
-        self.store.is_real_user = Mock(return_value=make_awaitable(True))
+        self.store.count_real_users = AsyncMock(return_value=1)  # type: ignore[method-assign]
+        self.store.is_real_user = AsyncMock(return_value=True)
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         directory_handler = self.hs.get_directory_handler()
@@ -319,8 +314,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
         self,
     ) -> None:
-        self.store.count_real_users = Mock(return_value=make_awaitable(2))  # type: ignore[assignment]
-        self.store.is_real_user = Mock(return_value=make_awaitable(True))
+        self.store.count_real_users = AsyncMock(return_value=2)  # type: ignore[method-assign]
+        self.store.is_real_user = AsyncMock(return_value=True)
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
index 41199ffa29..3e28117e2c 100644
--- a/tests/handlers/test_room_member.py
+++ b/tests/handlers/test_room_member.py
@@ -1,4 +1,4 @@
-from unittest.mock import Mock, patch
+from unittest.mock import AsyncMock, patch
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -16,7 +16,6 @@ from synapse.util import Clock
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.server import make_request
-from tests.test_utils import make_awaitable
 from tests.unittest import (
     FederatingHomeserverTestCase,
     HomeserverTestCase,
@@ -154,25 +153,21 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
             None,
         )
 
-        mock_make_membership_event = Mock(
-            return_value=make_awaitable(
-                (
-                    self.OTHER_SERVER_NAME,
-                    join_event,
-                    self.hs.config.server.default_room_version,
-                )
+        mock_make_membership_event = AsyncMock(
+            return_value=(
+                self.OTHER_SERVER_NAME,
+                join_event,
+                self.hs.config.server.default_room_version,
             )
         )
-        mock_send_join = Mock(
-            return_value=make_awaitable(
-                SendJoinResult(
-                    join_event,
-                    self.OTHER_SERVER_NAME,
-                    state=[create_event],
-                    auth_chain=[create_event],
-                    partial_state=False,
-                    servers_in_room=frozenset(),
-                )
+        mock_send_join = AsyncMock(
+            return_value=SendJoinResult(
+                join_event,
+                self.OTHER_SERVER_NAME,
+                state=[create_event],
+                auth_chain=[create_event],
+                partial_state=False,
+                servers_in_room=frozenset(),
             )
         )
 
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index b5c772a7ae..00f4e181e8 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -13,7 +13,7 @@
 #  limitations under the License.
 
 from typing import Any, Dict, Optional, Set, Tuple
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 import attr
 
@@ -25,7 +25,6 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
 
-from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
 # Check if we have the dependencies to run the tests.
@@ -134,7 +133,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]
 
         # send a mocked-up SAML response to the callback
         saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
@@ -164,7 +163,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]
 
         # Map a user via SSO.
         saml_response = FakeAuthnResponse(
@@ -206,11 +205,11 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]
 
         # mock out the error renderer too
         sso_handler = self.hs.get_sso_handler()
-        sso_handler.render_error = Mock(return_value=None)  # type: ignore[assignment]
+        sso_handler.render_error = Mock(return_value=None)  # type: ignore[method-assign]
 
         saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
         request = _mock_request()
@@ -227,9 +226,9 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler and error renderer
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]
         sso_handler = self.hs.get_sso_handler()
-        sso_handler.render_error = Mock(return_value=None)  # type: ignore[assignment]
+        sso_handler.render_error = Mock(return_value=None)  # type: ignore[method-assign]
 
         # register a user to occupy the first-choice MXID
         store = self.hs.get_datastores().main
@@ -312,7 +311,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[method-assign]
 
         # The response doesn't have the proper userGroup or department.
         saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py
index 8b6e4a40b6..a066745d70 100644
--- a/tests/handlers/test_send_email.py
+++ b/tests/handlers/test_send_email.py
@@ -13,19 +13,40 @@
 # limitations under the License.
 
 
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Type, Union
+from unittest.mock import patch
 
 from zope.interface import implementer
 
 from twisted.internet import defer
-from twisted.internet.address import IPv4Address
+from twisted.internet._sslverify import ClientTLSOptions
+from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.defer import ensureDeferred
+from twisted.internet.interfaces import IProtocolFactory
+from twisted.internet.ssl import ContextFactory
 from twisted.mail import interfaces, smtp
 
 from tests.server import FakeTransport
 from tests.unittest import HomeserverTestCase, override_config
 
 
+def TestingESMTPTLSClientFactory(
+    contextFactory: ContextFactory,
+    _connectWrapped: bool,
+    wrappedProtocol: IProtocolFactory,
+) -> IProtocolFactory:
+    """We use this to pass through in testing without using TLS, but
+    saving the context information to check that it would have happened.
+
+    Note that this is what the MemoryReactor does on connectSSL.
+    It only saves the contextFactory, but starts the connection with the
+    underlying Factory.
+    See: L{twisted.internet.testing.MemoryReactor.connectSSL}"""
+
+    wrappedProtocol._testingContextFactory = contextFactory  # type: ignore[attr-defined]
+    return wrappedProtocol
+
+
 @implementer(interfaces.IMessageDelivery)
 class _DummyMessageDelivery:
     def __init__(self) -> None:
@@ -75,7 +96,13 @@ class _DummyMessage:
         pass
 
 
-class SendEmailHandlerTestCase(HomeserverTestCase):
+class SendEmailHandlerTestCaseIPv4(HomeserverTestCase):
+    ip_class: Union[Type[IPv4Address], Type[IPv6Address]] = IPv4Address
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.reactor.lookups["localhost"] = "127.0.0.1"
+
     def test_send_email(self) -> None:
         """Happy-path test that we can send email to a non-TLS server."""
         h = self.hs.get_send_email_handler()
@@ -89,7 +116,7 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
         (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[
             0
         ]
-        self.assertEqual(host, "localhost")
+        self.assertEqual(host, self.reactor.lookups["localhost"])
         self.assertEqual(port, 25)
 
         # wire it up to an SMTP server
@@ -105,7 +132,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
             FakeTransport(
                 client_protocol,
                 self.reactor,
-                peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
+                peer_address=self.ip_class(
+                    "TCP", self.reactor.lookups["localhost"], 1234
+                ),
             )
         )
 
@@ -118,6 +147,10 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
         self.assertEqual(str(user), "foo@bar.com")
         self.assertIn(b"Subject: test subject", msg)
 
+    @patch(
+        "synapse.handlers.send_email.TLSMemoryBIOFactory",
+        TestingESMTPTLSClientFactory,
+    )
     @override_config(
         {
             "email": {
@@ -135,17 +168,23 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
             )
         )
         # there should be an attempt to connect to localhost:465
-        self.assertEqual(len(self.reactor.sslClients), 1)
+        self.assertEqual(len(self.reactor.tcpClients), 1)
         (
             host,
             port,
             client_factory,
-            contextFactory,
             _timeout,
             _bindAddress,
-        ) = self.reactor.sslClients[0]
-        self.assertEqual(host, "localhost")
+        ) = self.reactor.tcpClients[0]
+        self.assertEqual(host, self.reactor.lookups["localhost"])
         self.assertEqual(port, 465)
+        # We need to make sure that TLS is happenning
+        self.assertIsInstance(
+            client_factory._wrappedFactory._testingContextFactory,
+            ClientTLSOptions,
+        )
+        # And since we use endpoints, they go through reactor.connectTCP
+        # which works differently to connectSSL on the testing reactor
 
         # wire it up to an SMTP server
         message_delivery = _DummyMessageDelivery()
@@ -160,7 +199,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
             FakeTransport(
                 client_protocol,
                 self.reactor,
-                peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
+                peer_address=self.ip_class(
+                    "TCP", self.reactor.lookups["localhost"], 1234
+                ),
             )
         )
 
@@ -172,3 +213,11 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
         user, msg = message_delivery.messages.pop()
         self.assertEqual(str(user), "foo@bar.com")
         self.assertIn(b"Subject: test subject", msg)
+
+
+class SendEmailHandlerTestCaseIPv6(SendEmailHandlerTestCaseIPv4):
+    ip_class = IPv6Address
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.reactor.lookups["localhost"] = "::1"
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 9f035a02dc..948d04fc32 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Optional
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import AsyncMock, Mock, patch
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -29,7 +29,6 @@ from synapse.util import Clock
 
 import tests.unittest
 import tests.utils
-from tests.test_utils import make_awaitable
 
 
 class SyncTestCase(tests.unittest.HomeserverTestCase):
@@ -253,8 +252,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         mocked_get_prev_events = patch.object(
             self.hs.get_datastores().main,
             "get_prev_events_for_room",
-            new_callable=MagicMock,
-            return_value=make_awaitable([last_room_creation_event_id]),
+            new_callable=AsyncMock,
+            return_value=[last_room_creation_event_id],
         )
         with mocked_get_prev_events:
             self.helper.join(room_id, eve, tok=eve_token)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 5da1d95f0b..95106ec8f3 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -15,7 +15,7 @@
 
 import json
 from typing import Dict, List, Set
-from unittest.mock import ANY, Mock, call
+from unittest.mock import ANY, AsyncMock, Mock, call
 
 from netaddr import IPSet
 
@@ -33,7 +33,6 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.server import ThreadedMemoryReactorClock
-from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
 # Some local users to test with
@@ -74,11 +73,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         # we mock out the keyring so as to skip the authentication check on the
         # federation API call.
         mock_keyring = Mock(spec=["verify_json_for_server"])
-        mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
+        mock_keyring.verify_json_for_server = AsyncMock(return_value=True)
 
         # we mock out the federation client too
-        self.mock_federation_client = Mock(spec=["put_json"])
-        self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
+        self.mock_federation_client = AsyncMock(spec=["put_json"])
+        self.mock_federation_client.put_json.return_value = (200, "OK")
         self.mock_federation_client.agent = MatrixFederationAgent(
             reactor,
             tls_client_options_factory=None,
@@ -121,20 +120,16 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.datastore = hs.get_datastores().main
 
-        self.datastore.get_destination_retry_timings = Mock(
-            return_value=make_awaitable(None)
+        self.datastore.get_device_updates_by_remote = AsyncMock(  # type: ignore[method-assign]
+            return_value=(0, [])
         )
 
-        self.datastore.get_device_updates_by_remote = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable((0, []))
+        self.datastore.get_destination_last_successful_stream_ordering = AsyncMock(  # type: ignore[method-assign]
+            return_value=None
         )
 
-        self.datastore.get_destination_last_successful_stream_ordering = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None)
-        )
-
-        self.datastore.get_received_txn_response = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None)
+        self.datastore.get_received_txn_response = AsyncMock(  # type: ignore[method-assign]
+            return_value=None
         )
 
         self.room_members: List[UserID] = []
@@ -146,25 +141,25 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
                 raise AuthError(401, "User is not in the room")
             return None
 
-        hs.get_auth().check_user_in_room = Mock(  # type: ignore[assignment]
+        hs.get_auth().check_user_in_room = Mock(  # type: ignore[method-assign]
             side_effect=check_user_in_room
         )
 
         async def check_host_in_room(room_id: str, server_name: str) -> bool:
             return room_id == ROOM_ID
 
-        hs.get_event_auth_handler().is_host_in_room = Mock(  # type: ignore[assignment]
+        hs.get_event_auth_handler().is_host_in_room = Mock(  # type: ignore[method-assign]
             side_effect=check_host_in_room
         )
 
         async def get_current_hosts_in_room(room_id: str) -> Set[str]:
             return {member.domain for member in self.room_members}
 
-        hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(  # type: ignore[assignment]
+        hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(  # type: ignore[method-assign]
             side_effect=get_current_hosts_in_room
         )
 
-        hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = Mock(  # type: ignore[assignment]
+        hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = Mock(  # type: ignore[method-assign]
             side_effect=get_current_hosts_in_room
         )
 
@@ -173,27 +168,25 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.datastore.get_users_in_room = Mock(side_effect=get_users_in_room)
 
-        self.datastore.get_user_directory_stream_pos = Mock(  # type: ignore[assignment]
-            side_effect=(
-                # we deliberately return a non-None stream pos to avoid
-                # doing an initial_sync
-                lambda: make_awaitable(1)
-            )
+        self.datastore.get_user_directory_stream_pos = AsyncMock(  # type: ignore[method-assign]
+            # we deliberately return a non-None stream pos to avoid
+            # doing an initial_sync
+            return_value=1
         )
 
-        self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None))  # type: ignore[assignment]
+        self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None))  # type: ignore[method-assign]
 
-        self.datastore.get_to_device_stream_token = Mock(  # type: ignore[assignment]
-            side_effect=lambda: 0
+        self.datastore.get_to_device_stream_token = Mock(  # type: ignore[method-assign]
+            return_value=0
         )
-        self.datastore.get_new_device_msgs_for_remote = Mock(  # type: ignore[assignment]
-            side_effect=lambda *args, **kargs: make_awaitable(([], 0))
+        self.datastore.get_new_device_msgs_for_remote = AsyncMock(  # type: ignore[method-assign]
+            return_value=([], 0)
         )
-        self.datastore.delete_device_msgs_for_remote = Mock(  # type: ignore[assignment]
-            side_effect=lambda *args, **kargs: make_awaitable(None)
+        self.datastore.delete_device_msgs_for_remote = AsyncMock(  # type: ignore[method-assign]
+            return_value=None
         )
-        self.datastore.set_received_txn_response = Mock(  # type: ignore[assignment]
-            side_effect=lambda *args, **kwargs: make_awaitable(None)
+        self.datastore.set_received_txn_response = AsyncMock(  # type: ignore[method-assign]
+            return_value=None
         )
 
     def test_started_typing_local(self) -> None:
@@ -256,8 +249,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             ),
             json_data_callback=ANY,
             long_retries=True,
-            backoff_on_404=True,
             try_trailing_slash_on_400=True,
+            backoff_on_all_error_codes=True,
         )
 
     def test_started_typing_remote_recv(self) -> None:
@@ -371,7 +364,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             ),
             json_data_callback=ANY,
             long_retries=True,
-            backoff_on_404=True,
+            backoff_on_all_error_codes=True,
             try_trailing_slash_on_400=True,
         )
 
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 430209705e..b5f15aa7d4 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, Tuple
-from unittest.mock import Mock, patch
+from unittest.mock import AsyncMock, Mock, patch
 from urllib.parse import quote
 
 from twisted.test.proto_helpers import MemoryReactor
@@ -30,7 +30,7 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.storage.test_user_directory import GetUserDirectoryTables
-from tests.test_utils import event_injection, make_awaitable
+from tests.test_utils import event_injection
 from tests.test_utils.event_injection import inject_member_event
 from tests.unittest import override_config
 
@@ -471,7 +471,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             self.store.register_user(user_id=r_user_id, password_hash=None)
         )
 
-        mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
+        mock_remove_from_user_dir = AsyncMock(return_value=None)
         with patch.object(
             self.store, "remove_from_user_dir", mock_remove_from_user_dir
         ):
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 6a0b5fc0bd..0d17f2fe5b 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -14,8 +14,8 @@
 import base64
 import logging
 import os
-from typing import Any, Awaitable, Callable, Generator, List, Optional, cast
-from unittest.mock import Mock, patch
+from typing import Generator, List, Optional, cast
+from unittest.mock import AsyncMock, patch
 
 import treq
 from netaddr import IPSet
@@ -41,7 +41,7 @@ from twisted.web.iweb import IPolicyForHTTPS, IResponse
 from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto.context_factory import FederationPolicyForHTTPS
 from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
-from synapse.http.federation.srv_resolver import Server
+from synapse.http.federation.srv_resolver import Server, SrvResolver
 from synapse.http.federation.well_known_resolver import (
     WELL_KNOWN_MAX_SIZE,
     WellKnownResolver,
@@ -68,21 +68,11 @@ from tests.utils import checked_cast, default_config
 logger = logging.getLogger(__name__)
 
 
-# Once Async Mocks or lambdas are supported this can go away.
-def generate_resolve_service(
-    result: List[Server],
-) -> Callable[[Any], Awaitable[List[Server]]]:
-    async def resolve_service(_: Any) -> List[Server]:
-        return result
-
-    return resolve_service
-
-
 class MatrixFederationAgentTests(unittest.TestCase):
     def setUp(self) -> None:
         self.reactor = ThreadedMemoryReactorClock()
 
-        self.mock_resolver = Mock()
+        self.mock_resolver = AsyncMock(spec=SrvResolver)
 
         config_dict = default_config("test", parse=False)
         config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()]
@@ -636,7 +626,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
         self.reactor.lookups["testserv1"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix-federation://testserv1/foo/bar")
@@ -722,7 +712,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
@@ -776,7 +766,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """Test the behaviour when the .well-known delegates elsewhere"""
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
         self.reactor.lookups["testserv"] = "1.2.3.4"
         self.reactor.lookups["target-server"] = "1::f"
 
@@ -840,7 +830,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
         self.reactor.lookups["testserv"] = "1.2.3.4"
         self.reactor.lookups["target-server"] = "1::f"
 
@@ -930,7 +920,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
@@ -986,7 +976,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         # the config left to the default, which will not trust it (since the
         # presented cert is signed by a test CA)
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
         config = default_config("test", parse=True)
@@ -1037,9 +1027,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
-            [Server(host=b"srvtarget", port=8443)]
-        )
+        self.mock_resolver.resolve_service.return_value = [
+            Server(host=b"srvtarget", port=8443)
+        ]
         self.reactor.lookups["srvtarget"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
@@ -1094,9 +1084,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 443)
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
-            [Server(host=b"srvtarget", port=8443)]
-        )
+        self.mock_resolver.resolve_service.return_value = [
+            Server(host=b"srvtarget", port=8443)
+        ]
 
         self._handle_well_known_connection(
             client_factory,
@@ -1137,7 +1127,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """test the behaviour when the server name has idna chars in"""
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
+        self.mock_resolver.resolve_service.return_value = []
 
         # the resolver is always called with the IDNA hostname as a native string.
         self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"
@@ -1201,9 +1191,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """test the behaviour when the target of a SRV record has idna chars"""
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
-            [Server(host=b"xn--trget-3qa.com", port=8443)]  # târget.com
-        )
+        self.mock_resolver.resolve_service.return_value = [
+            Server(host=b"xn--trget-3qa.com", port=8443)
+        ]  # târget.com
         self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
 
         test_d = self._make_get_request(
@@ -1407,12 +1397,10 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """Test that other SRV results are tried if the first one fails."""
         self.agent = self._make_agent()
 
-        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
-            [
-                Server(host=b"target.com", port=8443),
-                Server(host=b"target.com", port=8444),
-            ]
-        )
+        self.mock_resolver.resolve_service.return_value = [
+            Server(host=b"target.com", port=8443),
+            Server(host=b"target.com", port=8444),
+        ]
         self.reactor.lookups["target.com"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index fa27f1279a..c379853e20 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -164,7 +164,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         # Call requestReceived to finish instantiating the object.
         request.content = BytesIO()
         # Partially skip some internal processing of SynapseRequest.
-        request._started_processing = Mock()  # type: ignore[assignment]
+        request._started_processing = Mock()  # type: ignore[method-assign]
         request.request_metrics = Mock(spec=["name"])
         with patch.object(Request, "render"):
             request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1")
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index fe631d7ecb..172fc3a736 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, Dict, Optional
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.internet import defer
 from twisted.test.proto_helpers import MemoryReactor
@@ -33,7 +33,6 @@ from synapse.util import Clock
 
 from tests.events.test_presence_router import send_presence_update, sync_presence
 from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.test_utils import simple_async_mock
 from tests.test_utils.event_injection import inject_member_event
 from tests.unittest import HomeserverTestCase, override_config
 
@@ -70,7 +69,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         # Mock out the calls over federation.
         self.fed_transport_client = Mock(spec=["send_transaction"])
-        self.fed_transport_client.send_transaction = simple_async_mock({})
+        self.fed_transport_client.send_transaction = AsyncMock(return_value={})
 
         return self.setup_test_homeserver(
             federation_transport_client=self.fed_transport_client,
@@ -234,7 +233,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
     def test_sending_events_into_room(self) -> None:
         """Tests that a module can send events into a room"""
         # Mock out create_and_send_nonmember_event to check whether events are being sent
-        self.event_creation_handler.create_and_send_nonmember_event = Mock(  # type: ignore[assignment]
+        self.event_creation_handler.create_and_send_nonmember_event = Mock(  # type: ignore[method-assign]
             spec=[],
             side_effect=self.event_creation_handler.create_and_send_nonmember_event,
         )
@@ -579,10 +578,8 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
         """Test that the module API can join a remote room."""
         # Necessary to fake a remote join.
         fake_stream_id = 1
-        mocked_remote_join = simple_async_mock(
-            return_value=("fake-event-id", fake_stream_id)
-        )
-        self.hs.get_room_member_handler()._remote_join = mocked_remote_join  # type: ignore[assignment]
+        mocked_remote_join = AsyncMock(return_value=("fake-event-id", fake_stream_id))
+        self.hs.get_room_member_handler()._remote_join = mocked_remote_join  # type: ignore[method-assign]
         fake_remote_host = f"{self.module_api.server_name}-remote"
 
         # Given that the join is to be faked, we expect the relevant join event not to
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index 829b9df83d..7c23b77e0a 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from typing import Any, Optional
-from unittest.mock import patch
+from unittest.mock import AsyncMock, patch
 
 from parameterized import parameterized
 
@@ -28,7 +28,6 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict, create_requester
 from synapse.util import Clock
 
-from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
 
@@ -191,7 +190,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
         # Mock the method which calculates push rules -- we do this instead of
         # e.g. checking the results in the database because we want to ensure
         # that code isn't even running.
-        bulk_evaluator._action_for_event_by_user = simple_async_mock()  # type: ignore[assignment]
+        bulk_evaluator._action_for_event_by_user = AsyncMock()  # type: ignore[method-assign]
 
         # Ensure no actions are generated!
         self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
@@ -382,7 +381,6 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
             )
         )
 
-    @override_config({"experimental_features": {"msc3958_supress_edit_notifs": True}})
     def test_suppress_edits(self) -> None:
         """Under the default push rules, event edits should not generate notifications."""
         bulk_evaluator = BulkPushRuleEvaluator(self.hs)
diff --git a/tests/replication/storage/test_events.py b/tests/replication/storage/test_events.py
index f7c6417a09..af25815fa5 100644
--- a/tests/replication/storage/test_events.py
+++ b/tests/replication/storage/test_events.py
@@ -58,7 +58,7 @@ def patch__eq__(cls: object) -> Callable[[], None]:
 
     def unpatch() -> None:
         if eq is not None:
-            cls.__eq__ = eq  # type: ignore[assignment]
+            cls.__eq__ = eq  # type: ignore[method-assign]
 
     return unpatch
 
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index a324b4d31d..9b28cd474f 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from netaddr import IPSet
 
@@ -26,7 +26,6 @@ from synapse.types import UserID, create_requester
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.server import get_clock
-from tests.test_utils import make_awaitable
 
 logger = logging.getLogger(__name__)
 
@@ -62,7 +61,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         new event.
         """
         mock_client = Mock(spec=["put_json"])
-        mock_client.put_json.return_value = make_awaitable({})
+        mock_client.put_json = AsyncMock(return_value={})
         mock_client.agent = self.matrix_federation_agent
         self.make_worker_hs(
             "synapse.app.generic_worker",
@@ -93,7 +92,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         new events.
         """
         mock_client1 = Mock(spec=["put_json"])
-        mock_client1.put_json.return_value = make_awaitable({})
+        mock_client1.put_json = AsyncMock(return_value={})
         mock_client1.agent = self.matrix_federation_agent
         self.make_worker_hs(
             "synapse.app.generic_worker",
@@ -108,7 +107,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         mock_client2 = Mock(spec=["put_json"])
-        mock_client2.put_json.return_value = make_awaitable({})
+        mock_client2.put_json = AsyncMock(return_value={})
         mock_client2.agent = self.matrix_federation_agent
         self.make_worker_hs(
             "synapse.app.generic_worker",
@@ -162,7 +161,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         new typing EDUs.
         """
         mock_client1 = Mock(spec=["put_json"])
-        mock_client1.put_json.return_value = make_awaitable({})
+        mock_client1.put_json = AsyncMock(return_value={})
         mock_client1.agent = self.matrix_federation_agent
         self.make_worker_hs(
             "synapse.app.generic_worker",
@@ -177,7 +176,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         mock_client2 = Mock(spec=["put_json"])
-        mock_client2.put_json.return_value = make_awaitable({})
+        mock_client2.put_json = AsyncMock(return_value={})
         mock_client2.agent = self.matrix_federation_agent
         self.make_worker_hs(
             "synapse.app.generic_worker",
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index feb81844ae..761871b933 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -18,7 +18,7 @@ import os
 import urllib.parse
 from binascii import unhexlify
 from typing import List, Optional
-from unittest.mock import Mock, patch
+from unittest.mock import AsyncMock, Mock, patch
 
 from parameterized import parameterized, parameterized_class
 
@@ -40,12 +40,13 @@ from synapse.rest.client import (
     user_directory,
 )
 from synapse.server import HomeServer
+from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
 from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeSite, make_request
-from tests.test_utils import SMALL_PNG, make_awaitable
+from tests.test_utils import SMALL_PNG
 from tests.unittest import override_config
 
 
@@ -71,8 +72,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
 
         self.hs.config.registration.registration_shared_secret = "shared"
 
-        self.hs.get_media_repository = Mock()  # type: ignore[assignment]
-        self.hs.get_deactivate_account_handler = Mock()  # type: ignore[assignment]
+        self.hs.get_media_repository = Mock()  # type: ignore[method-assign]
+        self.hs.get_deactivate_account_handler = Mock()  # type: ignore[method-assign]
 
         return self.hs
 
@@ -419,8 +420,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         store = self.hs.get_datastores().main
 
         # Set monthly active users to the limit
-        store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.hs.config.server.max_mau_value)
+        store.get_monthly_active_count = AsyncMock(
+            return_value=self.hs.config.server.max_mau_value
         )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
@@ -456,6 +457,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
     servlets = [
         synapse.rest.admin.register_servlets,
         login.register_servlets,
+        room.register_servlets,
     ]
     url = "/_synapse/admin/v2/users"
 
@@ -506,6 +508,62 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         # Check that all fields are available
         self._check_fields(channel.json_body["users"])
 
+    def test_last_seen(self) -> None:
+        """
+        Test that last_seen_ts field is properly working.
+        """
+        user1 = self.register_user("u1", "pass")
+        user1_token = self.login("u1", "pass")
+        user2 = self.register_user("u2", "pass")
+        user2_token = self.login("u2", "pass")
+        user3 = self.register_user("u3", "pass")
+        user3_token = self.login("u3", "pass")
+
+        self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+        self.reactor.advance(10)
+        self.helper.create_room_as(user2, tok=user2_token)
+        self.reactor.advance(10)
+        self.helper.create_room_as(user1, tok=user1_token)
+        self.reactor.advance(10)
+        self.helper.create_room_as(user3, tok=user3_token)
+        self.reactor.advance(10)
+
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(4, len(channel.json_body["users"]))
+        self.assertEqual(4, channel.json_body["total"])
+
+        admin_last_seen = channel.json_body["users"][0]["last_seen_ts"]
+        user1_last_seen = channel.json_body["users"][1]["last_seen_ts"]
+        user2_last_seen = channel.json_body["users"][2]["last_seen_ts"]
+        user3_last_seen = channel.json_body["users"][3]["last_seen_ts"]
+        self.assertTrue(admin_last_seen > 0 and admin_last_seen < 10000)
+        self.assertTrue(user2_last_seen > 10000 and user2_last_seen < 20000)
+        self.assertTrue(user1_last_seen > 20000 and user1_last_seen < 30000)
+        self.assertTrue(user3_last_seen > 30000 and user3_last_seen < 40000)
+
+        self._order_test([self.admin_user, user2, user1, user3], "last_seen_ts")
+
+        self.reactor.advance(LAST_SEEN_GRANULARITY / 1000)
+        self.helper.create_room_as(user1, tok=user1_token)
+        self.reactor.advance(10)
+
+        channel = self.make_request(
+            "GET",
+            self.url + "/" + user1,
+            access_token=self.admin_user_tok,
+        )
+        self.assertTrue(
+            channel.json_body["last_seen_ts"] > 40000 + LAST_SEEN_GRANULARITY
+        )
+
+        self._order_test([self.admin_user, user2, user3, user1], "last_seen_ts")
+
     def test_search_term(self) -> None:
         """Test that searching for a users works correctly"""
 
@@ -1135,6 +1193,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
             self.assertIn("displayname", u)
             self.assertIn("avatar_url", u)
             self.assertIn("creation_ts", u)
+            self.assertIn("last_seen_ts", u)
 
     def _create_users(self, number_users: int) -> None:
         """
@@ -1834,8 +1893,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
             )
 
         # Set monthly active users to the limit
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.hs.config.server.max_mau_value)
+        self.store.get_monthly_active_count = AsyncMock(
+            return_value=self.hs.config.server.max_mau_value
         )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
@@ -1871,8 +1930,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         handler = self.hs.get_registration_handler()
 
         # Set monthly active users to the limit
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.hs.config.server.max_mau_value)
+        self.store.get_monthly_active_count = AsyncMock(
+            return_value=self.hs.config.server.max_mau_value
         )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
@@ -3035,6 +3094,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertIn("consent_version", content)
         self.assertIn("consent_ts", content)
         self.assertIn("external_ids", content)
+        self.assertIn("last_seen_ts", content)
 
         # This key was removed intentionally. Ensure it is not accidentally re-included.
         self.assertNotIn("password_hash", content)
diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index 6c04e6c56c..4c69d224b8 100644
--- a/tests/rest/admin/test_username_available.py
+++ b/tests/rest/admin/test_username_available.py
@@ -50,7 +50,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
             )
 
         handler = self.hs.get_registration_handler()
-        handler.check_username = check_username  # type: ignore[assignment]
+        handler.check_username = check_username  # type: ignore[method-assign]
 
     def test_username_available(self) -> None:
         """
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index ac19f3c6da..e9f495e206 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -1346,7 +1346,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
                 return {}
 
         # Register a mock that will return the expected result depending on the remote.
-        self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)  # type: ignore[assignment]
+        self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)  # type: ignore[method-assign]
 
         # Check that we've got the correct response from the client-side endpoint.
         self._test_status(
diff --git a/tests/rest/client/test_account_data.py b/tests/rest/client/test_account_data.py
index d5b0640e7a..481db9a687 100644
--- a/tests/rest/client/test_account_data.py
+++ b/tests/rest/client/test_account_data.py
@@ -11,13 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
 
 from synapse.rest import admin
 from synapse.rest.client import account_data, login, room
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class AccountDataTestCase(unittest.HomeserverTestCase):
@@ -32,7 +31,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
         """Tests that the on_account_data_updated module callback is called correctly when
         a user's account data changes.
         """
-        mocked_callback = Mock(return_value=make_awaitable(None))
+        mocked_callback = AsyncMock(return_value=None)
         self.hs.get_account_data_handler()._on_account_data_updated_callbacks.append(
             mocked_callback
         )
diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py
index 54df2a252c..141e0f57a3 100644
--- a/tests/rest/client/test_events.py
+++ b/tests/rest/client/test_events.py
@@ -45,7 +45,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
 
         hs = self.setup_test_homeserver(config=config)
 
-        hs.get_federation_handler = Mock()  # type: ignore[assignment]
+        hs.get_federation_handler = Mock()  # type: ignore[method-assign]
 
         return hs
 
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index a2d5d340be..90a8df147c 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -65,14 +65,14 @@ class FilterTestCase(unittest.HomeserverTestCase):
 
     def test_add_filter_non_local_user(self) -> None:
         _is_mine = self.hs.is_mine
-        self.hs.is_mine = lambda target_user: False  # type: ignore[assignment]
+        self.hs.is_mine = lambda target_user: False  # type: ignore[method-assign]
         channel = self.make_request(
             "POST",
             "/_matrix/client/r0/user/%s/filter" % (self.user_id),
             self.EXAMPLE_FILTER_JSON,
         )
 
-        self.hs.is_mine = _is_mine  # type: ignore[assignment]
+        self.hs.is_mine = _is_mine  # type: ignore[method-assign]
         self.assertEqual(channel.code, 403)
         self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
 
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index ffbc13bb8d..a2a6589564 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -169,7 +169,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 # which sets these values to 10000, but as we're overriding the entire
                 # rc_login dict here, we need to set this manually as well
                 "account": {"per_second": 10000, "burst_count": 10000},
-            }
+            },
+            "experimental_features": {"msc4041_enabled": True},
         }
     )
     def test_POST_ratelimiting_per_address(self) -> None:
@@ -189,12 +190,15 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             if i == 5:
                 self.assertEqual(channel.code, 429, msg=channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
+                retry_header = channel.headers.getRawHeaders("Retry-After")
             else:
                 self.assertEqual(channel.code, 200, msg=channel.result)
 
         # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
         # than 1min.
-        self.assertTrue(retry_after_ms < 6000)
+        self.assertLess(retry_after_ms, 6000)
+        assert retry_header
+        self.assertLessEqual(int(retry_header[0]), 6)
 
         self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
@@ -217,7 +221,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 # which sets these values to 10000, but as we're overriding the entire
                 # rc_login dict here, we need to set this manually as well
                 "address": {"per_second": 10000, "burst_count": 10000},
-            }
+            },
+            "experimental_features": {"msc4041_enabled": True},
         }
     )
     def test_POST_ratelimiting_per_account(self) -> None:
@@ -234,12 +239,15 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             if i == 5:
                 self.assertEqual(channel.code, 429, msg=channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
+                retry_header = channel.headers.getRawHeaders("Retry-After")
             else:
                 self.assertEqual(channel.code, 200, msg=channel.result)
 
         # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
         # than 1min.
-        self.assertTrue(retry_after_ms < 6000)
+        self.assertLess(retry_after_ms, 6000)
+        assert retry_header
+        self.assertLessEqual(int(retry_header[0]), 6)
 
         self.reactor.advance(retry_after_ms / 1000.0)
 
@@ -262,7 +270,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 # rc_login dict here, we need to set this manually as well
                 "address": {"per_second": 10000, "burst_count": 10000},
                 "failed_attempts": {"per_second": 0.17, "burst_count": 5},
-            }
+            },
+            "experimental_features": {"msc4041_enabled": True},
         }
     )
     def test_POST_ratelimiting_per_account_failed_attempts(self) -> None:
@@ -279,12 +288,15 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             if i == 5:
                 self.assertEqual(channel.code, 429, msg=channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
+                retry_header = channel.headers.getRawHeaders("Retry-After")
             else:
                 self.assertEqual(channel.code, 403, msg=channel.result)
 
         # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
         # than 1min.
-        self.assertTrue(retry_after_ms < 6000)
+        self.assertLess(retry_after_ms, 6000)
+        assert retry_header
+        self.assertLessEqual(int(retry_header[0]), 6)
 
         self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
@@ -569,8 +581,9 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             body,
         )
         self.assertEqual(channel.code, 403, channel.result)
-        self.assertDictContainsSubset(
-            {"errcode": Codes.LIMIT_EXCEEDED, "extra": "value"}, channel.json_body
+        self.assertLessEqual(
+            {"errcode": Codes.LIMIT_EXCEEDED, "extra": "value"}.items(),
+            channel.json_body.items(),
         )
 
 
diff --git a/tests/rest/client/test_notifications.py b/tests/rest/client/test_notifications.py
index 700f6587a0..41ceb3db51 100644
--- a/tests/rest/client/test_notifications.py
+++ b/tests/rest/client/test_notifications.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -20,7 +20,6 @@ from synapse.rest.client import login, notifications, receipts, room
 from synapse.server import HomeServer
 from synapse.util import Clock
 
-from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase
 
 
@@ -45,7 +44,7 @@ class HTTPPusherTests(HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         # Mock out the calls over federation.
         fed_transport_client = Mock(spec=["send_transaction"])
-        fed_transport_client.send_transaction = simple_async_mock({})
+        fed_transport_client.send_transaction = AsyncMock(return_value={})
 
         return self.setup_test_homeserver(
             federation_transport_client=fed_transport_client,
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index e12098102b..66b387cea3 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from http import HTTPStatus
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -23,7 +23,6 @@ from synapse.types import UserID
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class PresenceTestCase(unittest.HomeserverTestCase):
@@ -36,7 +35,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.presence_handler = Mock(spec=PresenceHandler)
-        self.presence_handler.set_state.return_value = make_awaitable(None)
+        self.presence_handler.set_state = AsyncMock(return_value=None)
 
         hs = self.setup_test_homeserver(
             "red",
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index b228dba861..c33393dc28 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -75,7 +75,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(channel.code, 200, msg=channel.result)
         det_data = {"user_id": user_id, "home_server": self.hs.hostname}
-        self.assertDictContainsSubset(det_data, channel.json_body)
+        self.assertLessEqual(det_data.items(), channel.json_body.items())
 
     def test_POST_appservice_registration_no_type(self) -> None:
         as_token = "i_am_an_app_service"
@@ -136,7 +136,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             "device_id": device_id,
         }
         self.assertEqual(channel.code, 200, msg=channel.result)
-        self.assertDictContainsSubset(det_data, channel.json_body)
+        self.assertLessEqual(det_data.items(), channel.json_body.items())
 
     @override_config({"enable_registration": False})
     def test_POST_disabled_registration(self) -> None:
@@ -157,7 +157,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
         self.assertEqual(channel.code, 200, msg=channel.result)
-        self.assertDictContainsSubset(det_data, channel.json_body)
+        self.assertLessEqual(det_data.items(), channel.json_body.items())
 
     def test_POST_disabled_guest_registration(self) -> None:
         self.hs.config.registration.allow_guest_access = False
@@ -267,7 +267,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             "device_id": device_id,
         }
         self.assertEqual(channel.code, 200, msg=channel.result)
-        self.assertDictContainsSubset(det_data, channel.json_body)
+        self.assertLessEqual(det_data.items(), channel.json_body.items())
 
         # Check the `completed` counter has been incremented and pending is 0
         res = self.get_success(
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 9bfe913e45..61773fb28c 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -15,7 +15,7 @@
 
 import urllib.parse
 from typing import Any, Callable, Dict, List, Optional, Tuple
-from unittest.mock import patch
+from unittest.mock import AsyncMock, patch
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -28,7 +28,6 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeChannel
-from tests.test_utils import make_awaitable
 from tests.test_utils.event_injection import inject_event
 from tests.unittest import override_config
 
@@ -264,7 +263,8 @@ class RelationsTestCase(BaseRelationsTestCase):
         # Disable the validation to pretend this came over federation.
         with patch(
             "synapse.handlers.message.EventCreationHandler._validate_event_relation",
-            new=lambda self, event: make_awaitable(None),
+            new_callable=AsyncMock,
+            return_value=None,
         ):
             # Generate a various relations from a different room.
             self.get_success(
@@ -570,7 +570,7 @@ class RelationsTestCase(BaseRelationsTestCase):
         )
         self.assertEqual(200, channel.code, channel.json_body)
         event_result = channel.json_body
-        self.assertDictContainsSubset(original_body, event_result["content"])
+        self.assertLessEqual(original_body.items(), event_result["content"].items())
 
         # also check /context, which returns the *edited* event
         channel = self.make_request(
@@ -587,14 +587,14 @@ class RelationsTestCase(BaseRelationsTestCase):
             (context_result, "/context"),
         ):
             # The reference metadata should still be intact.
-            self.assertDictContainsSubset(
+            self.assertLessEqual(
                 {
                     "m.relates_to": {
                         "event_id": self.parent_id,
                         "rel_type": "m.reference",
                     }
-                },
-                result_event_dict["content"],
+                }.items(),
+                result_event_dict["content"].items(),
                 desc,
             )
 
@@ -1300,7 +1300,8 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         # not an event the Client-Server API will allow..
         with patch(
             "synapse.handlers.message.EventCreationHandler._validate_event_relation",
-            new=lambda self, event: make_awaitable(None),
+            new_callable=AsyncMock,
+            return_value=None,
         ):
             # Create a sub-thread off the thread, which is not allowed.
             self._send_relation(
@@ -1371,9 +1372,11 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         latest_event_in_thread = thread_summary["latest_event"]
         # The latest event in the thread should have the edit appear under the
         # bundled aggregations.
-        self.assertDictContainsSubset(
-            {"event_id": edit_event_id, "sender": "@alice:test"},
-            latest_event_in_thread["unsigned"]["m.relations"][RelationTypes.REPLACE],
+        self.assertLessEqual(
+            {"event_id": edit_event_id, "sender": "@alice:test"}.items(),
+            latest_event_in_thread["unsigned"]["m.relations"][
+                RelationTypes.REPLACE
+            ].items(),
         )
 
     def test_aggregation_get_event_for_annotation(self) -> None:
@@ -1636,9 +1639,9 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         ##################################################
         self.assertEqual(self._get_related_events(), list(reversed(thread_replies)))
         relations = self._get_bundled_aggregations()
-        self.assertDictContainsSubset(
-            {"count": 3, "current_user_participated": True},
-            relations[RelationTypes.THREAD],
+        self.assertLessEqual(
+            {"count": 3, "current_user_participated": True}.items(),
+            relations[RelationTypes.THREAD].items(),
         )
         # The latest event is the last sent event.
         self.assertEqual(
@@ -1657,9 +1660,9 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         # The thread should still exist, but the latest event should be updated.
         self.assertEqual(self._get_related_events(), list(reversed(thread_replies)))
         relations = self._get_bundled_aggregations()
-        self.assertDictContainsSubset(
-            {"count": 2, "current_user_participated": True},
-            relations[RelationTypes.THREAD],
+        self.assertLessEqual(
+            {"count": 2, "current_user_participated": True}.items(),
+            relations[RelationTypes.THREAD].items(),
         )
         # And the latest event is the last unredacted event.
         self.assertEqual(
@@ -1676,9 +1679,9 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         # Nothing should have changed (except the thread count).
         self.assertEqual(self._get_related_events(), thread_replies)
         relations = self._get_bundled_aggregations()
-        self.assertDictContainsSubset(
-            {"count": 1, "current_user_participated": True},
-            relations[RelationTypes.THREAD],
+        self.assertLessEqual(
+            {"count": 1, "current_user_participated": True}.items(),
+            relations[RelationTypes.THREAD].items(),
         )
         # And the latest event is the last unredacted event.
         self.assertEqual(
@@ -1773,12 +1776,12 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         event_ids = self._get_related_events()
         relations = self._get_bundled_aggregations()
         self.assertEqual(len(event_ids), 1)
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "count": 1,
                 "current_user_participated": True,
-            },
-            relations[RelationTypes.THREAD],
+            }.items(),
+            relations[RelationTypes.THREAD].items(),
         )
         self.assertEqual(
             relations[RelationTypes.THREAD]["latest_event"]["event_id"],
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 88e579dc39..47c1d38ad7 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -20,7 +20,7 @@
 import json
 from http import HTTPStatus
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
-from unittest.mock import Mock, call, patch
+from unittest.mock import AsyncMock, Mock, call, patch
 from urllib import parse as urlparse
 
 from parameterized import param, parameterized
@@ -52,7 +52,6 @@ from synapse.util.stringutils import random_string
 from tests import unittest
 from tests.http.server._base import make_request_with_cancellation_test
 from tests.storage.test_stream import PaginationTestCase
-from tests.test_utils import make_awaitable
 from tests.test_utils.event_injection import create_event
 from tests.unittest import override_config
 
@@ -69,15 +68,15 @@ class RoomBase(unittest.HomeserverTestCase):
             "red",
         )
 
-        self.hs.get_federation_handler = Mock()  # type: ignore[assignment]
-        self.hs.get_federation_handler.return_value.maybe_backfill = Mock(
-            return_value=make_awaitable(None)
+        self.hs.get_federation_handler = Mock()  # type: ignore[method-assign]
+        self.hs.get_federation_handler.return_value.maybe_backfill = AsyncMock(
+            return_value=None
         )
 
         async def _insert_client_ip(*args: Any, **kwargs: Any) -> None:
             return None
 
-        self.hs.get_datastores().main.insert_client_ip = _insert_client_ip  # type: ignore[assignment]
+        self.hs.get_datastores().main.insert_client_ip = _insert_client_ip  # type: ignore[method-assign]
 
         return self.hs
 
@@ -2375,7 +2374,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
     ]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        return self.setup_test_homeserver(federation_client=Mock())
+        return self.setup_test_homeserver(federation_client=AsyncMock())
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.register_user("user", "pass")
@@ -2385,7 +2384,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
 
     def test_simple(self) -> None:
         "Simple test for searching rooms over federation"
-        self.federation_client.get_public_rooms.return_value = make_awaitable({})  # type: ignore[attr-defined]
+        self.federation_client.get_public_rooms.return_value = {}  # type: ignore[attr-defined]
 
         search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
 
@@ -2413,7 +2412,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
         # with a 404, when using search filters.
         self.federation_client.get_public_rooms.side_effect = (  # type: ignore[attr-defined]
             HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""),
-            make_awaitable({}),
+            {},
         )
 
         search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
@@ -3413,17 +3412,17 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
         # Mock a few functions to prevent the test from failing due to failing to talk to
         # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
         # can check its call_count later on during the test.
-        make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
-        self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock  # type: ignore[assignment]
-        self.hs.get_identity_handler().lookup_3pid = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None),
+        make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0))
+        self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock  # type: ignore[method-assign]
+        self.hs.get_identity_handler().lookup_3pid = AsyncMock(  # type: ignore[method-assign]
+            return_value=None,
         )
 
         # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
         # allow everything for now.
         # `spec` argument is needed for this function mock to have `__qualname__`, which
         # is needed for `Measure` metrics buried in SpamChecker.
-        mock = Mock(return_value=make_awaitable(True), spec=lambda *x: None)
+        mock = AsyncMock(return_value=True, spec=lambda *x: None)
         self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append(
             mock
         )
@@ -3451,7 +3450,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
 
         # Now change the return value of the callback to deny any invite and test that
         # we can't send the invite.
-        mock.return_value = make_awaitable(False)
+        mock.return_value = False
         channel = self.make_request(
             method="POST",
             path="/rooms/" + self.room_id + "/invite",
@@ -3477,18 +3476,18 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
         # Mock a few functions to prevent the test from failing due to failing to talk to
         # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
         # can check its call_count later on during the test.
-        make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
-        self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock  # type: ignore[assignment]
-        self.hs.get_identity_handler().lookup_3pid = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None),
+        make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0))
+        self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock  # type: ignore[method-assign]
+        self.hs.get_identity_handler().lookup_3pid = AsyncMock(  # type: ignore[method-assign]
+            return_value=None,
         )
 
         # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
         # allow everything for now.
         # `spec` argument is needed for this function mock to have `__qualname__`, which
         # is needed for `Measure` metrics buried in SpamChecker.
-        mock = Mock(
-            return_value=make_awaitable(synapse.module_api.NOT_SPAM),
+        mock = AsyncMock(
+            return_value=synapse.module_api.NOT_SPAM,
             spec=lambda *x: None,
         )
         self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append(
@@ -3519,7 +3518,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
         # Now change the return value of the callback to deny any invite and test that
         # we can't send the invite. We pick an arbitrary error code to be able to check
         # that the same code has been returned
-        mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN)
+        mock.return_value = Codes.CONSENT_NOT_GIVEN
         channel = self.make_request(
             method="POST",
             path="/rooms/" + self.room_id + "/invite",
@@ -3538,7 +3537,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
         make_invite_mock.assert_called_once()
 
         # Run variant with `Tuple[Codes, dict]`.
-        mock.return_value = make_awaitable((Codes.EXPIRED_ACCOUNT, {"field": "value"}))
+        mock.return_value = (Codes.EXPIRED_ACCOUNT, {"field": "value"})
         channel = self.make_request(
             method="POST",
             path="/rooms/" + self.room_id + "/invite",
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index 8d2cdf8751..9aecf88e41 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -84,7 +84,7 @@ class RoomTestCase(_ShadowBannedBase):
     def test_invite_3pid(self) -> None:
         """Ensure that a 3PID invite does not attempt to contact the identity server."""
         identity_handler = self.hs.get_identity_handler()
-        identity_handler.lookup_3pid = Mock(  # type: ignore[assignment]
+        identity_handler.lookup_3pid = Mock(  # type: ignore[method-assign]
             side_effect=AssertionError("This should not get called")
         )
 
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index e5ba5a9706..57eb713b15 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import threading
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -33,7 +33,6 @@ from synapse.util import Clock
 from synapse.util.frozenutils import unfreeze
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 if TYPE_CHECKING:
     from synapse.module_api import ModuleApi
@@ -118,7 +117,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         async def _check_event_auth(origin: Any, event: Any, context: Any) -> None:
             pass
 
-        hs.get_federation_event_handler()._check_event_auth = _check_event_auth  # type: ignore[assignment]
+        hs.get_federation_event_handler()._check_event_auth = _check_event_auth  # type: ignore[method-assign]
 
         return hs
 
@@ -477,7 +476,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
 
     def test_on_new_event(self) -> None:
         """Test that the on_new_event callback is called on new events"""
-        on_new_event = Mock(make_awaitable(None))
+        on_new_event = AsyncMock(return_value=None)
         self.hs.get_module_api_callbacks().third_party_event_rules._on_new_event_callbacks.append(
             on_new_event
         )
@@ -580,7 +579,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo"
 
         # Register a mock callback.
-        m = Mock(return_value=make_awaitable(None))
+        m = AsyncMock(return_value=None)
         self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
             m
         )
@@ -641,7 +640,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo"
 
         # Register a mock callback.
-        m = Mock(return_value=make_awaitable(None))
+        m = AsyncMock(return_value=None)
         self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
             m
         )
@@ -682,7 +681,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         correctly when processing a user's deactivation.
         """
         # Register a mocked callback.
-        deactivation_mock = Mock(return_value=make_awaitable(None))
+        deactivation_mock = AsyncMock(return_value=None)
         third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._on_user_deactivation_status_changed_callbacks.append(
             deactivation_mock,
@@ -690,7 +689,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         # Also register a mocked callback for profile updates, to check that the
         # deactivation code calls it in a way that let modules know the user is being
         # deactivated.
-        profile_mock = Mock(return_value=make_awaitable(None))
+        profile_mock = AsyncMock(return_value=None)
         self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
             profile_mock,
         )
@@ -740,7 +739,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         well as a reactivation.
         """
         # Register a mock callback.
-        m = Mock(return_value=make_awaitable(None))
+        m = AsyncMock(return_value=None)
         third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._on_user_deactivation_status_changed_callbacks.append(m)
 
@@ -794,7 +793,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         correctly when processing a user's deactivation.
         """
         # Register a mocked callback.
-        deactivation_mock = Mock(return_value=make_awaitable(False))
+        deactivation_mock = AsyncMock(return_value=False)
         third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._check_can_deactivate_user_callbacks.append(
             deactivation_mock,
@@ -840,7 +839,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         correctly when processing a user's deactivation triggered by a server admin.
         """
         # Register a mocked callback.
-        deactivation_mock = Mock(return_value=make_awaitable(False))
+        deactivation_mock = AsyncMock(return_value=False)
         third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._check_can_deactivate_user_callbacks.append(
             deactivation_mock,
@@ -879,7 +878,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         correctly when processing an admin's shutdown room request.
         """
         # Register a mocked callback.
-        shutdown_mock = Mock(return_value=make_awaitable(False))
+        shutdown_mock = AsyncMock(return_value=False)
         third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._check_can_shutdown_room_callbacks.append(
             shutdown_mock,
@@ -915,7 +914,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         associating a 3PID to an account.
         """
         # Register a mocked callback.
-        threepid_bind_mock = Mock(return_value=make_awaitable(None))
+        threepid_bind_mock = AsyncMock(return_value=None)
         third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock)
 
@@ -957,11 +956,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         just before associating and removing a 3PID to/from an account.
         """
         # Pretend to be a Synapse module and register both callbacks as mocks.
-        on_add_user_third_party_identifier_callback_mock = Mock(
-            return_value=make_awaitable(None)
-        )
-        on_remove_user_third_party_identifier_callback_mock = Mock(
-            return_value=make_awaitable(None)
+        on_add_user_third_party_identifier_callback_mock = AsyncMock(return_value=None)
+        on_remove_user_third_party_identifier_callback_mock = AsyncMock(
+            return_value=None
         )
         self.hs.get_module_api().register_third_party_rules_callbacks(
             on_add_user_third_party_identifier=on_add_user_third_party_identifier_callback_mock,
@@ -1021,8 +1018,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         when a user is deactivated and their third-party ID associations are deleted.
         """
         # Pretend to be a Synapse module and register both callbacks as mocks.
-        on_remove_user_third_party_identifier_callback_mock = Mock(
-            return_value=make_awaitable(None)
+        on_remove_user_third_party_identifier_callback_mock = AsyncMock(
+            return_value=None
         )
         self.hs.get_module_api().register_third_party_rules_callbacks(
             on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index d8dc56261a..951a3cbc43 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -14,7 +14,7 @@
 
 from http import HTTPStatus
 from typing import Any, Generator, Tuple, cast
-from unittest.mock import Mock, call
+from unittest.mock import AsyncMock, Mock, call
 
 from twisted.internet import defer, reactor as _reactor
 
@@ -24,7 +24,6 @@ from synapse.types import ISynapseReactor, JsonDict
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 from tests.utils import MockClock
 
 reactor = cast(ISynapseReactor, _reactor)
@@ -53,7 +52,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
     def test_executes_given_function(
         self,
     ) -> Generator["defer.Deferred[Any]", object, None]:
-        cb = Mock(return_value=make_awaitable(self.mock_http_response))
+        cb = AsyncMock(return_value=self.mock_http_response)
         res = yield self.cache.fetch_or_execute_request(
             self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg"
         )
@@ -64,7 +63,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
     def test_deduplicates_based_on_key(
         self,
     ) -> Generator["defer.Deferred[Any]", object, None]:
-        cb = Mock(return_value=make_awaitable(self.mock_http_response))
+        cb = AsyncMock(return_value=self.mock_http_response)
         for i in range(3):  # invoke multiple times
             res = yield self.cache.fetch_or_execute_request(
                 self.mock_request,
@@ -168,7 +167,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]:
-        cb = Mock(return_value=make_awaitable(self.mock_http_response))
+        cb = AsyncMock(return_value=self.mock_http_response)
         yield self.cache.fetch_or_execute_request(
             self.mock_request, self.mock_requester, cb, "an arg"
         )
diff --git a/tests/server.py b/tests/server.py
index ff03d28864..08633fe640 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import hashlib
+import ipaddress
 import json
 import logging
 import os
@@ -45,7 +46,7 @@ import attr
 from typing_extensions import ParamSpec
 from zope.interface import implementer
 
-from twisted.internet import address, threads, udp
+from twisted.internet import address, tcp, threads, udp
 from twisted.internet._resolver import SimpleResolverComplexifier
 from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
 from twisted.internet.error import DNSLookupError
@@ -567,6 +568,8 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         conn = super().connectTCP(
             host, port, factory, timeout=timeout, bindAddress=None
         )
+        if self.lookups and host in self.lookups:
+            validate_connector(conn, self.lookups[host])
 
         callback = self._tcp_callbacks.get((host, port))
         if callback:
@@ -599,6 +602,55 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
             super().advance(0)
 
 
+def validate_connector(connector: tcp.Connector, expected_ip: str) -> None:
+    """Try to validate the obtained connector as it would happen when
+    synapse is running and the conection will be established.
+
+    This method will raise a useful exception when necessary, else it will
+    just do nothing.
+
+    This is in order to help catch quirks related to reactor.connectTCP,
+    since when called directly, the connector's destination will be of type
+    IPv4Address, with the hostname as the literal host that was given (which
+    could be an IPv6-only host or an IPv6 literal).
+
+    But when called from reactor.connectTCP *through* e.g. an Endpoint, the
+    connector's destination will contain the specific IP address with the
+    correct network stack class.
+
+    Note that testing code paths that use connectTCP directly should not be
+    affected by this check, unless they specifically add a test with a
+    matching reactor.lookups[HOSTNAME] = "IPv6Literal", where reactor is of
+    type ThreadedMemoryReactorClock.
+    For an example of implementing such tests, see test/handlers/send_email.py.
+    """
+    destination = connector.getDestination()
+
+    # We use address.IPv{4,6}Address to check what the reactor thinks it is
+    # is sending but check for validity with ipaddress.IPv{4,6}Address
+    # because they fail with IPs on the wrong network stack.
+    cls_mapping = {
+        address.IPv4Address: ipaddress.IPv4Address,
+        address.IPv6Address: ipaddress.IPv6Address,
+    }
+
+    cls = cls_mapping.get(destination.__class__)
+
+    if cls is not None:
+        try:
+            cls(expected_ip)
+        except Exception as exc:
+            raise ValueError(
+                "Invalid IP type and resolution for %s. Expected %s to be %s"
+                % (destination, expected_ip, cls.__name__)
+            ) from exc
+    else:
+        raise ValueError(
+            "Unknown address type %s for %s"
+            % (destination.__class__.__name__, destination)
+        )
+
+
 class ThreadPool:
     """
     Threadless thread pool.
@@ -670,7 +722,7 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
                 **kwargs,
             )
 
-        pool.runWithConnection = runWithConnection  # type: ignore[assignment]
+        pool.runWithConnection = runWithConnection  # type: ignore[method-assign]
         pool.runInteraction = runInteraction  # type: ignore[assignment]
         # Replace the thread pool with a threadless 'thread' pool
         pool.threadpool = ThreadPool(clock._reactor)
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index d2bfa53eda..17f428bfc5 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Tuple
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -29,7 +29,6 @@ from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 from tests.utils import default_config
 
@@ -69,24 +68,22 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         assert isinstance(rlsn, ResourceLimitsServerNotices)
         self._rlsn = rlsn
 
-        self._rlsn._store.user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(1000)
-        )
-        self._rlsn._server_notices_manager.send_notice = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(Mock())
+        self._rlsn._store.user_last_seen_monthly_active = AsyncMock(return_value=1000)
+        self._rlsn._server_notices_manager.send_notice = AsyncMock(  # type: ignore[method-assign]
+            return_value=Mock()
         )
         self._send_notice = self._rlsn._server_notices_manager.send_notice
 
         self.user_id = "@user_id:test"
 
-        self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
-            return_value=make_awaitable("!something:localhost")
+        self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = (
+            AsyncMock(return_value="!something:localhost")
         )
-        self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = Mock(
-            return_value=make_awaitable("!something:localhost")
+        self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = AsyncMock(
+            return_value="!something:localhost"
         )
-        self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
-        self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))  # type: ignore[assignment]
+        self._rlsn._store.add_tag_to_room = AsyncMock(return_value=None)  # type: ignore[method-assign]
+        self._rlsn._store.get_tags_for_room = AsyncMock(return_value={})  # type: ignore[method-assign]
 
     @override_config({"hs_disabled": True})
     def test_maybe_send_server_notice_disabled_hs(self) -> None:
@@ -103,14 +100,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
     def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None:
         """Test when user has blocked notice, but should have it removed"""
 
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None)
+        self._rlsn._auth_blocking.check_auth_blocking = AsyncMock(  # type: ignore[method-assign]
+            return_value=None
         )
         mock_event = Mock(
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
-        self._rlsn._store.get_events = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable({"123": mock_event})
+        self._rlsn._store.get_events = AsyncMock(  # type: ignore[method-assign]
+            return_value={"123": mock_event}
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
         # Would be better to check the content, but once == remove blocking event
@@ -125,16 +122,16 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test when user has blocked notice, but notice ought to be there (NOOP)
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None),
+        self._rlsn._auth_blocking.check_auth_blocking = AsyncMock(  # type: ignore[method-assign]
+            return_value=None,
             side_effect=ResourceLimitError(403, "foo"),
         )
 
         mock_event = Mock(
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
-        self._rlsn._store.get_events = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable({"123": mock_event})
+        self._rlsn._store.get_events = AsyncMock(  # type: ignore[method-assign]
+            return_value={"123": mock_event}
         )
 
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -145,8 +142,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test when user does not have blocked notice, but should have one
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None),
+        self._rlsn._auth_blocking.check_auth_blocking = AsyncMock(  # type: ignore[method-assign]
+            return_value=None,
             side_effect=ResourceLimitError(403, "foo"),
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -158,8 +155,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test when user does not have blocked notice, nor should they (NOOP)
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None)
+        self._rlsn._auth_blocking.check_auth_blocking = AsyncMock(  # type: ignore[method-assign]
+            return_value=None
         )
 
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -171,12 +168,10 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         Test when user is not part of the MAU cohort - this should not ever
         happen - but ...
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None)
-        )
-        self._rlsn._store.user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(None)
+        self._rlsn._auth_blocking.check_auth_blocking = AsyncMock(  # type: ignore[method-assign]
+            return_value=None
         )
+        self._rlsn._store.user_last_seen_monthly_active = AsyncMock(return_value=None)
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
         self._send_notice.assert_not_called()
@@ -189,8 +184,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         Test that when server is over MAU limit and alerting is suppressed, then
         an alert message is not sent into the room
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None),
+        self._rlsn._auth_blocking.check_auth_blocking = AsyncMock(  # type: ignore[method-assign]
+            return_value=None,
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
             ),
@@ -204,8 +199,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         Test that when a server is disabled, that MAU limit alerting is ignored.
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None),
+        self._rlsn._auth_blocking.check_auth_blocking = AsyncMock(  # type: ignore[method-assign]
+            return_value=None,
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
             ),
@@ -223,22 +218,22 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         When the room is already in a blocked state, test that when alerting
         is suppressed that the room is returned to an unblocked state.
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None),
+        self._rlsn._auth_blocking.check_auth_blocking = AsyncMock(  # type: ignore[method-assign]
+            return_value=None,
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
             ),
         )
 
-        self._rlsn._is_room_currently_blocked = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable((True, []))
+        self._rlsn._is_room_currently_blocked = AsyncMock(  # type: ignore[method-assign]
+            return_value=(True, [])
         )
 
         mock_event = Mock(
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
-        self._rlsn._store.get_events = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable({"123": mock_event})
+        self._rlsn._store.get_events = AsyncMock(  # type: ignore[method-assign]
+            return_value={"123": mock_event}
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
@@ -284,11 +279,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
         self.user_id = "@user_id:test"
 
     def test_server_notice_only_sent_once(self) -> None:
-        self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000))
+        self.store.get_monthly_active_count = AsyncMock(return_value=1000)
 
-        self.store.user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(1000)
-        )
+        self.store.user_last_seen_monthly_active = AsyncMock(return_value=1000)
 
         # Call the function multiple times to ensure we only send the notice once
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -327,7 +320,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
         hasn't been reached (since it's the only user and the limit is 5), so users
         shouldn't receive a server notice.
         """
-        m = Mock(return_value=make_awaitable(None))
+        m = AsyncMock(return_value=None)
         self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = m
 
         user_id = self.register_user("user", "password")
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index f541f1d6be..650b4941ba 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -132,6 +132,7 @@ class LockTestCase(unittest.HomeserverTestCase):
 
         # We simulate the process getting stuck by cancelling the looping call
         # that keeps the lock active.
+        assert lock._looping_call
         lock._looping_call.stop()
 
         # Wait for the lock to timeout.
@@ -403,6 +404,7 @@ class ReadWriteLockTestCase(unittest.HomeserverTestCase):
 
         # We simulate the process getting stuck by cancelling the looping call
         # that keeps the lock active.
+        assert lock._looping_call
         lock._looping_call.stop()
 
         # Wait for the lock to timeout.
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 71302facd1..cbce26a725 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -15,7 +15,7 @@ import json
 import os
 import tempfile
 from typing import List, cast
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 import yaml
 
@@ -35,7 +35,6 @@ from synapse.types import DeviceListUpdates
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
@@ -339,7 +338,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
 
         # we aren't testing store._base stuff here, so mock this out
         # (ignore needed because Mypy won't allow us to assign to a method otherwise)
-        self.store.get_events_as_list = Mock(return_value=make_awaitable(events))  # type: ignore[assignment]
+        self.store.get_events_as_list = AsyncMock(return_value=events)  # type: ignore[method-assign]
 
         self.get_success(self._insert_txn(self.as_list[1]["id"], 9, other_events))
         self.get_success(self._insert_txn(service.id, 10, events))
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index a4a823a252..abf7d0564d 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from unittest.mock import Mock
+import logging
+from unittest.mock import AsyncMock, Mock
 
 import yaml
 
@@ -32,7 +32,6 @@ from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable, simple_async_mock
 from tests.unittest import override_config
 
 
@@ -331,6 +330,28 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
         self.update_handler.side_effect = update_short
         self.get_success(self.updates.do_next_background_update(False))
 
+    def test_failed_update_logs_exception_details(self) -> None:
+        needle = "RUH ROH RAGGY"
+
+        def failing_update(progress: JsonDict, count: int) -> int:
+            raise Exception(needle)
+
+        self.update_handler.side_effect = failing_update
+        self.update_handler.reset_mock()
+
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                values={"update_name": "test_update", "progress_json": "{}"},
+            )
+        )
+
+        with self.assertLogs(level=logging.ERROR) as logs:
+            # Expect a back-to-back RuntimeError to be raised
+            self.get_failure(self.updates.run_background_updates(False), RuntimeError)
+
+        self.assertTrue(any(needle in log for log in logs.output), logs.output)
+
 
 class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -348,8 +369,8 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
 
         # Mock out the AsyncContextManager
         class MockCM:
-            __aenter__ = simple_async_mock(return_value=None)
-            __aexit__ = simple_async_mock(return_value=None)
+            __aenter__ = AsyncMock(return_value=None)
+            __aexit__ = AsyncMock(return_value=None)
 
         self._update_ctx_manager = MockCM
 
@@ -363,9 +384,9 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
         # Register the callbacks with more mocks
         self.hs.get_module_api().register_background_update_controller_callbacks(
             on_update=self._on_update,
-            min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)),
-            default_batch_size=Mock(
-                return_value=make_awaitable(self._default_batch_size),
+            min_batch_size=AsyncMock(return_value=self._default_batch_size),
+            default_batch_size=AsyncMock(
+                return_value=self._default_batch_size,
             ),
         )
 
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 209d68b40b..6b9692c486 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from typing import Any, Dict
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
 
 from parameterized import parameterized
 
@@ -30,7 +30,6 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.server import make_request
-from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
 
@@ -66,15 +65,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
         r = result[(user_id, device_id)]
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": user_id,
                 "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 12345678000,
-            },
-            r,
+            }.items(),
+            r.items(),
         )
 
     def test_insert_new_client_ip_none_device_id(self) -> None:
@@ -443,9 +442,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         lots_of_users = 100
         user_id = "@user:server"
 
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(lots_of_users)
-        )
+        self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users)
         self.get_success(
             self.store.insert_client_ip(
                 user_id, "access_token", "ip", "user_agent", "device_id"
@@ -529,15 +526,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
         r = result[(user_id, device_id)]
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": user_id,
                 "device_id": device_id,
                 "ip": None,
                 "user_agent": None,
                 "last_seen": None,
-            },
-            r,
+            }.items(),
+            r.items(),
         )
 
         # Register the background update to run again.
@@ -564,15 +561,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
         r = result[(user_id, device_id)]
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": user_id,
                 "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 0,
-            },
-            r,
+            }.items(),
+            r.items(),
         )
 
     def test_old_user_ips_pruned(self) -> None:
@@ -643,15 +640,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
         r = result2[(user_id, device_id)]
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": user_id,
                 "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 0,
-            },
-            r,
+            }.items(),
+            r.items(),
         )
 
     def test_invalid_user_agents_are_ignored(self) -> None:
@@ -780,13 +777,13 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
             self.store.get_last_client_ip_by_device(self.user_id, device_id)
         )
         r = result[(self.user_id, device_id)]
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": self.user_id,
                 "device_id": device_id,
                 "ip": expected_ip,
                 "user_agent": "Mozzila pizza",
                 "last_seen": 123456100,
-            },
-            r,
+            }.items(),
+            r.items(),
         )
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index f03807c8f9..58ab41cf26 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -58,13 +58,13 @@ class DeviceStoreTestCase(HomeserverTestCase):
 
         res = self.get_success(self.store.get_device("user_id", "device_id"))
         assert res is not None
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": "user_id",
                 "device_id": "device_id",
                 "display_name": "display_name",
-            },
-            res,
+            }.items(),
+            res.items(),
         )
 
     def test_get_devices_by_user(self) -> None:
@@ -80,21 +80,21 @@ class DeviceStoreTestCase(HomeserverTestCase):
 
         res = self.get_success(self.store.get_devices_by_user("user_id"))
         self.assertEqual(2, len(res.keys()))
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": "user_id",
                 "device_id": "device1",
                 "display_name": "display_name 1",
-            },
-            res["device1"],
+            }.items(),
+            res["device1"].items(),
         )
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "user_id": "user_id",
                 "device_id": "device2",
                 "display_name": "display_name 2",
-            },
-            res["device2"],
+            }.items(),
+            res["device2"].items(),
         )
 
     def test_count_devices_by_users(self) -> None:
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 5fde3b9c78..2033377b52 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -38,7 +38,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
         self.assertIn("user", res)
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
-        self.assertDictContainsSubset(json, dev)
+        self.assertLessEqual(json.items(), dev.items())
 
     def test_reupload_key(self) -> None:
         now = 1470174257070
@@ -71,8 +71,12 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
         self.assertIn("user", res)
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
-        self.assertDictContainsSubset(
-            {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
+        self.assertLessEqual(
+            {
+                "key": "value",
+                "unsigned": {"device_display_name": "display_name"},
+            }.items(),
+            dev.items(),
         )
 
     def test_multiple_devices(self) -> None:
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 2827738379..49366440ce 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, Dict, List
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -21,7 +21,6 @@ from synapse.server import HomeServer
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 from tests.unittest import default_config, override_config
 
 FORTY_DAYS = 40 * 24 * 60 * 60
@@ -253,7 +252,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         )
         self.get_success(d)
 
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
+        self.store.upsert_monthly_active_user = AsyncMock(return_value=None)  # type: ignore[method-assign]
 
         d = self.store.populate_monthly_active_users(user_id)
         self.get_success(d)
@@ -261,24 +260,22 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.upsert_monthly_active_user.assert_not_called()
 
     def test_populate_monthly_users_should_update(self) -> None:
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
+        self.store.upsert_monthly_active_user = AsyncMock(return_value=None)  # type: ignore[method-assign]
 
-        self.store.is_trial_user = Mock(return_value=make_awaitable(False))  # type: ignore[assignment]
+        self.store.is_trial_user = AsyncMock(return_value=False)  # type: ignore[method-assign]
 
-        self.store.user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(None)
-        )
+        self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
         d = self.store.populate_monthly_active_users("user_id")
         self.get_success(d)
 
         self.store.upsert_monthly_active_user.assert_called_once()
 
     def test_populate_monthly_users_should_not_update(self) -> None:
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
+        self.store.upsert_monthly_active_user = AsyncMock(return_value=None)  # type: ignore[method-assign]
 
-        self.store.is_trial_user = Mock(return_value=make_awaitable(False))  # type: ignore[assignment]
-        self.store.user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(self.hs.get_clock().time_msec())
+        self.store.is_trial_user = AsyncMock(return_value=False)  # type: ignore[method-assign]
+        self.store.user_last_seen_monthly_active = AsyncMock(
+            return_value=self.hs.get_clock().time_msec()
         )
 
         d = self.store.populate_monthly_active_users("user_id")
@@ -359,7 +356,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
     def test_no_users_when_not_tracking(self) -> None:
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
+        self.store.upsert_monthly_active_user = AsyncMock(return_value=None)  # type: ignore[method-assign]
 
         self.get_success(self.store.populate_monthly_active_users("@user:sever"))
 
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index ba41459d08..95c9792d54 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -51,6 +51,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
                 "locked": 0,
                 "shadow_banned": 0,
                 "approved": 1,
+                "last_seen_ts": None,
             },
             (self.get_success(self.store.get_user_by_id(self.user_id))),
         )
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 71ec74eadc..1e27f2c275 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -44,13 +44,13 @@ class RoomStoreTestCase(HomeserverTestCase):
     def test_get_room(self) -> None:
         res = self.get_success(self.store.get_room(self.room.to_string()))
         assert res is not None
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "room_id": self.room.to_string(),
                 "creator": self.u_creator.to_string(),
                 "is_public": True,
-            },
-            res,
+            }.items(),
+            res.items(),
         )
 
     def test_get_room_unknown_room(self) -> None:
@@ -59,13 +59,13 @@ class RoomStoreTestCase(HomeserverTestCase):
     def test_get_room_with_stats(self) -> None:
         res = self.get_success(self.store.get_room_with_stats(self.room.to_string()))
         assert res is not None
-        self.assertDictContainsSubset(
+        self.assertLessEqual(
             {
                 "room_id": self.room.to_string(),
                 "creator": self.u_creator.to_string(),
                 "public": True,
-            },
-            res,
+            }.items(),
+            res.items(),
         )
 
     def test_get_room_with_stats_unknown_room(self) -> None:
diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py
index 0e3fc2a77f..29be8cdbd0 100644
--- a/tests/storage/util/test_partial_state_events_tracker.py
+++ b/tests/storage/util/test_partial_state_events_tracker.py
@@ -22,7 +22,6 @@ from synapse.storage.util.partial_state_events_tracker import (
     PartialStateEventsTracker,
 )
 
-from tests.test_utils import make_awaitable
 from tests.unittest import TestCase
 
 
@@ -124,16 +123,17 @@ class PartialStateEventsTrackerTestCase(TestCase):
 class PartialCurrentStateTrackerTestCase(TestCase):
     def setUp(self) -> None:
         self.mock_store = mock.Mock(spec_set=["is_partial_state_room"])
+        self.mock_store.is_partial_state_room = mock.AsyncMock()
 
         self.tracker = PartialCurrentStateTracker(self.mock_store)
 
     def test_does_not_block_for_full_state_rooms(self) -> None:
-        self.mock_store.is_partial_state_room.return_value = make_awaitable(False)
+        self.mock_store.is_partial_state_room.return_value = False
 
         self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
 
     def test_blocks_for_partial_room_state(self) -> None:
-        self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
+        self.mock_store.is_partial_state_room.return_value = True
 
         d = ensureDeferred(self.tracker.await_full_state("room_id"))
 
@@ -156,7 +156,7 @@ class PartialCurrentStateTrackerTestCase(TestCase):
         self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
 
     def test_cancellation(self) -> None:
-        self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
+        self.mock_store.is_partial_state_room.return_value = True
 
         d1 = ensureDeferred(self.tracker.await_full_state("room_id"))
         self.assertNoResult(d1)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 6d15ac7597..f8ade6da38 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from typing import Collection, List, Optional, Union
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -31,7 +31,6 @@ from synapse.util import Clock
 from synapse.util.retryutils import NotRetryingDestination
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class MessageAcceptTests(unittest.HomeserverTestCase):
@@ -81,7 +80,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         ) -> None:
             pass
 
-        federation_event_handler._check_event_auth = _check_event_auth  # type: ignore[assignment]
+        federation_event_handler._check_event_auth = _check_event_auth  # type: ignore[method-assign]
         self.client = self.hs.get_federation_client()
 
         async def _check_sigs_and_hash_for_pulled_events_and_fetch(
@@ -191,12 +190,12 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
         # Register the mock on the federation client.
         federation_client = self.hs.get_federation_client()
-        federation_client.query_user_devices = Mock(side_effect=query_user_devices)  # type: ignore[assignment]
+        federation_client.query_user_devices = Mock(side_effect=query_user_devices)  # type: ignore[method-assign]
 
         # Register a mock on the store so that the incoming update doesn't fail because
         # we don't share a room with the user.
         store = self.hs.get_datastores().main
-        store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
+        store.get_rooms_for_user = AsyncMock(return_value=["!someroom:test"])
 
         # Manually inject a fake device list update. We need this update to include at
         # least one prev_id so that the user's device list will need to be retried.
@@ -241,27 +240,24 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
         # Register mock device list retrieval on the federation client.
         federation_client = self.hs.get_federation_client()
-        federation_client.query_user_devices = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(
-                {
+        federation_client.query_user_devices = AsyncMock(  # type: ignore[method-assign]
+            return_value={
+                "user_id": remote_user_id,
+                "stream_id": 1,
+                "devices": [],
+                "master_key": {
                     "user_id": remote_user_id,
-                    "stream_id": 1,
-                    "devices": [],
-                    "master_key": {
-                        "user_id": remote_user_id,
-                        "usage": ["master"],
-                        "keys": {"ed25519:" + remote_master_key: remote_master_key},
-                    },
-                    "self_signing_key": {
-                        "user_id": remote_user_id,
-                        "usage": ["self_signing"],
-                        "keys": {
-                            "ed25519:"
-                            + remote_self_signing_key: remote_self_signing_key
-                        },
+                    "usage": ["master"],
+                    "keys": {"ed25519:" + remote_master_key: remote_master_key},
+                },
+                "self_signing_key": {
+                    "user_id": remote_user_id,
+                    "usage": ["self_signing"],
+                    "keys": {
+                        "ed25519:" + remote_self_signing_key: remote_self_signing_key
                     },
-                }
-            )
+                },
+            }
         )
 
         # Resync the device list.
diff --git a/tests/test_state.py b/tests/test_state.py
index eded38c766..9c8679cc1d 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -714,7 +714,7 @@ class StateTestCase(unittest.TestCase):
         store = _DummyStore()
         store.register_events(old_state_1)
         store.register_events(old_state_2)
-        self.dummy_store.get_events = store.get_events  # type: ignore[assignment]
+        self.dummy_store.get_events = store.get_events  # type: ignore[method-assign]
 
         context: EventContext
         context = yield self._get_context(
@@ -773,7 +773,7 @@ class StateTestCase(unittest.TestCase):
         store = _DummyStore()
         store.register_events(old_state_1)
         store.register_events(old_state_2)
-        self.dummy_store.get_events = store.get_events  # type: ignore[assignment]
+        self.dummy_store.get_events = store.get_events  # type: ignore[method-assign]
 
         context: EventContext
         context = yield self._get_context(
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 52424aa087..64a49488c6 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -85,7 +85,9 @@ class TermsTestCase(unittest.HomeserverTestCase):
             }
         }
         self.assertIsInstance(channel.json_body["params"], dict)
-        self.assertDictContainsSubset(channel.json_body["params"], expected_params)
+        self.assertLessEqual(
+            channel.json_body["params"].items(), expected_params.items()
+        )
 
         # We have to complete the dummy auth stage before completing the terms stage
         request_data = {
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index c8cc841d95..fa731426cd 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -18,10 +18,8 @@ Utilities for running the unit tests
 import json
 import sys
 import warnings
-from asyncio import Future
 from binascii import unhexlify
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar
-from unittest.mock import Mock
+from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, TypeVar
 
 import attr
 import zope.interface
@@ -57,27 +55,12 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
     raise Exception("awaitable has not yet completed")
 
 
-def make_awaitable(result: TV) -> Awaitable[TV]:
-    """
-    Makes an awaitable, suitable for mocking an `async` function.
-    This uses Futures as they can be awaited multiple times so can be returned
-    to multiple callers.
-    """
-    future: Future[TV] = Future()
-    future.set_result(result)
-    return future
-
-
 def setup_awaitable_errors() -> Callable[[], None]:
     """
     Convert warnings from a non-awaited coroutines into errors.
     """
     warnings.simplefilter("error", RuntimeWarning)
 
-    # unraisablehook was added in Python 3.8.
-    if not hasattr(sys, "unraisablehook"):
-        return lambda: None
-
     # State shared between unraisablehook and check_for_unraisable_exceptions.
     unraisable_exceptions = []
     orig_unraisablehook = sys.unraisablehook
@@ -100,18 +83,6 @@ def setup_awaitable_errors() -> Callable[[], None]:
     return cleanup
 
 
-def simple_async_mock(
-    return_value: Optional[TV] = None, raises: Optional[Exception] = None
-) -> Mock:
-    # AsyncMock is not available in python3.5, this mimics part of its behaviour
-    async def cb(*args: Any, **kwargs: Any) -> Optional[TV]:
-        if raises:
-            raise raises
-        return return_value
-
-    return Mock(side_effect=cb)
-
-
 # Type ignore: it does not fully implement IResponse, but is good enough for tests
 @zope.interface.implementer(IResponse)
 @attr.s(slots=True, frozen=True, auto_attribs=True)
diff --git a/tests/unittest.py b/tests/unittest.py
index b0721e060c..5d3640d8ac 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -313,7 +313,7 @@ class HomeserverTestCase(TestCase):
         servlets: List of servlet registration function.
         user_id (str): The user ID to assume if auth is hijacked.
         hijack_auth: Whether to hijack auth to return the user specified
-        in user_id.
+           in user_id.
     """
 
     hijack_auth: ClassVar[bool] = True
@@ -395,9 +395,9 @@ class HomeserverTestCase(TestCase):
                     )
 
                 # Type ignore: mypy doesn't like us assigning to methods.
-                self.hs.get_auth().get_user_by_req = get_requester  # type: ignore[assignment]
-                self.hs.get_auth().get_user_by_access_token = get_requester  # type: ignore[assignment]
-                self.hs.get_auth().get_access_token_from_request = Mock(return_value=token)  # type: ignore[assignment]
+                self.hs.get_auth().get_user_by_req = get_requester  # type: ignore[method-assign]
+                self.hs.get_auth().get_user_by_access_token = get_requester  # type: ignore[method-assign]
+                self.hs.get_auth().get_access_token_from_request = Mock(return_value=token)  # type: ignore[method-assign]
 
         if self.needs_threadpool:
             self.reactor.threadpool = ThreadPool()  # type: ignore[assignment]
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index 91cac9822a..05983ed434 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -60,11 +60,9 @@ class ObservableDeferredTest(TestCase):
         observer1.addBoth(check_called_first)
 
         # store the results
-        results: List[Optional[ObservableDeferred[int]]] = [None, None]
+        results: List[Optional[int]] = [None, None]
 
-        def check_val(
-            res: ObservableDeferred[int], idx: int
-        ) -> ObservableDeferred[int]:
+        def check_val(res: int, idx: int) -> int:
             results[idx] = res
             return res
 
@@ -93,14 +91,14 @@ class ObservableDeferredTest(TestCase):
         observer1.addBoth(check_called_first)
 
         # store the results
-        results: List[Optional[ObservableDeferred[str]]] = [None, None]
+        results: List[Optional[Failure]] = [None, None]
 
-        def check_val(res: ObservableDeferred[str], idx: int) -> None:
+        def check_failure(res: Failure, idx: int) -> None:
             results[idx] = res
             return None
 
-        observer1.addErrback(check_val, 0)
-        observer2.addErrback(check_val, 1)
+        observer1.addErrback(check_failure, 0)
+        observer2.addErrback(check_failure, 1)
 
         try:
             raise Exception("gah!")
diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py
index 3a97559bf0..8665aeb50c 100644
--- a/tests/util/test_task_scheduler.py
+++ b/tests/util/test_task_scheduler.py
@@ -22,10 +22,11 @@ from synapse.types import JsonMapping, ScheduledTask, TaskStatus
 from synapse.util import Clock
 from synapse.util.task_scheduler import TaskScheduler
 
-from tests import unittest
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.unittest import HomeserverTestCase, override_config
 
 
-class TestTaskScheduler(unittest.HomeserverTestCase):
+class TestTaskScheduler(HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.task_scheduler = hs.get_task_scheduler()
         self.task_scheduler.register_action(self._test_task, "_test_task")
@@ -34,7 +35,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         self.task_scheduler.register_action(self._resumable_task, "_resumable_task")
 
     async def _test_task(
-        self, task: ScheduledTask, first_launch: bool
+        self, task: ScheduledTask
     ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
         # This test task will copy the parameters to the result
         result = None
@@ -77,7 +78,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         self.assertIsNone(task)
 
     async def _sleeping_task(
-        self, task: ScheduledTask, first_launch: bool
+        self, task: ScheduledTask
     ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
         # Sleep for a second
         await deferLater(self.reactor, 1, lambda: None)
@@ -85,24 +86,18 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
 
     def test_schedule_lot_of_tasks(self) -> None:
         """Schedule more than `TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS` tasks and check the behavior."""
-        timestamp = self.clock.time_msec() + 30 * 1000
         task_ids = []
         for i in range(TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS + 1):
             task_ids.append(
                 self.get_success(
                     self.task_scheduler.schedule_task(
                         "_sleeping_task",
-                        timestamp=timestamp,
                         params={"val": i},
                     )
                 )
             )
 
-        # The timestamp being 30s after now the task should been executed
-        # after the first scheduling loop is run
-        self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
-
-        # This is to give the time to the sleeping tasks to finish
+        # This is to give the time to the active tasks to finish
         self.reactor.advance(1)
 
         # Check that only MAX_CONCURRENT_RUNNING_TASKS tasks has run and that one
@@ -120,10 +115,11 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         )
 
         scheduled_tasks = [
-            t for t in tasks if t is not None and t.status == TaskStatus.SCHEDULED
+            t for t in tasks if t is not None and t.status == TaskStatus.ACTIVE
         ]
         self.assertEquals(len(scheduled_tasks), 1)
 
+        # We need to wait for the next run of the scheduler loop
         self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
         self.reactor.advance(1)
 
@@ -138,7 +134,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         )
 
     async def _raising_task(
-        self, task: ScheduledTask, first_launch: bool
+        self, task: ScheduledTask
     ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
         raise Exception("raising")
 
@@ -146,15 +142,13 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         """Schedule a task raising an exception and check it runs to failure and report exception content."""
         task_id = self.get_success(self.task_scheduler.schedule_task("_raising_task"))
 
-        self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
-
         task = self.get_success(self.task_scheduler.get_task(task_id))
         assert task is not None
         self.assertEqual(task.status, TaskStatus.FAILED)
         self.assertEqual(task.error, "raising")
 
     async def _resumable_task(
-        self, task: ScheduledTask, first_launch: bool
+        self, task: ScheduledTask
     ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
         if task.result and "in_progress" in task.result:
             return TaskStatus.COMPLETE, {"success": True}, None
@@ -169,8 +163,6 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         """Schedule a resumable task and check that it gets properly resumed and complete after simulating a synapse restart."""
         task_id = self.get_success(self.task_scheduler.schedule_task("_resumable_task"))
 
-        self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
-
         task = self.get_success(self.task_scheduler.get_task(task_id))
         assert task is not None
         self.assertEqual(task.status, TaskStatus.ACTIVE)
@@ -184,3 +176,33 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
         self.assertEqual(task.status, TaskStatus.COMPLETE)
         assert task.result is not None
         self.assertTrue(task.result.get("success"))
+
+
+class TestTaskSchedulerWithBackgroundWorker(BaseMultiWorkerStreamTestCase):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.task_scheduler = hs.get_task_scheduler()
+        self.task_scheduler.register_action(self._test_task, "_test_task")
+
+    async def _test_task(
+        self, task: ScheduledTask
+    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+        return (TaskStatus.COMPLETE, None, None)
+
+    @override_config({"run_background_tasks_on": "worker1"})
+    def test_schedule_task(self) -> None:
+        """Check that a task scheduled to run now is launch right away on the background worker."""
+        bg_worker_hs = self.make_worker_hs(
+            "synapse.app.generic_worker",
+            extra_config={"worker_name": "worker1"},
+        )
+        bg_worker_hs.get_task_scheduler().register_action(self._test_task, "_test_task")
+
+        task_id = self.get_success(
+            self.task_scheduler.schedule_task(
+                "_test_task",
+            )
+        )
+
+        task = self.get_success(self.task_scheduler.get_task(task_id))
+        assert task is not None
+        self.assertEqual(task.status, TaskStatus.COMPLETE)