summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py8
-rw-r--r--tests/events/test_presence_router.py4
-rw-r--r--tests/federation/test_federation_sender.py17
-rw-r--r--tests/handlers/test_deactivate_account.py48
-rw-r--r--tests/handlers/test_password_providers.py11
-rw-r--r--tests/handlers/test_register.py7
-rw-r--r--tests/handlers/test_room_member.py4
-rw-r--r--tests/handlers/test_typing.py8
-rw-r--r--tests/module_api/test_api.py7
-rw-r--r--tests/replication/_base.py90
-rw-r--r--tests/replication/tcp/test_handler.py4
-rw-r--r--tests/replication/test_sharded_event_persister.py7
-rw-r--r--tests/rest/admin/test_event_reports.py27
-rw-r--r--tests/rest/admin/test_room.py1
-rw-r--r--tests/rest/admin/test_server_notice.py56
-rw-r--r--tests/rest/admin/test_user.py92
-rw-r--r--tests/rest/client/test_account.py10
-rw-r--r--tests/rest/client/test_models.py53
-rw-r--r--tests/rest/client/test_register.py2
-rw-r--r--tests/rest/client/test_relations.py12
-rw-r--r--tests/rest/client/test_retention.py4
-rw-r--r--tests/rest/client/test_shadow_banned.py6
-rw-r--r--tests/server.py14
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py9
-rw-r--r--tests/storage/test_roommember.py70
-rw-r--r--tests/test_metrics.py36
-rw-r--r--tests/unittest.py15
27 files changed, 477 insertions, 145 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index dfcfaf79b6..e0f363555b 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -284,10 +284,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
             TokenLookupResult(
                 user_id="@baldrick:matrix.org",
                 device_id="device",
+                token_id=5,
                 token_owner="@admin:matrix.org",
+                token_used=True,
             )
         )
         self.store.insert_client_ip = simple_async_mock(None)
+        self.store.mark_access_token_as_used = simple_async_mock(None)
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
@@ -301,10 +304,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
             TokenLookupResult(
                 user_id="@baldrick:matrix.org",
                 device_id="device",
+                token_id=5,
                 token_owner="@admin:matrix.org",
+                token_used=True,
             )
         )
         self.store.insert_client_ip = simple_async_mock(None)
+        self.store.mark_access_token_as_used = simple_async_mock(None)
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
@@ -347,7 +353,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         serialized = macaroon.serialize()
 
         user_info = self.get_success(self.auth.get_user_by_access_token(serialized))
-        self.assertEqual(user_id, user_info.user_id)
+        self.assertEqual(user_id, user_info.user.to_string())
         self.assertTrue(user_info.is_guest)
         self.store.get_user_by_id.assert_called_with(user_id)
 
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index ffc3012a86..685a9a6d52 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -141,10 +141,6 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
         hs = self.setup_test_homeserver(
             federation_transport_client=fed_transport_client,
         )
-        # Load the modules into the homeserver
-        module_api = hs.get_module_api()
-        for module, config in hs.config.modules.loaded_modules:
-            module(config=config, api=module_api)
 
         load_legacy_presence_router(hs)
 
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 01a1db6115..a5aa500ef8 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -173,17 +173,24 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         return c
 
     def prepare(self, reactor, clock, hs):
-        # stub out `get_rooms_for_user` and `get_users_in_room` so that the
+        test_room_id = "!room:host1"
+
+        # stub out `get_rooms_for_user` and `get_current_hosts_in_room` so that the
         # server thinks the user shares a room with `@user2:host2`
         def get_rooms_for_user(user_id):
-            return defer.succeed({"!room:host1"})
+            return defer.succeed({test_room_id})
 
         hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user
 
-        def get_users_in_room(room_id):
-            return defer.succeed({"@user2:host2"})
+        async def get_current_hosts_in_room(room_id):
+            if room_id == test_room_id:
+                return ["host2"]
+
+            # TODO: We should fail the test when we encounter an unxpected room ID.
+            # We can't just use `self.fail(...)` here because the app code is greedy
+            # with `Exception` and will catch it before the test can see it.
 
-        hs.get_datastores().main.get_users_in_room = get_users_in_room
+        hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room
 
         # whenever send_transaction is called, record the edu data
         self.edus = []
diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py
index ff9f2e8edb..7b9b711521 100644
--- a/tests/handlers/test_deactivate_account.py
+++ b/tests/handlers/test_deactivate_account.py
@@ -11,11 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict
 
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import AccountDataTypes
+from synapse.push.baserules import PushRule
 from synapse.push.rulekinds import PRIORITY_CLASS_MAP
 from synapse.rest import admin
 from synapse.rest.client import account, login
@@ -130,12 +130,12 @@ class DeactivateAccountTestCase(HomeserverTestCase):
             ),
         )
 
-    def _is_custom_rule(self, push_rule: Dict[str, Any]) -> bool:
+    def _is_custom_rule(self, push_rule: PushRule) -> bool:
         """
         Default rules start with a dot: such as .m.rule and .im.vector.
         This function returns true iff a rule is custom (not default).
         """
-        return "/." not in push_rule["rule_id"]
+        return "/." not in push_rule.rule_id
 
     def test_push_rules_deleted_upon_account_deactivation(self) -> None:
         """
@@ -157,22 +157,21 @@ class DeactivateAccountTestCase(HomeserverTestCase):
         )
 
         # Test the rule exists
-        push_rules = self.get_success(self._store.get_push_rules_for_user(self.user))
+        filtered_push_rules = self.get_success(
+            self._store.get_push_rules_for_user(self.user)
+        )
         # Filter out default rules; we don't care
-        push_rules = list(filter(self._is_custom_rule, push_rules))
+        push_rules = [r for r, _ in filtered_push_rules if self._is_custom_rule(r)]
         # Check our rule made it
         self.assertEqual(
             push_rules,
             [
-                {
-                    "user_name": "@user:test",
-                    "rule_id": "personal.override.rule1",
-                    "priority_class": 5,
-                    "priority": 0,
-                    "conditions": [],
-                    "actions": [],
-                    "default": False,
-                }
+                PushRule(
+                    rule_id="personal.override.rule1",
+                    priority_class=5,
+                    conditions=[],
+                    actions=[],
+                )
             ],
             push_rules,
         )
@@ -180,9 +179,11 @@ class DeactivateAccountTestCase(HomeserverTestCase):
         # Request the deactivation of our account
         self._deactivate_my_account()
 
-        push_rules = self.get_success(self._store.get_push_rules_for_user(self.user))
+        filtered_push_rules = self.get_success(
+            self._store.get_push_rules_for_user(self.user)
+        )
         # Filter out default rules; we don't care
-        push_rules = list(filter(self._is_custom_rule, push_rules))
+        push_rules = [r for r, _ in filtered_push_rules if self._is_custom_rule(r)]
         # Check our rule no longer exists
         self.assertEqual(push_rules, [], push_rules)
 
@@ -321,3 +322,18 @@ class DeactivateAccountTestCase(HomeserverTestCase):
                 )
             ),
         )
+
+    def test_deactivate_account_needs_auth(self) -> None:
+        """
+        Tests that making a request to /deactivate with an empty body
+        succeeds in starting the user-interactive auth flow.
+        """
+        req = self.make_request(
+            "POST",
+            "account/deactivate",
+            {},
+            access_token=self.token,
+        )
+
+        self.assertEqual(req.code, 401, req)
+        self.assertEqual(req.json_body["flows"], [{"stages": ["m.login.password"]}])
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 4c62449c89..75934b1707 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -21,7 +21,6 @@ from unittest.mock import Mock
 import synapse
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes
-from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.module_api import ModuleApi
 from synapse.rest.client import account, devices, login, logout, register
 from synapse.types import JsonDict, UserID
@@ -167,16 +166,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
         super().setUp()
 
-    def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver()
-        # Load the modules into the homeserver
-        module_api = hs.get_module_api()
-        for module, config in hs.config.modules.loaded_modules:
-            module(config=config, api=module_api)
-        load_legacy_password_auth_providers(hs)
-
-        return hs
-
     @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
     def test_password_only_auth_progiver_login_legacy(self):
         self.password_only_auth_provider_login_test_body()
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 23f35d5bf5..86b3d51975 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -22,7 +22,6 @@ from synapse.api.errors import (
     ResourceLimitError,
     SynapseError,
 )
-from synapse.events.spamcheck import load_legacy_spam_checkers
 from synapse.spam_checker_api import RegistrationBehaviour
 from synapse.types import RoomAlias, RoomID, UserID, create_requester
 
@@ -144,12 +143,6 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             config=hs_config, federation_client=self.mock_federation_client
         )
 
-        load_legacy_spam_checkers(hs)
-
-        module_api = hs.get_module_api()
-        for module, config in hs.config.modules.loaded_modules:
-            module(config=config, api=module_api)
-
         return hs
 
     def prepare(self, reactor, clock, hs):
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
index b4e1405aee..1d13ed1e88 100644
--- a/tests/handlers/test_room_member.py
+++ b/tests/handlers/test_room_member.py
@@ -14,7 +14,7 @@ from synapse.server import HomeServer
 from synapse.types import UserID, create_requester
 from synapse.util import Clock
 
-from tests.replication._base import RedisMultiWorkerStreamTestCase
+from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.server import make_request
 from tests.test_utils import make_awaitable
 from tests.unittest import FederatingHomeserverTestCase, override_config
@@ -216,7 +216,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
     #   - trying to remote-join again.
 
 
-class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestCase):
+class TestReplicatedJoinsLimitedByPerRoomRateLimiter(BaseMultiWorkerStreamTestCase):
     servlets = [
         synapse.rest.admin.register_servlets,
         synapse.rest.client.login.register_servlets,
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 7af1333126..8adba29d7f 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -25,7 +25,7 @@ from synapse.api.constants import EduTypes
 from synapse.api.errors import AuthError
 from synapse.federation.transport.server import TransportLayerServer
 from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID, create_requester
+from synapse.types import JsonDict, Requester, UserID, create_requester
 from synapse.util import Clock
 
 from tests import unittest
@@ -117,8 +117,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.room_members = []
 
-        async def check_user_in_room(room_id: str, user_id: str) -> None:
-            if user_id not in [u.to_string() for u in self.room_members]:
+        async def check_user_in_room(room_id: str, requester: Requester) -> None:
+            if requester.user.to_string() not in [
+                u.to_string() for u in self.room_members
+            ]:
                 raise AuthError(401, "User is not in the room")
             return None
 
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 106159fa65..02cef6f876 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -30,7 +30,6 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.test_utils import simple_async_mock
 from tests.test_utils.event_injection import inject_member_event
 from tests.unittest import HomeserverTestCase, override_config
-from tests.utils import USE_POSTGRES_FOR_TESTS
 
 
 class ModuleApiTestCase(HomeserverTestCase):
@@ -738,11 +737,6 @@ class ModuleApiTestCase(HomeserverTestCase):
 class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
     """For testing ModuleApi functionality in a multi-worker setup"""
 
-    # Testing stream ID replication from the main to worker processes requires postgres
-    # (due to needing `MultiWriterIdGenerator`).
-    if not USE_POSTGRES_FOR_TESTS:
-        skip = "Requires Postgres"
-
     servlets = [
         admin.register_servlets,
         login.register_servlets,
@@ -752,7 +746,6 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
 
     def default_config(self):
         conf = super().default_config()
-        conf["redis"] = {"enabled": "true"}
         conf["stream_writers"] = {"presence": ["presence_writer"]}
         conf["instance_map"] = {
             "presence_writer": {"host": "testserv", "port": 1001},
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 970d5e533b..ce53f808db 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -24,11 +24,11 @@ from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.replication.http import ReplicationRestResource
 from synapse.replication.tcp.client import ReplicationDataHandler
 from synapse.replication.tcp.handler import ReplicationCommandHandler
-from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import (
-    ReplicationStreamProtocolFactory,
+from synapse.replication.tcp.protocol import (
+    ClientReplicationStreamProtocol,
     ServerReplicationStreamProtocol,
 )
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
 from synapse.server import HomeServer
 
 from tests import unittest
@@ -220,15 +220,34 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
 class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
     """Base class for tests running multiple workers.
 
+    Enables Redis, providing a fake Redis server.
+
     Automatically handle HTTP replication requests from workers to master,
     unlike `BaseStreamTestCase`.
     """
 
+    if not hiredis:
+        skip = "Requires hiredis"
+
+    if not USE_POSTGRES_FOR_TESTS:
+        # Redis replication only takes place on Postgres
+        skip = "Requires Postgres"
+
+    def default_config(self) -> Dict[str, Any]:
+        """
+        Overrides the default config to enable Redis.
+        Even if the test only uses make_worker_hs, the main process needs Redis
+        enabled otherwise it won't create a Fake Redis server to listen on the
+        Redis port and accept fake TCP connections.
+        """
+        base = super().default_config()
+        base["redis"] = {"enabled": True}
+        return base
+
     def setUp(self):
         super().setUp()
 
         # build a replication server
-        self.server_factory = ReplicationStreamProtocolFactory(self.hs)
         self.streamer = self.hs.get_replication_streamer()
 
         # Fake in memory Redis server that servers can connect to.
@@ -247,15 +266,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # handling inbound HTTP requests to that instance.
         self._hs_to_site = {self.hs: self.site}
 
-        if self.hs.config.redis.redis_enabled:
-            # Handle attempts to connect to fake redis server.
-            self.reactor.add_tcp_client_callback(
-                "localhost",
-                6379,
-                self.connect_any_redis_attempts,
-            )
+        # Handle attempts to connect to fake redis server.
+        self.reactor.add_tcp_client_callback(
+            "localhost",
+            6379,
+            self.connect_any_redis_attempts,
+        )
 
-            self.hs.get_replication_command_handler().start_replication(self.hs)
+        self.hs.get_replication_command_handler().start_replication(self.hs)
 
         # When we see a connection attempt to the master replication listener we
         # automatically set up the connection. This is so that tests don't
@@ -339,27 +357,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         store = worker_hs.get_datastores().main
         store.db_pool._db_pool = self.database_pool._db_pool
 
-        # Set up TCP replication between master and the new worker if we don't
-        # have Redis support enabled.
-        if not worker_hs.config.redis.redis_enabled:
-            repl_handler = ReplicationCommandHandler(worker_hs)
-            client = ClientReplicationStreamProtocol(
-                worker_hs,
-                "client",
-                "test",
-                self.clock,
-                repl_handler,
-            )
-            server = self.server_factory.buildProtocol(
-                IPv4Address("TCP", "127.0.0.1", 0)
-            )
-
-            client_transport = FakeTransport(server, self.reactor)
-            client.makeConnection(client_transport)
-
-            server_transport = FakeTransport(client, self.reactor)
-            server.makeConnection(server_transport)
-
         # Set up a resource for the worker
         resource = ReplicationRestResource(worker_hs)
 
@@ -378,8 +375,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             reactor=self.reactor,
         )
 
-        if worker_hs.config.redis.redis_enabled:
-            worker_hs.get_replication_command_handler().start_replication(worker_hs)
+        worker_hs.get_replication_command_handler().start_replication(worker_hs)
 
         return worker_hs
 
@@ -582,27 +578,3 @@ class FakeRedisPubSubProtocol(Protocol):
 
     def connectionLost(self, reason):
         self._server.remove_subscriber(self)
-
-
-class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase):
-    """
-    A test case that enables Redis, providing a fake Redis server.
-    """
-
-    if not hiredis:
-        skip = "Requires hiredis"
-
-    if not USE_POSTGRES_FOR_TESTS:
-        # Redis replication only takes place on Postgres
-        skip = "Requires Postgres"
-
-    def default_config(self) -> Dict[str, Any]:
-        """
-        Overrides the default config to enable Redis.
-        Even if the test only uses make_worker_hs, the main process needs Redis
-        enabled otherwise it won't create a Fake Redis server to listen on the
-        Redis port and accept fake TCP connections.
-        """
-        base = super().default_config()
-        base["redis"] = {"enabled": True}
-        return base
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index e6a19eafd5..1e299d2d67 100644
--- a/tests/replication/tcp/test_handler.py
+++ b/tests/replication/tcp/test_handler.py
@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from tests.replication._base import RedisMultiWorkerStreamTestCase
+from tests.replication._base import BaseMultiWorkerStreamTestCase
 
 
-class ChannelsTestCase(RedisMultiWorkerStreamTestCase):
+class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
     def test_subscribed_to_enough_redis_channels(self) -> None:
         # The default main process is subscribed to the USER_IP channel.
         self.assertCountEqual(
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index a7ca68069e..541d390286 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -20,7 +20,6 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.server import make_request
-from tests.utils import USE_POSTGRES_FOR_TESTS
 
 logger = logging.getLogger(__name__)
 
@@ -28,11 +27,6 @@ logger = logging.getLogger(__name__)
 class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
     """Checks event persisting sharding works"""
 
-    # Event persister sharding requires postgres (due to needing
-    # `MultiWriterIdGenerator`).
-    if not USE_POSTGRES_FOR_TESTS:
-        skip = "Requires Postgres"
-
     servlets = [
         admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
@@ -50,7 +44,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
 
     def default_config(self):
         conf = super().default_config()
-        conf["redis"] = {"enabled": "true"}
         conf["stream_writers"] = {"events": ["worker1", "worker2"]}
         conf["instance_map"] = {
             "worker1": {"host": "testserv", "port": 1001},
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index fbc490f46d..8a4e5c3f77 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -410,6 +410,33 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
             self.assertIn("score", c)
             self.assertIn("reason", c)
 
+    def test_count_correct_despite_table_deletions(self) -> None:
+        """
+        Tests that the count matches the number of rows, even if rows in joined tables
+        are missing.
+        """
+
+        # Delete rows from room_stats_state for one of our rooms.
+        self.get_success(
+            self.hs.get_datastores().main.db_pool.simple_delete(
+                "room_stats_state", {"room_id": self.room_id1}, desc="_"
+            )
+        )
+
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        # The 'total' field is 10 because only 10 reports will actually
+        # be retrievable since we deleted the rows in the room_stats_state
+        # table.
+        self.assertEqual(channel.json_body["total"], 10)
+        # This is consistent with the number of rows actually returned.
+        self.assertEqual(len(channel.json_body["event_reports"]), 10)
+
 
 class EventReportDetailTestCase(unittest.HomeserverTestCase):
     servlets = [
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index dd5000679a..fd6da557c1 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1633,6 +1633,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
         self.assertIn("history_visibility", channel.json_body)
         self.assertIn("state_events", channel.json_body)
         self.assertIn("room_type", channel.json_body)
+        self.assertIn("forgotten", channel.json_body)
         self.assertEqual(room_id_1, channel.json_body["room_id"])
 
     def test_single_room_devices(self) -> None:
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index 81e125e27d..a2f347f666 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -159,6 +159,62 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
         self.assertEqual("'msgtype' not in content", channel.json_body["error"])
 
+    @override_config(
+        {
+            "server_notices": {
+                "system_mxid_localpart": "notices",
+                "system_mxid_avatar_url": "somthingwrong",
+            },
+            "max_avatar_size": "10M",
+        }
+    )
+    def test_invalid_avatar_url(self) -> None:
+        """If avatar url in homeserver.yaml is invalid and
+        "check avatar size and mime type" is set, an error is returned.
+        TODO: Should be checked when reading the configuration."""
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={
+                "user_id": self.other_user,
+                "content": {"msgtype": "m.text", "body": "test msg"},
+            },
+        )
+
+        self.assertEqual(500, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+    @override_config(
+        {
+            "server_notices": {
+                "system_mxid_localpart": "notices",
+                "system_mxid_display_name": "test display name",
+                "system_mxid_avatar_url": None,
+            },
+            "max_avatar_size": "10M",
+        }
+    )
+    def test_displayname_is_set_avatar_is_none(self) -> None:
+        """
+        Tests that sending a server notices is successfully,
+        if a display_name is set, avatar_url is `None` and
+        "check avatar size and mime type" is set.
+        """
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={
+                "user_id": self.other_user,
+                "content": {"msgtype": "m.text", "body": "test msg"},
+            },
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        # user has one invite
+        self._check_invite_and_join_status(self.other_user, 1, 0)
+
     def test_server_notice_disabled(self) -> None:
         """Tests that server returns error if server notice is disabled"""
         channel = self.make_request(
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 411e4ec005..1afd082707 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1,4 +1,4 @@
-# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2018-2022 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -904,6 +904,96 @@ class UsersListTestCase(unittest.HomeserverTestCase):
             )
 
 
+class UserDevicesTestCase(unittest.HomeserverTestCase):
+    """
+    Tests user device management-related Admin APIs.
+    """
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        sync.register_servlets,
+    ]
+
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
+        # Set up an Admin user to query the Admin API with.
+        self.admin_user_id = self.register_user("admin", "pass", admin=True)
+        self.admin_user_token = self.login("admin", "pass")
+
+        # Set up a test user to query the devices of.
+        self.other_user_device_id = "TESTDEVICEID"
+        self.other_user_device_display_name = "My Test Device"
+        self.other_user_client_ip = "1.2.3.4"
+        self.other_user_user_agent = "EquestriaTechnology/123.0"
+
+        self.other_user_id = self.register_user("user", "pass", displayname="User1")
+        self.other_user_token = self.login(
+            "user",
+            "pass",
+            device_id=self.other_user_device_id,
+            additional_request_fields={
+                "initial_device_display_name": self.other_user_device_display_name,
+            },
+        )
+
+        # Have the "other user" make a request so that the "last_seen_*" fields are
+        # populated in the tests below.
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v3/sync",
+            access_token=self.other_user_token,
+            client_ip=self.other_user_client_ip,
+            custom_headers=[
+                ("User-Agent", self.other_user_user_agent),
+            ],
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+    def test_list_user_devices(self) -> None:
+        """
+        Tests that a user's devices and attributes are listed correctly via the Admin API.
+        """
+        # Request all devices of "other user"
+        channel = self.make_request(
+            "GET",
+            f"/_synapse/admin/v2/users/{self.other_user_id}/devices",
+            access_token=self.admin_user_token,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        # Double-check we got the single device expected
+        user_devices = channel.json_body["devices"]
+        self.assertEqual(len(user_devices), 1)
+        self.assertEqual(channel.json_body["total"], 1)
+
+        # Check that all the attributes of the device reported are as expected.
+        self._validate_attributes_of_device_response(user_devices[0])
+
+        # Request just a single device for "other user" by its ID
+        channel = self.make_request(
+            "GET",
+            f"/_synapse/admin/v2/users/{self.other_user_id}/devices/"
+            f"{self.other_user_device_id}",
+            access_token=self.admin_user_token,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        # Check that all the attributes of the device reported are as expected.
+        self._validate_attributes_of_device_response(channel.json_body)
+
+    def _validate_attributes_of_device_response(self, response: JsonDict) -> None:
+        # Check that all device expected attributes are present
+        self.assertEqual(response["user_id"], self.other_user_id)
+        self.assertEqual(response["device_id"], self.other_user_device_id)
+        self.assertEqual(response["display_name"], self.other_user_device_display_name)
+        self.assertEqual(response["last_seen_ip"], self.other_user_client_ip)
+        self.assertEqual(response["last_seen_user_agent"], self.other_user_user_agent)
+        self.assertIsInstance(response["last_seen_ts"], int)
+        self.assertGreater(response["last_seen_ts"], 0)
+
+
 class DeactivateAccountTestCase(unittest.HomeserverTestCase):
 
     servlets = [
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 7ae926dc9c..c1a7fb2f8a 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -488,7 +488,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST", "account/deactivate", request_data, access_token=tok
         )
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, 200, channel.json_body)
 
 
 class WhoamiTestCase(unittest.HomeserverTestCase):
@@ -641,21 +641,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
     def test_add_email_no_at(self) -> None:
         self._request_token_invalid_email(
             "address-without-at.bar",
-            expected_errcode=Codes.UNKNOWN,
+            expected_errcode=Codes.BAD_JSON,
             expected_error="Unable to parse email address",
         )
 
     def test_add_email_two_at(self) -> None:
         self._request_token_invalid_email(
             "foo@foo@test.bar",
-            expected_errcode=Codes.UNKNOWN,
+            expected_errcode=Codes.BAD_JSON,
             expected_error="Unable to parse email address",
         )
 
     def test_add_email_bad_format(self) -> None:
         self._request_token_invalid_email(
             "user@bad.example.net@good.example.com",
-            expected_errcode=Codes.UNKNOWN,
+            expected_errcode=Codes.BAD_JSON,
             expected_error="Unable to parse email address",
         )
 
@@ -1001,7 +1001,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
         )
         self.assertEqual(expected_errcode, channel.json_body["errcode"])
-        self.assertEqual(expected_error, channel.json_body["error"])
+        self.assertIn(expected_error, channel.json_body["error"])
 
     def _validate_token(self, link: str) -> None:
         # Remove the host
diff --git a/tests/rest/client/test_models.py b/tests/rest/client/test_models.py
new file mode 100644
index 0000000000..a9da00665e
--- /dev/null
+++ b/tests/rest/client/test_models.py
@@ -0,0 +1,53 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+from pydantic import ValidationError
+
+from synapse.rest.client.models import EmailRequestTokenBody
+
+
+class EmailRequestTokenBodyTestCase(unittest.TestCase):
+    base_request = {
+        "client_secret": "hunter2",
+        "email": "alice@wonderland.com",
+        "send_attempt": 1,
+    }
+
+    def test_token_required_if_id_server_provided(self) -> None:
+        with self.assertRaises(ValidationError):
+            EmailRequestTokenBody.parse_obj(
+                {
+                    **self.base_request,
+                    "id_server": "identity.wonderland.com",
+                }
+            )
+        with self.assertRaises(ValidationError):
+            EmailRequestTokenBody.parse_obj(
+                {
+                    **self.base_request,
+                    "id_server": "identity.wonderland.com",
+                    "id_access_token": None,
+                }
+            )
+
+    def test_token_typechecked_when_id_server_provided(self) -> None:
+        with self.assertRaises(ValidationError):
+            EmailRequestTokenBody.parse_obj(
+                {
+                    **self.base_request,
+                    "id_server": "identity.wonderland.com",
+                    "id_access_token": 1337,
+                }
+            )
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index ab4277dd31..b781875d52 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -586,9 +586,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
                 "require_at_registration": True,
             },
             "account_threepid_delegates": {
-                "email": "https://id_server",
                 "msisdn": "https://id_server",
             },
+            "email": {"notif_from": "Synapse <synapse@example.com>"},
         }
     )
     def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index d589f07314..651f4f415d 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -999,7 +999,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
                 bundled_aggregations,
             )
 
-        self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 6)
+        self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
 
     def test_annotation_to_annotation(self) -> None:
         """Any relation to an annotation should be ignored."""
@@ -1035,7 +1035,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
                 bundled_aggregations,
             )
 
-        self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6)
+        self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
 
     def test_thread(self) -> None:
         """
@@ -1080,21 +1080,21 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
 
         # The "user" sent the root event and is making queries for the bundled
         # aggregations: they have participated.
-        self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
+        self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9)
         # The "user2" sent replies in the thread and is making queries for the
         # bundled aggregations: they have participated.
         #
         # Note that this re-uses some cached values, so the total number of
         # queries is much smaller.
         self._test_bundled_aggregations(
-            RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
+            RelationTypes.THREAD, _gen_assert(True), 3, access_token=self.user2_token
         )
 
         # A user with no interactions with the thread: they have not participated.
         user3_id, user3_token = self._create_user("charlie")
         self.helper.join(self.room, user=user3_id, tok=user3_token)
         self._test_bundled_aggregations(
-            RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
+            RelationTypes.THREAD, _gen_assert(False), 3, access_token=user3_token
         )
 
     def test_thread_with_bundled_aggregations_for_latest(self) -> None:
@@ -1142,7 +1142,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
                 bundled_aggregations["latest_event"].get("unsigned"),
             )
 
-        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
+        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
 
     def test_nested_thread(self) -> None:
         """
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index ac9c113354..9c8c1889d3 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
 from synapse.util import Clock
 from synapse.visibility import filter_events_for_client
 
@@ -188,7 +188,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         message_handler = self.hs.get_message_handler()
         create_event = self.get_success(
             message_handler.get_room_data(
-                self.user_id, room_id, EventTypes.Create, state_key=""
+                create_requester(self.user_id), room_id, EventTypes.Create, state_key=""
             )
         )
 
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index d9bd8c4a28..c50f034b34 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -26,7 +26,7 @@ from synapse.rest.client import (
     room_upgrade_rest_servlet,
 )
 from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
 from synapse.util import Clock
 
 from tests import unittest
@@ -275,7 +275,7 @@ class ProfileTestCase(_ShadowBannedBase):
         message_handler = self.hs.get_message_handler()
         event = self.get_success(
             message_handler.get_room_data(
-                self.banned_user_id,
+                create_requester(self.banned_user_id),
                 room_id,
                 "m.room.member",
                 self.banned_user_id,
@@ -310,7 +310,7 @@ class ProfileTestCase(_ShadowBannedBase):
         message_handler = self.hs.get_message_handler()
         event = self.get_success(
             message_handler.get_room_data(
-                self.banned_user_id,
+                create_requester(self.banned_user_id),
                 room_id,
                 "m.room.member",
                 self.banned_user_id,
diff --git a/tests/server.py b/tests/server.py
index 9689e6a0cd..c447d5e4c4 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -61,6 +61,10 @@ from twisted.web.resource import IResource
 from twisted.web.server import Request, Site
 
 from synapse.config.database import DatabaseConnectionConfig
+from synapse.events.presence_router import load_legacy_presence_router
+from synapse.events.spamcheck import load_legacy_spam_checkers
+from synapse.events.third_party_rules import load_legacy_third_party_event_rules
+from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import ContextResourceUsage
 from synapse.server import HomeServer
@@ -913,4 +917,14 @@ def setup_test_homeserver(
     # Make the threadpool and database transactions synchronous for testing.
     _make_test_homeserver_synchronous(hs)
 
+    # Load any configured modules into the homeserver
+    module_api = hs.get_module_api()
+    for module, config in hs.config.modules.loaded_modules:
+        module(config=config, api=module_api)
+
+    load_legacy_spam_checkers(hs)
+    load_legacy_third_party_event_rules(hs)
+    load_legacy_presence_router(hs)
+    load_legacy_password_auth_providers(hs)
+
     return hs
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index e07ae78fc4..bf403045e9 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -11,16 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from unittest.mock import Mock
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType
 from synapse.api.errors import ResourceLimitError
 from synapse.rest import admin
 from synapse.rest.client import login, room, sync
+from synapse.server import HomeServer
 from synapse.server_notices.resource_limits_server_notices import (
     ResourceLimitsServerNotices,
 )
+from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable
@@ -52,7 +55,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
 
         return config
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.server_notices_sender = self.hs.get_server_notices_sender()
 
         # relying on [1] is far from ideal, but the only case where
@@ -251,7 +254,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
         c["admin_contact"] = "mailto:user@test.com"
         return c
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = self.hs.get_datastores().main
         self.server_notices_sender = self.hs.get_server_notices_sender()
         self.server_notices_manager = self.hs.get_server_notices_manager()
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 240b02cb9f..ceec690285 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -23,6 +23,7 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.server import TestHomeServer
+from tests.test_utils import event_injection
 
 
 class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
@@ -157,6 +158,75 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         # Check that alice's display name is now None
         self.assertEqual(row[0]["display_name"], None)
 
+    def test_room_is_locally_forgotten(self):
+        """Test that when the last local user has forgotten a room it is known as forgotten."""
+        # join two local and one remote user
+        self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+        self.get_success(
+            event_injection.inject_member_event(self.hs, self.room, self.u_bob, "join")
+        )
+        self.get_success(
+            event_injection.inject_member_event(
+                self.hs, self.room, self.u_charlie.to_string(), "join"
+            )
+        )
+        self.assertFalse(
+            self.get_success(self.store.is_locally_forgotten_room(self.room))
+        )
+
+        # local users leave the room and the room is not forgotten
+        self.get_success(
+            event_injection.inject_member_event(
+                self.hs, self.room, self.u_alice, "leave"
+            )
+        )
+        self.get_success(
+            event_injection.inject_member_event(self.hs, self.room, self.u_bob, "leave")
+        )
+        self.assertFalse(
+            self.get_success(self.store.is_locally_forgotten_room(self.room))
+        )
+
+        # first user forgets the room, room is not forgotten
+        self.get_success(self.store.forget(self.u_alice, self.room))
+        self.assertFalse(
+            self.get_success(self.store.is_locally_forgotten_room(self.room))
+        )
+
+        # second (last local) user forgets the room and the room is forgotten
+        self.get_success(self.store.forget(self.u_bob, self.room))
+        self.assertTrue(
+            self.get_success(self.store.is_locally_forgotten_room(self.room))
+        )
+
+    def test_join_locally_forgotten_room(self):
+        """Tests if a user joins a forgotten room the room is not forgotten anymore."""
+        self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+        self.assertFalse(
+            self.get_success(self.store.is_locally_forgotten_room(self.room))
+        )
+
+        # after leaving and forget the room, it is forgotten
+        self.get_success(
+            event_injection.inject_member_event(
+                self.hs, self.room, self.u_alice, "leave"
+            )
+        )
+        self.get_success(self.store.forget(self.u_alice, self.room))
+        self.assertTrue(
+            self.get_success(self.store.is_locally_forgotten_room(self.room))
+        )
+
+        # after rejoin the room is not forgotten anymore
+        self.get_success(
+            event_injection.inject_member_event(
+                self.hs, self.room, self.u_alice, "join"
+            )
+        )
+        self.assertFalse(
+            self.get_success(self.store.is_locally_forgotten_room(self.room))
+        )
+
 
 class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index b4574b2ffe..1a70eddc9b 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -12,7 +12,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+try:
+    from importlib import metadata
+except ImportError:
+    import importlib_metadata as metadata  # type: ignore[no-redef]
 
+from unittest.mock import patch
+
+from pkg_resources import parse_version
+
+from synapse.app._base import _set_prometheus_client_use_created_metrics
 from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
 from synapse.util.caches.deferred_cache import DeferredCache
 
@@ -162,3 +171,30 @@ class CacheMetricsTests(unittest.HomeserverTestCase):
 
         self.assertEqual(items["synapse_util_caches_cache_size"], "1.0")
         self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0")
+
+
+class PrometheusMetricsHackTestCase(unittest.HomeserverTestCase):
+    if parse_version(metadata.version("prometheus_client")) < parse_version("0.14.0"):
+        skip = "prometheus-client too old"
+
+    def test_created_metrics_disabled(self) -> None:
+        """
+        Tests that a brittle hack, to disable `_created` metrics, works.
+        This involves poking at the internals of prometheus-client.
+        It's not the end of the world if this doesn't work.
+
+        This test gives us a way to notice if prometheus-client changes
+        their internals.
+        """
+        import prometheus_client.metrics
+
+        PRIVATE_FLAG_NAME = "_use_created"
+
+        # By default, the pesky `_created` metrics are enabled.
+        # Check this assumption is still valid.
+        self.assertTrue(getattr(prometheus_client.metrics, PRIVATE_FLAG_NAME))
+
+        with patch("prometheus_client.metrics") as mock:
+            setattr(mock, PRIVATE_FLAG_NAME, True)
+            _set_prometheus_client_use_created_metrics(False)
+            self.assertFalse(getattr(mock, PRIVATE_FLAG_NAME, False))
diff --git a/tests/unittest.py b/tests/unittest.py
index bec4a3d023..975b0a23a7 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -677,14 +677,29 @@ class HomeserverTestCase(TestCase):
         username: str,
         password: str,
         device_id: Optional[str] = None,
+        additional_request_fields: Optional[Dict[str, str]] = None,
         custom_headers: Optional[Iterable[CustomHeaderType]] = None,
     ) -> str:
         """
         Log in a user, and get an access token. Requires the Login API be registered.
+
+        Args:
+            username: The localpart to assign to the new user.
+            password: The password to assign to the new user.
+            device_id: An optional device ID to assign to the new device created during
+                login.
+            additional_request_fields: A dictionary containing any additional /login
+                request fields and their values.
+            custom_headers: Custom HTTP headers and values to add to the /login request.
+
+        Returns:
+            The newly registered user's Matrix ID.
         """
         body = {"type": "m.login.password", "user": username, "password": password}
         if device_id:
             body["device_id"] = device_id
+        if additional_request_fields:
+            body.update(additional_request_fields)
 
         channel = self.make_request(
             "POST",