summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/app/test_phone_stats_home.py395
-rw-r--r--tests/config/test_load.py4
-rw-r--r--tests/events/test_presence_router.py14
-rw-r--r--tests/handlers/test_e2e_keys.py20
-rw-r--r--tests/handlers/test_profile.py2
-rw-r--r--tests/handlers/test_space_summary.py391
-rw-r--r--tests/handlers/test_stats.py203
-rw-r--r--tests/handlers/test_typing.py37
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py2
-rw-r--r--tests/http/test_fedclient.py8
-rw-r--r--tests/http/test_proxyagent.py65
-rw-r--r--tests/module_api/test_api.py16
-rw-r--r--tests/replication/_base.py18
-rw-r--r--tests/replication/tcp/streams/test_events.py14
-rw-r--r--tests/replication/tcp/streams/test_receipts.py4
-rw-r--r--tests/replication/tcp/streams/test_typing.py4
-rw-r--r--tests/replication/test_multi_media_repo.py6
-rw-r--r--tests/replication/test_sharded_event_persister.py6
-rw-r--r--tests/rest/admin/test_admin.py6
-rw-r--r--tests/rest/admin/test_room.py21
-rw-r--r--tests/rest/client/test_third_party_rules.py136
-rw-r--r--tests/rest/client/v1/test_login.py14
-rw-r--r--tests/rest/client/v1/test_rooms.py14
-rw-r--r--tests/rest/client/v1/utils.py30
-rw-r--r--tests/rest/client/v2_alpha/test_relations.py2
-rw-r--r--tests/rest/client/v2_alpha/test_report_event.py2
-rw-r--r--tests/rest/media/v1/test_media_storage.py2
-rw-r--r--tests/server.py8
-rw-r--r--tests/storage/databases/main/test_lock.py13
-rw-r--r--tests/storage/test_background_update.py4
-rw-r--r--tests/storage/test_directory.py2
-rw-r--r--tests/storage/test_id_generators.py6
-rw-r--r--tests/storage/test_profile.py12
-rw-r--r--tests/storage/test_purge.py2
-rw-r--r--tests/storage/test_room.py2
-rw-r--r--tests/test_event_auth.py23
-rw-r--r--tests/test_state.py3
-rw-r--r--tests/test_types.py4
-rw-r--r--tests/test_utils/html_parsers.py6
-rw-r--r--tests/unittest.py19
-rw-r--r--tests/util/caches/test_descriptors.py2
-rw-r--r--tests/util/test_itertools.py18
42 files changed, 1174 insertions, 386 deletions
diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py
new file mode 100644
index 0000000000..5527e278db
--- /dev/null
+++ b/tests/app/test_phone_stats_home.py
@@ -0,0 +1,395 @@
+import synapse
+from synapse.app.phone_stats_home import start_phone_stats_home
+from synapse.rest.client.v1 import login, room
+
+from tests import unittest
+from tests.unittest import HomeserverTestCase
+
+FIVE_MINUTES_IN_SECONDS = 300
+ONE_DAY_IN_SECONDS = 86400
+
+
+class PhoneHomeTestCase(HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    # Override the retention time for the user_ips table because otherwise it
+    # gets pruned too aggressively for our R30 test.
+    @unittest.override_config({"user_ips_max_age": "365d"})
+    def test_r30_minimum_usage(self):
+        """
+        Tests the minimum amount of interaction necessary for the R30 metric
+        to consider a user 'retained'.
+        """
+
+        # Register a user, log it in, create a room and send a message
+        user_id = self.register_user("u1", "secret!")
+        access_token = self.login("u1", "secret!")
+        room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token)
+        self.helper.send(room_id, "message", tok=access_token)
+
+        # Check the R30 results do not count that user.
+        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        self.assertEqual(r30_results, {"all": 0})
+
+        # Advance 30 days (+ 1 second, because strict inequality causes issues if we are
+        # bang on 30 days later).
+        self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1)
+
+        # (Make sure the user isn't somehow counted by this point.)
+        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        self.assertEqual(r30_results, {"all": 0})
+
+        # Send a message (this counts as activity)
+        self.helper.send(room_id, "message2", tok=access_token)
+
+        # We have to wait some time for _update_client_ips_batch to get
+        # called and update the user_ips table.
+        self.reactor.advance(2 * 60 * 60)
+
+        # *Now* the user is counted.
+        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        self.assertEqual(r30_results, {"all": 1, "unknown": 1})
+
+        # Advance 29 days. The user has now not posted for 29 days.
+        self.reactor.advance(29 * ONE_DAY_IN_SECONDS)
+
+        # The user is still counted.
+        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        self.assertEqual(r30_results, {"all": 1, "unknown": 1})
+
+        # Advance another day. The user has now not posted for 30 days.
+        self.reactor.advance(ONE_DAY_IN_SECONDS)
+
+        # The user is now no longer counted in R30.
+        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        self.assertEqual(r30_results, {"all": 0})
+
+    def test_r30_minimum_usage_using_default_config(self):
+        """
+        Tests the minimum amount of interaction necessary for the R30 metric
+        to consider a user 'retained'.
+
+        N.B. This test does not override the `user_ips_max_age` config setting,
+        which defaults to 28 days.
+        """
+
+        # Register a user, log it in, create a room and send a message
+        user_id = self.register_user("u1", "secret!")
+        access_token = self.login("u1", "secret!")
+        room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token)
+        self.helper.send(room_id, "message", tok=access_token)
+
+        # Check the R30 results do not count that user.
+        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        self.assertEqual(r30_results, {"all": 0})
+
+        # Advance 30 days (+ 1 second, because strict inequality causes issues if we are
+        # bang on 30 days later).
+        self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1)
+
+        # (Make sure the user isn't somehow counted by this point.)
+        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        self.assertEqual(r30_results, {"all": 0})
+
+        # Send a message (this counts as activity)
+        self.helper.send(room_id, "message2", tok=access_token)
+
+        # We have to wait some time for _update_client_ips_batch to get
+        # called and update the user_ips table.
+        self.reactor.advance(2 * 60 * 60)
+
+        # *Now* the user is counted.
+        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        self.assertEqual(r30_results, {"all": 1, "unknown": 1})
+
+        # Advance 27 days. The user has now not posted for 27 days.
+        self.reactor.advance(27 * ONE_DAY_IN_SECONDS)
+
+        # The user is still counted.
+        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        self.assertEqual(r30_results, {"all": 1, "unknown": 1})
+
+        # Advance another day. The user has now not posted for 28 days.
+        self.reactor.advance(ONE_DAY_IN_SECONDS)
+
+        # The user is now no longer counted in R30.
+        # (This is because the user_ips table has been pruned, which by default
+        # only preserves the last 28 days of entries.)
+        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        self.assertEqual(r30_results, {"all": 0})
+
+    def test_r30_user_must_be_retained_for_at_least_a_month(self):
+        """
+        Tests that a newly-registered user must be retained for a whole month
+        before appearing in the R30 statistic, even if they post every day
+        during that time!
+        """
+        # Register a user and send a message
+        user_id = self.register_user("u1", "secret!")
+        access_token = self.login("u1", "secret!")
+        room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token)
+        self.helper.send(room_id, "message", tok=access_token)
+
+        # Check the user does not contribute to R30 yet.
+        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        self.assertEqual(r30_results, {"all": 0})
+
+        for _ in range(30):
+            # This loop posts a message every day for 30 days
+            self.reactor.advance(ONE_DAY_IN_SECONDS)
+            self.helper.send(room_id, "I'm still here", tok=access_token)
+
+            # Notice that the user *still* does not contribute to R30!
+            r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+            self.assertEqual(r30_results, {"all": 0})
+
+        self.reactor.advance(ONE_DAY_IN_SECONDS)
+        self.helper.send(room_id, "Still here!", tok=access_token)
+
+        # *Now* the user appears in R30.
+        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        self.assertEqual(r30_results, {"all": 1, "unknown": 1})
+
+
+class PhoneHomeR30V2TestCase(HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def _advance_to(self, desired_time_secs: float):
+        now = self.hs.get_clock().time()
+        assert now < desired_time_secs
+        self.reactor.advance(desired_time_secs - now)
+
+    def make_homeserver(self, reactor, clock):
+        hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock)
+
+        # We don't want our tests to actually report statistics, so check
+        # that it's not enabled
+        assert not hs.config.report_stats
+
+        # This starts the needed data collection that we rely on to calculate
+        # R30v2 metrics.
+        start_phone_stats_home(hs)
+        return hs
+
+    def test_r30v2_minimum_usage(self):
+        """
+        Tests the minimum amount of interaction necessary for the R30v2 metric
+        to consider a user 'retained'.
+        """
+
+        # Register a user, log it in, create a room and send a message
+        user_id = self.register_user("u1", "secret!")
+        access_token = self.login("u1", "secret!")
+        room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token)
+        self.helper.send(room_id, "message", tok=access_token)
+        first_post_at = self.hs.get_clock().time()
+
+        # Give time for user_daily_visits table to be updated.
+        # (user_daily_visits is updated every 5 minutes using a looping call.)
+        self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+        store = self.hs.get_datastore()
+
+        # Check the R30 results do not count that user.
+        r30_results = self.get_success(store.count_r30v2_users())
+        self.assertEqual(
+            r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+        )
+
+        # Advance 31 days.
+        # (R30v2 includes users with **more** than 30 days between the two visits,
+        #  and user_daily_visits records the timestamp as the start of the day.)
+        self.reactor.advance(31 * ONE_DAY_IN_SECONDS)
+        # Also advance 5 minutes to let another user_daily_visits update occur
+        self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+        # (Make sure the user isn't somehow counted by this point.)
+        r30_results = self.get_success(store.count_r30v2_users())
+        self.assertEqual(
+            r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+        )
+
+        # Send a message (this counts as activity)
+        self.helper.send(room_id, "message2", tok=access_token)
+
+        # We have to wait a few minutes for the user_daily_visits table to
+        # be updated by a background process.
+        self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+        # *Now* the user is counted.
+        r30_results = self.get_success(store.count_r30v2_users())
+        self.assertEqual(
+            r30_results, {"all": 1, "android": 0, "electron": 0, "ios": 0, "web": 0}
+        )
+
+        # Advance to JUST under 60 days after the user's first post
+        self._advance_to(first_post_at + 60 * ONE_DAY_IN_SECONDS - 5)
+
+        # Check the user is still counted.
+        r30_results = self.get_success(store.count_r30v2_users())
+        self.assertEqual(
+            r30_results, {"all": 1, "android": 0, "electron": 0, "ios": 0, "web": 0}
+        )
+
+        # Advance into the next day. The user's first activity is now more than 60 days old.
+        self._advance_to(first_post_at + 60 * ONE_DAY_IN_SECONDS + 5)
+
+        # Check the user is now no longer counted in R30.
+        r30_results = self.get_success(store.count_r30v2_users())
+        self.assertEqual(
+            r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+        )
+
+    def test_r30v2_user_must_be_retained_for_at_least_a_month(self):
+        """
+        Tests that a newly-registered user must be retained for a whole month
+        before appearing in the R30v2 statistic, even if they post every day
+        during that time!
+        """
+
+        # set a custom user-agent to impersonate Element/Android.
+        headers = (
+            (
+                "User-Agent",
+                "Element/1.1 (Linux; U; Android 9; MatrixAndroidSDK_X 0.0.1)",
+            ),
+        )
+
+        # Register a user and send a message
+        user_id = self.register_user("u1", "secret!")
+        access_token = self.login("u1", "secret!", custom_headers=headers)
+        room_id = self.helper.create_room_as(
+            room_creator=user_id, tok=access_token, custom_headers=headers
+        )
+        self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
+
+        # Give time for user_daily_visits table to be updated.
+        # (user_daily_visits is updated every 5 minutes using a looping call.)
+        self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+        store = self.hs.get_datastore()
+
+        # Check the user does not contribute to R30 yet.
+        r30_results = self.get_success(store.count_r30v2_users())
+        self.assertEqual(
+            r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+        )
+
+        for _ in range(30):
+            # This loop posts a message every day for 30 days
+            self.reactor.advance(ONE_DAY_IN_SECONDS - FIVE_MINUTES_IN_SECONDS)
+            self.helper.send(
+                room_id, "I'm still here", tok=access_token, custom_headers=headers
+            )
+
+            # give time for user_daily_visits to update
+            self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+            # Notice that the user *still* does not contribute to R30!
+            r30_results = self.get_success(store.count_r30v2_users())
+            self.assertEqual(
+                r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+            )
+
+        # advance yet another day with more activity
+        self.reactor.advance(ONE_DAY_IN_SECONDS)
+        self.helper.send(
+            room_id, "Still here!", tok=access_token, custom_headers=headers
+        )
+
+        # give time for user_daily_visits to update
+        self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+        # *Now* the user appears in R30.
+        r30_results = self.get_success(store.count_r30v2_users())
+        self.assertEqual(
+            r30_results, {"all": 1, "android": 1, "electron": 0, "ios": 0, "web": 0}
+        )
+
+    def test_r30v2_returning_dormant_users_not_counted(self):
+        """
+        Tests that dormant users (users inactive for a long time) do not
+        contribute to R30v2 when they return for just a single day.
+        This is a key difference between R30 and R30v2.
+        """
+
+        # set a custom user-agent to impersonate Element/iOS.
+        headers = (
+            (
+                "User-Agent",
+                "Riot/1.4 (iPhone; iOS 13; Scale/4.00)",
+            ),
+        )
+
+        # Register a user and send a message
+        user_id = self.register_user("u1", "secret!")
+        access_token = self.login("u1", "secret!", custom_headers=headers)
+        room_id = self.helper.create_room_as(
+            room_creator=user_id, tok=access_token, custom_headers=headers
+        )
+        self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
+
+        # the user goes inactive for 2 months
+        self.reactor.advance(60 * ONE_DAY_IN_SECONDS)
+
+        # the user returns for one day, perhaps just to check out a new feature
+        self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
+
+        # Give time for user_daily_visits table to be updated.
+        # (user_daily_visits is updated every 5 minutes using a looping call.)
+        self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+        store = self.hs.get_datastore()
+
+        # Check that the user does not contribute to R30v2, even though it's been
+        # more than 30 days since registration.
+        r30_results = self.get_success(store.count_r30v2_users())
+        self.assertEqual(
+            r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+        )
+
+        # Check that this is a situation where old R30 differs:
+        # old R30 DOES count this as 'retained'.
+        r30_results = self.get_success(store.count_r30_users())
+        self.assertEqual(r30_results, {"all": 1, "ios": 1})
+
+        # Now we want to check that the user will still be able to appear in
+        # R30v2 as long as the user performs some other activity between
+        # 30 and 60 days later.
+        self.reactor.advance(32 * ONE_DAY_IN_SECONDS)
+        self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
+
+        # (give time for tables to update)
+        self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+        # Check the user now satisfies the requirements to appear in R30v2.
+        r30_results = self.get_success(store.count_r30v2_users())
+        self.assertEqual(
+            r30_results, {"all": 1, "ios": 1, "android": 0, "electron": 0, "web": 0}
+        )
+
+        # Advance to 59.5 days after the user's first R30v2-eligible activity.
+        self.reactor.advance(27.5 * ONE_DAY_IN_SECONDS)
+
+        # Check the user still appears in R30v2.
+        r30_results = self.get_success(store.count_r30v2_users())
+        self.assertEqual(
+            r30_results, {"all": 1, "ios": 1, "android": 0, "electron": 0, "web": 0}
+        )
+
+        # Advance to 60.5 days after the user's first R30v2-eligible activity.
+        self.reactor.advance(ONE_DAY_IN_SECONDS)
+
+        # Check the user no longer appears in R30v2.
+        r30_results = self.get_success(store.count_r30v2_users())
+        self.assertEqual(
+            r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+        )
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index ebe2c05165..903c69127d 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -43,7 +43,7 @@ class ConfigLoadingTestCase(unittest.TestCase):
     def test_generates_and_loads_macaroon_secret_key(self):
         self.generate_config()
 
-        with open(self.file, "r") as f:
+        with open(self.file) as f:
             raw = yaml.safe_load(f)
         self.assertIn("macaroon_secret_key", raw)
 
@@ -120,7 +120,7 @@ class ConfigLoadingTestCase(unittest.TestCase):
     def generate_config_and_remove_lines_containing(self, needle):
         self.generate_config()
 
-        with open(self.file, "r") as f:
+        with open(self.file) as f:
             contents = f.readlines()
         contents = [line for line in contents if needle not in line]
         with open(self.file, "w") as f:
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index 875b0d0a11..3f41e99950 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -152,7 +152,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
         )
         self.assertEqual(len(presence_updates), 1)
 
-        presence_update = presence_updates[0]  # type: UserPresenceState
+        presence_update: UserPresenceState = presence_updates[0]
         self.assertEqual(presence_update.user_id, self.other_user_one_id)
         self.assertEqual(presence_update.state, "online")
         self.assertEqual(presence_update.status_msg, "boop")
@@ -274,7 +274,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
         presence_updates, _ = sync_presence(self, self.other_user_id)
         self.assertEqual(len(presence_updates), 1)
 
-        presence_update = presence_updates[0]  # type: UserPresenceState
+        presence_update: UserPresenceState = presence_updates[0]
         self.assertEqual(presence_update.user_id, self.other_user_id)
         self.assertEqual(presence_update.state, "online")
         self.assertEqual(presence_update.status_msg, "I'm online!")
@@ -285,6 +285,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
         presence_updates, _ = sync_presence(self, self.presence_receiving_user_two_id)
         self.assertEqual(len(presence_updates), 3)
 
+        # We stagger sending of presence, so we need to wait a bit for them to
+        # get sent out.
+        self.reactor.advance(60)
+
         # Test that sending to a remote user works
         remote_user_id = "@far_away_person:island"
 
@@ -301,6 +305,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
             self.module_api.send_local_online_presence_to([remote_user_id])
         )
 
+        # We stagger sending of presence, so we need to wait a bit for them to
+        # get sent out.
+        self.reactor.advance(60)
+
         # Check that the expected presence updates were sent
         # We explicitly compare using sets as we expect that calling
         # module_api.send_local_online_presence_to will create a presence
@@ -320,7 +328,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
         )
         for call in calls:
             call_args = call[0]
-            federation_transaction = call_args[0]  # type: Transaction
+            federation_transaction: Transaction = call_args[0]
 
             # Get the sent EDUs in this transaction
             edus = federation_transaction.get_dict()["edus"]
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index e0a24824cc..39e7b1ab25 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -47,12 +47,16 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             "alg2:k3": {"key": "key3"},
         }
 
+        # Note that "signed_curve25519" is always returned in key count responses. This is necessary until
+        # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
         res = self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": keys}
             )
         )
-        self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
+        self.assertDictEqual(
+            res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
+        )
 
         # we should be able to change the signature without a problem
         keys["alg2:k2"]["signatures"]["k1"] = "sig2"
@@ -61,7 +65,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
                 local_user, device_id, {"one_time_keys": keys}
             )
         )
-        self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
+        self.assertDictEqual(
+            res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
+        )
 
     def test_change_one_time_keys(self):
         """attempts to change one-time-keys should be rejected"""
@@ -79,7 +85,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
                 local_user, device_id, {"one_time_keys": keys}
             )
         )
-        self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
+        self.assertDictEqual(
+            res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
+        )
 
         # Error when changing string key
         self.get_failure(
@@ -89,7 +97,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             SynapseError,
         )
 
-        # Error when replacing dict key with strin
+        # Error when replacing dict key with string
         self.get_failure(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
@@ -131,7 +139,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
                 local_user, device_id, {"one_time_keys": keys}
             )
         )
-        self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
+        self.assertDictEqual(
+            res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}}
+        )
 
         res2 = self.get_success(
             self.handler.claim_one_time_keys(
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index cdb41101b3..2928c4f48c 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -103,7 +103,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertIsNone(
-            (self.get_success(self.store.get_profile_displayname(self.frank.localpart)))
+            self.get_success(self.store.get_profile_displayname(self.frank.localpart))
         )
 
     def test_set_my_name_if_disabled(self):
diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py
index 9771d3fb3b..3f73ad7f94 100644
--- a/tests/handlers/test_space_summary.py
+++ b/tests/handlers/test_space_summary.py
@@ -14,8 +14,18 @@
 from typing import Any, Iterable, Optional, Tuple
 from unittest import mock
 
-from synapse.api.constants import EventContentFields, RoomTypes
+from synapse.api.constants import (
+    EventContentFields,
+    EventTypes,
+    HistoryVisibility,
+    JoinRules,
+    Membership,
+    RestrictedJoinRuleTypes,
+    RoomTypes,
+)
 from synapse.api.errors import AuthError
+from synapse.api.room_versions import RoomVersions
+from synapse.events import make_event_from_dict
 from synapse.handlers.space_summary import _child_events_comparison_key
 from synapse.rest import admin
 from synapse.rest.client.v1 import login, room
@@ -117,7 +127,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         """Add a child room to a space."""
         self.helper.send_state(
             space_id,
-            event_type="m.space.child",
+            event_type=EventTypes.SpaceChild,
             body={"via": [self.hs.hostname]},
             tok=token,
             state_key=room_id,
@@ -155,26 +165,379 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         # The user cannot see the space.
         self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
 
-        # Joining the room causes it to be visible.
-        self.helper.join(self.space, user2, tok=token2)
+        # If the space is made world-readable it should return a result.
+        self.helper.send_state(
+            self.space,
+            event_type=EventTypes.RoomHistoryVisibility,
+            body={"history_visibility": HistoryVisibility.WORLD_READABLE},
+            tok=self.token,
+        )
         result = self.get_success(self.handler.get_space_summary(user2, self.space))
-
-        # The result should only have the space, but includes the link to the room.
-        self._assert_rooms(result, [self.space])
+        self._assert_rooms(result, [self.space, self.room])
         self._assert_events(result, [(self.space, self.room)])
 
-    def test_world_readable(self):
-        """A world-readable room is visible to everyone."""
+        # Make it not world-readable again and confirm it results in an error.
         self.helper.send_state(
             self.space,
-            event_type="m.room.history_visibility",
-            body={"history_visibility": "world_readable"},
+            event_type=EventTypes.RoomHistoryVisibility,
+            body={"history_visibility": HistoryVisibility.JOINED},
+            tok=self.token,
+        )
+        self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
+
+        # Join the space and results should be returned.
+        self.helper.join(self.space, user2, tok=token2)
+        result = self.get_success(self.handler.get_space_summary(user2, self.space))
+        self._assert_rooms(result, [self.space, self.room])
+        self._assert_events(result, [(self.space, self.room)])
+
+    def _create_room_with_join_rule(
+        self, join_rule: str, room_version: Optional[str] = None, **extra_content
+    ) -> str:
+        """Create a room with the given join rule and add it to the space."""
+        room_id = self.helper.create_room_as(
+            self.user,
+            room_version=room_version,
             tok=self.token,
+            extra_content={
+                "initial_state": [
+                    {
+                        "type": EventTypes.JoinRules,
+                        "state_key": "",
+                        "content": {
+                            "join_rule": join_rule,
+                            **extra_content,
+                        },
+                    }
+                ]
+            },
         )
+        self._add_child(self.space, room_id, self.token)
+        return room_id
 
+    def test_filtering(self):
+        """
+        Rooms should be properly filtered to only include rooms the user has access to.
+        """
         user2 = self.register_user("user2", "pass")
+        token2 = self.login("user2", "pass")
 
-        # The space should be visible, as well as the link to the room.
+        # Create a few rooms which will have different properties.
+        public_room = self._create_room_with_join_rule(JoinRules.PUBLIC)
+        knock_room = self._create_room_with_join_rule(
+            JoinRules.KNOCK, room_version=RoomVersions.V7.identifier
+        )
+        not_invited_room = self._create_room_with_join_rule(JoinRules.INVITE)
+        invited_room = self._create_room_with_join_rule(JoinRules.INVITE)
+        self.helper.invite(invited_room, targ=user2, tok=self.token)
+        restricted_room = self._create_room_with_join_rule(
+            JoinRules.MSC3083_RESTRICTED,
+            room_version=RoomVersions.MSC3083.identifier,
+            allow=[],
+        )
+        restricted_accessible_room = self._create_room_with_join_rule(
+            JoinRules.MSC3083_RESTRICTED,
+            room_version=RoomVersions.MSC3083.identifier,
+            allow=[
+                {
+                    "type": RestrictedJoinRuleTypes.ROOM_MEMBERSHIP,
+                    "room_id": self.space,
+                    "via": [self.hs.hostname],
+                }
+            ],
+        )
+        world_readable_room = self._create_room_with_join_rule(JoinRules.INVITE)
+        self.helper.send_state(
+            world_readable_room,
+            event_type=EventTypes.RoomHistoryVisibility,
+            body={"history_visibility": HistoryVisibility.WORLD_READABLE},
+            tok=self.token,
+        )
+        joined_room = self._create_room_with_join_rule(JoinRules.INVITE)
+        self.helper.invite(joined_room, targ=user2, tok=self.token)
+        self.helper.join(joined_room, user2, tok=token2)
+
+        # Join the space.
+        self.helper.join(self.space, user2, tok=token2)
         result = self.get_success(self.handler.get_space_summary(user2, self.space))
-        self._assert_rooms(result, [self.space])
-        self._assert_events(result, [(self.space, self.room)])
+
+        self._assert_rooms(
+            result,
+            [
+                self.space,
+                self.room,
+                public_room,
+                knock_room,
+                invited_room,
+                restricted_accessible_room,
+                world_readable_room,
+                joined_room,
+            ],
+        )
+        self._assert_events(
+            result,
+            [
+                (self.space, self.room),
+                (self.space, public_room),
+                (self.space, knock_room),
+                (self.space, not_invited_room),
+                (self.space, invited_room),
+                (self.space, restricted_room),
+                (self.space, restricted_accessible_room),
+                (self.space, world_readable_room),
+                (self.space, joined_room),
+            ],
+        )
+
+    def test_complex_space(self):
+        """
+        Create a "complex" space to see how it handles things like loops and subspaces.
+        """
+        # Create an inaccessible room.
+        user2 = self.register_user("user2", "pass")
+        token2 = self.login("user2", "pass")
+        room2 = self.helper.create_room_as(user2, is_public=False, tok=token2)
+        # This is a bit odd as "user" is adding a room they don't know about, but
+        # it works for the tests.
+        self._add_child(self.space, room2, self.token)
+
+        # Create a subspace under the space with an additional room in it.
+        subspace = self.helper.create_room_as(
+            self.user,
+            tok=self.token,
+            extra_content={
+                "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+            },
+        )
+        subroom = self.helper.create_room_as(self.user, tok=self.token)
+        self._add_child(self.space, subspace, token=self.token)
+        self._add_child(subspace, subroom, token=self.token)
+        # Also add the two rooms from the space into this subspace (causing loops).
+        self._add_child(subspace, self.room, token=self.token)
+        self._add_child(subspace, room2, self.token)
+
+        result = self.get_success(self.handler.get_space_summary(self.user, self.space))
+
+        # The result should include each room a single time and each link.
+        self._assert_rooms(result, [self.space, self.room, subspace, subroom])
+        self._assert_events(
+            result,
+            [
+                (self.space, self.room),
+                (self.space, room2),
+                (self.space, subspace),
+                (subspace, subroom),
+                (subspace, self.room),
+                (subspace, room2),
+            ],
+        )
+
+    def test_fed_complex(self):
+        """
+        Return data over federation and ensure that it is handled properly.
+        """
+        fed_hostname = self.hs.hostname + "2"
+        subspace = "#subspace:" + fed_hostname
+        subroom = "#subroom:" + fed_hostname
+
+        async def summarize_remote_room(
+            _self, room, suggested_only, max_children, exclude_rooms
+        ):
+            # Return some good data, and some bad data:
+            #
+            # * Event *back* to the root room.
+            # * Unrelated events / rooms
+            # * Multiple levels of events (in a not-useful order, e.g. grandchild
+            #   events before child events).
+
+            # Note that these entries are brief, but should contain enough info.
+            rooms = [
+                {
+                    "room_id": subspace,
+                    "world_readable": True,
+                    "room_type": RoomTypes.SPACE,
+                },
+                {
+                    "room_id": subroom,
+                    "world_readable": True,
+                },
+            ]
+            event_content = {"via": [fed_hostname]}
+            events = [
+                {
+                    "room_id": subspace,
+                    "state_key": subroom,
+                    "content": event_content,
+                },
+            ]
+            return rooms, events
+
+        # Add a room to the space which is on another server.
+        self._add_child(self.space, subspace, self.token)
+
+        with mock.patch(
+            "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room",
+            new=summarize_remote_room,
+        ):
+            result = self.get_success(
+                self.handler.get_space_summary(self.user, self.space)
+            )
+
+        self._assert_rooms(result, [self.space, self.room, subspace, subroom])
+        self._assert_events(
+            result,
+            [
+                (self.space, self.room),
+                (self.space, subspace),
+                (subspace, subroom),
+            ],
+        )
+
+    def test_fed_filtering(self):
+        """
+        Rooms returned over federation should be properly filtered to only include
+        rooms the user has access to.
+        """
+        fed_hostname = self.hs.hostname + "2"
+        subspace = "#subspace:" + fed_hostname
+
+        # Create a few rooms which will have different properties.
+        public_room = "#public:" + fed_hostname
+        knock_room = "#knock:" + fed_hostname
+        not_invited_room = "#not_invited:" + fed_hostname
+        invited_room = "#invited:" + fed_hostname
+        restricted_room = "#restricted:" + fed_hostname
+        restricted_accessible_room = "#restricted_accessible:" + fed_hostname
+        world_readable_room = "#world_readable:" + fed_hostname
+        joined_room = self.helper.create_room_as(self.user, tok=self.token)
+
+        # Poke an invite over federation into the database.
+        fed_handler = self.hs.get_federation_handler()
+        event = make_event_from_dict(
+            {
+                "room_id": invited_room,
+                "event_id": "!abcd:" + fed_hostname,
+                "type": EventTypes.Member,
+                "sender": "@remote:" + fed_hostname,
+                "state_key": self.user,
+                "content": {"membership": Membership.INVITE},
+                "prev_events": [],
+                "auth_events": [],
+                "depth": 1,
+                "origin_server_ts": 1234,
+            }
+        )
+        self.get_success(
+            fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6)
+        )
+
+        async def summarize_remote_room(
+            _self, room, suggested_only, max_children, exclude_rooms
+        ):
+            # Note that these entries are brief, but should contain enough info.
+            rooms = [
+                {
+                    "room_id": public_room,
+                    "world_readable": False,
+                    "join_rules": JoinRules.PUBLIC,
+                },
+                {
+                    "room_id": knock_room,
+                    "world_readable": False,
+                    "join_rules": JoinRules.KNOCK,
+                },
+                {
+                    "room_id": not_invited_room,
+                    "world_readable": False,
+                    "join_rules": JoinRules.INVITE,
+                },
+                {
+                    "room_id": invited_room,
+                    "world_readable": False,
+                    "join_rules": JoinRules.INVITE,
+                },
+                {
+                    "room_id": restricted_room,
+                    "world_readable": False,
+                    "join_rules": JoinRules.MSC3083_RESTRICTED,
+                    "allowed_spaces": [],
+                },
+                {
+                    "room_id": restricted_accessible_room,
+                    "world_readable": False,
+                    "join_rules": JoinRules.MSC3083_RESTRICTED,
+                    "allowed_spaces": [self.room],
+                },
+                {
+                    "room_id": world_readable_room,
+                    "world_readable": True,
+                    "join_rules": JoinRules.INVITE,
+                },
+                {
+                    "room_id": joined_room,
+                    "world_readable": False,
+                    "join_rules": JoinRules.INVITE,
+                },
+            ]
+
+            # Place each room in the sub-space.
+            event_content = {"via": [fed_hostname]}
+            events = [
+                {
+                    "room_id": subspace,
+                    "state_key": room["room_id"],
+                    "content": event_content,
+                }
+                for room in rooms
+            ]
+
+            # Also include the subspace.
+            rooms.insert(
+                0,
+                {
+                    "room_id": subspace,
+                    "world_readable": True,
+                },
+            )
+            return rooms, events
+
+        # Add a room to the space which is on another server.
+        self._add_child(self.space, subspace, self.token)
+
+        with mock.patch(
+            "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room",
+            new=summarize_remote_room,
+        ):
+            result = self.get_success(
+                self.handler.get_space_summary(self.user, self.space)
+            )
+
+        self._assert_rooms(
+            result,
+            [
+                self.space,
+                self.room,
+                subspace,
+                public_room,
+                knock_room,
+                invited_room,
+                restricted_accessible_room,
+                world_readable_room,
+                joined_room,
+            ],
+        )
+        self._assert_events(
+            result,
+            [
+                (self.space, self.room),
+                (self.space, subspace),
+                (subspace, public_room),
+                (subspace, knock_room),
+                (subspace, not_invited_room),
+                (subspace, invited_room),
+                (subspace, restricted_room),
+                (subspace, restricted_accessible_room),
+                (subspace, world_readable_room),
+                (subspace, joined_room),
+            ],
+        )
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index c9d4fd9336..e4059acda3 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -88,16 +88,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
     def _get_current_stats(self, stats_type, stat_id):
         table, id_col = stats.TYPE_TO_TABLE[stats_type]
 
-        cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) + list(
-            stats.PER_SLICE_FIELDS[stats_type]
-        )
-
-        end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000)
+        cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type])
 
         return self.get_success(
             self.store.db_pool.simple_select_one(
-                table + "_historical",
-                {id_col: stat_id, end_ts: end_ts},
+                table + "_current",
+                {id_col: stat_id},
                 cols,
                 allow_none=True,
             )
@@ -156,115 +152,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         self.assertEqual(len(r), 1)
         self.assertEqual(r[0]["topic"], "foo")
 
-    def test_initial_earliest_token(self):
-        """
-        Ingestion via notify_new_event will ignore tokens that the background
-        update have already processed.
-        """
-
-        self.reactor.advance(86401)
-
-        self.hs.config.stats_enabled = False
-        self.handler.stats_enabled = False
-
-        u1 = self.register_user("u1", "pass")
-        u1_token = self.login("u1", "pass")
-
-        u2 = self.register_user("u2", "pass")
-        u2_token = self.login("u2", "pass")
-
-        u3 = self.register_user("u3", "pass")
-        u3_token = self.login("u3", "pass")
-
-        room_1 = self.helper.create_room_as(u1, tok=u1_token)
-        self.helper.send_state(
-            room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token
-        )
-
-        # Begin the ingestion by creating the temp tables. This will also store
-        # the position that the deltas should begin at, once they take over.
-        self.hs.config.stats_enabled = True
-        self.handler.stats_enabled = True
-        self.store.db_pool.updates._all_done = False
-        self.get_success(
-            self.store.db_pool.simple_update_one(
-                table="stats_incremental_position",
-                keyvalues={},
-                updatevalues={"stream_id": 0},
-            )
-        )
-
-        self.get_success(
-            self.store.db_pool.simple_insert(
-                "background_updates",
-                {"update_name": "populate_stats_prepare", "progress_json": "{}"},
-            )
-        )
-
-        while not self.get_success(
-            self.store.db_pool.updates.has_completed_background_updates()
-        ):
-            self.get_success(
-                self.store.db_pool.updates.do_next_background_update(100), by=0.1
-            )
-
-        # Now, before the table is actually ingested, add some more events.
-        self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token)
-        self.helper.join(room=room_1, user=u2, tok=u2_token)
-
-        # orig_delta_processor = self.store.
-
-        # Now do the initial ingestion.
-        self.get_success(
-            self.store.db_pool.simple_insert(
-                "background_updates",
-                {"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
-            )
-        )
-        self.get_success(
-            self.store.db_pool.simple_insert(
-                "background_updates",
-                {
-                    "update_name": "populate_stats_cleanup",
-                    "progress_json": "{}",
-                    "depends_on": "populate_stats_process_rooms",
-                },
-            )
-        )
-
-        self.store.db_pool.updates._all_done = False
-        while not self.get_success(
-            self.store.db_pool.updates.has_completed_background_updates()
-        ):
-            self.get_success(
-                self.store.db_pool.updates.do_next_background_update(100), by=0.1
-            )
-
-        self.reactor.advance(86401)
-
-        # Now add some more events, triggering ingestion. Because of the stream
-        # position being set to before the events sent in the middle, a simpler
-        # implementation would reprocess those events, and say there were four
-        # users, not three.
-        self.helper.invite(room=room_1, src=u1, targ=u3, tok=u1_token)
-        self.helper.join(room=room_1, user=u3, tok=u3_token)
-
-        # self.handler.notify_new_event()
-
-        # We need to let the delta processor advance

-        self.reactor.advance(10 * 60)
-
-        # Get the slices! There should be two -- day 1, and day 2.
-        r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0))
-
-        self.assertEqual(len(r), 2)
-
-        # The oldest has 2 joined members
-        self.assertEqual(r[-1]["joined_members"], 2)
-
-        # The newest has 3
-        self.assertEqual(r[0]["joined_members"], 3)
-
     def test_create_user(self):
         """
         When we create a user, it should have statistics already ready.
@@ -296,22 +183,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         self.assertIsNotNone(r1stats)
         self.assertIsNotNone(r2stats)
 
-        # contains the default things you'd expect in a fresh room
-        self.assertEqual(
-            r1stats["total_events"],
-            EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM,
-            "Wrong number of total_events in new room's stats!"
-            " You may need to update this if more state events are added to"
-            " the room creation process.",
-        )
-        self.assertEqual(
-            r2stats["total_events"],
-            EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM,
-            "Wrong number of total_events in new room's stats!"
-            " You may need to update this if more state events are added to"
-            " the room creation process.",
-        )
-
         self.assertEqual(
             r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
         )
@@ -327,24 +198,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         self.assertEqual(r2stats["invited_members"], 0)
         self.assertEqual(r2stats["banned_members"], 0)
 
-    def test_send_message_increments_total_events(self):
-        """
-        When we send a message, it increments total_events.
-        """
-
-        self._perform_background_initial_update()
-
-        u1 = self.register_user("u1", "pass")
-        u1token = self.login("u1", "pass")
-        r1 = self.helper.create_room_as(u1, tok=u1token)
-        r1stats_ante = self._get_current_stats("room", r1)
-
-        self.helper.send(r1, "hiss", tok=u1token)
-
-        r1stats_post = self._get_current_stats("room", r1)
-
-        self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
-
     def test_updating_profile_information_does_not_increase_joined_members_count(self):
         """
         Check that the joined_members count does not increase when a user changes their
@@ -378,7 +231,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
     def test_send_state_event_nonoverwriting(self):
         """
-        When we send a non-overwriting state event, it increments total_events AND current_state_events
+        When we send a non-overwriting state event, it increments current_state_events
         """
 
         self._perform_background_initial_update()
@@ -399,44 +252,14 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         r1stats_post = self._get_current_stats("room", r1)
 
-        self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
         self.assertEqual(
             r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
             1,
         )
 
-    def test_send_state_event_overwriting(self):
-        """
-        When we send an overwriting state event, it increments total_events ONLY
-        """
-
-        self._perform_background_initial_update()
-
-        u1 = self.register_user("u1", "pass")
-        u1token = self.login("u1", "pass")
-        r1 = self.helper.create_room_as(u1, tok=u1token)
-
-        self.helper.send_state(
-            r1, "cat.hissing", {"value": True}, tok=u1token, state_key="tabby"
-        )
-
-        r1stats_ante = self._get_current_stats("room", r1)
-
-        self.helper.send_state(
-            r1, "cat.hissing", {"value": False}, tok=u1token, state_key="tabby"
-        )
-
-        r1stats_post = self._get_current_stats("room", r1)
-
-        self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
-        self.assertEqual(
-            r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
-            0,
-        )
-
     def test_join_first_time(self):
         """
-        When a user joins a room for the first time, total_events, current_state_events and
+        When a user joins a room for the first time, current_state_events and
         joined_members should increase by exactly 1.
         """
 
@@ -455,7 +278,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         r1stats_post = self._get_current_stats("room", r1)
 
-        self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
         self.assertEqual(
             r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
             1,
@@ -466,7 +288,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
     def test_join_after_leave(self):
         """
-        When a user joins a room after being previously left, total_events and
+        When a user joins a room after being previously left,
         joined_members should increase by exactly 1.
         current_state_events should not increase.
         left_members should decrease by exactly 1.
@@ -490,7 +312,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         r1stats_post = self._get_current_stats("room", r1)
 
-        self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
         self.assertEqual(
             r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
             0,
@@ -504,7 +325,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
     def test_invited(self):
         """
-        When a user invites another user, current_state_events, total_events and
+        When a user invites another user, current_state_events and
         invited_members should increase by exactly 1.
         """
 
@@ -522,7 +343,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         r1stats_post = self._get_current_stats("room", r1)
 
-        self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
         self.assertEqual(
             r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
             1,
@@ -533,7 +353,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
     def test_join_after_invite(self):
         """
-        When a user joins a room after being invited, total_events and
+        When a user joins a room after being invited and
         joined_members should increase by exactly 1.
         current_state_events should not increase.
         invited_members should decrease by exactly 1.
@@ -556,7 +376,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         r1stats_post = self._get_current_stats("room", r1)
 
-        self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
         self.assertEqual(
             r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
             0,
@@ -570,7 +389,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
     def test_left(self):
         """
-        When a user leaves a room after joining, total_events and
+        When a user leaves a room after joining and
         left_members should increase by exactly 1.
         current_state_events should not increase.
         joined_members should decrease by exactly 1.
@@ -593,7 +412,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         r1stats_post = self._get_current_stats("room", r1)
 
-        self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
         self.assertEqual(
             r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
             0,
@@ -607,7 +425,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
     def test_banned(self):
         """
-        When a user is banned from a room after joining, total_events and
+        When a user is banned from a room after joining and
         left_members should increase by exactly 1.
         current_state_events should not increase.
         banned_members should decrease by exactly 1.
@@ -630,7 +448,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         r1stats_post = self._get_current_stats("room", r1)
 
-        self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
         self.assertEqual(
             r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
             0,
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index f58afbc244..fa3cff598e 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -38,6 +38,9 @@ U_ONION = UserID.from_string("@onion:farm")
 # Test room id
 ROOM_ID = "a-room"
 
+# Room we're not in
+OTHER_ROOM_ID = "another-room"
+
 
 def _expect_edu_transaction(edu_type, content, origin="test"):
     return {
@@ -115,6 +118,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         hs.get_auth().check_user_in_room = check_user_in_room
 
+        async def check_host_in_room(room_id, server_name):
+            return room_id == ROOM_ID
+
+        hs.get_event_auth_handler().check_host_in_room = check_host_in_room
+
         def get_joined_hosts_for_room(room_id):
             return {member.domain for member in self.room_members}
 
@@ -244,6 +252,35 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             ],
         )
 
+    def test_started_typing_remote_recv_not_in_room(self):
+        self.room_members = [U_APPLE, U_ONION]
+
+        self.assertEquals(self.event_source.get_current_key(), 0)
+
+        channel = self.make_request(
+            "PUT",
+            "/_matrix/federation/v1/send/1000000",
+            _make_edu_transaction_json(
+                "m.typing",
+                content={
+                    "room_id": OTHER_ROOM_ID,
+                    "user_id": U_ONION.to_string(),
+                    "typing": True,
+                },
+            ),
+            federation_auth_origin=b"farm",
+        )
+        self.assertEqual(channel.code, 200)
+
+        self.on_new_event.assert_not_called()
+
+        self.assertEquals(self.event_source.get_current_key(), 0)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[OTHER_ROOM_ID], from_key=0)
+        )
+        self.assertEquals(events[0], [])
+        self.assertEquals(events[1], 0)
+
     @override_config({"send_federation": True})
     def test_stopped_typing(self):
         self.room_members = [U_APPLE, U_BANANA, U_ONION]
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index e45980316b..a37bce08c3 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -273,7 +273,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.assertEqual(response.code, 200)
 
         # Send the body
-        request.write('{ "a": 1 }'.encode("ascii"))
+        request.write(b'{ "a": 1 }')
         request.finish()
 
         self.reactor.pump((0.1,))
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index ed9a884d76..d9a8b077d3 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -102,7 +102,7 @@ class FederationClientTests(HomeserverTestCase):
         self.assertNoResult(test_d)
 
         # Send it the HTTP response
-        res_json = '{ "a": 1 }'.encode("ascii")
+        res_json = b'{ "a": 1 }'
         protocol.dataReceived(
             b"HTTP/1.1 200 OK\r\n"
             b"Server: Fake\r\n"
@@ -339,10 +339,8 @@ class FederationClientTests(HomeserverTestCase):
 
         # Send it the HTTP response
         client.dataReceived(
-            (
-                b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n"
-                b"Server: Fake\r\n\r\n"
-            )
+            b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n"
+            b"Server: Fake\r\n\r\n"
         )
 
         # Push by enough to time it out
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index fefc8099c9..437113929a 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -205,6 +205,41 @@ class MatrixFederationAgentTests(TestCase):
 
     @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"})
     def test_http_request_via_proxy(self):
+        """
+        Tests that requests can be made through a proxy.
+        """
+        self._do_http_request_via_proxy(auth_credentials=None)
+
+    @patch.dict(
+        os.environ,
+        {"http_proxy": "bob:pinkponies@proxy.com:8888", "no_proxy": "unused.com"},
+    )
+    def test_http_request_via_proxy_with_auth(self):
+        """
+        Tests that authenticated requests can be made through a proxy.
+        """
+        self._do_http_request_via_proxy(auth_credentials="bob:pinkponies")
+
+    @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
+    def test_https_request_via_proxy(self):
+        """Tests that TLS-encrypted requests can be made through a proxy"""
+        self._do_https_request_via_proxy(auth_credentials=None)
+
+    @patch.dict(
+        os.environ,
+        {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
+    )
+    def test_https_request_via_proxy_with_auth(self):
+        """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
+        self._do_https_request_via_proxy(auth_credentials="bob:pinkponies")
+
+    def _do_http_request_via_proxy(
+        self,
+        auth_credentials: Optional[str] = None,
+    ):
+        """
+        Tests that requests can be made through a proxy.
+        """
         agent = ProxyAgent(self.reactor, use_proxy=True)
 
         self.reactor.lookups["proxy.com"] = "1.2.3.5"
@@ -229,6 +264,23 @@ class MatrixFederationAgentTests(TestCase):
         self.assertEqual(len(http_server.requests), 1)
 
         request = http_server.requests[0]
+
+        # Check whether auth credentials have been supplied to the proxy
+        proxy_auth_header_values = request.requestHeaders.getRawHeaders(
+            b"Proxy-Authorization"
+        )
+
+        if auth_credentials is not None:
+            # Compute the correct header value for Proxy-Authorization
+            encoded_credentials = base64.b64encode(b"bob:pinkponies")
+            expected_header_value = b"Basic " + encoded_credentials
+
+            # Validate the header's value
+            self.assertIn(expected_header_value, proxy_auth_header_values)
+        else:
+            # Check that the Proxy-Authorization header has not been supplied to the proxy
+            self.assertIsNone(proxy_auth_header_values)
+
         self.assertEqual(request.method, b"GET")
         self.assertEqual(request.path, b"http://test.com")
         self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
@@ -241,19 +293,6 @@ class MatrixFederationAgentTests(TestCase):
         body = self.successResultOf(treq.content(resp))
         self.assertEqual(body, b"result")
 
-    @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
-    def test_https_request_via_proxy(self):
-        """Tests that TLS-encrypted requests can be made through a proxy"""
-        self._do_https_request_via_proxy(auth_credentials=None)
-
-    @patch.dict(
-        os.environ,
-        {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
-    )
-    def test_https_request_via_proxy_with_auth(self):
-        """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
-        self._do_https_request_via_proxy(auth_credentials="bob:pinkponies")
-
     def _do_https_request_via_proxy(
         self,
         auth_credentials: Optional[str] = None,
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 2c68b9a13c..81d9e2f484 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -100,9 +100,9 @@ class ModuleApiTestCase(HomeserverTestCase):
             "content": content,
             "sender": user_id,
         }
-        event = self.get_success(
+        event: EventBase = self.get_success(
             self.module_api.create_and_send_event_into_room(event_dict)
-        )  # type: EventBase
+        )
         self.assertEqual(event.sender, user_id)
         self.assertEqual(event.type, "m.room.message")
         self.assertEqual(event.room_id, room_id)
@@ -136,9 +136,9 @@ class ModuleApiTestCase(HomeserverTestCase):
             "sender": user_id,
             "state_key": "",
         }
-        event = self.get_success(
+        event: EventBase = self.get_success(
             self.module_api.create_and_send_event_into_room(event_dict)
-        )  # type: EventBase
+        )
         self.assertEqual(event.sender, user_id)
         self.assertEqual(event.type, "m.room.power_levels")
         self.assertEqual(event.room_id, room_id)
@@ -281,7 +281,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         )
         for call in calls:
             call_args = call[0]
-            federation_transaction = call_args[0]  # type: Transaction
+            federation_transaction: Transaction = call_args[0]
 
             # Get the sent EDUs in this transaction
             edus = federation_transaction.get_dict()["edus"]
@@ -390,7 +390,7 @@ def _test_sending_local_online_presence_to_local_user(
     )
     test_case.assertEqual(len(presence_updates), 1)
 
-    presence_update = presence_updates[0]  # type: UserPresenceState
+    presence_update: UserPresenceState = presence_updates[0]
     test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
     test_case.assertEqual(presence_update.state, "online")
 
@@ -443,7 +443,7 @@ def _test_sending_local_online_presence_to_local_user(
     )
     test_case.assertEqual(len(presence_updates), 1)
 
-    presence_update = presence_updates[0]  # type: UserPresenceState
+    presence_update: UserPresenceState = presence_updates[0]
     test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
     test_case.assertEqual(presence_update.state, "online")
 
@@ -454,7 +454,7 @@ def _test_sending_local_online_presence_to_local_user(
     )
     test_case.assertEqual(len(presence_updates), 1)
 
-    presence_update = presence_updates[0]  # type: UserPresenceState
+    presence_update: UserPresenceState = presence_updates[0]
     test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
     test_case.assertEqual(presence_update.state, "online")
 
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 624bd1b927..e9fd991718 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -53,9 +53,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
         self.streamer = hs.get_replication_streamer()
-        self.server = server_factory.buildProtocol(
+        self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol(
             None
-        )  # type: ServerReplicationStreamProtocol
+        )
 
         # Make a new HomeServer object for the worker
         self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -195,7 +195,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         fetching updates for given stream.
         """
 
-        path = request.path  # type: bytes  # type: ignore
+        path: bytes = request.path  # type: ignore
         self.assertRegex(
             path,
             br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
@@ -212,7 +212,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
     unlike `BaseStreamTestCase`.
     """
 
-    servlets = []  # type: List[Callable[[HomeServer, JsonResource], None]]
+    servlets: List[Callable[[HomeServer, JsonResource], None]] = []
 
     def setUp(self):
         super().setUp()
@@ -448,7 +448,7 @@ class TestReplicationDataHandler(ReplicationDataHandler):
         super().__init__(hs)
 
         # list of received (stream_name, token, row) tuples
-        self.received_rdata_rows = []  # type: List[Tuple[str, int, Any]]
+        self.received_rdata_rows: List[Tuple[str, int, Any]] = []
 
     async def on_rdata(self, stream_name, instance_name, token, rows):
         await super().on_rdata(stream_name, instance_name, token, rows)
@@ -484,7 +484,7 @@ class FakeRedisPubSubServer:
 class FakeRedisPubSubProtocol(Protocol):
     """A connection from a client talking to the fake Redis server."""
 
-    transport = None  # type: Optional[FakeTransport]
+    transport: Optional[FakeTransport] = None
 
     def __init__(self, server: FakeRedisPubSubServer):
         self._server = server
@@ -550,12 +550,12 @@ class FakeRedisPubSubProtocol(Protocol):
         if obj is None:
             return "$-1\r\n"
         if isinstance(obj, str):
-            return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
+            return f"${len(obj)}\r\n{obj}\r\n"
         if isinstance(obj, int):
-            return ":{val}\r\n".format(val=obj)
+            return f":{obj}\r\n"
         if isinstance(obj, (list, tuple)):
             items = "".join(self.encode(a) for a in obj)
-            return "*{len}\r\n{items}".format(len=len(obj), items=items)
+            return f"*{len(obj)}\r\n{items}"
 
         raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
 
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index f51fa0a79e..666008425a 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -135,9 +135,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
 
         # this is the point in the DAG where we make a fork
-        fork_point = self.get_success(
+        fork_point: List[str] = self.get_success(
             self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
-        )  # type: List[str]
+        )
 
         events = [
             self._inject_state_event(sender=OTHER_USER)
@@ -238,7 +238,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.assertEqual(row.data.event_id, pl_event.event_id)
 
         # the state rows are unsorted
-        state_rows = []  # type: List[EventsStreamCurrentStateRow]
+        state_rows: List[EventsStreamCurrentStateRow] = []
         for stream_name, _, row in received_rows:
             self.assertEqual("events", stream_name)
             self.assertIsInstance(row, EventsStreamRow)
@@ -290,11 +290,11 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
 
         # this is the point in the DAG where we make a fork
-        fork_point = self.get_success(
+        fork_point: List[str] = self.get_success(
             self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
-        )  # type: List[str]
+        )
 
-        events = []  # type: List[EventBase]
+        events: List[EventBase] = []
         for user in user_ids:
             events.extend(
                 self._inject_state_event(sender=user) for _ in range(STATES_PER_USER)
@@ -355,7 +355,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
             self.assertEqual(row.data.event_id, pl_events[i].event_id)
 
             # the state rows are unsorted
-            state_rows = []  # type: List[EventsStreamCurrentStateRow]
+            state_rows: List[EventsStreamCurrentStateRow] = []
             for _ in range(STATES_PER_USER + 1):
                 stream_name, token, row = received_rows.pop(0)
                 self.assertEqual("events", stream_name)
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index 7f5d932f0b..38e292c1ab 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -43,7 +43,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
         stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "receipts")
         self.assertEqual(1, len(rdata_rows))
-        row = rdata_rows[0]  # type: ReceiptsStream.ReceiptsStreamRow
+        row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0]
         self.assertEqual("!room:blue", row.room_id)
         self.assertEqual("m.read", row.receipt_type)
         self.assertEqual(USER_ID, row.user_id)
@@ -75,7 +75,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
         self.assertEqual(token, 3)
         self.assertEqual(1, len(rdata_rows))
 
-        row = rdata_rows[0]  # type: ReceiptsStream.ReceiptsStreamRow
+        row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0]
         self.assertEqual("!room2:blue", row.room_id)
         self.assertEqual("m.read", row.receipt_type)
         self.assertEqual(USER_ID, row.user_id)
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index ecd360c2d0..3ff5afc6e5 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -47,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
-        row = rdata_rows[0]  # type: TypingStream.TypingStreamRow
+        row: TypingStream.TypingStreamRow = rdata_rows[0]
         self.assertEqual(ROOM_ID, row.room_id)
         self.assertEqual([USER_ID], row.user_ids)
 
@@ -102,7 +102,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
-        row = rdata_rows[0]  # type: TypingStream.TypingStreamRow
+        row: TypingStream.TypingStreamRow = rdata_rows[0]
         self.assertEqual(ROOM_ID, row.room_id)
         self.assertEqual([USER_ID], row.user_ids)
 
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 76e6644353..ffa425328f 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -31,7 +31,7 @@ from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
 
 logger = logging.getLogger(__name__)
 
-test_server_connection_factory = None  # type: Optional[TestServerTLSConnectionFactory]
+test_server_connection_factory: Optional[TestServerTLSConnectionFactory] = None
 
 
 class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
@@ -70,7 +70,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
             self.reactor,
             FakeSite(resource),
             "GET",
-            "/{}/{}".format(target, media_id),
+            f"/{target}/{media_id}",
             shorthand=False,
             access_token=self.access_token,
             await_result=False,
@@ -113,7 +113,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         self.assertEqual(request.method, b"GET")
         self.assertEqual(
             request.path,
-            "/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"),
+            f"/_matrix/media/r0/download/{target}/{media_id}".encode("utf-8"),
         )
         self.assertEqual(
             request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 5eca5c165d..f3615af97e 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -211,7 +211,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
             self.reactor,
             sync_hs_site,
             "GET",
-            "/sync?since={}".format(next_batch),
+            f"/sync?since={next_batch}",
             access_token=access_token,
         )
 
@@ -241,7 +241,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
             self.reactor,
             sync_hs_site,
             "GET",
-            "/sync?since={}".format(vector_clock_token),
+            f"/sync?since={vector_clock_token}",
             access_token=access_token,
         )
 
@@ -266,7 +266,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
             self.reactor,
             sync_hs_site,
             "GET",
-            "/sync?since={}".format(next_batch),
+            f"/sync?since={next_batch}",
             access_token=access_token,
         )
 
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 2f7090e554..a7c6e595b9 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -66,7 +66,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
         # Create a new group
         channel = self.make_request(
             "POST",
-            "/create_group".encode("ascii"),
+            b"/create_group",
             access_token=self.admin_user_tok,
             content={"localpart": "test"},
         )
@@ -129,9 +129,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
 
     def _get_groups_user_is_in(self, access_token):
         """Returns the list of groups the user is in (given their access token)"""
-        channel = self.make_request(
-            "GET", "/joined_groups".encode("ascii"), access_token=access_token
-        )
+        channel = self.make_request("GET", b"/joined_groups", access_token=access_token)
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index ee071c2477..17ec8bfd3b 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -535,7 +535,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
                 )
             )
 
-            self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+            self.assertEqual(count, 0, msg=f"Rows not purged in {table}")
 
     def _assert_peek(self, room_id, expect_code):
         """Assert that the admin user can (or cannot) peek into the room."""
@@ -599,7 +599,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
                 )
             )
 
-            self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+            self.assertEqual(count, 0, msg=f"Rows not purged in {table}")
 
 
 class RoomTestCase(unittest.HomeserverTestCase):
@@ -1280,7 +1280,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         self.public_room_id = self.helper.create_room_as(
             self.creator, tok=self.creator_tok, is_public=True
         )
-        self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id)
+        self.url = f"/_synapse/admin/v1/join/{self.public_room_id}"
 
     def test_requester_is_no_admin(self):
         """
@@ -1420,7 +1420,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         private_room_id = self.helper.create_room_as(
             self.creator, tok=self.creator_tok, is_public=False
         )
-        url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+        url = f"/_synapse/admin/v1/join/{private_room_id}"
         body = json.dumps({"user_id": self.second_user_id})
 
         channel = self.make_request(
@@ -1463,7 +1463,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
 
         # Join user to room.
 
-        url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+        url = f"/_synapse/admin/v1/join/{private_room_id}"
         body = json.dumps({"user_id": self.second_user_id})
 
         channel = self.make_request(
@@ -1493,7 +1493,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         private_room_id = self.helper.create_room_as(
             self.admin_user, tok=self.admin_user_tok, is_public=False
         )
-        url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+        url = f"/_synapse/admin/v1/join/{private_room_id}"
         body = json.dumps({"user_id": self.second_user_id})
 
         channel = self.make_request(
@@ -1633,7 +1633,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "POST",
-            "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+            f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin",
             content={},
             access_token=self.admin_user_tok,
         )
@@ -1660,7 +1660,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "POST",
-            "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+            f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin",
             content={},
             access_token=self.admin_user_tok,
         )
@@ -1686,7 +1686,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "POST",
-            "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+            f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin",
             content={"user_id": self.second_user_id},
             access_token=self.admin_user_tok,
         )
@@ -1720,7 +1720,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "POST",
-            "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+            f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin",
             content={},
             access_token=self.admin_user_tok,
         )
@@ -1753,7 +1753,6 @@ PURGE_TABLES = [
     "room_memberships",
     "room_stats_state",
     "room_stats_current",
-    "room_stats_historical",
     "room_stats_earliest_token",
     "rooms",
     "stream_ordering_to_exterm",
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index e1fe72fc5d..28dd47a28b 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -16,17 +16,19 @@ from typing import Dict
 from unittest.mock import Mock
 
 from synapse.events import EventBase
+from synapse.events.third_party_rules import load_legacy_third_party_event_rules
 from synapse.module_api import ModuleApi
 from synapse.rest import admin
 from synapse.rest.client.v1 import login, room
 from synapse.types import Requester, StateMap
+from synapse.util.frozenutils import unfreeze
 
 from tests import unittest
 
 thread_local = threading.local()
 
 
-class ThirdPartyRulesTestModule:
+class LegacyThirdPartyRulesTestModule:
     def __init__(self, config: Dict, module_api: ModuleApi):
         # keep a record of the "current" rules module, so that the test can patch
         # it if desired.
@@ -46,8 +48,26 @@ class ThirdPartyRulesTestModule:
         return config
 
 
-def current_rules_module() -> ThirdPartyRulesTestModule:
-    return thread_local.rules_module
+class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
+    def __init__(self, config: Dict, module_api: ModuleApi):
+        super().__init__(config, module_api)
+
+    def on_create_room(
+        self, requester: Requester, config: dict, is_requester_admin: bool
+    ):
+        return False
+
+
+class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
+    def __init__(self, config: Dict, module_api: ModuleApi):
+        super().__init__(config, module_api)
+
+    async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
+        d = event.get_dict()
+        content = unfreeze(event.content)
+        content["foo"] = "bar"
+        d["content"] = content
+        return d
 
 
 class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
@@ -57,20 +77,23 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def default_config(self):
-        config = super().default_config()
-        config["third_party_event_rules"] = {
-            "module": __name__ + ".ThirdPartyRulesTestModule",
-            "config": {},
-        }
-        return config
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver()
+
+        load_legacy_third_party_event_rules(hs)
+
+        return hs
 
     def prepare(self, reactor, clock, homeserver):
         # Create a user and room to play with during the tests
         self.user_id = self.register_user("kermit", "monkey")
         self.tok = self.login("kermit", "monkey")
 
-        self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+        # Some tests might prevent room creation on purpose.
+        try:
+            self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+        except Exception:
+            pass
 
     def test_third_party_rules(self):
         """Tests that a forbidden event is forbidden from being sent, but an allowed one
@@ -79,10 +102,12 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         # patch the rules module with a Mock which will return False for some event
         # types
         async def check(ev, state):
-            return ev.type != "foo.bar.forbidden"
+            return ev.type != "foo.bar.forbidden", None
 
         callback = Mock(spec=[], side_effect=check)
-        current_rules_module().check_event_allowed = callback
+        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [
+            callback
+        ]
 
         channel = self.make_request(
             "PUT",
@@ -116,9 +141,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         # first patch the event checker so that it will try to modify the event
         async def check(ev: EventBase, state):
             ev.content = {"x": "y"}
-            return True
+            return True, None
 
-        current_rules_module().check_event_allowed = check
+        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
 
         # now send the event
         channel = self.make_request(
@@ -127,7 +152,19 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
             {"x": "x"},
             access_token=self.tok,
         )
-        self.assertEqual(channel.result["code"], b"500", channel.result)
+        # check_event_allowed has some error handling, so it shouldn't 500 just because a
+        # module did something bad.
+        self.assertEqual(channel.code, 200, channel.result)
+        event_id = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        ev = channel.json_body
+        self.assertEqual(ev["content"]["x"], "x")
 
     def test_modify_event(self):
         """The module can return a modified version of the event"""
@@ -135,9 +172,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         async def check(ev: EventBase, state):
             d = ev.get_dict()
             d["content"] = {"x": "y"}
-            return d
+            return True, d
 
-        current_rules_module().check_event_allowed = check
+        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
 
         # now send the event
         channel = self.make_request(
@@ -168,9 +205,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
                 "msgtype": "m.text",
                 "body": d["content"]["body"].upper(),
             }
-            return d
+            return True, d
 
-        current_rules_module().check_event_allowed = check
+        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
 
         # Send an event, then edit it.
         channel = self.make_request(
@@ -222,7 +259,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         self.assertEqual(ev["content"]["body"], "EDITED BODY")
 
     def test_send_event(self):
-        """Tests that the module can send an event into a room via the module api"""
+        """Tests that a module can send an event into a room via the module api"""
         content = {
             "msgtype": "m.text",
             "body": "Hello!",
@@ -233,13 +270,60 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
             "content": content,
             "sender": self.user_id,
         }
-        event = self.get_success(
-            current_rules_module().module_api.create_and_send_event_into_room(
-                event_dict
-            )
-        )  # type: EventBase
+        event: EventBase = self.get_success(
+            self.hs.get_module_api().create_and_send_event_into_room(event_dict)
+        )
 
         self.assertEquals(event.sender, self.user_id)
         self.assertEquals(event.room_id, self.room_id)
         self.assertEquals(event.type, "m.room.message")
         self.assertEquals(event.content, content)
+
+    @unittest.override_config(
+        {
+            "third_party_event_rules": {
+                "module": __name__ + ".LegacyChangeEvents",
+                "config": {},
+            }
+        }
+    )
+    def test_legacy_check_event_allowed(self):
+        """Tests that the wrapper for legacy check_event_allowed callbacks works
+        correctly.
+        """
+        channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/send/m.room.message/1" % self.room_id,
+            {
+                "msgtype": "m.text",
+                "body": "Original body",
+            },
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+
+        event_id = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+
+        self.assertIn("foo", channel.json_body["content"].keys())
+        self.assertEqual(channel.json_body["content"]["foo"], "bar")
+
+    @unittest.override_config(
+        {
+            "third_party_event_rules": {
+                "module": __name__ + ".LegacyDenyNewRooms",
+                "config": {},
+            }
+        }
+    )
+    def test_legacy_on_create_room(self):
+        """Tests that the wrapper for legacy on_create_room callbacks works
+        correctly.
+        """
+        self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 605b952316..7eba69642a 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -453,7 +453,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
 
         # stick the flows results in a dict by type
-        flow_results = {}  # type: Dict[str, Any]
+        flow_results: Dict[str, Any] = {}
         for f in channel.json_body["flows"]:
             flow_type = f["type"]
             self.assertNotIn(
@@ -501,7 +501,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         p.close()
 
         # there should be a link for each href
-        returned_idps = []  # type: List[str]
+        returned_idps: List[str] = []
         for link in p.links:
             path, query = link.split("?", 1)
             self.assertEqual(path, "pick_idp")
@@ -582,7 +582,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         # ... and should have set a cookie including the redirect url
         cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
         assert cookie_headers
-        cookies = {}  # type: Dict[str, str]
+        cookies: Dict[str, str] = {}
         for h in cookie_headers:
             key, value = h.split(";")[0].split("=", maxsplit=1)
             cookies[key] = value
@@ -874,9 +874,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
     def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
-        result = jwt.encode(
-            payload, secret, self.jwt_algorithm
-        )  # type: Union[str, bytes]
+        result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm)
         if isinstance(result, bytes):
             return result.decode("ascii")
         return result
@@ -1084,7 +1082,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
 
     def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
-        result = jwt.encode(payload, secret, "RS256")  # type: Union[bytes,str]
+        result: Union[bytes, str] = jwt.encode(payload, secret, "RS256")
         if isinstance(result, bytes):
             return result.decode("ascii")
         return result
@@ -1272,7 +1270,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
         self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
 
         # ... with a username_mapping_session cookie
-        cookies = {}  # type: Dict[str,str]
+        cookies: Dict[str, str] = {}
         channel.extract_cookies(cookies)
         self.assertIn("username_mapping_session", cookies)
         session_id = cookies["username_mapping_session"]
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index e94566ffd7..3df070c936 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -1206,7 +1206,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
         reason = "hello"
         channel = self.make_request(
             "POST",
-            "/_matrix/client/r0/rooms/{}/join".format(self.room_id),
+            f"/_matrix/client/r0/rooms/{self.room_id}/join",
             content={"reason": reason},
             access_token=self.second_tok,
         )
@@ -1220,7 +1220,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
         reason = "hello"
         channel = self.make_request(
             "POST",
-            "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
+            f"/_matrix/client/r0/rooms/{self.room_id}/leave",
             content={"reason": reason},
             access_token=self.second_tok,
         )
@@ -1234,7 +1234,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
         reason = "hello"
         channel = self.make_request(
             "POST",
-            "/_matrix/client/r0/rooms/{}/kick".format(self.room_id),
+            f"/_matrix/client/r0/rooms/{self.room_id}/kick",
             content={"reason": reason, "user_id": self.second_user_id},
             access_token=self.second_tok,
         )
@@ -1248,7 +1248,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
         reason = "hello"
         channel = self.make_request(
             "POST",
-            "/_matrix/client/r0/rooms/{}/ban".format(self.room_id),
+            f"/_matrix/client/r0/rooms/{self.room_id}/ban",
             content={"reason": reason, "user_id": self.second_user_id},
             access_token=self.creator_tok,
         )
@@ -1260,7 +1260,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
         reason = "hello"
         channel = self.make_request(
             "POST",
-            "/_matrix/client/r0/rooms/{}/unban".format(self.room_id),
+            f"/_matrix/client/r0/rooms/{self.room_id}/unban",
             content={"reason": reason, "user_id": self.second_user_id},
             access_token=self.creator_tok,
         )
@@ -1272,7 +1272,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
         reason = "hello"
         channel = self.make_request(
             "POST",
-            "/_matrix/client/r0/rooms/{}/invite".format(self.room_id),
+            f"/_matrix/client/r0/rooms/{self.room_id}/invite",
             content={"reason": reason, "user_id": self.second_user_id},
             access_token=self.creator_tok,
         )
@@ -1291,7 +1291,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
         reason = "hello"
         channel = self.make_request(
             "POST",
-            "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
+            f"/_matrix/client/r0/rooms/{self.room_id}/leave",
             content={"reason": reason},
             access_token=self.second_tok,
         )
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 69798e95c3..fc2d35596e 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -19,7 +19,7 @@ import json
 import re
 import time
 import urllib.parse
-from typing import Any, Dict, Mapping, MutableMapping, Optional
+from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
 from unittest.mock import patch
 
 import attr
@@ -53,6 +53,9 @@ class RestHelper:
         tok: str = None,
         expect_code: int = 200,
         extra_content: Optional[Dict] = None,
+        custom_headers: Optional[
+            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+        ] = None,
     ) -> str:
         """
         Create a room.
@@ -87,6 +90,7 @@ class RestHelper:
             "POST",
             path,
             json.dumps(content).encode("utf8"),
+            custom_headers=custom_headers,
         )
 
         assert channel.result["code"] == b"%d" % expect_code, channel.result
@@ -175,14 +179,30 @@ class RestHelper:
 
         self.auth_user_id = temp_id
 
-    def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
+    def send(
+        self,
+        room_id,
+        body=None,
+        txn_id=None,
+        tok=None,
+        expect_code=200,
+        custom_headers: Optional[
+            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+        ] = None,
+    ):
         if body is None:
             body = "body_text_here"
 
         content = {"msgtype": "m.text", "body": body}
 
         return self.send_event(
-            room_id, "m.room.message", content, txn_id, tok, expect_code
+            room_id,
+            "m.room.message",
+            content,
+            txn_id,
+            tok,
+            expect_code,
+            custom_headers=custom_headers,
         )
 
     def send_event(
@@ -193,6 +213,9 @@ class RestHelper:
         txn_id=None,
         tok=None,
         expect_code=200,
+        custom_headers: Optional[
+            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+        ] = None,
     ):
         if txn_id is None:
             txn_id = "m%s" % (str(time.time()))
@@ -207,6 +230,7 @@ class RestHelper:
             "PUT",
             path,
             json.dumps(content or {}).encode("utf8"),
+            custom_headers=custom_headers,
         )
 
         assert (
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 856aa8682f..2e2f94742e 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -273,7 +273,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         prev_token = None
         found_event_ids = []
-        encoded_key = urllib.parse.quote_plus("👍".encode("utf-8"))
+        encoded_key = urllib.parse.quote_plus("👍".encode())
         for _ in range(20):
             from_token = ""
             if prev_token:
diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/v2_alpha/test_report_event.py
index 1ec6b05e5b..a76a6fef1e 100644
--- a/tests/rest/client/v2_alpha/test_report_event.py
+++ b/tests/rest/client/v2_alpha/test_report_event.py
@@ -41,7 +41,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
         self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok)
         resp = self.helper.send(self.room_id, tok=self.admin_user_tok)
         self.event_id = resp["event_id"]
-        self.report_path = "rooms/{}/report/{}".format(self.room_id, self.event_id)
+        self.report_path = f"rooms/{self.room_id}/report/{self.event_id}"
 
     def test_reason_str_and_score_int(self):
         data = {"reason": "this makes me sad", "score": -100}
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 95e7075841..2d6b49692e 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -310,7 +310,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         correctly decode it as the UTF-8 string, and use filename* in the
         response.
         """
-        filename = parse.quote("\u2603".encode("utf8")).encode("ascii")
+        filename = parse.quote("\u2603".encode()).encode("ascii")
         channel = self._req(
             b"inline; filename*=utf-8''" + filename + self.test_image.extension
         )
diff --git a/tests/server.py b/tests/server.py
index f32d8dc375..6fddd3b305 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -52,7 +52,7 @@ class FakeChannel:
     _reactor = attr.ib()
     result = attr.ib(type=dict, default=attr.Factory(dict))
     _ip = attr.ib(type=str, default="127.0.0.1")
-    _producer = None  # type: Optional[Union[IPullProducer, IPushProducer]]
+    _producer: Optional[Union[IPullProducer, IPushProducer]] = None
 
     @property
     def json_body(self):
@@ -316,8 +316,10 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
 
         self._tcp_callbacks = {}
         self._udp = []
-        lookups = self.lookups = {}  # type: Dict[str, str]
-        self._thread_callbacks = deque()  # type: Deque[Callable[[], None]]
+        self.lookups: Dict[str, str] = {}
+        self._thread_callbacks: Deque[Callable[[], None]] = deque()
+
+        lookups = self.lookups
 
         @implementer(IResolverSimple)
         class FakeResolver:
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 9ca70e7367..d326a1d6a6 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -98,3 +98,16 @@ class LockTestCase(unittest.HomeserverTestCase):
 
         lock2 = self.get_success(self.store.try_acquire_lock("name", "key"))
         self.assertIsNotNone(lock2)
+
+    def test_shutdown(self):
+        """Test that shutting down Synapse releases the locks"""
+        # Acquire two locks
+        lock = self.get_success(self.store.try_acquire_lock("name", "key1"))
+        self.assertIsNotNone(lock)
+        lock2 = self.get_success(self.store.try_acquire_lock("name", "key2"))
+        self.assertIsNotNone(lock2)
+
+        # Now call the shutdown code
+        self.get_success(self.store._on_shutdown())
+
+        self.assertEqual(self.store._live_tokens, {})
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 069db0edc4..0da42b5ac5 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -7,9 +7,7 @@ from tests import unittest
 
 class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, homeserver):
-        self.updates = (
-            self.hs.get_datastore().db_pool.updates
-        )  # type: BackgroundUpdater
+        self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates
         # the base test class should have run the real bg updates for us
         self.assertTrue(
             self.get_success(self.updates.has_completed_background_updates())
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index 41bef62ca8..43628ce44f 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -59,5 +59,5 @@ class DirectoryStoreTestCase(HomeserverTestCase):
         self.assertEqual(self.room.to_string(), room_id)
 
         self.assertIsNone(
-            (self.get_success(self.store.get_association_from_room_alias(self.alias)))
+            self.get_success(self.store.get_association_from_room_alias(self.alias))
         )
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 792b1c44c1..7486078284 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -27,7 +27,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
-        self.db_pool = self.store.db_pool  # type: DatabasePool
+        self.db_pool: DatabasePool = self.store.db_pool
 
         self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
 
@@ -460,7 +460,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
-        self.db_pool = self.store.db_pool  # type: DatabasePool
+        self.db_pool: DatabasePool = self.store.db_pool
 
         self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
 
@@ -586,7 +586,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
-        self.db_pool = self.store.db_pool  # type: DatabasePool
+        self.db_pool: DatabasePool = self.store.db_pool
 
         self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
 
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 8a446da848..a1ba99ff14 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -45,11 +45,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertIsNone(
-            (
-                self.get_success(
-                    self.store.get_profile_displayname(self.u_frank.localpart)
-                )
-            )
+            self.get_success(self.store.get_profile_displayname(self.u_frank.localpart))
         )
 
     def test_avatar_url(self):
@@ -76,9 +72,5 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertIsNone(
-            (
-                self.get_success(
-                    self.store.get_profile_avatar_url(self.u_frank.localpart)
-                )
-            )
+            self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart))
         )
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 54c5b470c7..e5574063f1 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -75,7 +75,7 @@ class PurgeTests(HomeserverTestCase):
         token = self.get_success(
             self.store.get_topological_token_for_event(last["event_id"])
         )
-        event = "t{}-{}".format(token.topological + 1, token.stream + 1)
+        event = f"t{token.topological + 1}-{token.stream + 1}"
 
         # Purge everything before this topological token
         f = self.get_failure(
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 70257bf210..31ce7f6252 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -49,7 +49,7 @@ class RoomStoreTestCase(HomeserverTestCase):
         )
 
     def test_get_room_unknown_room(self):
-        self.assertIsNone((self.get_success(self.store.get_room("!uknown:test"))))
+        self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))
 
     def test_get_room_with_stats(self):
         self.assertDictContainsSubset(
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 88888319cc..f73306ecc4 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -13,12 +13,13 @@
 # limitations under the License.
 
 import unittest
+from typing import Optional
 
 from synapse import event_auth
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersions
-from synapse.events import make_event_from_dict
-from synapse.types import get_domain_from_id
+from synapse.events import EventBase, make_event_from_dict
+from synapse.types import JsonDict, get_domain_from_id
 
 
 class EventAuthTestCase(unittest.TestCase):
@@ -432,7 +433,7 @@ class EventAuthTestCase(unittest.TestCase):
 TEST_ROOM_ID = "!test:room"
 
 
-def _create_event(user_id):
+def _create_event(user_id: str) -> EventBase:
     return make_event_from_dict(
         {
             "room_id": TEST_ROOM_ID,
@@ -444,7 +445,9 @@ def _create_event(user_id):
     )
 
 
-def _member_event(user_id, membership, sender=None):
+def _member_event(
+    user_id: str, membership: str, sender: Optional[str] = None
+) -> EventBase:
     return make_event_from_dict(
         {
             "room_id": TEST_ROOM_ID,
@@ -458,11 +461,11 @@ def _member_event(user_id, membership, sender=None):
     )
 
 
-def _join_event(user_id):
+def _join_event(user_id: str) -> EventBase:
     return _member_event(user_id, "join")
 
 
-def _power_levels_event(sender, content):
+def _power_levels_event(sender: str, content: JsonDict) -> EventBase:
     return make_event_from_dict(
         {
             "room_id": TEST_ROOM_ID,
@@ -475,7 +478,7 @@ def _power_levels_event(sender, content):
     )
 
 
-def _alias_event(sender, **kwargs):
+def _alias_event(sender: str, **kwargs) -> EventBase:
     data = {
         "room_id": TEST_ROOM_ID,
         "event_id": _get_event_id(),
@@ -488,7 +491,7 @@ def _alias_event(sender, **kwargs):
     return make_event_from_dict(data)
 
 
-def _random_state_event(sender):
+def _random_state_event(sender: str) -> EventBase:
     return make_event_from_dict(
         {
             "room_id": TEST_ROOM_ID,
@@ -501,7 +504,7 @@ def _random_state_event(sender):
     )
 
 
-def _join_rules_event(sender, join_rule):
+def _join_rules_event(sender: str, join_rule: str) -> EventBase:
     return make_event_from_dict(
         {
             "room_id": TEST_ROOM_ID,
@@ -519,7 +522,7 @@ def _join_rules_event(sender, join_rule):
 event_count = 0
 
 
-def _get_event_id():
+def _get_event_id() -> str:
     global event_count
     c = event_count
     event_count += 1
diff --git a/tests/test_state.py b/tests/test_state.py
index 62f7095873..e5488df1ac 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -168,6 +168,7 @@ class StateTestCase(unittest.TestCase):
                 "get_state_handler",
                 "get_clock",
                 "get_state_resolution_handler",
+                "get_account_validity_handler",
                 "hostname",
             ]
         )
@@ -199,7 +200,7 @@ class StateTestCase(unittest.TestCase):
 
         self.store.register_events(graph.walk())
 
-        context_store = {}  # type: dict[str, EventContext]
+        context_store: dict[str, EventContext] = {}
 
         for event in graph.walk():
             context = yield defer.ensureDeferred(
diff --git a/tests/test_types.py b/tests/test_types.py
index d7881021d3..0d0c00d97a 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -103,6 +103,4 @@ class MapUsernameTestCase(unittest.TestCase):
     def testNonAscii(self):
         # this should work with either a unicode or a bytes
         self.assertEqual(map_username_to_mxid_localpart("tĂȘst"), "t=c3=aast")
-        self.assertEqual(
-            map_username_to_mxid_localpart("tĂȘst".encode("utf-8")), "t=c3=aast"
-        )
+        self.assertEqual(map_username_to_mxid_localpart("tĂȘst".encode()), "t=c3=aast")
diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
index 1fbb38f4be..e878af5f12 100644
--- a/tests/test_utils/html_parsers.py
+++ b/tests/test_utils/html_parsers.py
@@ -23,13 +23,13 @@ class TestHtmlParser(HTMLParser):
         super().__init__()
 
         # a list of links found in the doc
-        self.links = []  # type: List[str]
+        self.links: List[str] = []
 
         # the values of any hidden <input>s: map from name to value
-        self.hiddens = {}  # type: Dict[str, Optional[str]]
+        self.hiddens: Dict[str, Optional[str]] = {}
 
         # the values of any radio buttons: map from name to list of values
-        self.radios = {}  # type: Dict[str, List[Optional[str]]]
+        self.radios: Dict[str, List[Optional[str]]] = {}
 
     def handle_starttag(
         self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
diff --git a/tests/unittest.py b/tests/unittest.py
index 74db7c08f1..3eec9c4d5b 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -140,7 +140,7 @@ class TestCase(unittest.TestCase):
             try:
                 self.assertEquals(attrs[key], getattr(obj, key))
             except AssertionError as e:
-                raise (type(e))("Assert error for '.{}':".format(key)) from e
+                raise (type(e))(f"Assert error for '.{key}':") from e
 
     def assert_dict(self, required, actual):
         """Does a partial assert of a dict.
@@ -520,7 +520,7 @@ class HomeserverTestCase(TestCase):
         if not isinstance(deferred, Deferred):
             return d
 
-        results = []  # type: list
+        results: list = []
         deferred.addBoth(results.append)
 
         self.pump(by=by)
@@ -594,7 +594,15 @@ class HomeserverTestCase(TestCase):
         user_id = channel.json_body["user_id"]
         return user_id
 
-    def login(self, username, password, device_id=None):
+    def login(
+        self,
+        username,
+        password,
+        device_id=None,
+        custom_headers: Optional[
+            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+        ] = None,
+    ):
         """
         Log in a user, and get an access token. Requires the Login API be
         registered.
@@ -605,7 +613,10 @@ class HomeserverTestCase(TestCase):
             body["device_id"] = device_id
 
         channel = self.make_request(
-            "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
+            "POST",
+            "/_matrix/client/r0/login",
+            json.dumps(body).encode("utf8"),
+            custom_headers=custom_headers,
         )
         self.assertEqual(channel.code, 200, channel.result)
 
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 0277998cbe..39947a166b 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -174,7 +174,7 @@ class DescriptorTestCase(unittest.TestCase):
                 return self.result
 
         obj = Cls()
-        callbacks = set()  # type: Set[str]
+        callbacks: Set[str] = set()
 
         # set off an asynchronous request
         obj.result = origin_d = defer.Deferred()
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index e712eb42ea..3c0ddd4f18 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -44,7 +44,7 @@ class ChunkSeqTests(TestCase):
         )
 
     def test_empty_input(self):
-        parts = chunk_seq([], 5)  # type: Iterable[Sequence]
+        parts: Iterable[Sequence] = chunk_seq([], 5)
 
         self.assertEqual(
             list(parts),
@@ -56,13 +56,13 @@ class SortTopologically(TestCase):
     def test_empty(self):
         "Test that an empty graph works correctly"
 
-        graph = {}  # type: Dict[int, List[int]]
+        graph: Dict[int, List[int]] = {}
         self.assertEqual(list(sorted_topologically([], graph)), [])
 
     def test_handle_empty_graph(self):
         "Test that a graph where a node doesn't have an entry is treated as empty"
 
-        graph = {}  # type: Dict[int, List[int]]
+        graph: Dict[int, List[int]] = {}
 
         # For disconnected nodes the output is simply sorted.
         self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
@@ -70,7 +70,7 @@ class SortTopologically(TestCase):
     def test_disconnected(self):
         "Test that a graph with no edges work"
 
-        graph = {1: [], 2: []}  # type: Dict[int, List[int]]
+        graph: Dict[int, List[int]] = {1: [], 2: []}
 
         # For disconnected nodes the output is simply sorted.
         self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
@@ -78,19 +78,19 @@ class SortTopologically(TestCase):
     def test_linear(self):
         "Test that a simple `4 -> 3 -> 2 -> 1` graph works"
 
-        graph = {1: [], 2: [1], 3: [2], 4: [3]}  # type: Dict[int, List[int]]
+        graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
 
         self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
 
     def test_subset(self):
         "Test that only sorting a subset of the graph works"
-        graph = {1: [], 2: [1], 3: [2], 4: [3]}  # type: Dict[int, List[int]]
+        graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
 
         self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
 
     def test_fork(self):
         "Test that a forked graph works"
-        graph = {1: [], 2: [1], 3: [1], 4: [2, 3]}  # type: Dict[int, List[int]]
+        graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]}
 
         # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should
         # always get the same one.
@@ -98,12 +98,12 @@ class SortTopologically(TestCase):
 
     def test_duplicates(self):
         "Test that a graph with duplicate edges work"
-        graph = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]}  # type: Dict[int, List[int]]
+        graph: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]}
 
         self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
 
     def test_multiple_paths(self):
         "Test that a graph with multiple paths between two nodes work"
-        graph = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]}  # type: Dict[int, List[int]]
+        graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]}
 
         self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])