summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-07-29 11:08:49 +0100
committerErik Johnston <erik@matrix.org>2021-07-29 11:08:49 +0100
commitc36c2777900284cf94e93e60e34c3b856bb31551 (patch)
tree5079c397821dab6f70dd0200a4c435c1b1d91db7 /synapse
parentMerge tag 'v1.38.1' (diff)
parentFixup changelog (diff)
downloadsynapse-c36c2777900284cf94e93e60e34c3b856bb31551.tar.xz
Merge tag 'v1.39.0rc3'
Synapse 1.39.0rc3 (2021-07-28)
==============================

Bugfixes
--------

- Fix a bug introduced in Synapse 1.38 which caused an exception at startup when SAML authentication was enabled. ([\#10477](https://github.com/matrix-org/synapse/issues/10477))
- Fix a long-standing bug where Synapse would not inform clients that a device had exhausted its one-time-key pool, potentially causing problems decrypting events. ([\#10485](https://github.com/matrix-org/synapse/issues/10485))
- Fix reporting old R30 stats as R30v2 stats. Introduced in v1.39.0rc1. ([\#10486](https://github.com/matrix-org/synapse/issues/10486))

Internal Changes
----------------

- Fix an error which prevented the Github Actions workflow to build the docker images from running. ([\#10461](https://github.com/matrix-org/synapse/issues/10461))
- Fix release script to correctly version debian changelog when doing RCs. ([\#10465](https://github.com/matrix-org/synapse/issues/10465))
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py58
-rw-r--r--synapse/api/constants.py8
-rw-r--r--synapse/api/errors.py4
-rw-r--r--synapse/api/filtering.py2
-rw-r--r--synapse/api/ratelimiting.py4
-rw-r--r--synapse/api/room_versions.py4
-rw-r--r--synapse/app/_base.py2
-rw-r--r--synapse/app/generic_worker.py8
-rw-r--r--synapse/app/homeserver.py6
-rw-r--r--synapse/app/phone_stats_home.py36
-rw-r--r--synapse/appservice/api.py4
-rw-r--r--synapse/config/account_validity.py102
-rw-r--r--synapse/config/appservice.py6
-rw-r--r--synapse/config/cache.py4
-rw-r--r--synapse/config/emailconfig.py4
-rw-r--r--synapse/config/experimental.py6
-rw-r--r--synapse/config/federation.py2
-rw-r--r--synapse/config/oidc.py2
-rw-r--r--synapse/config/password_auth_providers.py2
-rw-r--r--synapse/config/repository.py4
-rw-r--r--synapse/config/server.py16
-rw-r--r--synapse/config/spam_checker.py2
-rw-r--r--synapse/config/sso.py2
-rw-r--r--synapse/config/stats.py9
-rw-r--r--synapse/config/third_party_event_rules.py15
-rw-r--r--synapse/config/tls.py12
-rw-r--r--synapse/crypto/keyring.py20
-rw-r--r--synapse/event_auth.py11
-rw-r--r--synapse/events/__init__.py40
-rw-r--r--synapse/events/builder.py16
-rw-r--r--synapse/events/spamcheck.py4
-rw-r--r--synapse/events/third_party_rules.py245
-rw-r--r--synapse/federation/federation_client.py10
-rw-r--r--synapse/federation/federation_server.py43
-rw-r--r--synapse/federation/send_queue.py26
-rw-r--r--synapse/federation/sender/__init__.py110
-rw-r--r--synapse/federation/sender/per_destination_queue.py34
-rw-r--r--synapse/federation/transport/client.py8
-rw-r--r--synapse/federation/transport/server.py120
-rw-r--r--synapse/groups/groups_server.py12
-rw-r--r--synapse/handlers/_base.py10
-rw-r--r--synapse/handlers/account_validity.py128
-rw-r--r--synapse/handlers/admin.py4
-rw-r--r--synapse/handlers/appservice.py6
-rw-r--r--synapse/handlers/auth.py16
-rw-r--r--synapse/handlers/cas.py4
-rw-r--r--synapse/handlers/device.py14
-rw-r--r--synapse/handlers/devicemessage.py2
-rw-r--r--synapse/handlers/directory.py11
-rw-r--r--synapse/handlers/e2e_keys.py40
-rw-r--r--synapse/handlers/events.py6
-rw-r--r--synapse/handlers/federation.py53
-rw-r--r--synapse/handlers/groups_local.py4
-rw-r--r--synapse/handlers/identity.py4
-rw-r--r--synapse/handlers/initial_sync.py14
-rw-r--r--synapse/handlers/message.py44
-rw-r--r--synapse/handlers/oidc.py56
-rw-r--r--synapse/handlers/pagination.py4
-rw-r--r--synapse/handlers/presence.py28
-rw-r--r--synapse/handlers/profile.py4
-rw-r--r--synapse/handlers/receipts.py19
-rw-r--r--synapse/handlers/register.py20
-rw-r--r--synapse/handlers/room.py26
-rw-r--r--synapse/handlers/room_list.py50
-rw-r--r--synapse/handlers/saml.py8
-rw-r--r--synapse/handlers/search.py8
-rw-r--r--synapse/handlers/space_summary.py84
-rw-r--r--synapse/handlers/sso.py12
-rw-r--r--synapse/handlers/stats.py37
-rw-r--r--synapse/handlers/sync.py38
-rw-r--r--synapse/handlers/typing.py28
-rw-r--r--synapse/handlers/user_directory.py2
-rw-r--r--synapse/http/__init__.py2
-rw-r--r--synapse/http/client.py18
-rw-r--r--synapse/http/federation/well_known_resolver.py13
-rw-r--r--synapse/http/matrixfederationclient.py40
-rw-r--r--synapse/http/proxyagent.py14
-rw-r--r--synapse/http/server.py8
-rw-r--r--synapse/http/servlet.py2
-rw-r--r--synapse/http/site.py16
-rw-r--r--synapse/logging/_remote.py14
-rw-r--r--synapse/logging/_structured.py2
-rw-r--r--synapse/logging/context.py16
-rw-r--r--synapse/logging/opentracing.py12
-rw-r--r--synapse/metrics/__init__.py6
-rw-r--r--synapse/metrics/_exposition.py28
-rw-r--r--synapse/metrics/background_process_metrics.py7
-rw-r--r--synapse/module_api/__init__.py227
-rw-r--r--synapse/module_api/errors.py6
-rw-r--r--synapse/notifier.py18
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py4
-rw-r--r--synapse/push/clientformat.py4
-rw-r--r--synapse/push/emailpusher.py6
-rw-r--r--synapse/push/httppusher.py2
-rw-r--r--synapse/push/mailer.py12
-rw-r--r--synapse/push/presentable_names.py2
-rw-r--r--synapse/push/push_rule_evaluator.py4
-rw-r--r--synapse/push/pusher.py6
-rw-r--r--synapse/push/pusherpool.py26
-rw-r--r--synapse/python_dependencies.py4
-rw-r--r--synapse/replication/http/_base.py10
-rw-r--r--synapse/replication/slave/storage/_base.py6
-rw-r--r--synapse/replication/slave/storage/client_ips.py4
-rw-r--r--synapse/replication/tcp/client.py10
-rw-r--r--synapse/replication/tcp/commands.py6
-rw-r--r--synapse/replication/tcp/handler.py16
-rw-r--r--synapse/replication/tcp/protocol.py14
-rw-r--r--synapse/replication/tcp/redis.py8
-rw-r--r--synapse/replication/tcp/streams/_base.py14
-rw-r--r--synapse/replication/tcp/streams/events.py28
-rw-r--r--synapse/replication/tcp/streams/federation.py6
-rw-r--r--synapse/rest/admin/rooms.py17
-rw-r--r--synapse/rest/admin/users.py26
-rw-r--r--synapse/rest/client/v1/login.py33
-rw-r--r--synapse/rest/client/v1/room.py167
-rw-r--r--synapse/rest/client/v2_alpha/account_validity.py7
-rw-r--r--synapse/rest/client/v2_alpha/sendtodevice.py2
-rw-r--r--synapse/rest/consent/consent_resource.py4
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py4
-rw-r--r--synapse/rest/media/v1/__init__.py4
-rw-r--r--synapse/rest/media/v1/_base.py2
-rw-r--r--synapse/rest/media/v1/media_repository.py10
-rw-r--r--synapse/rest/media/v1/media_storage.py4
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py8
-rw-r--r--synapse/rest/media/v1/upload_resource.py6
-rw-r--r--synapse/rest/synapse/client/pick_username.py4
-rw-r--r--synapse/server.py6
-rw-r--r--synapse/server_notices/consent_server_notices.py2
-rw-r--r--synapse/server_notices/resource_limits_server_notices.py2
-rw-r--r--synapse/server_notices/server_notices_sender.py6
-rw-r--r--synapse/state/__init__.py20
-rw-r--r--synapse/state/v1.py2
-rw-r--r--synapse/state/v2.py18
-rw-r--r--synapse/storage/background_updates.py16
-rw-r--r--synapse/storage/database.py16
-rw-r--r--synapse/storage/databases/main/appservice.py4
-rw-r--r--synapse/storage/databases/main/deviceinbox.py4
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py11
-rw-r--r--synapse/storage/databases/main/event_federation.py30
-rw-r--r--synapse/storage/databases/main/event_push_actions.py2
-rw-r--r--synapse/storage/databases/main/events.py44
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py8
-rw-r--r--synapse/storage/databases/main/events_worker.py6
-rw-r--r--synapse/storage/databases/main/group_server.py6
-rw-r--r--synapse/storage/databases/main/lock.py6
-rw-r--r--synapse/storage/databases/main/metrics.py133
-rw-r--r--synapse/storage/databases/main/purge_events.py4
-rw-r--r--synapse/storage/databases/main/push_rule.py6
-rw-r--r--synapse/storage/databases/main/registration.py2
-rw-r--r--synapse/storage/databases/main/room.py104
-rw-r--r--synapse/storage/databases/main/roommember.py15
-rw-r--r--synapse/storage/databases/main/stats.py299
-rw-r--r--synapse/storage/databases/main/stream.py6
-rw-r--r--synapse/storage/databases/main/tags.py2
-rw-r--r--synapse/storage/databases/main/ui_auth.py4
-rw-r--r--synapse/storage/persist_events.py16
-rw-r--r--synapse/storage/prepare_database.py8
-rw-r--r--synapse/storage/schema/__init__.py6
-rw-r--r--synapse/storage/schema/main/delta/61/01change_appservices_txns.sql.postgres23
-rw-r--r--synapse/storage/schema/main/delta/61/02drop_redundant_room_depth_index.sql18
-rw-r--r--synapse/storage/schema/main/delta/61/03recreate_min_depth.py70
-rw-r--r--synapse/storage/schema/state/delta/61/02state_groups_state_n_distinct.sql.postgres34
-rw-r--r--synapse/storage/state.py4
-rw-r--r--synapse/storage/util/id_generators.py12
-rw-r--r--synapse/storage/util/sequence.py6
-rw-r--r--synapse/streams/events.py4
-rw-r--r--synapse/types.py14
-rw-r--r--synapse/util/async_helpers.py8
-rw-r--r--synapse/util/batching_queue.py8
-rw-r--r--synapse/util/caches/__init__.py4
-rw-r--r--synapse/util/caches/cached_call.py6
-rw-r--r--synapse/util/caches/deferred_cache.py12
-rw-r--r--synapse/util/caches/descriptors.py36
-rw-r--r--synapse/util/caches/dictionary_cache.py6
-rw-r--r--synapse/util/caches/expiringcache.py4
-rw-r--r--synapse/util/caches/lrucache.py11
-rw-r--r--synapse/util/caches/response_cache.py2
-rw-r--r--synapse/util/caches/stream_change_cache.py6
-rw-r--r--synapse/util/caches/treecache.py3
-rw-r--r--synapse/util/caches/ttlcache.py6
-rw-r--r--synapse/util/daemonize.py8
-rw-r--r--synapse/util/iterutils.py2
-rw-r--r--synapse/util/macaroons.py2
-rw-r--r--synapse/util/metrics.py2
-rw-r--r--synapse/util/patch_inline_callbacks.py4
-rw-r--r--synapse/visibility.py6
187 files changed, 2379 insertions, 1504 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 7ea5a790db..c9a445c8fe 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.38.1"
+__version__ = "1.39.0rc3"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 307f5f9a94..05699714ee 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -62,16 +62,14 @@ class Auth:
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
+        self._account_validity_handler = hs.get_account_validity_handler()
 
-        self.token_cache = LruCache(
+        self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
             10000, "token_cache"
-        )  # type: LruCache[str, Tuple[str, bool]]
+        )
 
         self._auth_blocking = AuthBlocking(self.hs)
 
-        self._account_validity_enabled = (
-            hs.config.account_validity.account_validity_enabled
-        )
         self._track_appservice_user_ips = hs.config.track_appservice_user_ips
         self._macaroon_secret_key = hs.config.macaroon_secret_key
         self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
@@ -187,12 +185,17 @@ class Auth:
             shadow_banned = user_info.shadow_banned
 
             # Deny the request if the user account has expired.
-            if self._account_validity_enabled and not allow_expired:
-                if await self.store.is_account_expired(
-                    user_info.user_id, self.clock.time_msec()
+            if not allow_expired:
+                if await self._account_validity_handler.is_user_expired(
+                    user_info.user_id
                 ):
+                    # Raise the error if either an account validity module has determined
+                    # the account has expired, or the legacy account validity
+                    # implementation is enabled and determined the account has expired
                     raise AuthError(
-                        403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
+                        403,
+                        "User account has expired",
+                        errcode=Codes.EXPIRED_ACCOUNT,
                     )
 
             device_id = user_info.device_id
@@ -240,6 +243,37 @@ class Auth:
         except KeyError:
             raise MissingClientTokenError()
 
+    async def validate_appservice_can_control_user_id(
+        self, app_service: ApplicationService, user_id: str
+    ):
+        """Validates that the app service is allowed to control
+        the given user.
+
+        Args:
+            app_service: The app service that controls the user
+            user_id: The author MXID that the app service is controlling
+
+        Raises:
+            AuthError: If the application service is not allowed to control the user
+                (user namespace regex does not match, wrong homeserver, etc)
+                or if the user has not been registered yet.
+        """
+
+        # It's ok if the app service is trying to use the sender from their registration
+        if app_service.sender == user_id:
+            pass
+        # Check to make sure the app service is allowed to control the user
+        elif not app_service.is_interested_in_user(user_id):
+            raise AuthError(
+                403,
+                "Application service cannot masquerade as this user (%s)." % user_id,
+            )
+        # Check to make sure the user is already registered on the homeserver
+        elif not (await self.store.get_user_by_id(user_id)):
+            raise AuthError(
+                403, "Application service has not registered this user (%s)" % user_id
+            )
+
     async def _get_appservice_user_id(
         self, request: Request
     ) -> Tuple[Optional[str], Optional[ApplicationService]]:
@@ -261,13 +295,11 @@ class Auth:
             return app_service.sender, app_service
 
         user_id = request.args[b"user_id"][0].decode("utf8")
+        await self.validate_appservice_can_control_user_id(app_service, user_id)
+
         if app_service.sender == user_id:
             return app_service.sender, app_service
 
-        if not app_service.is_interested_in_user(user_id):
-            raise AuthError(403, "Application service cannot masquerade as this user.")
-        if not (await self.store.get_user_by_id(user_id)):
-            raise AuthError(403, "Application service has not registered this user")
         return user_id, app_service
 
     async def get_user_by_access_token(
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 8363c2bb0f..8c7ad2a407 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -127,6 +127,14 @@ class ToDeviceEventTypes:
     RoomKeyRequest = "m.room_key_request"
 
 
+class DeviceKeyAlgorithms:
+    """Spec'd algorithms for the generation of per-device keys"""
+
+    ED25519 = "ed25519"
+    CURVE25519 = "curve25519"
+    SIGNED_CURVE25519 = "signed_curve25519"
+
+
 class EduTypes:
     Presence = "m.presence"
 
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 4cb8bbaf70..054ab14ab6 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -118,7 +118,7 @@ class RedirectException(CodeMessageException):
         super().__init__(code=http_code, msg=msg)
         self.location = location
 
-        self.cookies = []  # type: List[bytes]
+        self.cookies: List[bytes] = []
 
 
 class SynapseError(CodeMessageException):
@@ -160,7 +160,7 @@ class ProxiedRequestError(SynapseError):
     ):
         super().__init__(code, msg, errcode)
         if additional_fields is None:
-            self._additional_fields = {}  # type: Dict
+            self._additional_fields: Dict = {}
         else:
             self._additional_fields = dict(additional_fields)
 
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index ce49a0ad58..ad1ff6a9df 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -289,7 +289,7 @@ class Filter:
             room_id = None
             ev_type = "m.presence"
             contains_url = False
-            labels = []  # type: List[str]
+            labels: List[str] = []
         else:
             sender = event.get("sender", None)
             if not sender:
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index b9a10283f4..3e3d09bbd2 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -46,9 +46,7 @@ class Ratelimiter:
         #   * How many times an action has occurred since a point in time
         #   * The point in time
         #   * The rate_hz of this particular entry. This can vary per request
-        self.actions = (
-            OrderedDict()
-        )  # type: OrderedDict[Hashable, Tuple[float, int, float]]
+        self.actions: OrderedDict[Hashable, Tuple[float, int, float]] = OrderedDict()
 
     async def can_do_action(
         self,
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index f6c1c97b40..a20abc5a65 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -195,7 +195,7 @@ class RoomVersions:
     )
 
 
-KNOWN_ROOM_VERSIONS = {
+KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
     v.identifier: v
     for v in (
         RoomVersions.V1,
@@ -209,4 +209,4 @@ KNOWN_ROOM_VERSIONS = {
         RoomVersions.V7,
     )
     # Note that we do not include MSC2043 here unless it is enabled in the config.
-}  # type: Dict[str, RoomVersion]
+}
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index b30571fe49..50a02f51f5 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -38,6 +38,7 @@ from synapse.app.phone_stats_home import start_phone_stats_home
 from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto import context_factory
 from synapse.events.spamcheck import load_legacy_spam_checkers
+from synapse.events.third_party_rules import load_legacy_third_party_event_rules
 from synapse.logging.context import PreserveLoggingContext
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.metrics.jemalloc import setup_jemalloc_stats
@@ -368,6 +369,7 @@ async def start(hs: "HomeServer"):
         module(config=config, api=module_api)
 
     load_legacy_spam_checkers(hs)
+    load_legacy_third_party_event_rules(hs)
 
     # If we've configured an expiry time for caches, start the background job now.
     setup_expire_lru_cache_entries(hs)
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 5b041fcaad..c3d4992518 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -270,7 +270,7 @@ class GenericWorkerServer(HomeServer):
             site_tag = port
 
         # We always include a health resource.
-        resources = {"/health": HealthResource()}  # type: Dict[str, IResource]
+        resources: Dict[str, IResource] = {"/health": HealthResource()}
 
         for res in listener_config.http_options.resources:
             for name in res.names:
@@ -395,10 +395,8 @@ class GenericWorkerServer(HomeServer):
             elif listener.type == "metrics":
                 if not self.config.enable_metrics:
                     logger.warning(
-                        (
-                            "Metrics listener configured, but "
-                            "enable_metrics is not True!"
-                        )
+                        "Metrics listener configured, but "
+                        "enable_metrics is not True!"
                     )
                 else:
                     _base.listen_metrics(listener.bind_addresses, listener.port)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 7af56ac136..920b34d97b 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -305,10 +305,8 @@ class SynapseHomeServer(HomeServer):
             elif listener.type == "metrics":
                 if not self.config.enable_metrics:
                     logger.warning(
-                        (
-                            "Metrics listener configured, but "
-                            "enable_metrics is not True!"
-                        )
+                        "Metrics listener configured, but "
+                        "enable_metrics is not True!"
                     )
                 else:
                     _base.listen_metrics(listener.bind_addresses, listener.port)
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index 8f86cecb76..86ad7337a9 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -71,6 +71,8 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
     # General statistics
     #
 
+    store = hs.get_datastore()
+
     stats["homeserver"] = hs.config.server_name
     stats["server_context"] = hs.config.server_context
     stats["timestamp"] = now
@@ -79,34 +81,38 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
     stats["python_version"] = "{}.{}.{}".format(
         version.major, version.minor, version.micro
     )
-    stats["total_users"] = await hs.get_datastore().count_all_users()
+    stats["total_users"] = await store.count_all_users()
 
-    total_nonbridged_users = await hs.get_datastore().count_nonbridged_users()
+    total_nonbridged_users = await store.count_nonbridged_users()
     stats["total_nonbridged_users"] = total_nonbridged_users
 
-    daily_user_type_results = await hs.get_datastore().count_daily_user_type()
+    daily_user_type_results = await store.count_daily_user_type()
     for name, count in daily_user_type_results.items():
         stats["daily_user_type_" + name] = count
 
-    room_count = await hs.get_datastore().get_room_count()
+    room_count = await store.get_room_count()
     stats["total_room_count"] = room_count
 
-    stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
-    stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
-    daily_active_e2ee_rooms = await hs.get_datastore().count_daily_active_e2ee_rooms()
+    stats["daily_active_users"] = await store.count_daily_users()
+    stats["monthly_active_users"] = await store.count_monthly_users()
+    daily_active_e2ee_rooms = await store.count_daily_active_e2ee_rooms()
     stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms
-    stats["daily_e2ee_messages"] = await hs.get_datastore().count_daily_e2ee_messages()
-    daily_sent_e2ee_messages = await hs.get_datastore().count_daily_sent_e2ee_messages()
+    stats["daily_e2ee_messages"] = await store.count_daily_e2ee_messages()
+    daily_sent_e2ee_messages = await store.count_daily_sent_e2ee_messages()
     stats["daily_sent_e2ee_messages"] = daily_sent_e2ee_messages
-    stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
-    stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
-    daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
+    stats["daily_active_rooms"] = await store.count_daily_active_rooms()
+    stats["daily_messages"] = await store.count_daily_messages()
+    daily_sent_messages = await store.count_daily_sent_messages()
     stats["daily_sent_messages"] = daily_sent_messages
 
-    r30_results = await hs.get_datastore().count_r30_users()
+    r30_results = await store.count_r30_users()
     for name, count in r30_results.items():
         stats["r30_users_" + name] = count
 
+    r30v2_results = await store.count_r30v2_users()
+    for name, count in r30v2_results.items():
+        stats["r30v2_users_" + name] = count
+
     stats["cache_factor"] = hs.config.caches.global_factor
     stats["event_cache_size"] = hs.config.caches.event_cache_size
 
@@ -115,8 +121,8 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
     #
 
     # This only reports info about the *main* database.
-    stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__
-    stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version
+    stats["database_engine"] = store.db_pool.engine.module.__name__
+    stats["database_server_version"] = store.db_pool.engine.server_version
 
     #
     # Logging configuration
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 61152b2c46..935f24263c 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -88,9 +88,9 @@ class ApplicationServiceApi(SimpleHttpClient):
         super().__init__(hs)
         self.clock = hs.get_clock()
 
-        self.protocol_meta_cache = ResponseCache(
+        self.protocol_meta_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
             hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
-        )  # type: ResponseCache[Tuple[str, str]]
+        )
 
     async def query_user(self, service, user_id):
         if service.url is None:
diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py
index 957de7f3a6..6be4eafe55 100644
--- a/synapse/config/account_validity.py
+++ b/synapse/config/account_validity.py
@@ -18,6 +18,21 @@ class AccountValidityConfig(Config):
     section = "account_validity"
 
     def read_config(self, config, **kwargs):
+        """Parses the old account validity config. The config format looks like this:
+
+        account_validity:
+            enabled: true
+            period: 6w
+            renew_at: 1w
+            renew_email_subject: "Renew your %(app)s account"
+            template_dir: "res/templates"
+            account_renewed_html_path: "account_renewed.html"
+            invalid_token_html_path: "invalid_token.html"
+
+        We expect admins to use modules for this feature (which is why it doesn't appear
+        in the sample config file), but we want to keep support for it around for a bit
+        for backwards compatibility.
+        """
         account_validity_config = config.get("account_validity") or {}
         self.account_validity_enabled = account_validity_config.get("enabled", False)
         self.account_validity_renew_by_email_enabled = (
@@ -75,90 +90,3 @@ class AccountValidityConfig(Config):
             ],
             account_validity_template_dir,
         )
-
-    def generate_config_section(self, **kwargs):
-        return """\
-        ## Account Validity ##
-
-        # Optional account validity configuration. This allows for accounts to be denied
-        # any request after a given period.
-        #
-        # Once this feature is enabled, Synapse will look for registered users without an
-        # expiration date at startup and will add one to every account it found using the
-        # current settings at that time.
-        # This means that, if a validity period is set, and Synapse is restarted (it will
-        # then derive an expiration date from the current validity period), and some time
-        # after that the validity period changes and Synapse is restarted, the users'
-        # expiration dates won't be updated unless their account is manually renewed. This
-        # date will be randomly selected within a range [now + period - d ; now + period],
-        # where d is equal to 10% of the validity period.
-        #
-        account_validity:
-          # The account validity feature is disabled by default. Uncomment the
-          # following line to enable it.
-          #
-          #enabled: true
-
-          # The period after which an account is valid after its registration. When
-          # renewing the account, its validity period will be extended by this amount
-          # of time. This parameter is required when using the account validity
-          # feature.
-          #
-          #period: 6w
-
-          # The amount of time before an account's expiry date at which Synapse will
-          # send an email to the account's email address with a renewal link. By
-          # default, no such emails are sent.
-          #
-          # If you enable this setting, you will also need to fill out the 'email' and
-          # 'public_baseurl' configuration sections.
-          #
-          #renew_at: 1w
-
-          # The subject of the email sent out with the renewal link. '%(app)s' can be
-          # used as a placeholder for the 'app_name' parameter from the 'email'
-          # section.
-          #
-          # Note that the placeholder must be written '%(app)s', including the
-          # trailing 's'.
-          #
-          # If this is not set, a default value is used.
-          #
-          #renew_email_subject: "Renew your %(app)s account"
-
-          # Directory in which Synapse will try to find templates for the HTML files to
-          # serve to the user when trying to renew an account. If not set, default
-          # templates from within the Synapse package will be used.
-          #
-          # The currently available templates are:
-          #
-          # * account_renewed.html: Displayed to the user after they have successfully
-          #       renewed their account.
-          #
-          # * account_previously_renewed.html: Displayed to the user if they attempt to
-          #       renew their account with a token that is valid, but that has already
-          #       been used. In this case the account is not renewed again.
-          #
-          # * invalid_token.html: Displayed to the user when they try to renew an account
-          #       with an unknown or invalid renewal token.
-          #
-          # See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for
-          # default template contents.
-          #
-          # The file name of some of these templates can be configured below for legacy
-          # reasons.
-          #
-          #template_dir: "res/templates"
-
-          # A custom file name for the 'account_renewed.html' template.
-          #
-          # If not set, the file is assumed to be named "account_renewed.html".
-          #
-          #account_renewed_html_path: "account_renewed.html"
-
-          # A custom file name for the 'invalid_token.html' template.
-          #
-          # If not set, the file is assumed to be named "invalid_token.html".
-          #
-          #invalid_token_html_path: "invalid_token.html"
-        """
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 746fc3cc02..1ebea88db2 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -57,14 +57,14 @@ def load_appservices(hostname, config_files):
         return []
 
     # Dicts of value -> filename
-    seen_as_tokens = {}  # type: Dict[str, str]
-    seen_ids = {}  # type: Dict[str, str]
+    seen_as_tokens: Dict[str, str] = {}
+    seen_ids: Dict[str, str] = {}
 
     appservices = []
 
     for config_file in config_files:
         try:
-            with open(config_file, "r") as f:
+            with open(config_file) as f:
                 appservice = _load_appservice(hostname, yaml.safe_load(f), config_file)
                 if appservice.id in seen_ids:
                     raise ConfigError(
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 7789b40323..8d5f38b5d9 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -25,7 +25,7 @@ from ._base import Config, ConfigError
 _CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR"
 
 # Map from canonicalised cache name to cache.
-_CACHES = {}  # type: Dict[str, Callable[[float], None]]
+_CACHES: Dict[str, Callable[[float], None]] = {}
 
 # a lock on the contents of _CACHES
 _CACHES_LOCK = threading.Lock()
@@ -157,7 +157,7 @@ class CacheConfig(Config):
         self.event_cache_size = self.parse_size(
             config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE)
         )
-        self.cache_factors = {}  # type: Dict[str, float]
+        self.cache_factors: Dict[str, float] = {}
 
         cache_config = config.get("caches") or {}
         self.global_factor = cache_config.get(
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 5564d7d097..bcecbfec03 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -134,9 +134,9 @@ class EmailConfig(Config):
 
                 # trusted_third_party_id_servers does not contain a scheme whereas
                 # account_threepid_delegate_email is expected to. Presume https
-                self.account_threepid_delegate_email = (
+                self.account_threepid_delegate_email: Optional[str] = (
                     "https://" + first_trusted_identity_server
-                )  # type: Optional[str]
+                )
                 self.using_identity_server_from_trusted_list = True
             else:
                 raise ConfigError(
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 7fb1f7021f..e25ccba9ac 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -25,10 +25,10 @@ class ExperimentalConfig(Config):
         experimental = config.get("experimental_features") or {}
 
         # MSC2858 (multiple SSO identity providers)
-        self.msc2858_enabled = experimental.get("msc2858_enabled", False)  # type: bool
+        self.msc2858_enabled: bool = experimental.get("msc2858_enabled", False)
 
         # MSC3026 (busy presence state)
-        self.msc3026_enabled = experimental.get("msc3026_enabled", False)  # type: bool
+        self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False)
 
         # MSC2716 (backfill existing history)
-        self.msc2716_enabled = experimental.get("msc2716_enabled", False)  # type: bool
+        self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False)
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index cdd7a1ef05..7d64993e22 100644
--- a/synapse/config/federation.py
+++ b/synapse/config/federation.py
@@ -22,7 +22,7 @@ class FederationConfig(Config):
 
     def read_config(self, config, **kwargs):
         # FIXME: federation_domain_whitelist needs sytests
-        self.federation_domain_whitelist = None  # type: Optional[dict]
+        self.federation_domain_whitelist: Optional[dict] = None
         federation_domain_whitelist = config.get("federation_domain_whitelist", None)
 
         if federation_domain_whitelist is not None:
diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index 942e2672a9..ba89d11cf0 100644
--- a/synapse/config/oidc.py
+++ b/synapse/config/oidc.py
@@ -460,7 +460,7 @@ def _parse_oidc_config_dict(
             ) from e
 
     client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key")
-    client_secret_jwt_key = None  # type: Optional[OidcProviderClientSecretJwtKey]
+    client_secret_jwt_key: Optional[OidcProviderClientSecretJwtKey] = None
     if client_secret_jwt_key_config is not None:
         keyfile = client_secret_jwt_key_config.get("key_file")
         if keyfile:
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index fd90b79772..0f5b2b3977 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -25,7 +25,7 @@ class PasswordAuthProviderConfig(Config):
     section = "authproviders"
 
     def read_config(self, config, **kwargs):
-        self.password_providers = []  # type: List[Any]
+        self.password_providers: List[Any] = []
         providers = []
 
         # We want to be backwards compatible with the old `ldap_config`
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index a7a82742ac..0dfb3a227a 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -62,7 +62,7 @@ def parse_thumbnail_requirements(thumbnail_sizes):
         Dictionary mapping from media type string to list of
         ThumbnailRequirement tuples.
     """
-    requirements = {}  # type: Dict[str, List]
+    requirements: Dict[str, List] = {}
     for size in thumbnail_sizes:
         width = size["width"]
         height = size["height"]
@@ -141,7 +141,7 @@ class ContentRepositoryConfig(Config):
         #
         # We don't create the storage providers here as not all workers need
         # them to be started.
-        self.media_storage_providers = []  # type: List[tuple]
+        self.media_storage_providers: List[tuple] = []
 
         for i, provider_config in enumerate(storage_providers):
             # We special case the module "file_system" so as not to need to
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 6bff715230..b9e0c0b300 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -505,7 +505,7 @@ class ServerConfig(Config):
                 " greater than 'allowed_lifetime_max'"
             )
 
-        self.retention_purge_jobs = []  # type: List[Dict[str, Optional[int]]]
+        self.retention_purge_jobs: List[Dict[str, Optional[int]]] = []
         for purge_job_config in retention_config.get("purge_jobs", []):
             interval_config = purge_job_config.get("interval")
 
@@ -688,23 +688,21 @@ class ServerConfig(Config):
         # not included in the sample configuration file on purpose as it's a temporary
         # hack, so that some users can trial the new defaults without impacting every
         # user on the homeserver.
-        users_new_default_push_rules = (
+        users_new_default_push_rules: list = (
             config.get("users_new_default_push_rules") or []
-        )  # type: list
+        )
         if not isinstance(users_new_default_push_rules, list):
             raise ConfigError("'users_new_default_push_rules' must be a list")
 
         # Turn the list into a set to improve lookup speed.
-        self.users_new_default_push_rules = set(
-            users_new_default_push_rules
-        )  # type: set
+        self.users_new_default_push_rules: set = set(users_new_default_push_rules)
 
         # Whitelist of domain names that given next_link parameters must have
-        next_link_domain_whitelist = config.get(
+        next_link_domain_whitelist: Optional[List[str]] = config.get(
             "next_link_domain_whitelist"
-        )  # type: Optional[List[str]]
+        )
 
-        self.next_link_domain_whitelist = None  # type: Optional[Set[str]]
+        self.next_link_domain_whitelist: Optional[Set[str]] = None
         if next_link_domain_whitelist is not None:
             if not isinstance(next_link_domain_whitelist, list):
                 raise ConfigError("'next_link_domain_whitelist' must be a list")
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index cb7716c837..a233a9ce03 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -34,7 +34,7 @@ class SpamCheckerConfig(Config):
     section = "spamchecker"
 
     def read_config(self, config, **kwargs):
-        self.spam_checkers = []  # type: List[Tuple[Any, Dict]]
+        self.spam_checkers: List[Tuple[Any, Dict]] = []
 
         spam_checkers = config.get("spam_checker") or []
         if isinstance(spam_checkers, dict):
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index e4346e02aa..d0f04cf8e6 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -39,7 +39,7 @@ class SSOConfig(Config):
     section = "sso"
 
     def read_config(self, config, **kwargs):
-        sso_config = config.get("sso") or {}  # type: Dict[str, Any]
+        sso_config: Dict[str, Any] = config.get("sso") or {}
 
         # The sso-specific template_dir
         self.sso_template_dir = sso_config.get("template_dir")
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
index 78f61fe9da..6f253e00c0 100644
--- a/synapse/config/stats.py
+++ b/synapse/config/stats.py
@@ -38,13 +38,9 @@ class StatsConfig(Config):
 
     def read_config(self, config, **kwargs):
         self.stats_enabled = True
-        self.stats_bucket_size = 86400 * 1000
         stats_config = config.get("stats", None)
         if stats_config:
             self.stats_enabled = stats_config.get("enabled", self.stats_enabled)
-            self.stats_bucket_size = self.parse_duration(
-                stats_config.get("bucket_size", "1d")
-            )
         if not self.stats_enabled:
             logger.warning(ROOM_STATS_DISABLED_WARN)
 
@@ -59,9 +55,4 @@ class StatsConfig(Config):
           # correctly.
           #
           #enabled: false
-
-          # The size of each timeslice in the room_stats_historical and
-          # user_stats_historical tables, as a time period. Defaults to "1d".
-          #
-          #bucket_size: 1h
         """
diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py
index f502ff539e..a3fae02420 100644
--- a/synapse/config/third_party_event_rules.py
+++ b/synapse/config/third_party_event_rules.py
@@ -28,18 +28,3 @@ class ThirdPartyRulesConfig(Config):
             self.third_party_event_rules = load_module(
                 provider, ("third_party_event_rules",)
             )
-
-    def generate_config_section(self, **kwargs):
-        return """\
-        # Server admins can define a Python module that implements extra rules for
-        # allowing or denying incoming events. In order to work, this module needs to
-        # override the methods defined in synapse/events/third_party_rules.py.
-        #
-        # This feature is designed to be used in closed federations only, where each
-        # participating server enforces the same rules.
-        #
-        #third_party_event_rules:
-        #  module: "my_custom_project.SuperRulesSet"
-        #  config:
-        #    example_option: 'things'
-        """
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 9a16a8fbae..5679f05e42 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -66,10 +66,8 @@ class TlsConfig(Config):
         if self.federation_client_minimum_tls_version == "1.3":
             if getattr(SSL, "OP_NO_TLSv1_3", None) is None:
                 raise ConfigError(
-                    (
-                        "federation_client_minimum_tls_version cannot be 1.3, "
-                        "your OpenSSL does not support it"
-                    )
+                    "federation_client_minimum_tls_version cannot be 1.3, "
+                    "your OpenSSL does not support it"
                 )
 
         # Whitelist of domains to not verify certificates for
@@ -80,7 +78,7 @@ class TlsConfig(Config):
             fed_whitelist_entries = []
 
         # Support globs (*) in whitelist values
-        self.federation_certificate_verification_whitelist = []  # type: List[Pattern]
+        self.federation_certificate_verification_whitelist: List[Pattern] = []
         for entry in fed_whitelist_entries:
             try:
                 entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii"))
@@ -132,8 +130,8 @@ class TlsConfig(Config):
             "use_insecure_ssl_client_just_for_testing_do_not_use"
         )
 
-        self.tls_certificate = None  # type: Optional[crypto.X509]
-        self.tls_private_key = None  # type: Optional[crypto.PKey]
+        self.tls_certificate: Optional[crypto.X509] = None
+        self.tls_private_key: Optional[crypto.PKey] = None
 
     def is_disk_cert_valid(self, allow_self_signed=True):
         """
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index e5a4685ed4..9e9b1c1c86 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -170,11 +170,13 @@ class Keyring:
             )
         self._key_fetchers = key_fetchers
 
-        self._server_queue = BatchingQueue(
+        self._server_queue: BatchingQueue[
+            _FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]
+        ] = BatchingQueue(
             "keyring_server",
             clock=hs.get_clock(),
             process_batch_callback=self._inner_fetch_key_requests,
-        )  # type: BatchingQueue[_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]]
+        )
 
     async def verify_json_for_server(
         self,
@@ -330,7 +332,7 @@ class Keyring:
         # First we need to deduplicate requests for the same key. We do this by
         # taking the *maximum* requested `minimum_valid_until_ts` for each pair
         # of server name/key ID.
-        server_to_key_to_ts = {}  # type: Dict[str, Dict[str, int]]
+        server_to_key_to_ts: Dict[str, Dict[str, int]] = {}
         for request in requests:
             by_server = server_to_key_to_ts.setdefault(request.server_name, {})
             for key_id in request.key_ids:
@@ -355,7 +357,7 @@ class Keyring:
 
         # We now convert the returned list of results into a map from server
         # name to key ID to FetchKeyResult, to return.
-        to_return = {}  # type: Dict[str, Dict[str, FetchKeyResult]]
+        to_return: Dict[str, Dict[str, FetchKeyResult]] = {}
         for (request, results) in zip(deduped_requests, results_per_request):
             to_return_by_server = to_return.setdefault(request.server_name, {})
             for key_id, key_result in results.items():
@@ -455,7 +457,7 @@ class StoreKeyFetcher(KeyFetcher):
         )
 
         res = await self.store.get_server_verify_keys(key_ids_to_fetch)
-        keys = {}  # type: Dict[str, Dict[str, FetchKeyResult]]
+        keys: Dict[str, Dict[str, FetchKeyResult]] = {}
         for (server_name, key_id), key in res.items():
             keys.setdefault(server_name, {})[key_id] = key
         return keys
@@ -603,7 +605,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
             ).addErrback(unwrapFirstError)
         )
 
-        union_of_keys = {}  # type: Dict[str, Dict[str, FetchKeyResult]]
+        union_of_keys: Dict[str, Dict[str, FetchKeyResult]] = {}
         for result in results:
             for server_name, keys in result.items():
                 union_of_keys.setdefault(server_name, {}).update(keys)
@@ -656,8 +658,8 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         except HttpResponseException as e:
             raise KeyLookupError("Remote server returned an error: %s" % (e,))
 
-        keys = {}  # type: Dict[str, Dict[str, FetchKeyResult]]
-        added_keys = []  # type: List[Tuple[str, str, FetchKeyResult]]
+        keys: Dict[str, Dict[str, FetchKeyResult]] = {}
+        added_keys: List[Tuple[str, str, FetchKeyResult]] = []
 
         time_now_ms = self.clock.time_msec()
 
@@ -805,7 +807,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         Raises:
             KeyLookupError if there was a problem making the lookup
         """
-        keys = {}  # type: Dict[str, FetchKeyResult]
+        keys: Dict[str, FetchKeyResult] = {}
 
         for requested_key_id in key_ids:
             # we may have found this key as a side-effect of asking for another.
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 89bcf81515..137dff2513 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -48,6 +48,9 @@ def check(
         room_version_obj: the version of the room
         event: the event being checked.
         auth_events: the existing room state.
+        do_sig_check: True if it should be verified that the sending server
+            signed the event.
+        do_size_check: True if the size of the event fields should be verified.
 
     Raises:
         AuthError if the checks fail
@@ -528,7 +531,7 @@ def _check_power_levels(
     user_level = get_user_power_level(event.user_id, auth_events)
 
     # Check other levels:
-    levels_to_check = [
+    levels_to_check: List[Tuple[str, Optional[str]]] = [
         ("users_default", None),
         ("events_default", None),
         ("state_default", None),
@@ -536,7 +539,7 @@ def _check_power_levels(
         ("redact", None),
         ("kick", None),
         ("invite", None),
-    ]  # type: List[Tuple[str, Optional[str]]]
+    ]
 
     old_list = current_state.content.get("users", {})
     for user in set(list(old_list) + list(user_list)):
@@ -566,12 +569,12 @@ def _check_power_levels(
             new_loc = new_loc.get(dir, {})
 
         if level_to_check in old_loc:
-            old_level = int(old_loc[level_to_check])  # type: Optional[int]
+            old_level: Optional[int] = int(old_loc[level_to_check])
         else:
             old_level = None
 
         if level_to_check in new_loc:
-            new_level = int(new_loc[level_to_check])  # type: Optional[int]
+            new_level: Optional[int] = int(new_loc[level_to_check])
         else:
             new_level = None
 
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 6286ad999a..0298af4c02 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -105,28 +105,28 @@ class _EventInternalMetadata:
         self._dict = dict(internal_metadata_dict)
 
         # the stream ordering of this event. None, until it has been persisted.
-        self.stream_ordering = None  # type: Optional[int]
+        self.stream_ordering: Optional[int] = None
 
         # whether this event is an outlier (ie, whether we have the state at that point
         # in the DAG)
         self.outlier = False
 
-    out_of_band_membership = DictProperty("out_of_band_membership")  # type: bool
-    send_on_behalf_of = DictProperty("send_on_behalf_of")  # type: str
-    recheck_redaction = DictProperty("recheck_redaction")  # type: bool
-    soft_failed = DictProperty("soft_failed")  # type: bool
-    proactively_send = DictProperty("proactively_send")  # type: bool
-    redacted = DictProperty("redacted")  # type: bool
-    txn_id = DictProperty("txn_id")  # type: str
-    token_id = DictProperty("token_id")  # type: int
-    historical = DictProperty("historical")  # type: bool
+    out_of_band_membership: bool = DictProperty("out_of_band_membership")
+    send_on_behalf_of: str = DictProperty("send_on_behalf_of")
+    recheck_redaction: bool = DictProperty("recheck_redaction")
+    soft_failed: bool = DictProperty("soft_failed")
+    proactively_send: bool = DictProperty("proactively_send")
+    redacted: bool = DictProperty("redacted")
+    txn_id: str = DictProperty("txn_id")
+    token_id: int = DictProperty("token_id")
+    historical: bool = DictProperty("historical")
 
     # XXX: These are set by StreamWorkerStore._set_before_and_after.
     # I'm pretty sure that these are never persisted to the database, so shouldn't
     # be here
-    before = DictProperty("before")  # type: RoomStreamToken
-    after = DictProperty("after")  # type: RoomStreamToken
-    order = DictProperty("order")  # type: Tuple[int, int]
+    before: RoomStreamToken = DictProperty("before")
+    after: RoomStreamToken = DictProperty("after")
+    order: Tuple[int, int] = DictProperty("order")
 
     def get_dict(self) -> JsonDict:
         return dict(self._dict)
@@ -291,6 +291,20 @@ class EventBase(metaclass=abc.ABCMeta):
 
         return pdu_json
 
+    def get_templated_pdu_json(self) -> JsonDict:
+        """
+        Return a JSON object suitable for a templated event, as used in the
+        make_{join,leave,knock} workflow.
+        """
+        # By using _dict directly we don't pull in signatures/unsigned.
+        template_json = dict(self._dict)
+        # The hashes (similar to the signature) need to be recalculated by the
+        # joining/leaving/knocking server after (potentially) modifying the
+        # event.
+        template_json.pop("hashes")
+
+        return template_json
+
     def __set__(self, instance, value):
         raise AttributeError("Unrecognized attribute %s" % (instance,))
 
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 26e3950859..87e2bb123b 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -132,12 +132,12 @@ class EventBuilder:
         format_version = self.room_version.event_format
         if format_version == EventFormatVersions.V1:
             # The types of auth/prev events changes between event versions.
-            auth_events = await self._store.add_event_hashes(
-                auth_event_ids
-            )  # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
-            prev_events = await self._store.add_event_hashes(
-                prev_event_ids
-            )  # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
+            auth_events: Union[
+                List[str], List[Tuple[str, Dict[str, str]]]
+            ] = await self._store.add_event_hashes(auth_event_ids)
+            prev_events: Union[
+                List[str], List[Tuple[str, Dict[str, str]]]
+            ] = await self._store.add_event_hashes(prev_event_ids)
         else:
             auth_events = auth_event_ids
             prev_events = prev_event_ids
@@ -156,7 +156,7 @@ class EventBuilder:
         # the db)
         depth = min(depth, MAX_DEPTH)
 
-        event_dict = {
+        event_dict: Dict[str, Any] = {
             "auth_events": auth_events,
             "prev_events": prev_events,
             "type": self.type,
@@ -166,7 +166,7 @@ class EventBuilder:
             "unsigned": self.unsigned,
             "depth": depth,
             "prev_state": [],
-        }  # type: Dict[str, Any]
+        }
 
         if self.is_state():
             event_dict["state_key"] = self._state_key
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index efec16c226..57f1d53fa8 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -76,7 +76,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
     """Wrapper that loads spam checkers configured using the old configuration, and
     registers the spam checker hooks they implement.
     """
-    spam_checkers = []  # type: List[Any]
+    spam_checkers: List[Any] = []
     api = hs.get_module_api()
     for module, config in hs.config.spam_checkers:
         # Older spam checkers don't accept the `api` argument, so we
@@ -239,7 +239,7 @@ class SpamChecker:
             will be used as the error message returned to the user.
         """
         for callback in self._check_event_for_spam_callbacks:
-            res = await callback(event)  # type: Union[bool, str]
+            res: Union[bool, str] = await callback(event)
             if res:
                 return res
 
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index f7944fd834..7a6eb3e516 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -11,16 +11,124 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
 
-from typing import TYPE_CHECKING, Union
-
+from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.types import Requester, StateMap
+from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
+logger = logging.getLogger(__name__)
+
+
+CHECK_EVENT_ALLOWED_CALLBACK = Callable[
+    [EventBase, StateMap[EventBase]], Awaitable[Tuple[bool, Optional[dict]]]
+]
+ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable]
+CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[
+    [str, str, StateMap[EventBase]], Awaitable[bool]
+]
+CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
+    [str, StateMap[EventBase], str], Awaitable[bool]
+]
+
+
+def load_legacy_third_party_event_rules(hs: "HomeServer"):
+    """Wrapper that loads a third party event rules module configured using the old
+    configuration, and registers the hooks they implement.
+    """
+    if hs.config.third_party_event_rules is None:
+        return
+
+    module, config = hs.config.third_party_event_rules
+
+    api = hs.get_module_api()
+    third_party_rules = module(config=config, module_api=api)
+
+    # The known hooks. If a module implements a method which name appears in this set,
+    # we'll want to register it.
+    third_party_event_rules_methods = {
+        "check_event_allowed",
+        "on_create_room",
+        "check_threepid_can_be_invited",
+        "check_visibility_can_be_modified",
+    }
+
+    def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
+        # f might be None if the callback isn't implemented by the module. In this
+        # case we don't want to register a callback at all so we return None.
+        if f is None:
+            return None
+
+        # We return a separate wrapper for these methods because, in order to wrap them
+        # correctly, we need to await its result. Therefore it doesn't make a lot of
+        # sense to make it go through the run() wrapper.
+        if f.__name__ == "check_event_allowed":
+
+            # We need to wrap check_event_allowed because its old form would return either
+            # a boolean or a dict, but now we want to return the dict separately from the
+            # boolean.
+            async def wrap_check_event_allowed(
+                event: EventBase,
+                state_events: StateMap[EventBase],
+            ) -> Tuple[bool, Optional[dict]]:
+                # We've already made sure f is not None above, but mypy doesn't do well
+                # across function boundaries so we need to tell it f is definitely not
+                # None.
+                assert f is not None
+
+                res = await f(event, state_events)
+                if isinstance(res, dict):
+                    return True, res
+                else:
+                    return res, None
+
+            return wrap_check_event_allowed
+
+        if f.__name__ == "on_create_room":
+
+            # We need to wrap on_create_room because its old form would return a boolean
+            # if the room creation is denied, but now we just want it to raise an
+            # exception.
+            async def wrap_on_create_room(
+                requester: Requester, config: dict, is_requester_admin: bool
+            ) -> None:
+                # We've already made sure f is not None above, but mypy doesn't do well
+                # across function boundaries so we need to tell it f is definitely not
+                # None.
+                assert f is not None
+
+                res = await f(requester, config, is_requester_admin)
+                if res is False:
+                    raise SynapseError(
+                        403,
+                        "Room creation forbidden with these parameters",
+                    )
+
+            return wrap_on_create_room
+
+        def run(*args, **kwargs):
+            # mypy doesn't do well across function boundaries so we need to tell it
+            # f is definitely not None.
+            assert f is not None
+
+            return maybe_awaitable(f(*args, **kwargs))
+
+        return run
+
+    # Register the hooks through the module API.
+    hooks = {
+        hook: async_wrapper(getattr(third_party_rules, hook, None))
+        for hook in third_party_event_rules_methods
+    }
+
+    api.register_third_party_rules_callbacks(**hooks)
+
 
 class ThirdPartyEventRules:
     """Allows server admins to provide a Python module implementing an extra
@@ -35,36 +143,65 @@ class ThirdPartyEventRules:
 
         self.store = hs.get_datastore()
 
-        module = None
-        config = None
-        if hs.config.third_party_event_rules:
-            module, config = hs.config.third_party_event_rules
+        self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = []
+        self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = []
+        self._check_threepid_can_be_invited_callbacks: List[
+            CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
+        ] = []
+        self._check_visibility_can_be_modified_callbacks: List[
+            CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
+        ] = []
+
+    def register_third_party_rules_callbacks(
+        self,
+        check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None,
+        on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None,
+        check_threepid_can_be_invited: Optional[
+            CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
+        ] = None,
+        check_visibility_can_be_modified: Optional[
+            CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
+        ] = None,
+    ):
+        """Register callbacks from modules for each hook."""
+        if check_event_allowed is not None:
+            self._check_event_allowed_callbacks.append(check_event_allowed)
+
+        if on_create_room is not None:
+            self._on_create_room_callbacks.append(on_create_room)
+
+        if check_threepid_can_be_invited is not None:
+            self._check_threepid_can_be_invited_callbacks.append(
+                check_threepid_can_be_invited,
+            )
 
-        if module is not None:
-            self.third_party_rules = module(
-                config=config,
-                module_api=hs.get_module_api(),
+        if check_visibility_can_be_modified is not None:
+            self._check_visibility_can_be_modified_callbacks.append(
+                check_visibility_can_be_modified,
             )
 
     async def check_event_allowed(
         self, event: EventBase, context: EventContext
-    ) -> Union[bool, dict]:
+    ) -> Tuple[bool, Optional[dict]]:
         """Check if a provided event should be allowed in the given context.
 
         The module can return:
             * True: the event is allowed.
             * False: the event is not allowed, and should be rejected with M_FORBIDDEN.
-            * a dict: replacement event data.
+
+        If the event is allowed, the module can also return a dictionary to use as a
+        replacement for the event.
 
         Args:
             event: The event to be checked.
             context: The context of the event.
 
         Returns:
-            The result from the ThirdPartyRules module, as above
+            The result from the ThirdPartyRules module, as above.
         """
-        if self.third_party_rules is None:
-            return True
+        # Bail out early without hitting the store if we don't have any callbacks to run.
+        if len(self._check_event_allowed_callbacks) == 0:
+            return True, None
 
         prev_state_ids = await context.get_prev_state_ids()
 
@@ -77,29 +214,46 @@ class ThirdPartyEventRules:
         # the hashes and signatures.
         event.freeze()
 
-        return await self.third_party_rules.check_event_allowed(event, state_events)
+        for callback in self._check_event_allowed_callbacks:
+            try:
+                res, replacement_data = await callback(event, state_events)
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+                continue
+
+            # Return if the event shouldn't be allowed or if the module came up with a
+            # replacement dict for the event.
+            if res is False:
+                return res, None
+            elif isinstance(replacement_data, dict):
+                return True, replacement_data
+
+        return True, None
 
     async def on_create_room(
         self, requester: Requester, config: dict, is_requester_admin: bool
-    ) -> bool:
-        """Intercept requests to create room to allow, deny or update the
-        request config.
+    ) -> None:
+        """Intercept requests to create room to maybe deny it (via an exception) or
+        update the request config.
 
         Args:
             requester
             config: The creation config from the client.
             is_requester_admin: If the requester is an admin
-
-        Returns:
-            Whether room creation is allowed or denied.
         """
-
-        if self.third_party_rules is None:
-            return True
-
-        return await self.third_party_rules.on_create_room(
-            requester, config, is_requester_admin
-        )
+        for callback in self._on_create_room_callbacks:
+            try:
+                await callback(requester, config, is_requester_admin)
+            except Exception as e:
+                # Don't silence the errors raised by this callback since we expect it to
+                # raise an exception to deny the creation of the room; instead make sure
+                # it's a SynapseError we can send to clients.
+                if not isinstance(e, SynapseError):
+                    e = SynapseError(
+                        403, "Room creation forbidden with these parameters"
+                    )
+
+                raise e
 
     async def check_threepid_can_be_invited(
         self, medium: str, address: str, room_id: str
@@ -114,15 +268,20 @@ class ThirdPartyEventRules:
         Returns:
             True if the 3PID can be invited, False if not.
         """
-
-        if self.third_party_rules is None:
+        # Bail out early without hitting the store if we don't have any callbacks to run.
+        if len(self._check_threepid_can_be_invited_callbacks) == 0:
             return True
 
         state_events = await self._get_state_map_for_room(room_id)
 
-        return await self.third_party_rules.check_threepid_can_be_invited(
-            medium, address, state_events
-        )
+        for callback in self._check_threepid_can_be_invited_callbacks:
+            try:
+                if await callback(medium, address, state_events) is False:
+                    return False
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+
+        return True
 
     async def check_visibility_can_be_modified(
         self, room_id: str, new_visibility: str
@@ -137,18 +296,20 @@ class ThirdPartyEventRules:
         Returns:
             True if the room's visibility can be modified, False if not.
         """
-        if self.third_party_rules is None:
-            return True
-
-        check_func = getattr(
-            self.third_party_rules, "check_visibility_can_be_modified", None
-        )
-        if not check_func or not callable(check_func):
+        # Bail out early without hitting the store if we don't have any callback
+        if len(self._check_visibility_can_be_modified_callbacks) == 0:
             return True
 
         state_events = await self._get_state_map_for_room(room_id)
 
-        return await check_func(room_id, state_events, new_visibility)
+        for callback in self._check_visibility_can_be_modified_callbacks:
+            try:
+                if await callback(room_id, state_events, new_visibility) is False:
+                    return False
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+
+        return True
 
     async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
         """Given a room ID, return the state events of that room.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index ed09c6af1f..c767d30627 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -86,7 +86,7 @@ class FederationClient(FederationBase):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.pdu_destination_tried = {}  # type: Dict[str, Dict[str, int]]
+        self.pdu_destination_tried: Dict[str, Dict[str, int]] = {}
         self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
         self.state = hs.get_state_handler()
         self.transport_layer = hs.get_federation_transport_client()
@@ -94,13 +94,13 @@ class FederationClient(FederationBase):
         self.hostname = hs.hostname
         self.signing_key = hs.signing_key
 
-        self._get_pdu_cache = ExpiringCache(
+        self._get_pdu_cache: ExpiringCache[str, EventBase] = ExpiringCache(
             cache_name="get_pdu_cache",
             clock=self._clock,
             max_len=1000,
             expiry_ms=120 * 1000,
             reset_expiry_on_get=False,
-        )  # type: ExpiringCache[str, EventBase]
+        )
 
     def _clear_tried_cache(self):
         """Clear pdu_destination_tried cache"""
@@ -293,10 +293,10 @@ class FederationClient(FederationBase):
                     transaction_data,
                 )
 
-                pdu_list = [
+                pdu_list: List[EventBase] = [
                     event_from_pdu_json(p, room_version, outlier=outlier)
                     for p in transaction_data["pdus"]
-                ]  # type: List[EventBase]
+                ]
 
                 if pdu_list and pdu_list[0]:
                     pdu = pdu_list[0]
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index ac0f2ccfb3..29619aeeb8 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -122,12 +122,12 @@ class FederationServer(FederationBase):
 
         # origins that we are currently processing a transaction from.
         # a dict from origin to txn id.
-        self._active_transactions = {}  # type: Dict[str, str]
+        self._active_transactions: Dict[str, str] = {}
 
         # We cache results for transaction with the same ID
-        self._transaction_resp_cache = ResponseCache(
+        self._transaction_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
             hs.get_clock(), "fed_txn_handler", timeout_ms=30000
-        )  # type: ResponseCache[Tuple[str, str]]
+        )
 
         self.transaction_actions = TransactionActions(self.store)
 
@@ -135,12 +135,12 @@ class FederationServer(FederationBase):
 
         # We cache responses to state queries, as they take a while and often
         # come in waves.
-        self._state_resp_cache = ResponseCache(
-            hs.get_clock(), "state_resp", timeout_ms=30000
-        )  # type: ResponseCache[Tuple[str, Optional[str]]]
-        self._state_ids_resp_cache = ResponseCache(
+        self._state_resp_cache: ResponseCache[
+            Tuple[str, Optional[str]]
+        ] = ResponseCache(hs.get_clock(), "state_resp", timeout_ms=30000)
+        self._state_ids_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
             hs.get_clock(), "state_ids_resp", timeout_ms=30000
-        )  # type: ResponseCache[Tuple[str, str]]
+        )
 
         self._federation_metrics_domains = (
             hs.config.federation.federation_metrics_domains
@@ -337,7 +337,7 @@ class FederationServer(FederationBase):
 
         origin_host, _ = parse_server_name(origin)
 
-        pdus_by_room = {}  # type: Dict[str, List[EventBase]]
+        pdus_by_room: Dict[str, List[EventBase]] = {}
 
         newest_pdu_ts = 0
 
@@ -516,9 +516,9 @@ class FederationServer(FederationBase):
         self, room_id: str, event_id: Optional[str]
     ) -> Dict[str, list]:
         if event_id:
-            pdus = await self.handler.get_state_for_pdu(
+            pdus: Iterable[EventBase] = await self.handler.get_state_for_pdu(
                 room_id, event_id
-            )  # type: Iterable[EventBase]
+            )
         else:
             pdus = (await self.state.get_current_state(room_id)).values()
 
@@ -562,8 +562,7 @@ class FederationServer(FederationBase):
             raise IncompatibleRoomVersionError(room_version=room_version)
 
         pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
-        time_now = self._clock.time_msec()
-        return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
+        return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}
 
     async def on_invite_request(
         self, origin: str, content: JsonDict, room_version_id: str
@@ -611,8 +610,7 @@ class FederationServer(FederationBase):
 
         room_version = await self.store.get_room_version_id(room_id)
 
-        time_now = self._clock.time_msec()
-        return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
+        return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}
 
     async def on_send_leave_request(
         self, origin: str, content: JsonDict, room_id: str
@@ -659,9 +657,8 @@ class FederationServer(FederationBase):
             )
 
         pdu = await self.handler.on_make_knock_request(origin, room_id, user_id)
-        time_now = self._clock.time_msec()
         return {
-            "event": pdu.get_pdu_json(time_now),
+            "event": pdu.get_templated_pdu_json(),
             "room_version": room_version.identifier,
         }
 
@@ -791,7 +788,7 @@ class FederationServer(FederationBase):
         log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
         results = await self.store.claim_e2e_one_time_keys(query)
 
-        json_result = {}  # type: Dict[str, Dict[str, dict]]
+        json_result: Dict[str, Dict[str, dict]] = {}
         for user_id, device_keys in results.items():
             for device_id, keys in device_keys.items():
                 for key_id, json_str in keys.items():
@@ -1119,17 +1116,13 @@ class FederationHandlerRegistry:
         self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
         self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)
 
-        self.edu_handlers = (
-            {}
-        )  # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
-        self.query_handlers = (
-            {}
-        )  # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]
+        self.edu_handlers: Dict[str, Callable[[str, dict], Awaitable[None]]] = {}
+        self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
 
         # Map from type to instance names that we should route EDU handling to.
         # We randomly choose one instance from the list to route to for each new
         # EDU received.
-        self._edu_type_to_instance = {}  # type: Dict[str, List[str]]
+        self._edu_type_to_instance: Dict[str, List[str]] = {}
 
     def register_edu_handler(
         self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 65d76ea974..1fbf325fdc 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -71,34 +71,32 @@ class FederationRemoteSendQueue(AbstractFederationSender):
         # We may have multiple federation sender instances, so we need to track
         # their positions separately.
         self._sender_instances = hs.config.worker.federation_shard_config.instances
-        self._sender_positions = {}  # type: Dict[str, int]
+        self._sender_positions: Dict[str, int] = {}
 
         # Pending presence map user_id -> UserPresenceState
-        self.presence_map = {}  # type: Dict[str, UserPresenceState]
+        self.presence_map: Dict[str, UserPresenceState] = {}
 
         # Stores the destinations we need to explicitly send presence to about a
         # given user.
         # Stream position -> (user_id, destinations)
-        self.presence_destinations = (
-            SortedDict()
-        )  # type: SortedDict[int, Tuple[str, Iterable[str]]]
+        self.presence_destinations: SortedDict[
+            int, Tuple[str, Iterable[str]]
+        ] = SortedDict()
 
         # (destination, key) -> EDU
-        self.keyed_edu = {}  # type: Dict[Tuple[str, tuple], Edu]
+        self.keyed_edu: Dict[Tuple[str, tuple], Edu] = {}
 
         # stream position -> (destination, key)
-        self.keyed_edu_changed = (
-            SortedDict()
-        )  # type: SortedDict[int, Tuple[str, tuple]]
+        self.keyed_edu_changed: SortedDict[int, Tuple[str, tuple]] = SortedDict()
 
-        self.edus = SortedDict()  # type: SortedDict[int, Edu]
+        self.edus: SortedDict[int, Edu] = SortedDict()
 
         # stream ID for the next entry into keyed_edu_changed/edus.
         self.pos = 1
 
         # map from stream ID to the time that stream entry was generated, so that we
         # can clear out entries after a while
-        self.pos_time = SortedDict()  # type: SortedDict[int, int]
+        self.pos_time: SortedDict[int, int] = SortedDict()
 
         # EVERYTHING IS SAD. In particular, python only makes new scopes when
         # we make a new function, so we need to make a new function so the inner
@@ -291,7 +289,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
 
         # list of tuple(int, BaseFederationRow), where the first is the position
         # of the federation stream.
-        rows = []  # type: List[Tuple[int, BaseFederationRow]]
+        rows: List[Tuple[int, BaseFederationRow]] = []
 
         # Fetch presence to send to destinations
         i = self.presence_destinations.bisect_right(from_token)
@@ -445,11 +443,11 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))):  # Edu
         buff.edus.setdefault(self.edu.destination, []).append(self.edu)
 
 
-_rowtypes = (
+_rowtypes: Tuple[Type[BaseFederationRow], ...] = (
     PresenceDestinationsRow,
     KeyedEduRow,
     EduRow,
-)  # type: Tuple[Type[BaseFederationRow], ...]
+)
 
 TypeToRow = {Row.TypeId: Row for Row in _rowtypes}
 
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index deb40f4610..d980e0d986 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -14,9 +14,12 @@
 
 import abc
 import logging
+from collections import OrderedDict
 from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple
 
+import attr
 from prometheus_client import Counter
+from typing_extensions import Literal
 
 from twisted.internet import defer
 
@@ -33,8 +36,12 @@ from synapse.metrics import (
     event_processing_loop_room_count,
     events_processed_counter,
 )
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
 from synapse.types import JsonDict, ReadReceipt, RoomStreamToken
+from synapse.util import Clock
 from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
@@ -137,6 +144,84 @@ class AbstractFederationSender(metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
 
+@attr.s
+class _PresenceQueue:
+    """A queue of destinations that need to be woken up due to new presence
+    updates.
+
+    Staggers waking up of per destination queues to ensure that we don't attempt
+    to start TLS connections with many hosts all at once, leading to pinned CPU.
+    """
+
+    # The maximum duration in seconds between queuing up a destination and it
+    # being woken up.
+    _MAX_TIME_IN_QUEUE = 30.0
+
+    # The maximum duration in seconds between waking up consecutive destination
+    # queues.
+    _MAX_DELAY = 0.1
+
+    sender: "FederationSender" = attr.ib()
+    clock: Clock = attr.ib()
+    queue: "OrderedDict[str, Literal[None]]" = attr.ib(factory=OrderedDict)
+    processing: bool = attr.ib(default=False)
+
+    def add_to_queue(self, destination: str) -> None:
+        """Add a destination to the queue to be woken up."""
+
+        self.queue[destination] = None
+
+        if not self.processing:
+            self._handle()
+
+    @wrap_as_background_process("_PresenceQueue.handle")
+    async def _handle(self) -> None:
+        """Background process to drain the queue."""
+
+        if not self.queue:
+            return
+
+        assert not self.processing
+        self.processing = True
+
+        try:
+            # We start with a delay that should drain the queue quickly enough that
+            # we process all destinations in the queue in _MAX_TIME_IN_QUEUE
+            # seconds.
+            #
+            # We also add an upper bound to the delay, to gracefully handle the
+            # case where the queue only has a few entries in it.
+            current_sleep_seconds = min(
+                self._MAX_DELAY, self._MAX_TIME_IN_QUEUE / len(self.queue)
+            )
+
+            while self.queue:
+                destination, _ = self.queue.popitem(last=False)
+
+                queue = self.sender._get_per_destination_queue(destination)
+
+                if not queue._new_data_to_send:
+                    # The per destination queue has already been woken up.
+                    continue
+
+                queue.attempt_new_transaction()
+
+                await self.clock.sleep(current_sleep_seconds)
+
+                if not self.queue:
+                    break
+
+                # More destinations may have been added to the queue, so we may
+                # need to reduce the delay to ensure everything gets processed
+                # within _MAX_TIME_IN_QUEUE seconds.
+                current_sleep_seconds = min(
+                    current_sleep_seconds, self._MAX_TIME_IN_QUEUE / len(self.queue)
+                )
+
+        finally:
+            self.processing = False
+
+
 class FederationSender(AbstractFederationSender):
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
@@ -148,14 +233,14 @@ class FederationSender(AbstractFederationSender):
         self.clock = hs.get_clock()
         self.is_mine_id = hs.is_mine_id
 
-        self._presence_router = None  # type: Optional[PresenceRouter]
+        self._presence_router: Optional["PresenceRouter"] = None
         self._transaction_manager = TransactionManager(hs)
 
         self._instance_name = hs.get_instance_name()
         self._federation_shard_config = hs.config.worker.federation_shard_config
 
         # map from destination to PerDestinationQueue
-        self._per_destination_queues = {}  # type: Dict[str, PerDestinationQueue]
+        self._per_destination_queues: Dict[str, PerDestinationQueue] = {}
 
         LaterGauge(
             "synapse_federation_transaction_queue_pending_destinations",
@@ -192,9 +277,7 @@ class FederationSender(AbstractFederationSender):
         # awaiting a call to flush_read_receipts_for_room. The presence of an entry
         # here for a given room means that we are rate-limiting RR flushes to that room,
         # and that there is a pending call to _flush_rrs_for_room in the system.
-        self._queues_awaiting_rr_flush_by_room = (
-            {}
-        )  # type: Dict[str, Set[PerDestinationQueue]]
+        self._queues_awaiting_rr_flush_by_room: Dict[str, Set[PerDestinationQueue]] = {}
 
         self._rr_txn_interval_per_room_ms = (
             1000.0 / hs.config.federation_rr_transactions_per_room_per_second
@@ -210,6 +293,8 @@ class FederationSender(AbstractFederationSender):
 
         self._external_cache = hs.get_external_cache()
 
+        self._presence_queue = _PresenceQueue(self, self.clock)
+
     def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
         """Get or create a PerDestinationQueue for the given destination
 
@@ -265,7 +350,7 @@ class FederationSender(AbstractFederationSender):
                     if not event.internal_metadata.should_proactively_send():
                         return
 
-                    destinations = None  # type: Optional[Set[str]]
+                    destinations: Optional[Set[str]] = None
                     if not event.prev_event_ids():
                         # If there are no prev event IDs then the state is empty
                         # and so no remote servers in the room
@@ -331,7 +416,7 @@ class FederationSender(AbstractFederationSender):
                         for event in events:
                             await handle_event(event)
 
-                events_by_room = {}  # type: Dict[str, List[EventBase]]
+                events_by_room: Dict[str, List[EventBase]] = {}
                 for event in events:
                     events_by_room.setdefault(event.room_id, []).append(event)
 
@@ -519,7 +604,12 @@ class FederationSender(AbstractFederationSender):
                 self._instance_name, destination
             ):
                 continue
-            self._get_per_destination_queue(destination).send_presence(states)
+
+            self._get_per_destination_queue(destination).send_presence(
+                states, start_loop=False
+            )
+
+            self._presence_queue.add_to_queue(destination)
 
     def build_and_send_edu(
         self,
@@ -628,7 +718,7 @@ class FederationSender(AbstractFederationSender):
         In order to reduce load spikes, adds a delay between each destination.
         """
 
-        last_processed = None  # type: Optional[str]
+        last_processed: Optional[str] = None
 
         while True:
             destinations_to_wake = (
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 3a2efd56ee..c11d1f6d31 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -105,34 +105,34 @@ class PerDestinationQueue:
         # catch-up at startup.
         # New events will only be sent once this is finished, at which point
         # _catching_up is flipped to False.
-        self._catching_up = True  # type: bool
+        self._catching_up: bool = True
 
         # The stream_ordering of the most recent PDU that was discarded due to
         # being in catch-up mode.
-        self._catchup_last_skipped = 0  # type: int
+        self._catchup_last_skipped: int = 0
 
         # Cache of the last successfully-transmitted stream ordering for this
         # destination (we are the only updater so this is safe)
-        self._last_successful_stream_ordering = None  # type: Optional[int]
+        self._last_successful_stream_ordering: Optional[int] = None
 
         # a queue of pending PDUs
-        self._pending_pdus = []  # type: List[EventBase]
+        self._pending_pdus: List[EventBase] = []
 
         # XXX this is never actually used: see
         # https://github.com/matrix-org/synapse/issues/7549
-        self._pending_edus = []  # type: List[Edu]
+        self._pending_edus: List[Edu] = []
 
         # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
         # based on their key (e.g. typing events by room_id)
         # Map of (edu_type, key) -> Edu
-        self._pending_edus_keyed = {}  # type: Dict[Tuple[str, Hashable], Edu]
+        self._pending_edus_keyed: Dict[Tuple[str, Hashable], Edu] = {}
 
         # Map of user_id -> UserPresenceState of pending presence to be sent to this
         # destination
-        self._pending_presence = {}  # type: Dict[str, UserPresenceState]
+        self._pending_presence: Dict[str, UserPresenceState] = {}
 
         # room_id -> receipt_type -> user_id -> receipt_dict
-        self._pending_rrs = {}  # type: Dict[str, Dict[str, Dict[str, dict]]]
+        self._pending_rrs: Dict[str, Dict[str, Dict[str, dict]]] = {}
         self._rrs_pending_flush = False
 
         # stream_id of last successfully sent to-device message.
@@ -171,14 +171,24 @@ class PerDestinationQueue:
 
         self.attempt_new_transaction()
 
-    def send_presence(self, states: Iterable[UserPresenceState]) -> None:
-        """Add presence updates to the queue. Start the transmission loop if necessary.
+    def send_presence(
+        self, states: Iterable[UserPresenceState], start_loop: bool = True
+    ) -> None:
+        """Add presence updates to the queue.
+
+        Args:
+            states: Presence updates to send
+            start_loop: Whether to start the transmission loop if not already
+                running.
 
         Args:
             states: presence to send
         """
         self._pending_presence.update({state.user_id: state for state in states})
-        self.attempt_new_transaction()
+        self._new_data_to_send = True
+
+        if start_loop:
+            self.attempt_new_transaction()
 
     def queue_read_receipt(self, receipt: ReadReceipt) -> None:
         """Add a RR to the list to be sent. Doesn't start the transmission loop yet
@@ -243,7 +253,7 @@ class PerDestinationQueue:
         )
 
     async def _transaction_transmission_loop(self) -> None:
-        pending_pdus = []  # type: List[EventBase]
+        pending_pdus: List[EventBase] = []
         try:
             self.transmission_loop_running = True
 
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index c9e7c57461..98b1bf77fd 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -395,9 +395,9 @@ class TransportLayerClient:
             # this uses MSC2197 (Search Filtering over Federation)
             path = _create_v1_path("/publicRooms")
 
-            data = {
+            data: Dict[str, Any] = {
                 "include_all_networks": "true" if include_all_networks else "false"
-            }  # type: Dict[str, Any]
+            }
             if third_party_instance_id:
                 data["third_party_instance_id"] = third_party_instance_id
             if limit:
@@ -423,9 +423,9 @@ class TransportLayerClient:
         else:
             path = _create_v1_path("/publicRooms")
 
-            args = {
+            args: Dict[str, Any] = {
                 "include_all_networks": "true" if include_all_networks else "false"
-            }  # type: Dict[str, Any]
+            }
             if third_party_instance_id:
                 args["third_party_instance_id"] = (third_party_instance_id,)
             if limit:
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index d37d9565fc..2974d4d0cc 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1013,7 +1013,7 @@ class PublicRoomList(BaseFederationServlet):
         if not self.allow_access:
             raise FederationDeniedError(origin)
 
-        limit = int(content.get("limit", 100))  # type: Optional[int]
+        limit: Optional[int] = int(content.get("limit", 100))
         since_token = content.get("since", None)
         search_filter = content.get("filter", None)
 
@@ -1095,7 +1095,9 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet):
         query: Dict[bytes, List[bytes]],
         group_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1110,7 +1112,9 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet):
         query: Dict[bytes, List[bytes]],
         group_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1131,7 +1135,9 @@ class FederationGroupsSummaryServlet(BaseGroupsServerServlet):
         query: Dict[bytes, List[bytes]],
         group_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1152,7 +1158,9 @@ class FederationGroupsRoomsServlet(BaseGroupsServerServlet):
         query: Dict[bytes, List[bytes]],
         group_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1174,7 +1182,9 @@ class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet):
         group_id: str,
         room_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1192,7 +1202,9 @@ class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet):
         group_id: str,
         room_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1220,7 +1232,9 @@ class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet):
         room_id: str,
         config_key: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1243,7 +1257,9 @@ class FederationGroupsUsersServlet(BaseGroupsServerServlet):
         query: Dict[bytes, List[bytes]],
         group_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1264,7 +1280,9 @@ class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet):
         query: Dict[bytes, List[bytes]],
         group_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1288,7 +1306,9 @@ class FederationGroupsInviteServlet(BaseGroupsServerServlet):
         group_id: str,
         user_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1354,7 +1374,9 @@ class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet):
         group_id: str,
         user_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1487,7 +1509,9 @@ class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet):
         category_id: str,
         room_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1523,7 +1547,9 @@ class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet):
         category_id: str,
         room_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1549,7 +1575,9 @@ class FederationGroupsCategoriesServlet(BaseGroupsServerServlet):
         query: Dict[bytes, List[bytes]],
         group_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1571,7 +1599,9 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
         group_id: str,
         category_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1589,7 +1619,9 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
         group_id: str,
         category_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1618,7 +1650,9 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
         group_id: str,
         category_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1644,7 +1678,9 @@ class FederationGroupsRolesServlet(BaseGroupsServerServlet):
         query: Dict[bytes, List[bytes]],
         group_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1666,7 +1702,9 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet):
         group_id: str,
         role_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1682,7 +1720,9 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet):
         group_id: str,
         role_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1713,7 +1753,9 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet):
         group_id: str,
         role_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1750,7 +1792,9 @@ class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet):
         role_id: str,
         user_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1784,7 +1828,9 @@ class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet):
         role_id: str,
         user_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1825,7 +1871,9 @@ class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet):
         query: Dict[bytes, List[bytes]],
         group_id: str,
     ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
@@ -1943,7 +1991,7 @@ class RoomComplexityServlet(BaseFederationServlet):
         return 200, complexity
 
 
-FEDERATION_SERVLET_CLASSES = (
+FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
     FederationSendServlet,
     FederationEventServlet,
     FederationStateV1Servlet,
@@ -1971,15 +2019,13 @@ FEDERATION_SERVLET_CLASSES = (
     FederationSpaceSummaryServlet,
     FederationV1SendKnockServlet,
     FederationMakeKnockServlet,
-)  # type: Tuple[Type[BaseFederationServlet], ...]
+)
 
-OPENID_SERVLET_CLASSES = (
-    OpenIdUserInfo,
-)  # type: Tuple[Type[BaseFederationServlet], ...]
+OPENID_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (OpenIdUserInfo,)
 
-ROOM_LIST_CLASSES = (PublicRoomList,)  # type: Tuple[Type[PublicRoomList], ...]
+ROOM_LIST_CLASSES: Tuple[Type[PublicRoomList], ...] = (PublicRoomList,)
 
-GROUP_SERVER_SERVLET_CLASSES = (
+GROUP_SERVER_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
     FederationGroupsProfileServlet,
     FederationGroupsSummaryServlet,
     FederationGroupsRoomsServlet,
@@ -1998,19 +2044,19 @@ GROUP_SERVER_SERVLET_CLASSES = (
     FederationGroupsAddRoomsServlet,
     FederationGroupsAddRoomsConfigServlet,
     FederationGroupsSettingJoinPolicyServlet,
-)  # type: Tuple[Type[BaseFederationServlet], ...]
+)
 
 
-GROUP_LOCAL_SERVLET_CLASSES = (
+GROUP_LOCAL_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
     FederationGroupsLocalInviteServlet,
     FederationGroupsRemoveLocalUserServlet,
     FederationGroupsBulkPublicisedServlet,
-)  # type: Tuple[Type[BaseFederationServlet], ...]
+)
 
 
-GROUP_ATTESTATION_SERVLET_CLASSES = (
+GROUP_ATTESTATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
     FederationGroupsRenewAttestaionServlet,
-)  # type: Tuple[Type[BaseFederationServlet], ...]
+)
 
 
 DEFAULT_SERVLET_GROUPS = (
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index a06d060ebf..3dc55ab861 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -707,9 +707,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         See accept_invite, join_group.
         """
         if not self.hs.is_mine_id(user_id):
-            local_attestation = self.attestations.create_attestation(
-                group_id, user_id
-            )  # type: Optional[JsonDict]
+            local_attestation: Optional[
+                JsonDict
+            ] = self.attestations.create_attestation(group_id, user_id)
 
             remote_attestation = content["attestation"]
 
@@ -868,9 +868,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
                 remote_attestation, user_id=requester_user_id, group_id=group_id
             )
 
-            local_attestation = self.attestations.create_attestation(
-                group_id, requester_user_id
-            )  # type: Optional[JsonDict]
+            local_attestation: Optional[
+                JsonDict
+            ] = self.attestations.create_attestation(group_id, requester_user_id)
         else:
             local_attestation = None
             remote_attestation = None
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index d800e16912..6a05a65305 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -15,8 +15,6 @@
 import logging
 from typing import TYPE_CHECKING, Optional
 
-import synapse.state
-import synapse.storage
 import synapse.types
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.ratelimiting import Ratelimiter
@@ -38,10 +36,10 @@ class BaseHandler:
     """
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()  # type: synapse.storage.DataStore
+        self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
-        self.state_handler = hs.get_state_handler()  # type: synapse.state.StateHandler
+        self.state_handler = hs.get_state_handler()
         self.distributor = hs.get_distributor()
         self.clock = hs.get_clock()
         self.hs = hs
@@ -55,12 +53,12 @@ class BaseHandler:
         # Check whether ratelimiting room admin message redaction is enabled
         # by the presence of rate limits in the config
         if self.hs.config.rc_admin_redaction:
-            self.admin_redaction_ratelimiter = Ratelimiter(
+            self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
                 store=self.store,
                 clock=self.clock,
                 rate_hz=self.hs.config.rc_admin_redaction.per_second,
                 burst_count=self.hs.config.rc_admin_redaction.burst_count,
-            )  # type: Optional[Ratelimiter]
+            )
         else:
             self.admin_redaction_ratelimiter = None
 
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index d752cf34f0..078accd634 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -15,9 +15,11 @@
 import email.mime.multipart
 import email.utils
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
 
-from synapse.api.errors import StoreError, SynapseError
+from twisted.web.http import Request
+
+from synapse.api.errors import AuthError, StoreError, SynapseError
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.types import UserID
 from synapse.util import stringutils
@@ -27,6 +29,15 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+# Types for callbacks to be registered via the module api
+IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]]
+ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable]
+# Temporary hooks to allow for a transition from `/_matrix/client` endpoints
+# to `/_synapse/client/account_validity`. See `register_account_validity_callbacks`.
+ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable]
+ON_LEGACY_RENEW_CALLBACK = Callable[[str], Awaitable[Tuple[bool, bool, int]]]
+ON_LEGACY_ADMIN_REQUEST = Callable[[Request], Awaitable]
+
 
 class AccountValidityHandler:
     def __init__(self, hs: "HomeServer"):
@@ -70,6 +81,99 @@ class AccountValidityHandler:
             if hs.config.run_background_tasks:
                 self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
 
+        self._is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = []
+        self._on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = []
+        self._on_legacy_send_mail_callback: Optional[
+            ON_LEGACY_SEND_MAIL_CALLBACK
+        ] = None
+        self._on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None
+
+        # The legacy admin requests callback isn't a protected attribute because we need
+        # to access it from the admin servlet, which is outside of this handler.
+        self.on_legacy_admin_request_callback: Optional[ON_LEGACY_ADMIN_REQUEST] = None
+
+    def register_account_validity_callbacks(
+        self,
+        is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
+        on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
+        on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
+        on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
+        on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
+    ):
+        """Register callbacks from module for each hook."""
+        if is_user_expired is not None:
+            self._is_user_expired_callbacks.append(is_user_expired)
+
+        if on_user_registration is not None:
+            self._on_user_registration_callbacks.append(on_user_registration)
+
+        # The builtin account validity feature exposes 3 endpoints (send_mail, renew, and
+        # an admin one). As part of moving the feature into a module, we need to change
+        # the path from /_matrix/client/unstable/account_validity/... to
+        # /_synapse/client/account_validity, because:
+        #
+        #   * the feature isn't part of the Matrix spec thus shouldn't live under /_matrix
+        #   * the way we register servlets means that modules can't register resources
+        #     under /_matrix/client
+        #
+        # We need to allow for a transition period between the old and new endpoints
+        # in order to allow for clients to update (and for emails to be processed).
+        #
+        # Once the email-account-validity module is loaded, it will take control of account
+        # validity by moving the rows from our `account_validity` table into its own table.
+        #
+        # Therefore, we need to allow modules (in practice just the one implementing the
+        # email-based account validity) to temporarily hook into the legacy endpoints so we
+        # can route the traffic coming into the old endpoints into the module, which is
+        # why we have the following three temporary hooks.
+        if on_legacy_send_mail is not None:
+            if self._on_legacy_send_mail_callback is not None:
+                raise RuntimeError("Tried to register on_legacy_send_mail twice")
+
+            self._on_legacy_send_mail_callback = on_legacy_send_mail
+
+        if on_legacy_renew is not None:
+            if self._on_legacy_renew_callback is not None:
+                raise RuntimeError("Tried to register on_legacy_renew twice")
+
+            self._on_legacy_renew_callback = on_legacy_renew
+
+        if on_legacy_admin_request is not None:
+            if self.on_legacy_admin_request_callback is not None:
+                raise RuntimeError("Tried to register on_legacy_admin_request twice")
+
+            self.on_legacy_admin_request_callback = on_legacy_admin_request
+
+    async def is_user_expired(self, user_id: str) -> bool:
+        """Checks if a user has expired against third-party modules.
+
+        Args:
+            user_id: The user to check the expiry of.
+
+        Returns:
+            Whether the user has expired.
+        """
+        for callback in self._is_user_expired_callbacks:
+            expired = await callback(user_id)
+            if expired is not None:
+                return expired
+
+        if self._account_validity_enabled:
+            # If no module could determine whether the user has expired and the legacy
+            # configuration is enabled, fall back to it.
+            return await self.store.is_account_expired(user_id, self.clock.time_msec())
+
+        return False
+
+    async def on_user_registration(self, user_id: str):
+        """Tell third-party modules about a user's registration.
+
+        Args:
+            user_id: The ID of the newly registered user.
+        """
+        for callback in self._on_user_registration_callbacks:
+            await callback(user_id)
+
     @wrap_as_background_process("send_renewals")
     async def _send_renewal_emails(self) -> None:
         """Gets the list of users whose account is expiring in the amount of time
@@ -95,6 +199,17 @@ class AccountValidityHandler:
         Raises:
             SynapseError if the user is not set to renew.
         """
+        # If a module supports sending a renewal email from here, do that, otherwise do
+        # the legacy dance.
+        if self._on_legacy_send_mail_callback is not None:
+            await self._on_legacy_send_mail_callback(user_id)
+            return
+
+        if not self._account_validity_renew_by_email_enabled:
+            raise AuthError(
+                403, "Account renewal via email is disabled on this server."
+            )
+
         expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
 
         # If this user isn't set to be expired, raise an error.
@@ -209,6 +324,10 @@ class AccountValidityHandler:
         token is considered stale. A token is stale if the 'token_used_ts_ms' db column
         is non-null.
 
+        This method exists to support handling the legacy account validity /renew
+        endpoint. If a module implements the on_legacy_renew callback, then this process
+        is delegated to the module instead.
+
         Args:
             renewal_token: Token sent with the renewal request.
         Returns:
@@ -218,6 +337,11 @@ class AccountValidityHandler:
               * An int representing the user's expiry timestamp as milliseconds since the
                 epoch, or 0 if the token was invalid.
         """
+        # If a module supports triggering a renew from here, do that, otherwise do the
+        # legacy dance.
+        if self._on_legacy_renew_callback is not None:
+            return await self._on_legacy_renew_callback(renewal_token)
+
         try:
             (
                 user_id,
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index d75a8b15c3..bfa7f2c545 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -139,7 +139,7 @@ class AdminHandler(BaseHandler):
             to_key = RoomStreamToken(None, stream_ordering)
 
             # Events that we've processed in this room
-            written_events = set()  # type: Set[str]
+            written_events: Set[str] = set()
 
             # We need to track gaps in the events stream so that we can then
             # write out the state at those events. We do this by keeping track
@@ -152,7 +152,7 @@ class AdminHandler(BaseHandler):
             # The reverse mapping to above, i.e. map from unseen event to events
             # that have the unseen event in their prev_events, i.e. the unseen
             # events "children".
-            unseen_to_child_events = {}  # type: Dict[str, Set[str]]
+            unseen_to_child_events: Dict[str, Set[str]] = {}
 
             # We fetch events in the room the user could see by fetching *all*
             # events that we have and then filtering, this isn't the most
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 862638cc4f..21a17cd2e8 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -96,7 +96,7 @@ class ApplicationServicesHandler:
                         self.current_max, limit
                     )
 
-                    events_by_room = {}  # type: Dict[str, List[EventBase]]
+                    events_by_room: Dict[str, List[EventBase]] = {}
                     for event in events:
                         events_by_room.setdefault(event.room_id, []).append(event)
 
@@ -275,7 +275,7 @@ class ApplicationServicesHandler:
     async def _handle_presence(
         self, service: ApplicationService, users: Collection[Union[str, UserID]]
     ) -> List[JsonDict]:
-        events = []  # type: List[JsonDict]
+        events: List[JsonDict] = []
         presence_source = self.event_sources.sources["presence"]
         from_key = await self.store.get_type_stream_id_for_appservice(
             service, "presence"
@@ -375,7 +375,7 @@ class ApplicationServicesHandler:
         self, only_protocol: Optional[str] = None
     ) -> Dict[str, JsonDict]:
         services = self.store.get_app_services()
-        protocols = {}  # type: Dict[str, List[JsonDict]]
+        protocols: Dict[str, List[JsonDict]] = {}
 
         # Collect up all the individual protocol responses out of the ASes
         for s in services:
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index e2ac595a62..22a8552241 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -191,7 +191,7 @@ class AuthHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.checkers = {}  # type: Dict[str, UserInteractiveAuthChecker]
+        self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
         for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
             inst = auth_checker_class(hs)
             if inst.is_enabled():
@@ -296,7 +296,7 @@ class AuthHandler(BaseHandler):
 
         # A mapping of user ID to extra attributes to include in the login
         # response.
-        self._extra_attributes = {}  # type: Dict[str, SsoLoginExtraAttributes]
+        self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {}
 
     async def validate_user_via_ui_auth(
         self,
@@ -500,7 +500,7 @@ class AuthHandler(BaseHandler):
                 all the stages in any of the permitted flows.
         """
 
-        sid = None  # type: Optional[str]
+        sid: Optional[str] = None
         authdict = clientdict.pop("auth", {})
         if "session" in authdict:
             sid = authdict["session"]
@@ -588,9 +588,9 @@ class AuthHandler(BaseHandler):
             )
 
         # check auth type currently being presented
-        errordict = {}  # type: Dict[str, Any]
+        errordict: Dict[str, Any] = {}
         if "type" in authdict:
-            login_type = authdict["type"]  # type: str
+            login_type: str = authdict["type"]
             try:
                 result = await self._check_auth_dict(authdict, clientip)
                 if result:
@@ -766,7 +766,7 @@ class AuthHandler(BaseHandler):
             LoginType.TERMS: self._get_params_terms,
         }
 
-        params = {}  # type: Dict[str, Any]
+        params: Dict[str, Any] = {}
 
         for f in public_flows:
             for stage in f:
@@ -1530,9 +1530,9 @@ class AuthHandler(BaseHandler):
         except StoreError:
             raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
 
-        user_id_to_verify = await self.get_session_data(
+        user_id_to_verify: str = await self.get_session_data(
             session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
-        )  # type: str
+        )
 
         idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
             user_id_to_verify
diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py
index 7346ccfe93..0325f86e20 100644
--- a/synapse/handlers/cas.py
+++ b/synapse/handlers/cas.py
@@ -40,7 +40,7 @@ class CasError(Exception):
 
     def __str__(self):
         if self.error_description:
-            return "{}: {}".format(self.error, self.error_description)
+            return f"{self.error}: {self.error_description}"
         return self.error
 
 
@@ -171,7 +171,7 @@ class CasHandler:
 
         # Iterate through the nodes and pull out the user and any extra attributes.
         user = None
-        attributes = {}  # type: Dict[str, List[Optional[str]]]
+        attributes: Dict[str, List[Optional[str]]] = {}
         for child in root[0]:
             if child.tag.endswith("user"):
                 user = child.text
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 95bdc5902a..46ee834407 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -452,7 +452,7 @@ class DeviceHandler(DeviceWorkerHandler):
             user_id
         )
 
-        hosts = set()  # type: Set[str]
+        hosts: Set[str] = set()
         if self.hs.is_mine_id(user_id):
             hosts.update(get_domain_from_id(u) for u in users_who_share_room)
             hosts.discard(self.server_name)
@@ -613,20 +613,20 @@ class DeviceListUpdater:
         self._remote_edu_linearizer = Linearizer(name="remote_device_list")
 
         # user_id -> list of updates waiting to be handled.
-        self._pending_updates = (
-            {}
-        )  # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]]
+        self._pending_updates: Dict[
+            str, List[Tuple[str, str, Iterable[str], JsonDict]]
+        ] = {}
 
         # Recently seen stream ids. We don't bother keeping these in the DB,
         # but they're useful to have them about to reduce the number of spurious
         # resyncs.
-        self._seen_updates = ExpiringCache(
+        self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache(
             cache_name="device_update_edu",
             clock=self.clock,
             max_len=10000,
             expiry_ms=30 * 60 * 1000,
             iterable=True,
-        )  # type: ExpiringCache[str, Set[str]]
+        )
 
         # Attempt to resync out of sync device lists every 30s.
         self._resync_retry_in_progress = False
@@ -755,7 +755,7 @@ class DeviceListUpdater:
         """Given a list of updates for a user figure out if we need to do a full
         resync, or whether we have enough data that we can just apply the delta.
         """
-        seen_updates = self._seen_updates.get(user_id, set())  # type: Set[str]
+        seen_updates: Set[str] = self._seen_updates.get(user_id, set())
 
         extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
 
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 580b941595..679b47f081 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -203,7 +203,7 @@ class DeviceMessageHandler:
         log_kv({"number_of_to_device_messages": len(messages)})
         set_tag("sender", sender_user_id)
         local_messages = {}
-        remote_messages = {}  # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
+        remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
         for user_id, by_device in messages.items():
             # Ratelimit local cross-user key requests by the sending device.
             if (
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 4064a2b859..d487fee627 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -22,6 +22,7 @@ from synapse.api.errors import (
     CodeMessageException,
     Codes,
     NotFoundError,
+    RequestSendFailed,
     ShadowBanError,
     StoreError,
     SynapseError,
@@ -236,9 +237,9 @@ class DirectoryHandler(BaseHandler):
     async def get_association(self, room_alias: RoomAlias) -> JsonDict:
         room_id = None
         if self.hs.is_mine(room_alias):
-            result = await self.get_association_from_room_alias(
-                room_alias
-            )  # type: Optional[RoomAliasMapping]
+            result: Optional[
+                RoomAliasMapping
+            ] = await self.get_association_from_room_alias(room_alias)
 
             if result:
                 room_id = result.room_id
@@ -252,12 +253,14 @@ class DirectoryHandler(BaseHandler):
                     retry_on_dns_fail=False,
                     ignore_backoff=True,
                 )
+            except RequestSendFailed:
+                raise SynapseError(502, "Failed to fetch alias")
             except CodeMessageException as e:
                 logging.warning("Error retrieving alias")
                 if e.code == 404:
                     fed_result = None
                 else:
-                    raise
+                    raise SynapseError(502, "Failed to fetch alias")
 
             if fed_result and "room_id" in fed_result and "servers" in fed_result:
                 room_id = fed_result["room_id"]
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 3972849d4d..d92370859f 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -115,9 +115,9 @@ class E2eKeysHandler:
                 the number of in-flight queries at a time.
         """
         with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
-            device_keys_query = query_body.get(
+            device_keys_query: Dict[str, Iterable[str]] = query_body.get(
                 "device_keys", {}
-            )  # type: Dict[str, Iterable[str]]
+            )
 
             # separate users by domain.
             # make a map from domain to user_id to device_ids
@@ -136,7 +136,7 @@ class E2eKeysHandler:
 
             # First get local devices.
             # A map of destination -> failure response.
-            failures = {}  # type: Dict[str, JsonDict]
+            failures: Dict[str, JsonDict] = {}
             results = {}
             if local_query:
                 local_result = await self.query_local_devices(local_query)
@@ -151,11 +151,9 @@ class E2eKeysHandler:
 
             # Now attempt to get any remote devices from our local cache.
             # A map of destination -> user ID -> device IDs.
-            remote_queries_not_in_cache = (
-                {}
-            )  # type: Dict[str, Dict[str, Iterable[str]]]
+            remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
             if remote_queries:
-                query_list = []  # type: List[Tuple[str, Optional[str]]]
+                query_list: List[Tuple[str, Optional[str]]] = []
                 for user_id, device_ids in remote_queries.items():
                     if device_ids:
                         query_list.extend(
@@ -362,9 +360,9 @@ class E2eKeysHandler:
             A map from user_id -> device_id -> device details
         """
         set_tag("local_query", query)
-        local_query = []  # type: List[Tuple[str, Optional[str]]]
+        local_query: List[Tuple[str, Optional[str]]] = []
 
-        result_dict = {}  # type: Dict[str, Dict[str, dict]]
+        result_dict: Dict[str, Dict[str, dict]] = {}
         for user_id, device_ids in query.items():
             # we use UserID.from_string to catch invalid user ids
             if not self.is_mine(UserID.from_string(user_id)):
@@ -402,9 +400,9 @@ class E2eKeysHandler:
         self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
     ) -> JsonDict:
         """Handle a device key query from a federated server"""
-        device_keys_query = query_body.get(
+        device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
             "device_keys", {}
-        )  # type: Dict[str, Optional[List[str]]]
+        )
         res = await self.query_local_devices(device_keys_query)
         ret = {"device_keys": res}
 
@@ -421,8 +419,8 @@ class E2eKeysHandler:
     async def claim_one_time_keys(
         self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
     ) -> JsonDict:
-        local_query = []  # type: List[Tuple[str, str, str]]
-        remote_queries = {}  # type: Dict[str, Dict[str, Dict[str, str]]]
+        local_query: List[Tuple[str, str, str]] = []
+        remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
 
         for user_id, one_time_keys in query.get("one_time_keys", {}).items():
             # we use UserID.from_string to catch invalid user ids
@@ -439,8 +437,8 @@ class E2eKeysHandler:
         results = await self.store.claim_e2e_one_time_keys(local_query)
 
         # A map of user ID -> device ID -> key ID -> key.
-        json_result = {}  # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
-        failures = {}  # type: Dict[str, JsonDict]
+        json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+        failures: Dict[str, JsonDict] = {}
         for user_id, device_keys in results.items():
             for device_id, keys in device_keys.items():
                 for key_id, json_str in keys.items():
@@ -768,8 +766,8 @@ class E2eKeysHandler:
         Raises:
             SynapseError: if the input is malformed
         """
-        signature_list = []  # type: List[SignatureListItem]
-        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
+        signature_list: List["SignatureListItem"] = []
+        failures: Dict[str, Dict[str, JsonDict]] = {}
         if not signatures:
             return signature_list, failures
 
@@ -930,8 +928,8 @@ class E2eKeysHandler:
         Raises:
             SynapseError: if the input is malformed
         """
-        signature_list = []  # type: List[SignatureListItem]
-        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
+        signature_list: List["SignatureListItem"] = []
+        failures: Dict[str, Dict[str, JsonDict]] = {}
         if not signatures:
             return signature_list, failures
 
@@ -1300,7 +1298,7 @@ class SigningKeyEduUpdater:
         self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
 
         # user_id -> list of updates waiting to be handled.
-        self._pending_updates = {}  # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
+        self._pending_updates: Dict[str, List[Tuple[JsonDict, JsonDict]]] = {}
 
     async def incoming_signing_key_update(
         self, origin: str, edu_content: JsonDict
@@ -1349,7 +1347,7 @@ class SigningKeyEduUpdater:
                 # This can happen since we batch updates
                 return
 
-            device_ids = []  # type: List[str]
+            device_ids: List[str] = []
 
             logger.info("pending updates: %r", pending_updates)
 
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index f134f1e234..4b3f037072 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -93,7 +93,7 @@ class EventStreamHandler(BaseHandler):
 
             # When the user joins a new room, or another user joins a currently
             # joined room, we need to send down presence for those users.
-            to_add = []  # type: List[JsonDict]
+            to_add: List[JsonDict] = []
             for event in events:
                 if not isinstance(event, EventBase):
                     continue
@@ -103,9 +103,9 @@ class EventStreamHandler(BaseHandler):
                     # Send down presence.
                     if event.state_key == auth_user_id:
                         # Send down presence for everyone in the room.
-                        users = await self.store.get_users_in_room(
+                        users: Iterable[str] = await self.store.get_users_in_room(
                             event.room_id
-                        )  # type: Iterable[str]
+                        )
                     else:
                         users = [event.state_key]
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 991ec9919a..5728719909 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -181,7 +181,7 @@ class FederationHandler(BaseHandler):
 
         # When joining a room we need to queue any events for that room up.
         # For each room, a list of (pdu, origin) tuples.
-        self.room_queues = {}  # type: Dict[str, List[Tuple[EventBase, str]]]
+        self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {}
         self._room_pdu_linearizer = Linearizer("fed_room_pdu")
 
         self._room_backfill = Linearizer("room_backfill")
@@ -368,7 +368,7 @@ class FederationHandler(BaseHandler):
                     ours = await self.state_store.get_state_groups_ids(room_id, seen)
 
                     # state_maps is a list of mappings from (type, state_key) to event_id
-                    state_maps = list(ours.values())  # type: List[StateMap[str]]
+                    state_maps: List[StateMap[str]] = list(ours.values())
 
                     # we don't need this any more, let's delete it.
                     del ours
@@ -735,7 +735,7 @@ class FederationHandler(BaseHandler):
         # we need to make sure we re-load from the database to get the rejected
         # state correct.
         fetched_events.update(
-            (await self.store.get_events(missing_desired_events, allow_rejected=True))
+            await self.store.get_events(missing_desired_events, allow_rejected=True)
         )
 
         # check for events which were in the wrong room.
@@ -845,7 +845,7 @@ class FederationHandler(BaseHandler):
                 # exact key to expect. Otherwise check it matches any key we
                 # have for that device.
 
-                current_keys = []  # type: Container[str]
+                current_keys: Container[str] = []
 
                 if device:
                     keys = device.get("keys", {}).get("keys", {})
@@ -1185,7 +1185,7 @@ class FederationHandler(BaseHandler):
                 if e_type == EventTypes.Member and event.membership == Membership.JOIN
             ]
 
-            joined_domains = {}  # type: Dict[str, int]
+            joined_domains: Dict[str, int] = {}
             for u, d in joined_users:
                 try:
                     dom = get_domain_from_id(u)
@@ -1314,7 +1314,7 @@ class FederationHandler(BaseHandler):
 
         room_version = await self.store.get_room_version(room_id)
 
-        event_map = {}  # type: Dict[str, EventBase]
+        event_map: Dict[str, EventBase] = {}
 
         async def get_event(event_id: str):
             with nested_logging_context(event_id):
@@ -1414,12 +1414,15 @@ class FederationHandler(BaseHandler):
 
         Invites must be signed by the invitee's server before distribution.
         """
-        pdu = await self.federation_client.send_invite(
-            destination=target_host,
-            room_id=event.room_id,
-            event_id=event.event_id,
-            pdu=event,
-        )
+        try:
+            pdu = await self.federation_client.send_invite(
+                destination=target_host,
+                room_id=event.room_id,
+                event_id=event.event_id,
+                pdu=event,
+            )
+        except RequestSendFailed:
+            raise SynapseError(502, f"Can't connect to server {target_host}")
 
         return pdu
 
@@ -1593,7 +1596,7 @@ class FederationHandler(BaseHandler):
 
         # Ask the remote server to create a valid knock event for us. Once received,
         # we sign the event
-        params = {"ver": supported_room_versions}  # type: Dict[str, Iterable[str]]
+        params: Dict[str, Iterable[str]] = {"ver": supported_room_versions}
         origin, event, event_format_version = await self._make_and_verify_event(
             target_hosts, room_id, knockee, Membership.KNOCK, content, params=params
         )
@@ -1931,7 +1934,7 @@ class FederationHandler(BaseHandler):
             builder=builder
         )
 
-        event_allowed = await self.third_party_event_rules.check_event_allowed(
+        event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
             event, context
         )
         if not event_allowed:
@@ -2023,7 +2026,7 @@ class FederationHandler(BaseHandler):
         # for knock events, we run the third-party event rules. It's not entirely clear
         # why we don't do this for other sorts of membership events.
         if event.membership == Membership.KNOCK:
-            event_allowed = await self.third_party_event_rules.check_event_allowed(
+            event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
                 event, context
             )
             if not event_allowed:
@@ -2450,14 +2453,14 @@ class FederationHandler(BaseHandler):
             state_sets_d = await self.state_store.get_state_groups(
                 event.room_id, extrem_ids
             )
-            state_sets = list(state_sets_d.values())  # type: List[Iterable[EventBase]]
+            state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
             state_sets.append(state)
             current_states = await self.state_handler.resolve_events(
                 room_version, state_sets, event
             )
-            current_state_ids = {
+            current_state_ids: StateMap[str] = {
                 k: e.event_id for k, e in current_states.items()
-            }  # type: StateMap[str]
+            }
         else:
             current_state_ids = await self.state_handler.get_current_state_ids(
                 event.room_id, latest_event_ids=extrem_ids
@@ -2814,7 +2817,7 @@ class FederationHandler(BaseHandler):
         """
         # exclude the state key of the new event from the current_state in the context.
         if event.is_state():
-            event_key = (event.type, event.state_key)  # type: Optional[Tuple[str, str]]
+            event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
         else:
             event_key = None
         state_updates = {
@@ -3031,9 +3034,13 @@ class FederationHandler(BaseHandler):
             await member_handler.send_membership_event(None, event, context)
         else:
             destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
-            await self.federation_client.forward_third_party_invite(
-                destinations, room_id, event_dict
-            )
+
+            try:
+                await self.federation_client.forward_third_party_invite(
+                    destinations, room_id, event_dict
+                )
+            except (RequestSendFailed, HttpResponseException):
+                raise SynapseError(502, "Failed to forward third party invite")
 
     async def on_exchange_third_party_invite_request(
         self, event_dict: JsonDict
@@ -3149,7 +3156,7 @@ class FederationHandler(BaseHandler):
 
         logger.debug("Checking auth on event %r", event.content)
 
-        last_exception = None  # type: Optional[Exception]
+        last_exception: Optional[Exception] = None
 
         # for each public key in the 3pid invite event
         for public_key_object in event_auth.get_public_keys(invite_event):
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 157f2ff218..1a6c5c64a2 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -214,7 +214,7 @@ class GroupsLocalWorkerHandler:
     async def bulk_get_publicised_groups(
         self, user_ids: Iterable[str], proxy: bool = True
     ) -> JsonDict:
-        destinations = {}  # type: Dict[str, Set[str]]
+        destinations: Dict[str, Set[str]] = {}
         local_users = set()
 
         for user_id in user_ids:
@@ -227,7 +227,7 @@ class GroupsLocalWorkerHandler:
             raise SynapseError(400, "Some user_ids are not local")
 
         results = {}
-        failed_results = []  # type: List[str]
+        failed_results: List[str] = []
         for destination, dest_user_ids in destinations.items():
             try:
                 r = await self.transport_client.bulk_get_publicised_groups(
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 33d16fbf9c..0961dec5ab 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -302,7 +302,7 @@ class IdentityHandler(BaseHandler):
             )
 
         url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
-        url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
+        url_bytes = b"/_matrix/identity/api/v1/3pid/unbind"
 
         content = {
             "mxid": mxid,
@@ -695,7 +695,7 @@ class IdentityHandler(BaseHandler):
                 return data["mxid"]
         except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
-        except IOError as e:
+        except OSError as e:
             logger.warning("Error from v1 identity server lookup: %s" % (e,))
 
         return None
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 76242865ae..5d49640760 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -46,9 +46,17 @@ class InitialSyncHandler(BaseHandler):
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
         self.validator = EventValidator()
-        self.snapshot_cache = ResponseCache(
-            hs.get_clock(), "initial_sync_cache"
-        )  # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
+        self.snapshot_cache: ResponseCache[
+            Tuple[
+                str,
+                Optional[StreamToken],
+                Optional[StreamToken],
+                str,
+                Optional[int],
+                bool,
+                bool,
+            ]
+        ] = ResponseCache(hs.get_clock(), "initial_sync_cache")
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 66e40a915d..8a0024ce84 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -81,7 +81,7 @@ class MessageHandler:
 
         # The scheduled call to self._expire_event. None if no call is currently
         # scheduled.
-        self._scheduled_expiry = None  # type: Optional[IDelayedCall]
+        self._scheduled_expiry: Optional[IDelayedCall] = None
 
         if not hs.config.worker_app:
             run_as_background_process(
@@ -196,9 +196,7 @@ class MessageHandler:
                 room_state_events = await self.state_store.get_state_for_events(
                     [event.event_id], state_filter=state_filter
                 )
-                room_state = room_state_events[
-                    event.event_id
-                ]  # type: Mapping[Any, EventBase]
+                room_state: Mapping[Any, EventBase] = room_state_events[event.event_id]
             else:
                 raise AuthError(
                     403,
@@ -421,9 +419,9 @@ class EventCreationHandler:
         self.action_generator = hs.get_action_generator()
 
         self.spam_checker = hs.get_spam_checker()
-        self.third_party_event_rules = (
+        self.third_party_event_rules: "ThirdPartyEventRules" = (
             self.hs.get_third_party_event_rules()
-        )  # type: ThirdPartyEventRules
+        )
 
         self._block_events_without_consent_error = (
             self.config.block_events_without_consent_error
@@ -440,7 +438,7 @@ class EventCreationHandler:
         #
         # map from room id to time-of-last-attempt.
         #
-        self._rooms_to_exclude_from_dummy_event_insertion = {}  # type: Dict[str, int]
+        self._rooms_to_exclude_from_dummy_event_insertion: Dict[str, int] = {}
         # The number of forward extremeities before a dummy event is sent.
         self._dummy_events_threshold = hs.config.dummy_events_threshold
 
@@ -465,9 +463,7 @@ class EventCreationHandler:
         # Stores the state groups we've recently added to the joined hosts
         # external cache. Note that the timeout must be significantly less than
         # the TTL on the external cache.
-        self._external_cache_joined_hosts_updates = (
-            None
-        )  # type: Optional[ExpiringCache]
+        self._external_cache_joined_hosts_updates: Optional[ExpiringCache] = None
         if self._external_cache.is_enabled():
             self._external_cache_joined_hosts_updates = ExpiringCache(
                 "_external_cache_joined_hosts_updates",
@@ -518,6 +514,9 @@ class EventCreationHandler:
             outlier: Indicates whether the event is an `outlier`, i.e. if
                 it's from an arbitrary point and floating in the DAG as
                 opposed to being inline with the current DAG.
+            historical: Indicates whether the message is being inserted
+                back in time around some existing events. This is used to skip
+                a few checks and mark the event as backfilled.
             depth: Override the depth used to order the event in the DAG.
                 Should normally be set to None, which will cause the depth to be calculated
                 based on the prev_events.
@@ -772,6 +771,7 @@ class EventCreationHandler:
         txn_id: Optional[str] = None,
         ignore_shadow_ban: bool = False,
         outlier: bool = False,
+        historical: bool = False,
         depth: Optional[int] = None,
     ) -> Tuple[EventBase, int]:
         """
@@ -799,6 +799,9 @@ class EventCreationHandler:
             outlier: Indicates whether the event is an `outlier`, i.e. if
                 it's from an arbitrary point and floating in the DAG as
                 opposed to being inline with the current DAG.
+            historical: Indicates whether the message is being inserted
+                back in time around some existing events. This is used to skip
+                a few checks and mark the event as backfilled.
             depth: Override the depth used to order the event in the DAG.
                 Should normally be set to None, which will cause the depth to be calculated
                 based on the prev_events.
@@ -847,6 +850,7 @@ class EventCreationHandler:
                 prev_event_ids=prev_event_ids,
                 auth_event_ids=auth_event_ids,
                 outlier=outlier,
+                historical=historical,
                 depth=depth,
             )
 
@@ -945,10 +949,10 @@ class EventCreationHandler:
         if requester:
             context.app_service = requester.app_service
 
-        third_party_result = await self.third_party_event_rules.check_event_allowed(
+        res, new_content = await self.third_party_event_rules.check_event_allowed(
             event, context
         )
-        if not third_party_result:
+        if res is False:
             logger.info(
                 "Event %s forbidden by third-party rules",
                 event,
@@ -956,11 +960,11 @@ class EventCreationHandler:
             raise SynapseError(
                 403, "This event is not allowed in this context", Codes.FORBIDDEN
             )
-        elif isinstance(third_party_result, dict):
+        elif new_content is not None:
             # the third-party rules want to replace the event. We'll need to build a new
             # event.
             event, context = await self._rebuild_event_after_third_party_rules(
-                third_party_result, event
+                new_content, event
             )
 
         self.validator.validate_new(event, self.config)
@@ -1291,7 +1295,7 @@ class EventCreationHandler:
             # Validate a newly added alias or newly added alt_aliases.
 
             original_alias = None
-            original_alt_aliases = []  # type: List[str]
+            original_alt_aliases: List[str] = []
 
             original_event_id = event.unsigned.get("replaces_state")
             if original_event_id:
@@ -1594,11 +1598,13 @@ class EventCreationHandler:
         for k, v in original_event.internal_metadata.get_dict().items():
             setattr(builder.internal_metadata, k, v)
 
-        # the event type hasn't changed, so there's no point in re-calculating the
-        # auth events.
+        # modules can send new state events, so we re-calculate the auth events just in
+        # case.
+        prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
+
         event = await builder.build(
-            prev_event_ids=original_event.prev_event_ids(),
-            auth_event_ids=original_event.auth_event_ids(),
+            prev_event_ids=prev_event_ids,
+            auth_event_ids=None,
         )
 
         # we rebuild the event context, to be on the safe side. If nothing else,
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index ee6e41c0e4..eca8f16040 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -72,26 +72,26 @@ _SESSION_COOKIES = [
     (b"oidc_session_no_samesite", b"HttpOnly"),
 ]
 
+
 #: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
 #: OpenID.Core sec 3.1.3.3.
-Token = TypedDict(
-    "Token",
-    {
-        "access_token": str,
-        "token_type": str,
-        "id_token": Optional[str],
-        "refresh_token": Optional[str],
-        "expires_in": int,
-        "scope": Optional[str],
-    },
-)
+class Token(TypedDict):
+    access_token: str
+    token_type: str
+    id_token: Optional[str]
+    refresh_token: Optional[str]
+    expires_in: int
+    scope: Optional[str]
+
 
 #: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but
 #: there is no real point of doing this in our case.
 JWK = Dict[str, str]
 
+
 #: A JWK Set, as per RFC7517 sec 5.
-JWKS = TypedDict("JWKS", {"keys": List[JWK]})
+class JWKS(TypedDict):
+    keys: List[JWK]
 
 
 class OidcHandler:
@@ -105,9 +105,9 @@ class OidcHandler:
         assert provider_confs
 
         self._token_generator = OidcSessionTokenGenerator(hs)
-        self._providers = {
+        self._providers: Dict[str, "OidcProvider"] = {
             p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
-        }  # type: Dict[str, OidcProvider]
+        }
 
     async def load_metadata(self) -> None:
         """Validate the config and load the metadata from the remote endpoint.
@@ -178,7 +178,7 @@ class OidcHandler:
         # are two.
 
         for cookie_name, _ in _SESSION_COOKIES:
-            session = request.getCookie(cookie_name)  # type: Optional[bytes]
+            session: Optional[bytes] = request.getCookie(cookie_name)
             if session is not None:
                 break
         else:
@@ -255,7 +255,7 @@ class OidcError(Exception):
 
     def __str__(self):
         if self.error_description:
-            return "{}: {}".format(self.error, self.error_description)
+            return f"{self.error}: {self.error_description}"
         return self.error
 
 
@@ -277,7 +277,7 @@ class OidcProvider:
         self._token_generator = token_generator
 
         self._config = provider
-        self._callback_url = hs.config.oidc_callback_url  # type: str
+        self._callback_url: str = hs.config.oidc_callback_url
 
         # Calculate the prefix for OIDC callback paths based on the public_baseurl.
         # We'll insert this into the Path= parameter of any session cookies we set.
@@ -290,7 +290,7 @@ class OidcProvider:
         self._scopes = provider.scopes
         self._user_profile_method = provider.user_profile_method
 
-        client_secret = None  # type: Union[None, str, JwtClientSecret]
+        client_secret: Optional[Union[str, JwtClientSecret]] = None
         if provider.client_secret:
             client_secret = provider.client_secret
         elif provider.client_secret_jwt_key:
@@ -305,7 +305,7 @@ class OidcProvider:
             provider.client_id,
             client_secret,
             provider.client_auth_method,
-        )  # type: ClientAuth
+        )
         self._client_auth_method = provider.client_auth_method
 
         # cache of metadata for the identity provider (endpoint uris, mostly). This is
@@ -324,7 +324,7 @@ class OidcProvider:
         self._allow_existing_users = provider.allow_existing_users
 
         self._http_client = hs.get_proxied_http_client()
-        self._server_name = hs.config.server_name  # type: str
+        self._server_name: str = hs.config.server_name
 
         # identifier for the external_ids table
         self.idp_id = provider.idp_id
@@ -639,7 +639,7 @@ class OidcProvider:
             )
             logger.warning(description)
             # Body was still valid JSON. Might be useful to log it for debugging.
-            logger.warning("Code exchange response: {resp!r}".format(resp=resp))
+            logger.warning("Code exchange response: %r", resp)
             raise OidcError("server_error", description)
 
         return resp
@@ -1217,10 +1217,12 @@ class OidcSessionData:
     ui_auth_session_id = attr.ib(type=str)
 
 
-UserAttributeDict = TypedDict(
-    "UserAttributeDict",
-    {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]},
-)
+class UserAttributeDict(TypedDict):
+    localpart: Optional[str]
+    display_name: Optional[str]
+    emails: List[str]
+
+
 C = TypeVar("C")
 
 
@@ -1381,7 +1383,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
         if display_name == "":
             display_name = None
 
-        emails = []  # type: List[str]
+        emails: List[str] = []
         email = render_template_field(self._config.email_template)
         if email:
             emails.append(email)
@@ -1391,7 +1393,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
         )
 
     async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
-        extras = {}  # type: Dict[str, str]
+        extras: Dict[str, str] = {}
         for key, template in self._config.extra_attributes.items():
             try:
                 extras[key] = template.render(user=userinfo).strip()
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 1e1186c29e..1dbafd253d 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -81,9 +81,9 @@ class PaginationHandler:
         self._server_name = hs.hostname
 
         self.pagination_lock = ReadWriteLock()
-        self._purges_in_progress_by_room = set()  # type: Set[str]
+        self._purges_in_progress_by_room: Set[str] = set()
         # map from purge id to PurgeStatus
-        self._purges_by_id = {}  # type: Dict[str, PurgeStatus]
+        self._purges_by_id: Dict[str, PurgeStatus] = {}
         self._event_serializer = hs.get_event_client_serializer()
 
         self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 44ed7a0712..016c5df2ca 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -378,14 +378,14 @@ class WorkerPresenceHandler(BasePresenceHandler):
 
         # The number of ongoing syncs on this process, by user id.
         # Empty if _presence_enabled is false.
-        self._user_to_num_current_syncs = {}  # type: Dict[str, int]
+        self._user_to_num_current_syncs: Dict[str, int] = {}
 
         self.notifier = hs.get_notifier()
         self.instance_id = hs.get_instance_id()
 
         # user_id -> last_sync_ms. Lists the users that have stopped syncing but
         # we haven't notified the presence writer of that yet
-        self.users_going_offline = {}  # type: Dict[str, int]
+        self.users_going_offline: Dict[str, int] = {}
 
         self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs)
         self._set_state_client = ReplicationPresenceSetState.make_client(hs)
@@ -650,7 +650,7 @@ class PresenceHandler(BasePresenceHandler):
 
         # Set of users who have presence in the `user_to_current_state` that
         # have not yet been persisted
-        self.unpersisted_users_changes = set()  # type: Set[str]
+        self.unpersisted_users_changes: Set[str] = set()
 
         hs.get_reactor().addSystemEventTrigger(
             "before",
@@ -664,7 +664,7 @@ class PresenceHandler(BasePresenceHandler):
 
         # Keeps track of the number of *ongoing* syncs on this process. While
         # this is non zero a user will never go offline.
-        self.user_to_num_current_syncs = {}  # type: Dict[str, int]
+        self.user_to_num_current_syncs: Dict[str, int] = {}
 
         # Keeps track of the number of *ongoing* syncs on other processes.
         # While any sync is ongoing on another process the user will never
@@ -674,8 +674,8 @@ class PresenceHandler(BasePresenceHandler):
         # we assume that all the sync requests on that process have stopped.
         # Stored as a dict from process_id to set of user_id, and a dict of
         # process_id to millisecond timestamp last updated.
-        self.external_process_to_current_syncs = {}  # type: Dict[str, Set[str]]
-        self.external_process_last_updated_ms = {}  # type: Dict[str, int]
+        self.external_process_to_current_syncs: Dict[str, Set[str]] = {}
+        self.external_process_last_updated_ms: Dict[str, int] = {}
 
         self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
 
@@ -1581,9 +1581,7 @@ class PresenceEventSource:
 
             # The set of users that we're interested in and that have had a presence update.
             # We'll actually pull the presence updates for these users at the end.
-            interested_and_updated_users = (
-                set()
-            )  # type: Union[Set[str], FrozenSet[str]]
+            interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
 
             if from_key:
                 # First get all users that have had a presence update
@@ -1950,8 +1948,8 @@ async def get_interested_parties(
         A 2-tuple of `(room_ids_to_states, users_to_states)`,
         with each item being a dict of `entity_name` -> `[UserPresenceState]`
     """
-    room_ids_to_states = {}  # type: Dict[str, List[UserPresenceState]]
-    users_to_states = {}  # type: Dict[str, List[UserPresenceState]]
+    room_ids_to_states: Dict[str, List[UserPresenceState]] = {}
+    users_to_states: Dict[str, List[UserPresenceState]] = {}
     for state in states:
         room_ids = await store.get_rooms_for_user(state.user_id)
         for room_id in room_ids:
@@ -2063,12 +2061,12 @@ class PresenceFederationQueue:
         # stream_id, destinations, user_ids)`. We don't store the full states
         # for efficiency, and remote workers will already have the full states
         # cached.
-        self._queue = []  # type: List[Tuple[int, int, Collection[str], Set[str]]]
+        self._queue: List[Tuple[int, int, Collection[str], Set[str]]] = []
 
         self._next_id = 1
 
         # Map from instance name to current token
-        self._current_tokens = {}  # type: Dict[str, int]
+        self._current_tokens: Dict[str, int] = {}
 
         if self._queue_presence_updates:
             self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
@@ -2168,7 +2166,7 @@ class PresenceFederationQueue:
         # handle the case where `from_token` stream ID has already been dropped.
         start_idx = max(from_token + 1 - self._next_id, -len(self._queue))
 
-        to_send = []  # type: List[Tuple[int, Tuple[str, str]]]
+        to_send: List[Tuple[int, Tuple[str, str]]] = []
         limited = False
         new_id = upto_token
         for _, stream_id, destinations, user_ids in self._queue[start_idx:]:
@@ -2216,7 +2214,7 @@ class PresenceFederationQueue:
         if not self._federation:
             return
 
-        hosts_to_users = {}  # type: Dict[str, Set[str]]
+        hosts_to_users: Dict[str, Set[str]] = {}
         for row in rows:
             hosts_to_users.setdefault(row.destination, set()).add(row.user_id)
 
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 05b4a97b59..20a033d0ba 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -197,7 +197,7 @@ class ProfileHandler(BaseHandler):
                 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
             )
 
-        displayname_to_set = new_displayname  # type: Optional[str]
+        displayname_to_set: Optional[str] = new_displayname
         if new_displayname == "":
             displayname_to_set = None
 
@@ -286,7 +286,7 @@ class ProfileHandler(BaseHandler):
                 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
             )
 
-        avatar_url_to_set = new_avatar_url  # type: Optional[str]
+        avatar_url_to_set: Optional[str] = new_avatar_url
         if new_avatar_url == "":
             avatar_url_to_set = None
 
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index f782d9db32..283483fc2c 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -30,6 +30,8 @@ class ReceiptsHandler(BaseHandler):
 
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
+        self.event_auth_handler = hs.get_event_auth_handler()
+
         self.hs = hs
 
         # We only need to poke the federation sender explicitly if its on the
@@ -59,6 +61,19 @@ class ReceiptsHandler(BaseHandler):
         """Called when we receive an EDU of type m.receipt from a remote HS."""
         receipts = []
         for room_id, room_values in content.items():
+            # If we're not in the room just ditch the event entirely. This is
+            # probably an old server that has come back and thinks we're still in
+            # the room (or we've been rejoined to the room by a state reset).
+            is_in_room = await self.event_auth_handler.check_host_in_room(
+                room_id, self.server_name
+            )
+            if not is_in_room:
+                logger.info(
+                    "Ignoring receipt from %s as we're not in the room",
+                    origin,
+                )
+                continue
+
             for receipt_type, users in room_values.items():
                 for user_id, user_values in users.items():
                     if get_domain_from_id(user_id) != origin:
@@ -83,8 +98,8 @@ class ReceiptsHandler(BaseHandler):
 
     async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
         """Takes a list of receipts, stores them and informs the notifier."""
-        min_batch_id = None  # type: Optional[int]
-        max_batch_id = None  # type: Optional[int]
+        min_batch_id: Optional[int] = None
+        max_batch_id: Optional[int] = None
 
         for receipt in receipts:
             res = await self.store.insert_receipt(
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 26ef016179..8cf614136e 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -55,15 +55,12 @@ login_counter = Counter(
     ["guest", "auth_provider"],
 )
 
-LoginDict = TypedDict(
-    "LoginDict",
-    {
-        "device_id": str,
-        "access_token": str,
-        "valid_until_ms": Optional[int],
-        "refresh_token": Optional[str],
-    },
-)
+
+class LoginDict(TypedDict):
+    device_id: str
+    access_token: str
+    valid_until_ms: Optional[int]
+    refresh_token: Optional[str]
 
 
 class RegistrationHandler(BaseHandler):
@@ -77,6 +74,7 @@ class RegistrationHandler(BaseHandler):
         self.identity_handler = self.hs.get_identity_handler()
         self.ratelimiter = hs.get_registration_ratelimiter()
         self.macaroon_gen = hs.get_macaroon_generator()
+        self._account_validity_handler = hs.get_account_validity_handler()
         self._server_notices_mxid = hs.config.server_notices_mxid
         self._server_name = hs.hostname
 
@@ -700,6 +698,10 @@ class RegistrationHandler(BaseHandler):
                 shadow_banned=shadow_banned,
             )
 
+            # Only call the account validity module(s) on the main process, to avoid
+            # repeating e.g. database writes on all of the workers.
+            await self._account_validity_handler.on_user_registration(user_id)
+
     async def register_device(
         self,
         user_id: str,
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 579b1b93c5..370561e549 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -87,7 +87,7 @@ class RoomCreationHandler(BaseHandler):
         self.config = hs.config
 
         # Room state based off defined presets
-        self._presets_dict = {
+        self._presets_dict: Dict[str, Dict[str, Any]] = {
             RoomCreationPreset.PRIVATE_CHAT: {
                 "join_rules": JoinRules.INVITE,
                 "history_visibility": HistoryVisibility.SHARED,
@@ -109,7 +109,7 @@ class RoomCreationHandler(BaseHandler):
                 "guest_can_join": False,
                 "power_level_content_override": {},
             },
-        }  # type: Dict[str, Dict[str, Any]]
+        }
 
         # Modify presets to selectively enable encryption by default per homeserver config
         for preset_name, preset_config in self._presets_dict.items():
@@ -127,9 +127,9 @@ class RoomCreationHandler(BaseHandler):
         # If a user tries to update the same room multiple times in quick
         # succession, only process the first attempt and return its result to
         # subsequent requests
-        self._upgrade_response_cache = ResponseCache(
+        self._upgrade_response_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
             hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
-        )  # type: ResponseCache[Tuple[str, str]]
+        )
         self._server_notices_mxid = hs.config.server_notices_mxid
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
@@ -377,10 +377,10 @@ class RoomCreationHandler(BaseHandler):
         if not await self.spam_checker.user_may_create_room(user_id):
             raise SynapseError(403, "You are not permitted to create rooms")
 
-        creation_content = {
+        creation_content: JsonDict = {
             "room_version": new_room_version.identifier,
             "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
-        }  # type: JsonDict
+        }
 
         # Check if old room was non-federatable
 
@@ -618,15 +618,11 @@ class RoomCreationHandler(BaseHandler):
         else:
             is_requester_admin = await self.auth.is_server_admin(requester.user)
 
-        # Check whether the third party rules allows/changes the room create
-        # request.
-        event_allowed = await self.third_party_event_rules.on_create_room(
+        # Let the third party rules modify the room creation config if needed, or abort
+        # the room creation entirely with an exception.
+        await self.third_party_event_rules.on_create_room(
             requester, config, is_requester_admin=is_requester_admin
         )
-        if not event_allowed:
-            raise SynapseError(
-                403, "You are not permitted to create rooms", Codes.FORBIDDEN
-            )
 
         if not is_requester_admin and not await self.spam_checker.user_may_create_room(
             user_id
@@ -936,7 +932,7 @@ class RoomCreationHandler(BaseHandler):
                 etype=EventTypes.PowerLevels, content=pl_content
             )
         else:
-            power_level_content = {
+            power_level_content: JsonDict = {
                 "users": {creator_id: 100},
                 "users_default": 0,
                 "events": {
@@ -955,7 +951,7 @@ class RoomCreationHandler(BaseHandler):
                 "kick": 50,
                 "redact": 50,
                 "invite": 50,
-            }  # type: JsonDict
+            }
 
             if config["original_invitees_have_ops"]:
                 for invitee in invite_list:
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 5e3ef7ce3a..fae2c098e3 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -20,7 +20,12 @@ import msgpack
 from unpaddedbase64 import decode_base64, encode_base64
 
 from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
-from synapse.api.errors import Codes, HttpResponseException
+from synapse.api.errors import (
+    Codes,
+    HttpResponseException,
+    RequestSendFailed,
+    SynapseError,
+)
 from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.response_cache import ResponseCache
@@ -42,12 +47,12 @@ class RoomListHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.enable_room_list_search = hs.config.enable_room_list_search
-        self.response_cache = ResponseCache(
-            hs.get_clock(), "room_list"
-        )  # type: ResponseCache[Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]]
-        self.remote_response_cache = ResponseCache(
-            hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000
-        )  # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
+        self.response_cache: ResponseCache[
+            Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]
+        ] = ResponseCache(hs.get_clock(), "room_list")
+        self.remote_response_cache: ResponseCache[
+            Tuple[str, Optional[int], Optional[str], bool, Optional[str]]
+        ] = ResponseCache(hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000)
 
     async def get_local_public_room_list(
         self,
@@ -134,10 +139,10 @@ class RoomListHandler(BaseHandler):
         if since_token:
             batch_token = RoomListNextBatch.from_token(since_token)
 
-            bounds = (
+            bounds: Optional[Tuple[int, str]] = (
                 batch_token.last_joined_members,
                 batch_token.last_room_id,
-            )  # type: Optional[Tuple[int, str]]
+            )
             forwards = batch_token.direction_is_forward
             has_batch_token = True
         else:
@@ -177,7 +182,7 @@ class RoomListHandler(BaseHandler):
 
         results = [build_room_entry(r) for r in results]
 
-        response = {}  # type: JsonDict
+        response: JsonDict = {}
         num_results = len(results)
         if limit is not None:
             more_to_come = num_results == probing_limit
@@ -378,7 +383,11 @@ class RoomListHandler(BaseHandler):
                 ):
                     logger.debug("Falling back to locally-filtered /publicRooms")
                 else:
-                    raise  # Not an error that should trigger a fallback.
+                    # Not an error that should trigger a fallback.
+                    raise SynapseError(502, "Failed to fetch room list")
+            except RequestSendFailed:
+                # Not an error that should trigger a fallback.
+                raise SynapseError(502, "Failed to fetch room list")
 
             # if we reach this point, then we fall back to the situation where
             # we currently don't support searching across federation, so we have
@@ -417,14 +426,17 @@ class RoomListHandler(BaseHandler):
         repl_layer = self.hs.get_federation_client()
         if search_filter:
             # We can't cache when asking for search
-            return await repl_layer.get_public_rooms(
-                server_name,
-                limit=limit,
-                since_token=since_token,
-                search_filter=search_filter,
-                include_all_networks=include_all_networks,
-                third_party_instance_id=third_party_instance_id,
-            )
+            try:
+                return await repl_layer.get_public_rooms(
+                    server_name,
+                    limit=limit,
+                    since_token=since_token,
+                    search_filter=search_filter,
+                    include_all_networks=include_all_networks,
+                    third_party_instance_id=third_party_instance_id,
+                )
+            except (RequestSendFailed, HttpResponseException):
+                raise SynapseError(502, "Failed to fetch room list")
 
         key = (
             server_name,
diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py
index 80ba65b9e0..e6e71e9729 100644
--- a/synapse/handlers/saml.py
+++ b/synapse/handlers/saml.py
@@ -83,7 +83,7 @@ class SamlHandler(BaseHandler):
         self.unstable_idp_brand = None
 
         # a map from saml session id to Saml2SessionData object
-        self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
+        self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {}
 
         self._sso_handler = hs.get_sso_handler()
         self._sso_handler.register_identity_provider(self)
@@ -372,7 +372,7 @@ class SamlHandler(BaseHandler):
 
 
 DOT_REPLACE_PATTERN = re.compile(
-    ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+    "[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)
 )
 
 
@@ -386,10 +386,10 @@ def dot_replace_for_mxid(username: str) -> str:
     return username
 
 
-MXID_MAPPER_MAP = {
+MXID_MAPPER_MAP: Dict[str, Callable[[str], str]] = {
     "hexencode": map_username_to_mxid_localpart,
     "dotreplace": dot_replace_for_mxid,
-}  # type: Dict[str, Callable[[str], str]]
+}
 
 
 @attr.s
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 4e718d3f63..8226d6f5a1 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -192,7 +192,7 @@ class SearchHandler(BaseHandler):
         # If doing a subset of all rooms seearch, check if any of the rooms
         # are from an upgraded room, and search their contents as well
         if search_filter.rooms:
-            historical_room_ids = []  # type: List[str]
+            historical_room_ids: List[str] = []
             for room_id in search_filter.rooms:
                 # Add any previous rooms to the search if they exist
                 ids = await self.get_old_rooms_from_upgraded_room(room_id)
@@ -216,9 +216,9 @@ class SearchHandler(BaseHandler):
         rank_map = {}  # event_id -> rank of event
         allowed_events = []
         # Holds result of grouping by room, if applicable
-        room_groups = {}  # type: Dict[str, JsonDict]
+        room_groups: Dict[str, JsonDict] = {}
         # Holds result of grouping by sender, if applicable
-        sender_group = {}  # type: Dict[str, JsonDict]
+        sender_group: Dict[str, JsonDict] = {}
 
         # Holds the next_batch for the entire result set if one of those exists
         global_next_batch = None
@@ -262,7 +262,7 @@ class SearchHandler(BaseHandler):
                 s["results"].append(e.event_id)
 
         elif order_by == "recent":
-            room_events = []  # type: List[EventBase]
+            room_events: List[EventBase] = []
             i = 0
 
             pagination_token = batch_token
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
index b585057ec3..5f7d4602bd 100644
--- a/synapse/handlers/space_summary.py
+++ b/synapse/handlers/space_summary.py
@@ -24,6 +24,7 @@ from synapse.api.constants import (
     EventContentFields,
     EventTypes,
     HistoryVisibility,
+    JoinRules,
     Membership,
     RoomTypes,
 )
@@ -89,14 +90,14 @@ class SpaceSummaryHandler:
         room_queue = deque((_RoomQueueEntry(room_id, ()),))
 
         # rooms we have already processed
-        processed_rooms = set()  # type: Set[str]
+        processed_rooms: Set[str] = set()
 
         # events we have already processed. We don't necessarily have their event ids,
         # so instead we key on (room id, state key)
-        processed_events = set()  # type: Set[Tuple[str, str]]
+        processed_events: Set[Tuple[str, str]] = set()
 
-        rooms_result = []  # type: List[JsonDict]
-        events_result = []  # type: List[JsonDict]
+        rooms_result: List[JsonDict] = []
+        events_result: List[JsonDict] = []
 
         while room_queue and len(rooms_result) < MAX_ROOMS:
             queue_entry = room_queue.popleft()
@@ -150,14 +151,21 @@ class SpaceSummaryHandler:
                     # The room should only be included in the summary if:
                     #     a. the user is in the room;
                     #     b. the room is world readable; or
-                    #     c. the user is in a space that has been granted access to
-                    #        the room.
+                    #     c. the user could join the room, e.g. the join rules
+                    #        are set to public or the user is in a space that
+                    #        has been granted access to the room.
                     #
                     # Note that we know the user is not in the root room (which is
                     # why the remote call was made in the first place), but the user
                     # could be in one of the children rooms and we just didn't know
                     # about the link.
-                    include_room = room.get("world_readable") is True
+
+                    # The API doesn't return the room version so assume that a
+                    # join rule of knock is valid.
+                    include_room = (
+                        room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK)
+                        or room.get("world_readable") is True
+                    )
 
                     # Check if the user is a member of any of the allowed spaces
                     # from the response.
@@ -264,10 +272,10 @@ class SpaceSummaryHandler:
         # the set of rooms that we should not walk further. Initialise it with the
         # excluded-rooms list; we will add other rooms as we process them so that
         # we do not loop.
-        processed_rooms = set(exclude_rooms)  # type: Set[str]
+        processed_rooms: Set[str] = set(exclude_rooms)
 
-        rooms_result = []  # type: List[JsonDict]
-        events_result = []  # type: List[JsonDict]
+        rooms_result: List[JsonDict] = []
+        events_result: List[JsonDict] = []
 
         while room_queue and len(rooms_result) < MAX_ROOMS:
             room_id = room_queue.popleft()
@@ -345,7 +353,7 @@ class SpaceSummaryHandler:
             max_children = MAX_ROOMS_PER_SPACE
 
         now = self._clock.time_msec()
-        events_result = []  # type: List[JsonDict]
+        events_result: List[JsonDict] = []
         for edge_event in itertools.islice(child_events, max_children):
             events_result.append(
                 await self._event_serializer.serialize_event(
@@ -420,9 +428,8 @@ class SpaceSummaryHandler:
 
         It should be included if:
 
-        * The requester is joined or invited to the room.
-        * The requester can join without an invite (per MSC3083).
-        * The origin server has any user that is joined or invited to the room.
+        * The requester is joined or can join the room (per MSC3173).
+        * The origin server has any user that is joined or can join the room.
         * The history visibility is set to world readable.
 
         Args:
@@ -441,13 +448,39 @@ class SpaceSummaryHandler:
 
         # If there's no state for the room, it isn't known.
         if not state_ids:
+            # The user might have a pending invite for the room.
+            if requester and await self._store.get_invite_for_local_user_in_room(
+                requester, room_id
+            ):
+                return True
+
             logger.info("room %s is unknown, omitting from summary", room_id)
             return False
 
         room_version = await self._store.get_room_version(room_id)
 
-        # if we have an authenticated requesting user, first check if they are able to view
-        # stripped state in the room.
+        # Include the room if it has join rules of public or knock.
+        join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""))
+        if join_rules_event_id:
+            join_rules_event = await self._store.get_event(join_rules_event_id)
+            join_rule = join_rules_event.content.get("join_rule")
+            if join_rule == JoinRules.PUBLIC or (
+                room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
+            ):
+                return True
+
+        # Include the room if it is peekable.
+        hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, ""))
+        if hist_vis_event_id:
+            hist_vis_ev = await self._store.get_event(hist_vis_event_id)
+            hist_vis = hist_vis_ev.content.get("history_visibility")
+            if hist_vis == HistoryVisibility.WORLD_READABLE:
+                return True
+
+        # Otherwise we need to check information specific to the user or server.
+
+        # If we have an authenticated requesting user, check if they are a member
+        # of the room (or can join the room).
         if requester:
             member_event_id = state_ids.get((EventTypes.Member, requester), None)
 
@@ -470,9 +503,11 @@ class SpaceSummaryHandler:
                     return True
 
         # If this is a request over federation, check if the host is in the room or
-        # is in one of the spaces specified via the join rules.
+        # has a user who could join the room.
         elif origin:
-            if await self._event_auth_handler.check_host_in_room(room_id, origin):
+            if await self._event_auth_handler.check_host_in_room(
+                room_id, origin
+            ) or await self._store.is_host_invited(room_id, origin):
                 return True
 
             # Alternately, if the host has a user in any of the spaces specified
@@ -490,18 +525,10 @@ class SpaceSummaryHandler:
                     ):
                         return True
 
-        # otherwise, check if the room is peekable
-        hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, ""), None)
-        if hist_vis_event_id:
-            hist_vis_ev = await self._store.get_event(hist_vis_event_id)
-            hist_vis = hist_vis_ev.content.get("history_visibility")
-            if hist_vis == HistoryVisibility.WORLD_READABLE:
-                return True
-
         logger.info(
-            "room %s is unpeekable and user %s is not a member / not allowed to join, omitting from summary",
+            "room %s is unpeekable and requester %s is not a member / not allowed to join, omitting from summary",
             room_id,
-            requester,
+            requester or origin,
         )
         return False
 
@@ -535,6 +562,7 @@ class SpaceSummaryHandler:
             "canonical_alias": stats["canonical_alias"],
             "num_joined_members": stats["joined_members"],
             "avatar_url": stats["avatar"],
+            "join_rules": stats["join_rules"],
             "world_readable": (
                 stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
             ),
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 0b297e54c4..1b855a685c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -202,10 +202,10 @@ class SsoHandler:
         self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
 
         # a map from session id to session data
-        self._username_mapping_sessions = {}  # type: Dict[str, UsernameMappingSession]
+        self._username_mapping_sessions: Dict[str, UsernameMappingSession] = {}
 
         # map from idp_id to SsoIdentityProvider
-        self._identity_providers = {}  # type: Dict[str, SsoIdentityProvider]
+        self._identity_providers: Dict[str, SsoIdentityProvider] = {}
 
         self._consent_at_registration = hs.config.consent.user_consent_at_registration
 
@@ -296,7 +296,7 @@ class SsoHandler:
             )
 
         # if the client chose an IdP, use that
-        idp = None  # type: Optional[SsoIdentityProvider]
+        idp: Optional[SsoIdentityProvider] = None
         if idp_id:
             idp = self._identity_providers.get(idp_id)
             if not idp:
@@ -669,9 +669,9 @@ class SsoHandler:
             remote_user_id,
         )
 
-        user_id_to_verify = await self._auth_handler.get_session_data(
+        user_id_to_verify: str = await self._auth_handler.get_session_data(
             ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
-        )  # type: str
+        )
 
         if not user_id:
             logger.warning(
@@ -793,7 +793,7 @@ class SsoHandler:
         session.use_display_name = use_display_name
 
         emails_from_idp = set(session.emails)
-        filtered_emails = set()  # type: Set[str]
+        filtered_emails: Set[str] = set()
 
         # we iterate through the list rather than just building a set conjunction, so
         # that we can log attempts to use unknown addresses
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 4e45d1da57..3fd89af2a4 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -45,12 +45,11 @@ class StatsHandler:
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
         self.is_mine_id = hs.is_mine_id
-        self.stats_bucket_size = hs.config.stats_bucket_size
 
         self.stats_enabled = hs.config.stats_enabled
 
         # The current position in the current_state_delta stream
-        self.pos = None  # type: Optional[int]
+        self.pos: Optional[int] = None
 
         # Guard to ensure we only process deltas one at a time
         self._is_processing = False
@@ -106,20 +105,6 @@ class StatsHandler:
                 room_deltas = {}
                 user_deltas = {}
 
-            # Then count deltas for total_events and total_event_bytes.
-            (
-                room_count,
-                user_count,
-            ) = await self.store.get_changes_room_total_events_and_bytes(
-                self.pos, max_pos
-            )
-
-            for room_id, fields in room_count.items():
-                room_deltas.setdefault(room_id, Counter()).update(fields)
-
-            for user_id, fields in user_count.items():
-                user_deltas.setdefault(user_id, Counter()).update(fields)
-
             logger.debug("room_deltas: %s", room_deltas)
             logger.debug("user_deltas: %s", user_deltas)
 
@@ -146,10 +131,10 @@ class StatsHandler:
             mapping from room/user ID to changes in the various fields.
         """
 
-        room_to_stats_deltas = {}  # type: Dict[str, CounterType[str]]
-        user_to_stats_deltas = {}  # type: Dict[str, CounterType[str]]
+        room_to_stats_deltas: Dict[str, CounterType[str]] = {}
+        user_to_stats_deltas: Dict[str, CounterType[str]] = {}
 
-        room_to_state_updates = {}  # type: Dict[str, Dict[str, Any]]
+        room_to_state_updates: Dict[str, Dict[str, Any]] = {}
 
         for delta in deltas:
             typ = delta["type"]
@@ -179,14 +164,12 @@ class StatsHandler:
                 )
                 continue
 
-            event_content = {}  # type: JsonDict
+            event_content: JsonDict = {}
 
-            sender = None
             if event_id is not None:
                 event = await self.store.get_event(event_id, allow_none=True)
                 if event:
                     event_content = event.content or {}
-                    sender = event.sender
 
             # All the values in this dict are deltas (RELATIVE changes)
             room_stats_delta = room_to_stats_deltas.setdefault(room_id, Counter())
@@ -244,12 +227,6 @@ class StatsHandler:
                     room_stats_delta["joined_members"] += 1
                 elif membership == Membership.INVITE:
                     room_stats_delta["invited_members"] += 1
-
-                    if sender and self.is_mine_id(sender):
-                        user_to_stats_deltas.setdefault(sender, Counter())[
-                            "invites_sent"
-                        ] += 1
-
                 elif membership == Membership.LEAVE:
                     room_stats_delta["left_members"] += 1
                 elif membership == Membership.BAN:
@@ -279,10 +256,6 @@ class StatsHandler:
                 room_state["is_federatable"] = (
                     event_content.get("m.federate", True) is True
                 )
-                if sender and self.is_mine_id(sender):
-                    user_to_stats_deltas.setdefault(sender, Counter())[
-                        "rooms_created"
-                    ] += 1
             elif typ == EventTypes.JoinRules:
                 room_state["join_rules"] = event_content.get("join_rule")
             elif typ == EventTypes.RoomHistoryVisibility:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b9a0361059..f30bfcc93c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -278,12 +278,14 @@ class SyncHandler:
         self.state_store = self.storage.state
 
         # ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
-        self.lazy_loaded_members_cache = ExpiringCache(
+        self.lazy_loaded_members_cache: ExpiringCache[
+            Tuple[str, Optional[str]], LruCache[str, str]
+        ] = ExpiringCache(
             "lazy_loaded_members_cache",
             self.clock,
             max_len=0,
             expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
-        )  # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]
+        )
 
     async def wait_for_sync_for_user(
         self,
@@ -440,7 +442,7 @@ class SyncHandler:
             )
             now_token = now_token.copy_and_replace("typing_key", typing_key)
 
-            ephemeral_by_room = {}  # type: JsonDict
+            ephemeral_by_room: JsonDict = {}
 
             for event in typing:
                 # we want to exclude the room_id from the event, but modifying the
@@ -502,7 +504,7 @@ class SyncHandler:
                 # We check if there are any state events, if there are then we pass
                 # all current state events to the filter_events function. This is to
                 # ensure that we always include current state in the timeline
-                current_state_ids = frozenset()  # type: FrozenSet[str]
+                current_state_ids: FrozenSet[str] = frozenset()
                 if any(e.is_state() for e in recents):
                     current_state_ids_map = await self.store.get_current_state_ids(
                         room_id
@@ -783,9 +785,9 @@ class SyncHandler:
     def get_lazy_loaded_members_cache(
         self, cache_key: Tuple[str, Optional[str]]
     ) -> LruCache[str, str]:
-        cache = self.lazy_loaded_members_cache.get(
+        cache: Optional[LruCache[str, str]] = self.lazy_loaded_members_cache.get(
             cache_key
-        )  # type: Optional[LruCache[str, str]]
+        )
         if cache is None:
             logger.debug("creating LruCache for %r", cache_key)
             cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
@@ -984,7 +986,7 @@ class SyncHandler:
                     if t[0] == EventTypes.Member:
                         cache.set(t[1], event_id)
 
-        state = {}  # type: Dict[str, EventBase]
+        state: Dict[str, EventBase] = {}
         if state_ids:
             state = await self.store.get_events(list(state_ids.values()))
 
@@ -1088,9 +1090,13 @@ class SyncHandler:
 
         logger.debug("Fetching OTK data")
         device_id = sync_config.device_id
-        one_time_key_counts = {}  # type: JsonDict
-        unused_fallback_key_types = []  # type: List[str]
+        one_time_key_counts: JsonDict = {}
+        unused_fallback_key_types: List[str] = []
         if device_id:
+            # TODO: We should have a way to let clients differentiate between the states of:
+            #   * no change in OTK count since the provided since token
+            #   * the server has zero OTKs left for this device
+            #  Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
             one_time_key_counts = await self.store.count_e2e_one_time_keys(
                 user_id, device_id
             )
@@ -1437,7 +1443,7 @@ class SyncHandler:
         )
 
         if block_all_room_ephemeral:
-            ephemeral_by_room = {}  # type: Dict[str, List[JsonDict]]
+            ephemeral_by_room: Dict[str, List[JsonDict]] = {}
         else:
             now_token, ephemeral_by_room = await self.ephemeral_by_room(
                 sync_result_builder,
@@ -1468,7 +1474,7 @@ class SyncHandler:
 
         # If there is ignored users account data and it matches the proper type,
         # then use it.
-        ignored_users = frozenset()  # type: FrozenSet[str]
+        ignored_users: FrozenSet[str] = frozenset()
         if ignored_account_data:
             ignored_users_data = ignored_account_data.get("ignored_users", {})
             if isinstance(ignored_users_data, dict):
@@ -1586,7 +1592,7 @@ class SyncHandler:
             user_id, since_token.room_key, now_token.room_key
         )
 
-        mem_change_events_by_room_id = {}  # type: Dict[str, List[EventBase]]
+        mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
         for event in rooms_changed:
             mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
 
@@ -1599,7 +1605,7 @@ class SyncHandler:
             logger.debug(
                 "Membership changes in %s: [%s]",
                 room_id,
-                ", ".join(("%s (%s)" % (e.event_id, e.membership) for e in events)),
+                ", ".join("%s (%s)" % (e.event_id, e.membership) for e in events),
             )
 
             non_joins = [e for e in events if e.membership != Membership.JOIN]
@@ -1722,7 +1728,7 @@ class SyncHandler:
                 # This is all screaming out for a refactor, as the logic here is
                 # subtle and the moving parts numerous.
                 if leave_event.internal_metadata.is_out_of_band_membership():
-                    batch_events = [leave_event]  # type: Optional[List[EventBase]]
+                    batch_events: Optional[List[EventBase]] = [leave_event]
                 else:
                     batch_events = None
 
@@ -1971,7 +1977,7 @@ class SyncHandler:
             room_id, batch, sync_config, since_token, now_token, full_state=full_state
         )
 
-        summary = {}  # type: Optional[JsonDict]
+        summary: Optional[JsonDict] = {}
 
         # we include a summary in room responses when we're lazy loading
         # members (as the client otherwise doesn't have enough info to form
@@ -1995,7 +2001,7 @@ class SyncHandler:
             )
 
         if room_builder.rtype == "joined":
-            unread_notifications = {}  # type: Dict[str, int]
+            unread_notifications: Dict[str, int] = {}
             room_sync = JoinedSyncResult(
                 room_id=room_id,
                 timeline=batch,
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index e22393adc4..0cb651a400 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -68,11 +68,11 @@ class FollowerTypingHandler:
             )
 
         # map room IDs to serial numbers
-        self._room_serials = {}  # type: Dict[str, int]
+        self._room_serials: Dict[str, int] = {}
         # map room IDs to sets of users currently typing
-        self._room_typing = {}  # type: Dict[str, Set[str]]
+        self._room_typing: Dict[str, Set[str]] = {}
 
-        self._member_last_federation_poke = {}  # type: Dict[RoomMember, int]
+        self._member_last_federation_poke: Dict[RoomMember, int] = {}
         self.wheel_timer = WheelTimer(bucket_size=5000)
         self._latest_room_serial = 0
 
@@ -208,6 +208,7 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
+        self.event_auth_handler = hs.get_event_auth_handler()
 
         self.hs = hs
 
@@ -216,7 +217,7 @@ class TypingWriterHandler(FollowerTypingHandler):
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
         # clock time we expect to stop
-        self._member_typing_until = {}  # type: Dict[RoomMember, int]
+        self._member_typing_until: Dict[RoomMember, int] = {}
 
         # caches which room_ids changed at which serials
         self._typing_stream_change_cache = StreamChangeCache(
@@ -326,6 +327,19 @@ class TypingWriterHandler(FollowerTypingHandler):
         room_id = content["room_id"]
         user_id = content["user_id"]
 
+        # If we're not in the room just ditch the event entirely. This is
+        # probably an old server that has come back and thinks we're still in
+        # the room (or we've been rejoined to the room by a state reset).
+        is_in_room = await self.event_auth_handler.check_host_in_room(
+            room_id, self.server_name
+        )
+        if not is_in_room:
+            logger.info(
+                "Ignoring typing update from %s as we're not in the room",
+                origin,
+            )
+            return
+
         member = RoomMember(user_id=user_id, room_id=room_id)
 
         # Check that the string is a valid user id
@@ -391,9 +405,9 @@ class TypingWriterHandler(FollowerTypingHandler):
         if last_id == current_id:
             return [], current_id, False
 
-        changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
-            last_id
-        )  # type: Optional[Iterable[str]]
+        changed_rooms: Optional[
+            Iterable[str]
+        ] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
 
         if changed_rooms is None:
             changed_rooms = self._room_serials
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index dacc4f3076..6edb1da50a 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -52,7 +52,7 @@ class UserDirectoryHandler(StateDeltasHandler):
         self.search_all_users = hs.config.user_directory_search_all_users
         self.spam_checker = hs.get_spam_checker()
         # The current position in the current_state_delta stream
-        self.pos = None  # type: Optional[int]
+        self.pos: Optional[int] = None
 
         # Guard to ensure we only process deltas one at a time
         self._is_processing = False
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index ed4671b7de..578fc48ef4 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -69,7 +69,7 @@ def _get_requested_host(request: IRequest) -> bytes:
         return hostname
 
     # no Host header, use the address/port that the request arrived on
-    host = request.getHost()  # type: Union[address.IPv4Address, address.IPv6Address]
+    host: Union[address.IPv4Address, address.IPv6Address] = request.getHost()
 
     hostname = host.host.encode("ascii")
 
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 1ca6624fd5..2ac76b15c2 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -160,7 +160,7 @@ class _IPBlacklistingResolver:
     def resolveHostName(
         self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
     ) -> IResolutionReceiver:
-        addresses = []  # type: List[IAddress]
+        addresses: List[IAddress] = []
 
         def _callback() -> None:
             has_bad_ip = False
@@ -333,9 +333,9 @@ class SimpleHttpClient:
         if self._ip_blacklist:
             # If we have an IP blacklist, we need to use a DNS resolver which
             # filters out blacklisted IP addresses, to prevent DNS rebinding.
-            self.reactor = BlacklistingReactorWrapper(
+            self.reactor: ISynapseReactor = BlacklistingReactorWrapper(
                 hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
-            )  # type: ISynapseReactor
+            )
         else:
             self.reactor = hs.get_reactor()
 
@@ -349,14 +349,14 @@ class SimpleHttpClient:
         pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5))
         pool.cachedConnectionTimeout = 2 * 60
 
-        self.agent = ProxyAgent(
+        self.agent: IAgent = ProxyAgent(
             self.reactor,
             hs.get_reactor(),
             connectTimeout=15,
             contextFactory=self.hs.get_http_client_context_factory(),
             pool=pool,
             use_proxy=use_proxy,
-        )  # type: IAgent
+        )
 
         if self._ip_blacklist:
             # If we have an IP blacklist, we then install the blacklisting Agent
@@ -411,7 +411,7 @@ class SimpleHttpClient:
                         cooperator=self._cooperator,
                     )
 
-                request_deferred = treq.request(
+                request_deferred: defer.Deferred = treq.request(
                     method,
                     uri,
                     agent=self.agent,
@@ -421,7 +421,7 @@ class SimpleHttpClient:
                     # response bodies.
                     unbuffered=True,
                     **self._extra_treq_args,
-                )  # type: defer.Deferred
+                )
 
                 # we use our own timeout mechanism rather than treq's as a workaround
                 # for https://twistedmatrix.com/trac/ticket/9534.
@@ -772,7 +772,7 @@ class BodyExceededMaxSize(Exception):
 class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
     """A protocol which immediately errors upon receiving data."""
 
-    transport = None  # type: Optional[ITCPTransport]
+    transport: Optional[ITCPTransport] = None
 
     def __init__(self, deferred: defer.Deferred):
         self.deferred = deferred
@@ -798,7 +798,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
 class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
     """A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
 
-    transport = None  # type: Optional[ITCPTransport]
+    transport: Optional[ITCPTransport] = None
 
     def __init__(
         self, stream: ByteWriteable, deferred: defer.Deferred, max_size: Optional[int]
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 20d39a4ea6..43f2140429 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -70,10 +70,8 @@ WELL_KNOWN_RETRY_ATTEMPTS = 3
 logger = logging.getLogger(__name__)
 
 
-_well_known_cache = TTLCache("well-known")  # type: TTLCache[bytes, Optional[bytes]]
-_had_valid_well_known_cache = TTLCache(
-    "had-valid-well-known"
-)  # type: TTLCache[bytes, bool]
+_well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
+_had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known")
 
 
 @attr.s(slots=True, frozen=True)
@@ -130,9 +128,10 @@ class WellKnownResolver:
         # requests for the same server in parallel?
         try:
             with Measure(self._clock, "get_well_known"):
-                result, cache_period = await self._fetch_well_known(
-                    server_name
-                )  # type: Optional[bytes], float
+                result: Optional[bytes]
+                cache_period: float
+
+                result, cache_period = await self._fetch_well_known(server_name)
 
         except _FetchWellKnownFailure as e:
             if prev_result and e.temporary:
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index b8849c0150..2efa15bf04 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -43,6 +43,7 @@ from twisted.internet import defer
 from twisted.internet.error import DNSLookupError
 from twisted.internet.interfaces import IReactorTime
 from twisted.internet.task import _EPSILON, Cooperator
+from twisted.web.client import ResponseFailed
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IBodyProducer, IResponse
 
@@ -105,7 +106,7 @@ class ByteParser(ByteWriteable, Generic[T], abc.ABC):
     the parsed data.
     """
 
-    CONTENT_TYPE = abc.abstractproperty()  # type: str  # type: ignore
+    CONTENT_TYPE: str = abc.abstractproperty()  # type: ignore
     """The expected content type of the response, e.g. `application/json`. If
     the content type doesn't match we fail the request.
     """
@@ -262,6 +263,15 @@ async def _handle_response(
             request.uri.decode("ascii"),
         )
         raise RequestSendFailed(e, can_retry=True) from e
+    except ResponseFailed as e:
+        logger.warning(
+            "{%s} [%s] Failed to read response - %s %s",
+            request.txn_id,
+            request.destination,
+            request.method,
+            request.uri.decode("ascii"),
+        )
+        raise RequestSendFailed(e, can_retry=True) from e
     except Exception as e:
         logger.warning(
             "{%s} [%s] Error reading response %s %s: %s",
@@ -317,11 +327,11 @@ class MatrixFederationHttpClient:
 
         # We need to use a DNS resolver which filters out blacklisted IP
         # addresses, to prevent DNS rebinding.
-        self.reactor = BlacklistingReactorWrapper(
+        self.reactor: ISynapseReactor = BlacklistingReactorWrapper(
             hs.get_reactor(),
             hs.config.federation_ip_range_whitelist,
             hs.config.federation_ip_range_blacklist,
-        )  # type: ISynapseReactor
+        )
 
         user_agent = hs.version_string
         if hs.config.user_agent_suffix:
@@ -494,7 +504,7 @@ class MatrixFederationHttpClient:
         )
 
         # Inject the span into the headers
-        headers_dict = {}  # type: Dict[bytes, List[bytes]]
+        headers_dict: Dict[bytes, List[bytes]] = {}
         opentracing.inject_header_dict(headers_dict, request.destination)
 
         headers_dict[b"User-Agent"] = [self.version_string_bytes]
@@ -523,9 +533,9 @@ class MatrixFederationHttpClient:
                             destination_bytes, method_bytes, url_to_sign_bytes, json
                         )
                         data = encode_canonical_json(json)
-                        producer = QuieterFileBodyProducer(
+                        producer: Optional[IBodyProducer] = QuieterFileBodyProducer(
                             BytesIO(data), cooperator=self._cooperator
-                        )  # type: Optional[IBodyProducer]
+                        )
                     else:
                         producer = None
                         auth_headers = self.build_auth_headers(
@@ -1137,6 +1147,24 @@ class MatrixFederationHttpClient:
                 msg,
             )
             raise SynapseError(502, msg, Codes.TOO_LARGE)
+        except defer.TimeoutError as e:
+            logger.warning(
+                "{%s} [%s] Timed out reading response - %s %s",
+                request.txn_id,
+                request.destination,
+                request.method,
+                request.uri.decode("ascii"),
+            )
+            raise RequestSendFailed(e, can_retry=True) from e
+        except ResponseFailed as e:
+            logger.warning(
+                "{%s} [%s] Failed to read response - %s %s",
+                request.txn_id,
+                request.destination,
+                request.method,
+                request.uri.decode("ascii"),
+            )
+            raise RequestSendFailed(e, can_retry=True) from e
         except Exception as e:
             logger.warning(
                 "{%s} [%s] Error reading response: %s",
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index 7dfae8b786..f7193e60bd 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -117,7 +117,8 @@ class ProxyAgent(_AgentBase):
             https_proxy = proxies["https"].encode() if "https" in proxies else None
             no_proxy = proxies["no"] if "no" in proxies else None
 
-        # Parse credentials from https proxy connection string if present
+        # Parse credentials from http and https proxy connection string if present
+        self.http_proxy_creds, http_proxy = parse_username_password(http_proxy)
         self.https_proxy_creds, https_proxy = parse_username_password(https_proxy)
 
         self.http_proxy_endpoint = _http_proxy_endpoint(
@@ -171,7 +172,7 @@ class ProxyAgent(_AgentBase):
         """
         uri = uri.strip()
         if not _VALID_URI.match(uri):
-            raise ValueError("Invalid URI {!r}".format(uri))
+            raise ValueError(f"Invalid URI {uri!r}")
 
         parsed_uri = URI.fromBytes(uri)
         pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
@@ -189,6 +190,15 @@ class ProxyAgent(_AgentBase):
             and self.http_proxy_endpoint
             and not should_skip_proxy
         ):
+            # Determine whether we need to set Proxy-Authorization headers
+            if self.http_proxy_creds:
+                # Set a Proxy-Authorization header
+                if headers is None:
+                    headers = Headers()
+                headers.addRawHeader(
+                    b"Proxy-Authorization",
+                    self.http_proxy_creds.as_proxy_authorization_value(),
+                )
             # Cache *all* connections under the same key, since we are only
             # connecting to a single destination, the proxy:
             pool_key = ("http-proxy", self.http_proxy_endpoint)
diff --git a/synapse/http/server.py b/synapse/http/server.py
index efbc6d5b25..b79fa722e9 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -81,7 +81,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
 
     if f.check(SynapseError):
         # mypy doesn't understand that f.check asserts the type.
-        exc = f.value  # type: SynapseError  # type: ignore
+        exc: SynapseError = f.value  # type: ignore
         error_code = exc.code
         error_dict = exc.error_dict()
 
@@ -132,7 +132,7 @@ def return_html_error(
     """
     if f.check(CodeMessageException):
         # mypy doesn't understand that f.check asserts the type.
-        cme = f.value  # type: CodeMessageException  # type: ignore
+        cme: CodeMessageException = f.value  # type: ignore
         code = cme.code
         msg = cme.msg
 
@@ -404,7 +404,7 @@ class JsonResource(DirectServeJsonResource):
             key word arguments to pass to the callback
         """
         # At this point the path must be bytes.
-        request_path_bytes = request.path  # type: bytes  # type: ignore
+        request_path_bytes: bytes = request.path  # type: ignore
         request_path = request_path_bytes.decode("ascii")
         # Treat HEAD requests as GET requests.
         request_method = request.method
@@ -557,7 +557,7 @@ class _ByteProducer:
         request: Request,
         iterator: Iterator[bytes],
     ):
-        self._request = request  # type: Optional[Request]
+        self._request: Optional[Request] = request
         self._iterator = iterator
         self._paused = False
 
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 6ba2ce1e53..04560fb589 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -205,7 +205,7 @@ def parse_string(
             parameter is present, must be one of a list of allowed values and
             is not one of those allowed values.
     """
-    args = request.args  # type: Dict[bytes, List[bytes]]  # type: ignore
+    args: Dict[bytes, List[bytes]] = request.args  # type: ignore
     return parse_string_from_args(
         args,
         name,
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 40754b7bea..190084e8aa 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -64,16 +64,16 @@ class SynapseRequest(Request):
     def __init__(self, channel, *args, max_request_body_size=1024, **kw):
         Request.__init__(self, channel, *args, **kw)
         self._max_request_body_size = max_request_body_size
-        self.site = channel.site  # type: SynapseSite
+        self.site: SynapseSite = channel.site
         self._channel = channel  # this is used by the tests
         self.start_time = 0.0
 
         # The requester, if authenticated. For federation requests this is the
         # server name, for client requests this is the Requester object.
-        self._requester = None  # type: Optional[Union[Requester, str]]
+        self._requester: Optional[Union[Requester, str]] = None
 
         # we can't yet create the logcontext, as we don't know the method.
-        self.logcontext = None  # type: Optional[LoggingContext]
+        self.logcontext: Optional[LoggingContext] = None
 
         global _next_request_seq
         self.request_seq = _next_request_seq
@@ -152,7 +152,7 @@ class SynapseRequest(Request):
         Returns:
             The redacted URI as a string.
         """
-        uri = self.uri  # type: Union[bytes, str]
+        uri: Union[bytes, str] = self.uri
         if isinstance(uri, bytes):
             uri = uri.decode("ascii", errors="replace")
         return redact_uri(uri)
@@ -167,7 +167,7 @@ class SynapseRequest(Request):
         Returns:
             The request method as a string.
         """
-        method = self.method  # type: Union[bytes, str]
+        method: Union[bytes, str] = self.method
         if isinstance(method, bytes):
             return self.method.decode("ascii")
         return method
@@ -384,7 +384,7 @@ class SynapseRequest(Request):
         # authenticated (e.g. and admin is puppetting a user) then we log both.
         requester, authenticated_entity = self.get_authenticated_entity()
         if authenticated_entity:
-            requester = "{}.{}".format(authenticated_entity, requester)
+            requester = f"{authenticated_entity}.{requester}"
 
         self.site.access_logger.log(
             log_level,
@@ -434,8 +434,8 @@ class XForwardedForRequest(SynapseRequest):
     """
 
     # the client IP and ssl flag, as extracted from the headers.
-    _forwarded_for = None  # type: Optional[_XForwardedForAddress]
-    _forwarded_https = False  # type: bool
+    _forwarded_for: "Optional[_XForwardedForAddress]" = None
+    _forwarded_https: bool = False
 
     def requestReceived(self, command, path, version):
         # this method is called by the Channel once the full request has been
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index c515690b38..8202d0494d 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -110,9 +110,9 @@ class RemoteHandler(logging.Handler):
         self.port = port
         self.maximum_buffer = maximum_buffer
 
-        self._buffer = deque()  # type: Deque[logging.LogRecord]
-        self._connection_waiter = None  # type: Optional[Deferred]
-        self._producer = None  # type: Optional[LogProducer]
+        self._buffer: Deque[logging.LogRecord] = deque()
+        self._connection_waiter: Optional[Deferred] = None
+        self._producer: Optional[LogProducer] = None
 
         # Connect without DNS lookups if it's a direct IP.
         if _reactor is None:
@@ -123,9 +123,9 @@ class RemoteHandler(logging.Handler):
         try:
             ip = ip_address(self.host)
             if isinstance(ip, IPv4Address):
-                endpoint = TCP4ClientEndpoint(
+                endpoint: IStreamClientEndpoint = TCP4ClientEndpoint(
                     _reactor, self.host, self.port
-                )  # type: IStreamClientEndpoint
+                )
             elif isinstance(ip, IPv6Address):
                 endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
             else:
@@ -165,7 +165,7 @@ class RemoteHandler(logging.Handler):
         def writer(result: Protocol) -> None:
             # Force recognising transport as a Connection and not the more
             # generic ITransport.
-            transport = result.transport  # type: Connection  # type: ignore
+            transport: Connection = result.transport  # type: ignore
 
             # We have a connection. If we already have a producer, and its
             # transport is the same, just trigger a resumeProducing.
@@ -188,7 +188,7 @@ class RemoteHandler(logging.Handler):
             self._producer.resumeProducing()
             self._connection_waiter = None
 
-        deferred = self._service.whenConnected(failAfterFailures=1)  # type: Deferred
+        deferred: Deferred = self._service.whenConnected(failAfterFailures=1)
         deferred.addCallbacks(writer, fail)
         self._connection_waiter = deferred
 
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index c7a971a9d6..b9933a1528 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -63,7 +63,7 @@ def parse_drain_configs(
             DrainType.CONSOLE_JSON,
             DrainType.FILE_JSON,
         ):
-            formatter = "json"  # type: Optional[str]
+            formatter: Optional[str] = "json"
         elif logging_type in (
             DrainType.CONSOLE_JSON_TERSE,
             DrainType.NETWORK_JSON_TERSE,
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 7fc11a9ac2..18ac507802 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -113,13 +113,13 @@ class ContextResourceUsage:
             self.reset()
         else:
             # FIXME: mypy can't infer the types set via reset() above, so specify explicitly for now
-            self.ru_utime = copy_from.ru_utime  # type: float
-            self.ru_stime = copy_from.ru_stime  # type: float
-            self.db_txn_count = copy_from.db_txn_count  # type: int
+            self.ru_utime: float = copy_from.ru_utime
+            self.ru_stime: float = copy_from.ru_stime
+            self.db_txn_count: int = copy_from.db_txn_count
 
-            self.db_txn_duration_sec = copy_from.db_txn_duration_sec  # type: float
-            self.db_sched_duration_sec = copy_from.db_sched_duration_sec  # type: float
-            self.evt_db_fetch_count = copy_from.evt_db_fetch_count  # type: int
+            self.db_txn_duration_sec: float = copy_from.db_txn_duration_sec
+            self.db_sched_duration_sec: float = copy_from.db_sched_duration_sec
+            self.evt_db_fetch_count: int = copy_from.evt_db_fetch_count
 
     def copy(self) -> "ContextResourceUsage":
         return ContextResourceUsage(copy_from=self)
@@ -289,12 +289,12 @@ class LoggingContext:
 
         # The thread resource usage when the logcontext became active. None
         # if the context is not currently active.
-        self.usage_start = None  # type: Optional[resource._RUsage]
+        self.usage_start: Optional[resource._RUsage] = None
 
         self.main_thread = get_thread_id()
         self.request = None
         self.tag = ""
-        self.scope = None  # type: Optional[_LogContextScope]
+        self.scope: Optional["_LogContextScope"] = None
 
         # keep track of whether we have hit the __exit__ block for this context
         # (suggesting that the the thing that created the context thinks it should
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 140ed711e3..ecd51f1b4a 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -251,7 +251,7 @@ try:
             except Exception:
                 logger.exception("Failed to report span")
 
-    RustReporter = _WrappedRustReporter  # type: Optional[Type[_WrappedRustReporter]]
+    RustReporter: Optional[Type[_WrappedRustReporter]] = _WrappedRustReporter
 except ImportError:
     RustReporter = None
 
@@ -286,7 +286,7 @@ class SynapseBaggage:
 # Block everything by default
 # A regex which matches the server_names to expose traces for.
 # None means 'block everything'.
-_homeserver_whitelist = None  # type: Optional[Pattern[str]]
+_homeserver_whitelist: Optional[Pattern[str]] = None
 
 # Util methods
 
@@ -374,7 +374,7 @@ def init_tracer(hs: "HomeServer"):
 
     config = JaegerConfig(
         config=hs.config.jaeger_config,
-        service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()),
+        service_name=f"{hs.config.server_name} {hs.get_instance_name()}",
         scope_manager=LogContextScopeManager(hs.config),
         metrics_factory=PrometheusMetricsFactory(),
     )
@@ -662,7 +662,7 @@ def inject_header_dict(
 
     span = opentracing.tracer.active_span
 
-    carrier = {}  # type: Dict[str, str]
+    carrier: Dict[str, str] = {}
     opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
 
     for key, value in carrier.items():
@@ -704,7 +704,7 @@ def get_active_span_text_map(destination=None):
     if destination and not whitelisted_homeserver(destination):
         return {}
 
-    carrier = {}  # type: Dict[str, str]
+    carrier: Dict[str, str] = {}
     opentracing.tracer.inject(
         opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
     )
@@ -718,7 +718,7 @@ def active_span_context_as_string():
     Returns:
         The active span context encoded as a string.
     """
-    carrier = {}  # type: Dict[str, str]
+    carrier: Dict[str, str] = {}
     if opentracing:
         opentracing.tracer.inject(
             opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index fef2846669..f237b8a236 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -46,7 +46,7 @@ logger = logging.getLogger(__name__)
 METRICS_PREFIX = "/_synapse/metrics"
 
 running_on_pypy = platform.python_implementation() == "PyPy"
-all_gauges = {}  # type: Dict[str, Union[LaterGauge, InFlightGauge]]
+all_gauges: "Dict[str, Union[LaterGauge, InFlightGauge]]" = {}
 
 HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
 
@@ -130,7 +130,7 @@ class InFlightGauge:
         )
 
         # Counts number of in flight blocks for a given set of label values
-        self._registrations = {}  # type: Dict
+        self._registrations: Dict = {}
 
         # Protects access to _registrations
         self._lock = threading.Lock()
@@ -248,7 +248,7 @@ class GaugeBucketCollector:
 
         # We initially set this to None. We won't report metrics until
         # this has been initialised after a successful data update
-        self._metric = None  # type: Optional[GaugeHistogramMetricFamily]
+        self._metric: Optional[GaugeHistogramMetricFamily] = None
 
         registry.register(self)
 
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py
index 8002be56e0..bb9bcb5592 100644
--- a/synapse/metrics/_exposition.py
+++ b/synapse/metrics/_exposition.py
@@ -34,7 +34,7 @@ from twisted.web.resource import Resource
 
 from synapse.util import caches
 
-CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8")
+CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8"
 
 
 INF = float("inf")
@@ -55,8 +55,8 @@ def floatToGoString(d):
         # Go switches to exponents sooner than Python.
         # We only need to care about positive values for le/quantile.
         if d > 0 and dot > 6:
-            mantissa = "{0}.{1}{2}".format(s[0], s[1:dot], s[dot + 1 :]).rstrip("0.")
-            return "{0}e+0{1}".format(mantissa, dot - 1)
+            mantissa = f"{s[0]}.{s[1:dot]}{s[dot + 1 :]}".rstrip("0.")
+            return f"{mantissa}e+0{dot - 1}"
         return s
 
 
@@ -65,7 +65,7 @@ def sample_line(line, name):
         labelstr = "{{{0}}}".format(
             ",".join(
                 [
-                    '{0}="{1}"'.format(
+                    '{}="{}"'.format(
                         k,
                         v.replace("\\", r"\\").replace("\n", r"\n").replace('"', r"\""),
                     )
@@ -78,10 +78,8 @@ def sample_line(line, name):
     timestamp = ""
     if line.timestamp is not None:
         # Convert to milliseconds.
-        timestamp = " {0:d}".format(int(float(line.timestamp) * 1000))
-    return "{0}{1} {2}{3}\n".format(
-        name, labelstr, floatToGoString(line.value), timestamp
-    )
+        timestamp = f" {int(float(line.timestamp) * 1000):d}"
+    return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp)
 
 
 def generate_latest(registry, emit_help=False):
@@ -118,14 +116,14 @@ def generate_latest(registry, emit_help=False):
         # Output in the old format for compatibility.
         if emit_help:
             output.append(
-                "# HELP {0} {1}\n".format(
+                "# HELP {} {}\n".format(
                     mname,
                     metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
                 )
             )
-        output.append("# TYPE {0} {1}\n".format(mname, mtype))
+        output.append(f"# TYPE {mname} {mtype}\n")
 
-        om_samples = {}  # type: Dict[str, List[str]]
+        om_samples: Dict[str, List[str]] = {}
         for s in metric.samples:
             for suffix in ["_created", "_gsum", "_gcount"]:
                 if s.name == metric.name + suffix:
@@ -143,13 +141,13 @@ def generate_latest(registry, emit_help=False):
         for suffix, lines in sorted(om_samples.items()):
             if emit_help:
                 output.append(
-                    "# HELP {0}{1} {2}\n".format(
+                    "# HELP {}{} {}\n".format(
                         metric.name,
                         suffix,
                         metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
                     )
                 )
-            output.append("# TYPE {0}{1} gauge\n".format(metric.name, suffix))
+            output.append(f"# TYPE {metric.name}{suffix} gauge\n")
             output.extend(lines)
 
         # Get rid of the weird colon things while we're at it
@@ -163,12 +161,12 @@ def generate_latest(registry, emit_help=False):
         # Also output in the new format, if it's different.
         if emit_help:
             output.append(
-                "# HELP {0} {1}\n".format(
+                "# HELP {} {}\n".format(
                     mnewname,
                     metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
                 )
             )
-        output.append("# TYPE {0} {1}\n".format(mnewname, mtype))
+        output.append(f"# TYPE {mnewname} {mtype}\n")
 
         for s in metric.samples:
             # Get rid of the OpenMetrics specific samples (we should already have
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index de96ca0821..3a14260752 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -93,7 +93,7 @@ _background_process_db_sched_duration = Counter(
 # map from description to a counter, so that we can name our logcontexts
 # incrementally. (It actually duplicates _background_process_start_count, but
 # it's much simpler to do so than to try to combine them.)
-_background_process_counts = {}  # type: Dict[str, int]
+_background_process_counts: Dict[str, int] = {}
 
 # Set of all running background processes that became active active since the
 # last time metrics were scraped (i.e. background processes that performed some
@@ -103,7 +103,7 @@ _background_process_counts = {}  # type: Dict[str, int]
 # background processes stacking up behind a lock or linearizer, where we then
 # only need to iterate over and update metrics for the process that have
 # actually been active and can ignore the idle ones.
-_background_processes_active_since_last_scrape = set()  # type: Set[_BackgroundProcess]
+_background_processes_active_since_last_scrape: "Set[_BackgroundProcess]" = set()
 
 # A lock that covers the above set and dict
 _bg_metrics_lock = threading.Lock()
@@ -137,8 +137,7 @@ class _Collector:
             _background_process_db_txn_duration,
             _background_process_db_sched_duration,
         ):
-            for r in m.collect():
-                yield r
+            yield from m.collect()
 
 
 REGISTRY.register(_Collector())
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 721c45abac..1259fc2d90 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -12,18 +12,42 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import email.utils
 import logging
-from typing import TYPE_CHECKING, Any, Generator, Iterable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+)
+
+import jinja2
 
 from twisted.internet import defer
 from twisted.web.resource import IResource
 
 from synapse.events import EventBase
 from synapse.http.client import SimpleHttpClient
+from synapse.http.server import (
+    DirectServeHtmlResource,
+    DirectServeJsonResource,
+    respond_with_html,
+)
+from synapse.http.servlet import parse_json_object_from_request
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.roommember import ProfileInfo
 from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, UserID, create_requester
+from synapse.types import JsonDict, Requester, UserID, create_requester
+from synapse.util import Clock
+from synapse.util.caches.descriptors import cached
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -33,7 +57,20 @@ This package defines the 'stable' API which can be used by extension modules whi
 are loaded into Synapse.
 """
 
-__all__ = ["errors", "make_deferred_yieldable", "run_in_background", "ModuleApi"]
+__all__ = [
+    "errors",
+    "make_deferred_yieldable",
+    "parse_json_object_from_request",
+    "respond_with_html",
+    "run_in_background",
+    "cached",
+    "UserID",
+    "DatabasePool",
+    "LoggingTransaction",
+    "DirectServeHtmlResource",
+    "DirectServeJsonResource",
+    "ModuleApi",
+]
 
 logger = logging.getLogger(__name__)
 
@@ -52,12 +89,28 @@ class ModuleApi:
         self._server_name = hs.hostname
         self._presence_stream = hs.get_event_sources().sources["presence"]
         self._state = hs.get_state_handler()
+        self._clock: Clock = hs.get_clock()
+        self._send_email_handler = hs.get_send_email_handler()
+
+        try:
+            app_name = self._hs.config.email_app_name
+
+            self._from_string = self._hs.config.email_notif_from % {"app": app_name}
+        except (KeyError, TypeError):
+            # If substitution failed (which can happen if the string contains
+            # placeholders other than just "app", or if the type of the placeholder is
+            # not a string), fall back to the bare strings.
+            self._from_string = self._hs.config.email_notif_from
+
+        self._raw_from = email.utils.parseaddr(self._from_string)[1]
 
         # We expose these as properties below in order to attach a helpful docstring.
-        self._http_client = hs.get_simple_http_client()  # type: SimpleHttpClient
+        self._http_client: SimpleHttpClient = hs.get_simple_http_client()
         self._public_room_list_manager = PublicRoomListManager(hs)
 
         self._spam_checker = hs.get_spam_checker()
+        self._account_validity_handler = hs.get_account_validity_handler()
+        self._third_party_event_rules = hs.get_third_party_event_rules()
 
     #################################################################################
     # The following methods should only be called during the module's initialisation.
@@ -67,6 +120,16 @@ class ModuleApi:
         """Registers callbacks for spam checking capabilities."""
         return self._spam_checker.register_callbacks
 
+    @property
+    def register_account_validity_callbacks(self):
+        """Registers callbacks for account validity capabilities."""
+        return self._account_validity_handler.register_account_validity_callbacks
+
+    @property
+    def register_third_party_rules_callbacks(self):
+        """Registers callbacks for third party event rules capabilities."""
+        return self._third_party_event_rules.register_third_party_rules_callbacks
+
     def register_web_resource(self, path: str, resource: IResource):
         """Registers a web resource to be served at the given path.
 
@@ -101,22 +164,56 @@ class ModuleApi:
         """
         return self._public_room_list_manager
 
-    def get_user_by_req(self, req, allow_guest=False):
+    @property
+    def public_baseurl(self) -> str:
+        """The configured public base URL for this homeserver."""
+        return self._hs.config.public_baseurl
+
+    @property
+    def email_app_name(self) -> str:
+        """The application name configured in the homeserver's configuration."""
+        return self._hs.config.email.email_app_name
+
+    async def get_user_by_req(
+        self,
+        req: SynapseRequest,
+        allow_guest: bool = False,
+        allow_expired: bool = False,
+    ) -> Requester:
         """Check the access_token provided for a request
 
         Args:
-            req (twisted.web.server.Request): Incoming HTTP request
-            allow_guest (bool): True if guest users should be allowed. If this
+            req: Incoming HTTP request
+            allow_guest: True if guest users should be allowed. If this
                 is False, and the access token is for a guest user, an
                 AuthError will be thrown
+            allow_expired: True if expired users should be allowed. If this
+                is False, and the access token is for an expired user, an
+                AuthError will be thrown
+
         Returns:
-            twisted.internet.defer.Deferred[synapse.types.Requester]:
-                the requester for this request
+            The requester for this request
+
         Raises:
-            synapse.api.errors.AuthError: if no user by that token exists,
+            InvalidClientCredentialsError: if no user by that token exists,
                 or the token is invalid.
         """
-        return self._auth.get_user_by_req(req, allow_guest)
+        return await self._auth.get_user_by_req(
+            req,
+            allow_guest,
+            allow_expired=allow_expired,
+        )
+
+    async def is_user_admin(self, user_id: str) -> bool:
+        """Checks if a user is a server admin.
+
+        Args:
+            user_id: The Matrix ID of the user to check.
+
+        Returns:
+            True if the user is a server admin, False otherwise.
+        """
+        return await self._store.is_server_admin(UserID.from_string(user_id))
 
     def get_qualified_user_id(self, username):
         """Qualify a user id, if necessary
@@ -134,6 +231,32 @@ class ModuleApi:
             return username
         return UserID(username, self._hs.hostname).to_string()
 
+    async def get_profile_for_user(self, localpart: str) -> ProfileInfo:
+        """Look up the profile info for the user with the given localpart.
+
+        Args:
+            localpart: The localpart to look up profile information for.
+
+        Returns:
+            The profile information (i.e. display name and avatar URL).
+        """
+        return await self._store.get_profileinfo(localpart)
+
+    async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]:
+        """Look up the threepids (email addresses and phone numbers) associated with the
+        given Matrix user ID.
+
+        Args:
+            user_id: The Matrix user ID to look up threepids for.
+
+        Returns:
+            A list of threepids, each threepid being represented by a dictionary
+            containing a "medium" key which value is "email" for email addresses and
+            "msisdn" for phone numbers, and an "address" key which value is the
+            threepid's address.
+        """
+        return await self._store.user_get_threepids(user_id)
+
     def check_user_exists(self, user_id):
         """Check if user exists.
 
@@ -464,6 +587,88 @@ class ModuleApi:
                 presence_events, destination
             )
 
+    def looping_background_call(
+        self,
+        f: Callable,
+        msec: float,
+        *args,
+        desc: Optional[str] = None,
+        **kwargs,
+    ):
+        """Wraps a function as a background process and calls it repeatedly.
+
+        Waits `msec` initially before calling `f` for the first time.
+
+        Args:
+            f: The function to call repeatedly. f can be either synchronous or
+                asynchronous, and must follow Synapse's logcontext rules.
+                More info about logcontexts is available at
+                https://matrix-org.github.io/synapse/latest/log_contexts.html
+            msec: How long to wait between calls in milliseconds.
+            *args: Positional arguments to pass to function.
+            desc: The background task's description. Default to the function's name.
+            **kwargs: Key arguments to pass to function.
+        """
+        if desc is None:
+            desc = f.__name__
+
+        if self._hs.config.run_background_tasks:
+            self._clock.looping_call(
+                run_as_background_process,
+                msec,
+                desc,
+                f,
+                *args,
+                **kwargs,
+            )
+        else:
+            logger.warning(
+                "Not running looping call %s as the configuration forbids it",
+                f,
+            )
+
+    async def send_mail(
+        self,
+        recipient: str,
+        subject: str,
+        html: str,
+        text: str,
+    ):
+        """Send an email on behalf of the homeserver.
+
+        Args:
+            recipient: The email address for the recipient.
+            subject: The email's subject.
+            html: The email's HTML content.
+            text: The email's text content.
+        """
+        await self._send_email_handler.send_email(
+            email_address=recipient,
+            subject=subject,
+            app_name=self.email_app_name,
+            html=html,
+            text=text,
+        )
+
+    def read_templates(
+        self,
+        filenames: List[str],
+        custom_template_directory: Optional[str] = None,
+    ) -> List[jinja2.Template]:
+        """Read and load the content of the template files at the given location.
+        By default, Synapse will look for these templates in its configured template
+        directory, but another directory to search in can be provided.
+
+        Args:
+            filenames: The name of the template files to look for.
+            custom_template_directory: An additional directory to look for the files in.
+
+        Returns:
+            A list containing the loaded templates, with the orders matching the one of
+            the filenames parameter.
+        """
+        return self._hs.config.read_templates(filenames, custom_template_directory)
+
 
 class PublicRoomListManager:
     """Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py
index 02bbb0be39..98ea911a81 100644
--- a/synapse/module_api/errors.py
+++ b/synapse/module_api/errors.py
@@ -14,5 +14,9 @@
 
 """Exception types which are exposed as part of the stable module API"""
 
-from synapse.api.errors import RedirectException, SynapseError  # noqa: F401
+from synapse.api.errors import (  # noqa: F401
+    InvalidClientCredentialsError,
+    RedirectException,
+    SynapseError,
+)
 from synapse.config._base import ConfigError  # noqa: F401
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 3c3cc47631..c5fbebc17d 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -203,21 +203,21 @@ class Notifier:
     UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
 
     def __init__(self, hs: "synapse.server.HomeServer"):
-        self.user_to_user_stream = {}  # type: Dict[str, _NotifierUserStream]
-        self.room_to_user_streams = {}  # type: Dict[str, Set[_NotifierUserStream]]
+        self.user_to_user_stream: Dict[str, _NotifierUserStream] = {}
+        self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {}
 
         self.hs = hs
         self.storage = hs.get_storage()
         self.event_sources = hs.get_event_sources()
         self.store = hs.get_datastore()
-        self.pending_new_room_events = []  # type: List[_PendingRoomEventEntry]
+        self.pending_new_room_events: List[_PendingRoomEventEntry] = []
 
         # Called when there are new things to stream over replication
-        self.replication_callbacks = []  # type: List[Callable[[], None]]
+        self.replication_callbacks: List[Callable[[], None]] = []
 
         # Called when remote servers have come back online after having been
         # down.
-        self.remote_server_up_callbacks = []  # type: List[Callable[[str], None]]
+        self.remote_server_up_callbacks: List[Callable[[str], None]] = []
 
         self.clock = hs.get_clock()
         self.appservice_handler = hs.get_application_service_handler()
@@ -237,7 +237,7 @@ class Notifier:
         # when rendering the metrics page, which is likely once per minute at
         # most when scraping it.
         def count_listeners():
-            all_user_streams = set()  # type: Set[_NotifierUserStream]
+            all_user_streams: Set[_NotifierUserStream] = set()
 
             for streams in list(self.room_to_user_streams.values()):
                 all_user_streams |= streams
@@ -329,8 +329,8 @@ class Notifier:
         pending = self.pending_new_room_events
         self.pending_new_room_events = []
 
-        users = set()  # type: Set[UserID]
-        rooms = set()  # type: Set[str]
+        users: Set[UserID] = set()
+        rooms: Set[str] = set()
 
         for entry in pending:
             if entry.event_pos.persisted_after(max_room_stream_token):
@@ -580,7 +580,7 @@ class Notifier:
             if after_token == before_token:
                 return EventStreamResult([], (from_token, from_token))
 
-            events = []  # type: List[EventBase]
+            events: List[EventBase] = []
             end_token = from_token
 
             for name, source in self.event_sources.sources.items():
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 669ea462e2..c337e530d3 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -194,7 +194,7 @@ class BulkPushRuleEvaluator:
         count_as_unread = _should_count_as_unread(event, context)
 
         rules_by_user = await self._get_rules_for_event(event, context)
-        actions_by_user = {}  # type: Dict[str, List[Union[dict, str]]]
+        actions_by_user: Dict[str, List[Union[dict, str]]] = {}
 
         room_members = await self.store.get_joined_users_from_context(event, context)
 
@@ -207,7 +207,7 @@ class BulkPushRuleEvaluator:
             event, len(room_members), sender_power_level, power_levels
         )
 
-        condition_cache = {}  # type: Dict[str, bool]
+        condition_cache: Dict[str, bool] = {}
 
         # If the event is not a state event check if any users ignore the sender.
         if not event.is_state():
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index 2ee0ccd58a..1fc9716a34 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -26,10 +26,10 @@ def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, l
     # We're going to be mutating this a lot, so do a deep copy
     ruleslist = copy.deepcopy(ruleslist)
 
-    rules = {
+    rules: Dict[str, Dict[str, List[Dict[str, Any]]]] = {
         "global": {},
         "device": {},
-    }  # type: Dict[str, Dict[str, List[Dict[str, Any]]]]
+    }
 
     rules["global"] = _add_empty_priority_class_arrays(rules["global"])
 
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 99a18874d1..e08e125cb8 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -66,8 +66,8 @@ class EmailPusher(Pusher):
 
         self.store = self.hs.get_datastore()
         self.email = pusher_config.pushkey
-        self.timed_call = None  # type: Optional[IDelayedCall]
-        self.throttle_params = {}  # type: Dict[str, ThrottleParams]
+        self.timed_call: Optional[IDelayedCall] = None
+        self.throttle_params: Dict[str, ThrottleParams] = {}
         self._inited = False
 
         self._is_processing = False
@@ -168,7 +168,7 @@ class EmailPusher(Pusher):
             )
         )
 
-        soonest_due_at = None  # type: Optional[int]
+        soonest_due_at: Optional[int] = None
 
         if not unprocessed:
             await self.save_last_stream_ordering_and_success(self.max_stream_ordering)
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 06bf5f8ada..36aabd8422 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -71,7 +71,7 @@ class HttpPusher(Pusher):
         self.data = pusher_config.data
         self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
         self.failing_since = pusher_config.failing_since
-        self.timed_call = None  # type: Optional[IDelayedCall]
+        self.timed_call: Optional[IDelayedCall] = None
         self._is_processing = False
         self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
         self._pusherpool = hs.get_pusherpool()
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 5f9ea5003a..7be5fe1e9b 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -110,7 +110,7 @@ class Mailer:
         self.state_handler = self.hs.get_state_handler()
         self.storage = hs.get_storage()
         self.app_name = app_name
-        self.email_subjects = hs.config.email_subjects  # type: EmailSubjectConfig
+        self.email_subjects: EmailSubjectConfig = hs.config.email_subjects
 
         logger.info("Created Mailer for app_name %s" % app_name)
 
@@ -230,7 +230,7 @@ class Mailer:
             [pa["event_id"] for pa in push_actions]
         )
 
-        notifs_by_room = {}  # type: Dict[str, List[Dict[str, Any]]]
+        notifs_by_room: Dict[str, List[Dict[str, Any]]] = {}
         for pa in push_actions:
             notifs_by_room.setdefault(pa["room_id"], []).append(pa)
 
@@ -356,13 +356,13 @@ class Mailer:
 
         room_name = await calculate_room_name(self.store, room_state_ids, user_id)
 
-        room_vars = {
+        room_vars: Dict[str, Any] = {
             "title": room_name,
             "hash": string_ordinal_total(room_id),  # See sender avatar hash
             "notifs": [],
             "invite": is_invite,
             "link": self._make_room_link(room_id),
-        }  # type: Dict[str, Any]
+        }
 
         if not is_invite:
             for n in notifs:
@@ -460,9 +460,9 @@ class Mailer:
         type_state_key = ("m.room.member", event.sender)
         sender_state_event_id = room_state_ids.get(type_state_key)
         if sender_state_event_id:
-            sender_state_event = await self.store.get_event(
+            sender_state_event: Optional[EventBase] = await self.store.get_event(
                 sender_state_event_id
-            )  # type: Optional[EventBase]
+            )
         else:
             # Attempt to check the historical state for the room.
             historical_state = await self.state_store.get_state_for_event(
diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index 412941393f..0510c1cbd5 100644
--- a/synapse/push/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -199,7 +199,7 @@ def name_from_member_event(member_event: EventBase) -> str:
 
 
 def _state_as_two_level_dict(state: StateMap[str]) -> Dict[str, Dict[str, str]]:
-    ret = {}  # type: Dict[str, Dict[str, str]]
+    ret: Dict[str, Dict[str, str]] = {}
     for k, v in state.items():
         ret.setdefault(k[0], {})[k[1]] = v
     return ret
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 98b90a4f51..7a8dc63976 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -195,9 +195,9 @@ class PushRuleEvaluatorForEvent:
 
 
 # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
-regex_cache = LruCache(
+regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache(
     50000, "regex_push_cache"
-)  # type: LruCache[Tuple[str, bool, bool], Pattern]
+)
 
 
 def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index c51938b8cf..021275437c 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -31,13 +31,13 @@ class PusherFactory:
         self.hs = hs
         self.config = hs.config
 
-        self.pusher_types = {
+        self.pusher_types: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]] = {
             "http": HttpPusher
-        }  # type: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]]
+        }
 
         logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
         if hs.config.email_enable_notifs:
-            self.mailers = {}  # type: Dict[str, Mailer]
+            self.mailers: Dict[str, Mailer] = {}
 
             self._notif_template_html = hs.config.email_notif_template_html
             self._notif_template_text = hs.config.email_notif_template_text
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 579fcdf472..85621f33ef 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -62,10 +62,6 @@ class PusherPool:
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
 
-        self._account_validity_enabled = (
-            hs.config.account_validity.account_validity_enabled
-        )
-
         # We shard the handling of push notifications by user ID.
         self._pusher_shard_config = hs.config.push.pusher_shard_config
         self._instance_name = hs.get_instance_name()
@@ -87,7 +83,9 @@ class PusherPool:
         self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering()
 
         # map from user id to app_id:pushkey to pusher
-        self.pushers = {}  # type: Dict[str, Dict[str, Pusher]]
+        self.pushers: Dict[str, Dict[str, Pusher]] = {}
+
+        self._account_validity_handler = hs.get_account_validity_handler()
 
     def start(self) -> None:
         """Starts the pushers off in a background process."""
@@ -238,12 +236,9 @@ class PusherPool:
 
             for u in users_affected:
                 # Don't push if the user account has expired
-                if self._account_validity_enabled:
-                    expired = await self.store.is_account_expired(
-                        u, self.clock.time_msec()
-                    )
-                    if expired:
-                        continue
+                expired = await self._account_validity_handler.is_user_expired(u)
+                if expired:
+                    continue
 
                 if u in self.pushers:
                     for p in self.pushers[u].values():
@@ -268,12 +263,9 @@ class PusherPool:
 
             for u in users_affected:
                 # Don't push if the user account has expired
-                if self._account_validity_enabled:
-                    expired = await self.store.is_account_expired(
-                        u, self.clock.time_msec()
-                    )
-                    if expired:
-                        continue
+                expired = await self._account_validity_handler.is_user_expired(u)
+                if expired:
+                    continue
 
                 if u in self.pushers:
                     for p in self.pushers[u].values():
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 271c17c226..cdcbdd772b 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -115,7 +115,7 @@ CONDITIONAL_REQUIREMENTS = {
     "cache_memory": ["pympler"],
 }
 
-ALL_OPTIONAL_REQUIREMENTS = set()  # type: Set[str]
+ALL_OPTIONAL_REQUIREMENTS: Set[str] = set()
 
 for name, optional_deps in CONDITIONAL_REQUIREMENTS.items():
     # Exclude systemd as it's a system-based requirement.
@@ -193,7 +193,7 @@ def check_requirements(for_feature=None):
     if not for_feature:
         # Check the optional dependencies are up to date. We allow them to not be
         # installed.
-        OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), [])  # type: List[str]
+        OPTS: List[str] = sum(CONDITIONAL_REQUIREMENTS.values(), [])
 
         for dependency in OPTS:
             try:
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index f13a7c23b4..25589b0042 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -85,17 +85,17 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
             is received.
     """
 
-    NAME = abc.abstractproperty()  # type: str  # type: ignore
-    PATH_ARGS = abc.abstractproperty()  # type: Tuple[str, ...]  # type: ignore
+    NAME: str = abc.abstractproperty()  # type: ignore
+    PATH_ARGS: Tuple[str, ...] = abc.abstractproperty()  # type: ignore
     METHOD = "POST"
     CACHE = True
     RETRY_ON_TIMEOUT = True
 
     def __init__(self, hs: "HomeServer"):
         if self.CACHE:
-            self.response_cache = ResponseCache(
+            self.response_cache: ResponseCache[str] = ResponseCache(
                 hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000
-            )  # type: ResponseCache[str]
+            )
 
         # We reserve `instance_name` as a parameter to sending requests, so we
         # assert here that sub classes don't try and use the name.
@@ -232,7 +232,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                 # have a good idea that the request has either succeeded or failed on
                 # the master, and so whether we should clean up or not.
                 while True:
-                    headers = {}  # type: Dict[bytes, List[bytes]]
+                    headers: Dict[bytes, List[bytes]] = {}
                     # Add an authorization header, if configured.
                     if replication_secret:
                         headers[b"Authorization"] = [b"Bearer " + replication_secret]
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index faa99387a7..e460dd85cd 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -27,7 +27,9 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
-            self._cache_id_gen = MultiWriterIdGenerator(
+            self._cache_id_gen: Optional[
+                MultiWriterIdGenerator
+            ] = MultiWriterIdGenerator(
                 db_conn,
                 database,
                 stream_name="caches",
@@ -41,7 +43,7 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
                 ],
                 sequence_name="cache_invalidation_stream_seq",
                 writers=[],
-            )  # type: Optional[MultiWriterIdGenerator]
+            )
         else:
             self._cache_id_gen = None
 
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 13ed87adc4..436d39c320 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -23,9 +23,9 @@ class SlavedClientIpStore(BaseSlavedStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
 
-        self.client_ip_last_seen = LruCache(
+        self.client_ip_last_seen: LruCache[tuple, int] = LruCache(
             cache_name="client_ip_last_seen", max_size=50000
-        )  # type: LruCache[tuple, int]
+        )
 
     async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
         now = int(self._clock.time_msec())
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 62d7809175..9d4859798b 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -121,13 +121,13 @@ class ReplicationDataHandler:
         self._pusher_pool = hs.get_pusherpool()
         self._presence_handler = hs.get_presence_handler()
 
-        self.send_handler = None  # type: Optional[FederationSenderHandler]
+        self.send_handler: Optional[FederationSenderHandler] = None
         if hs.should_send_federation():
             self.send_handler = FederationSenderHandler(hs)
 
         # Map from stream to list of deferreds waiting for the stream to
         # arrive at a particular position. The lists are sorted by stream position.
-        self._streams_to_waiters = {}  # type: Dict[str, List[Tuple[int, Deferred]]]
+        self._streams_to_waiters: Dict[str, List[Tuple[int, Deferred]]] = {}
 
     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
@@ -173,7 +173,7 @@ class ReplicationDataHandler:
             if entities:
                 self.notifier.on_new_event("to_device_key", token, users=entities)
         elif stream_name == DeviceListsStream.NAME:
-            all_room_ids = set()  # type: Set[str]
+            all_room_ids: Set[str] = set()
             for row in rows:
                 if row.entity.startswith("@"):
                     room_ids = await self.store.get_rooms_for_user(row.entity)
@@ -201,7 +201,7 @@ class ReplicationDataHandler:
                 if row.data.rejected:
                     continue
 
-                extra_users = ()  # type: Tuple[UserID, ...]
+                extra_users: Tuple[UserID, ...] = ()
                 if row.data.type == EventTypes.Member and row.data.state_key:
                     extra_users = (UserID.from_string(row.data.state_key),)
 
@@ -348,7 +348,7 @@ class FederationSenderHandler:
 
         # Stores the latest position in the federation stream we've gotten up
         # to. This is always set before we use it.
-        self.federation_position = None  # type: Optional[int]
+        self.federation_position: Optional[int] = None
 
         self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
 
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 505d450e19..1311b013da 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -34,7 +34,7 @@ class Command(metaclass=abc.ABCMeta):
     A full command line on the wire is constructed from `NAME + " " + to_line()`
     """
 
-    NAME = None  # type: str
+    NAME: str
 
     @classmethod
     @abc.abstractmethod
@@ -380,7 +380,7 @@ class RemoteServerUpCommand(_SimpleCommand):
     NAME = "REMOTE_SERVER_UP"
 
 
-_COMMANDS = (
+_COMMANDS: Tuple[Type[Command], ...] = (
     ServerCommand,
     RdataCommand,
     PositionCommand,
@@ -393,7 +393,7 @@ _COMMANDS = (
     UserIpCommand,
     RemoteServerUpCommand,
     ClearUserSyncsCommand,
-)  # type: Tuple[Type[Command], ...]
+)
 
 # Map of command name to command type.
 COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS}
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 2ad7a200bb..eae4515363 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -105,12 +105,12 @@ class ReplicationCommandHandler:
             hs.get_instance_name() in hs.config.worker.writers.presence
         )
 
-        self._streams = {
+        self._streams: Dict[str, Stream] = {
             stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
-        }  # type: Dict[str, Stream]
+        }
 
         # List of streams that this instance is the source of
-        self._streams_to_replicate = []  # type: List[Stream]
+        self._streams_to_replicate: List[Stream] = []
 
         for stream in self._streams.values():
             if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME:
@@ -180,14 +180,14 @@ class ReplicationCommandHandler:
 
         # Map of stream name to batched updates. See RdataCommand for info on
         # how batching works.
-        self._pending_batches = {}  # type: Dict[str, List[Any]]
+        self._pending_batches: Dict[str, List[Any]] = {}
 
         # The factory used to create connections.
-        self._factory = None  # type: Optional[ReconnectingClientFactory]
+        self._factory: Optional[ReconnectingClientFactory] = None
 
         # The currently connected connections. (The list of places we need to send
         # outgoing replication commands to.)
-        self._connections = []  # type: List[IReplicationConnection]
+        self._connections: List[IReplicationConnection] = []
 
         LaterGauge(
             "synapse_replication_tcp_resource_total_connections",
@@ -200,7 +200,7 @@ class ReplicationCommandHandler:
         # them in order in a separate background process.
 
         # the streams which are currently being processed by _unsafe_process_queue
-        self._processing_streams = set()  # type: Set[str]
+        self._processing_streams: Set[str] = set()
 
         # for each stream, a queue of commands that are awaiting processing, and the
         # connection that they arrived on.
@@ -210,7 +210,7 @@ class ReplicationCommandHandler:
 
         # For each connection, the incoming stream names that have received a POSITION
         # from that connection.
-        self._streams_by_connection = {}  # type: Dict[IReplicationConnection, Set[str]]
+        self._streams_by_connection: Dict[IReplicationConnection, Set[str]] = {}
 
         LaterGauge(
             "synapse_replication_tcp_command_queue",
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 6e3705364f..8c80153ab6 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -102,7 +102,7 @@ tcp_outbound_commands_counter = Counter(
 
 # A list of all connected protocols. This allows us to send metrics about the
 # connections.
-connected_connections = []  # type: List[BaseReplicationStreamProtocol]
+connected_connections: "List[BaseReplicationStreamProtocol]" = []
 
 
 logger = logging.getLogger(__name__)
@@ -146,15 +146,15 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
     # The transport is going to be an ITCPTransport, but that doesn't have the
     # (un)registerProducer methods, those are only on the implementation.
-    transport = None  # type: Connection
+    transport: Connection
 
     delimiter = b"\n"
 
     # Valid commands we expect to receive
-    VALID_INBOUND_COMMANDS = []  # type: Collection[str]
+    VALID_INBOUND_COMMANDS: Collection[str] = []
 
     # Valid commands we can send
-    VALID_OUTBOUND_COMMANDS = []  # type: Collection[str]
+    VALID_OUTBOUND_COMMANDS: Collection[str] = []
 
     max_line_buffer = 10000
 
@@ -165,7 +165,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.last_received_command = self.clock.time_msec()
         self.last_sent_command = 0
         # When we requested the connection be closed
-        self.time_we_closed = None  # type: Optional[int]
+        self.time_we_closed: Optional[int] = None
 
         self.received_ping = False  # Have we received a ping from the other side
 
@@ -175,10 +175,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.conn_id = random_string(5)  # To dedupe in case of name clashes.
 
         # List of pending commands to send once we've established the connection
-        self.pending_commands = []  # type: List[Command]
+        self.pending_commands: List[Command] = []
 
         # The LoopingCall for sending pings.
-        self._send_ping_loop = None  # type: Optional[task.LoopingCall]
+        self._send_ping_loop: Optional[task.LoopingCall] = None
 
         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 6a2c2655e4..8c0df627c8 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -57,7 +57,7 @@ class ConstantProperty(Generic[T, V]):
     it.
     """
 
-    constant = attr.ib()  # type: V
+    constant: V = attr.ib()
 
     def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V:
         return self.constant
@@ -91,9 +91,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
             commands.
     """
 
-    synapse_handler = None  # type: ReplicationCommandHandler
-    synapse_stream_name = None  # type: str
-    synapse_outbound_redis_connection = None  # type: txredisapi.RedisProtocol
+    synapse_handler: "ReplicationCommandHandler"
+    synapse_stream_name: str
+    synapse_outbound_redis_connection: txredisapi.RedisProtocol
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b03824925a..3716c41bea 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -85,9 +85,9 @@ class Stream:
     time it was called.
     """
 
-    NAME = None  # type: str  # The name of the stream
+    NAME: str  # The name of the stream
     # The type of the row. Used by the default impl of parse_row.
-    ROW_TYPE = None  # type: Any
+    ROW_TYPE: Any = None
 
     @classmethod
     def parse_row(cls, row: StreamRow):
@@ -283,9 +283,7 @@ class PresenceStream(Stream):
 
             assert isinstance(presence_handler, PresenceHandler)
 
-            update_function = (
-                presence_handler.get_all_presence_updates
-            )  # type: UpdateFunction
+            update_function: UpdateFunction = presence_handler.get_all_presence_updates
         else:
             # Query presence writer process
             update_function = make_http_update_function(hs, self.NAME)
@@ -334,9 +332,9 @@ class TypingStream(Stream):
         if writer_instance == hs.get_instance_name():
             # On the writer, query the typing handler
             typing_writer_handler = hs.get_typing_writer_handler()
-            update_function = (
-                typing_writer_handler.get_all_typing_updates
-            )  # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]]
+            update_function: Callable[
+                [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
+            ] = typing_writer_handler.get_all_typing_updates
             current_token_function = typing_writer_handler.get_current_token
         else:
             # Query the typing writer process
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index e7e87bac92..a030e9299e 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -65,7 +65,7 @@ class BaseEventsStreamRow:
     """
 
     # Unique string that ids the type. Must be overridden in sub classes.
-    TypeId = None  # type: str
+    TypeId: str
 
     @classmethod
     def from_data(cls, data):
@@ -103,10 +103,10 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
     event_id = attr.ib()  # str, optional
 
 
-_EventRows = (
+_EventRows: Tuple[Type[BaseEventsStreamRow], ...] = (
     EventsStreamEventRow,
     EventsStreamCurrentStateRow,
-)  # type: Tuple[Type[BaseEventsStreamRow], ...]
+)
 
 TypeToRow = {Row.TypeId: Row for Row in _EventRows}
 
@@ -157,9 +157,9 @@ class EventsStream(Stream):
 
         # now we fetch up to that many rows from the events table
 
-        event_rows = await self._store.get_all_new_forward_event_rows(
+        event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows(
             instance_name, from_token, current_token, target_row_count
-        )  # type: List[Tuple]
+        )
 
         # we rely on get_all_new_forward_event_rows strictly honouring the limit, so
         # that we know it is safe to just take upper_limit = event_rows[-1][0].
@@ -172,7 +172,7 @@ class EventsStream(Stream):
 
         if len(event_rows) == target_row_count:
             limited = True
-            upper_limit = event_rows[-1][0]  # type: int
+            upper_limit: int = event_rows[-1][0]
         else:
             limited = False
             upper_limit = current_token
@@ -191,30 +191,30 @@ class EventsStream(Stream):
         # finally, fetch the ex-outliers rows. We assume there are few enough of these
         # not to bother with the limit.
 
-        ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
+        ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows(
             instance_name, from_token, upper_limit
-        )  # type: List[Tuple]
+        )
 
         # we now need to turn the raw database rows returned into tuples suitable
         # for the replication protocol (basically, we add an identifier to
         # distinguish the row type). At the same time, we can limit the event_rows
         # to the max stream_id from state_rows.
 
-        event_updates = (
+        event_updates: Iterable[Tuple[int, Tuple]] = (
             (stream_id, (EventsStreamEventRow.TypeId, rest))
             for (stream_id, *rest) in event_rows
             if stream_id <= upper_limit
-        )  # type: Iterable[Tuple[int, Tuple]]
+        )
 
-        state_updates = (
+        state_updates: Iterable[Tuple[int, Tuple]] = (
             (stream_id, (EventsStreamCurrentStateRow.TypeId, rest))
             for (stream_id, *rest) in state_rows
-        )  # type: Iterable[Tuple[int, Tuple]]
+        )
 
-        ex_outliers_updates = (
+        ex_outliers_updates: Iterable[Tuple[int, Tuple]] = (
             (stream_id, (EventsStreamEventRow.TypeId, rest))
             for (stream_id, *rest) in ex_outliers_rows
-        )  # type: Iterable[Tuple[int, Tuple]]
+        )
 
         # we need to return a sorted list, so merge them together.
         updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates))
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 096a85d363..c445af9bd9 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -51,9 +51,9 @@ class FederationStream(Stream):
             current_token = current_token_without_instance(
                 federation_sender.get_current_token
             )
-            update_function = (
-                federation_sender.get_replication_rows
-            )  # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]]
+            update_function: Callable[
+                [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
+            ] = federation_sender.get_replication_rows
 
         elif hs.should_send_federation():
             # federation sender: Query master process
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index f0cddd2d2c..40ee33646c 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -402,9 +402,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
 
         # Get the room ID from the identifier.
         try:
-            remote_room_hosts = [
+            remote_room_hosts: Optional[List[str]] = [
                 x.decode("ascii") for x in request.args[b"server_name"]
-            ]  # type: Optional[List[str]]
+            ]
         except Exception:
             remote_room_hosts = None
         room_id, remote_room_hosts = await self.resolve_room_id(
@@ -462,6 +462,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
         super().__init__(hs)
         self.hs = hs
         self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
         self.event_creation_handler = hs.get_event_creation_handler()
         self.state_handler = hs.get_state_handler()
         self.is_mine_id = hs.is_mine_id
@@ -500,7 +501,13 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
             admin_user_id = None
 
             for admin_user in reversed(admin_users):
-                if room_state.get((EventTypes.Member, admin_user)):
+                (
+                    current_membership_type,
+                    _,
+                ) = await self.store.get_local_current_membership_for_user_in_room(
+                    admin_user, room_id
+                )
+                if current_membership_type == "join":
                     admin_user_id = admin_user
                     break
 
@@ -652,9 +659,7 @@ class RoomEventContextServlet(RestServlet):
         filter_str = parse_string(request, "filter", encoding="utf-8")
         if filter_str:
             filter_json = urlparse.unquote(filter_str)
-            event_filter = Filter(
-                json_decoder.decode(filter_json)
-            )  # type: Optional[Filter]
+            event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
         else:
             event_filter = None
 
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 7d75564758..589e47fa47 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -357,7 +357,7 @@ class UserRegisterServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         self.auth_handler = hs.get_auth_handler()
         self.reactor = hs.get_reactor()
-        self.nonces = {}  # type: Dict[str, int]
+        self.nonces: Dict[str, int] = {}
         self.hs = hs
 
     def _clear_old_nonces(self):
@@ -560,16 +560,24 @@ class AccountValidityRenewServlet(RestServlet):
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        body = parse_json_object_from_request(request)
+        if self.account_activity_handler.on_legacy_admin_request_callback:
+            expiration_ts = await (
+                self.account_activity_handler.on_legacy_admin_request_callback(request)
+            )
+        else:
+            body = parse_json_object_from_request(request)
 
-        if "user_id" not in body:
-            raise SynapseError(400, "Missing property 'user_id' in the request body")
+            if "user_id" not in body:
+                raise SynapseError(
+                    400,
+                    "Missing property 'user_id' in the request body",
+                )
 
-        expiration_ts = await self.account_activity_handler.renew_account_for_user(
-            body["user_id"],
-            body.get("expiration_ts"),
-            not body.get("enable_renewal_emails", True),
-        )
+            expiration_ts = await self.account_activity_handler.renew_account_for_user(
+                body["user_id"],
+                body.get("expiration_ts"),
+                not body.get("enable_renewal_emails", True),
+            )
 
         res = {"expiration_ts": expiration_ts}
         return 200, res
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index cbcb60fe31..11567bf32c 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -44,19 +44,14 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-LoginResponse = TypedDict(
-    "LoginResponse",
-    {
-        "user_id": str,
-        "access_token": str,
-        "home_server": str,
-        "expires_in_ms": Optional[int],
-        "refresh_token": Optional[str],
-        "device_id": str,
-        "well_known": Optional[Dict[str, Any]],
-    },
-    total=False,
-)
+class LoginResponse(TypedDict, total=False):
+    user_id: str
+    access_token: str
+    home_server: str
+    expires_in_ms: Optional[int]
+    refresh_token: Optional[str]
+    device_id: str
+    well_known: Optional[Dict[str, Any]]
 
 
 class LoginRestServlet(RestServlet):
@@ -121,7 +116,7 @@ class LoginRestServlet(RestServlet):
             flows.append({"type": LoginRestServlet.CAS_TYPE})
 
         if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
-            sso_flow = {
+            sso_flow: JsonDict = {
                 "type": LoginRestServlet.SSO_TYPE,
                 "identity_providers": [
                     _get_auth_flow_dict_for_idp(
@@ -129,7 +124,7 @@ class LoginRestServlet(RestServlet):
                     )
                     for idp in self._sso_handler.get_identity_providers().values()
                 ],
-            }  # type: JsonDict
+            }
 
             if self._msc2858_enabled:
                 # backwards-compatibility support for clients which don't
@@ -150,9 +145,7 @@ class LoginRestServlet(RestServlet):
             # login flow types returned.
             flows.append({"type": LoginRestServlet.TOKEN_TYPE})
 
-        flows.extend(
-            ({"type": t} for t in self.auth_handler.get_supported_login_types())
-        )
+        flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types())
 
         flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
 
@@ -447,7 +440,7 @@ def _get_auth_flow_dict_for_idp(
         use_unstable_brands: whether we should use brand identifiers suitable
            for the unstable API
     """
-    e = {"id": idp.idp_id, "name": idp.idp_name}  # type: JsonDict
+    e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name}
     if idp.idp_icon:
         e["icon"] = idp.idp_icon
     if idp.idp_brand:
@@ -561,7 +554,7 @@ class SsoRedirectServlet(RestServlet):
             finish_request(request)
             return
 
-        args = request.args  # type: Dict[bytes, List[bytes]]  # type: ignore
+        args: Dict[bytes, List[bytes]] = request.args  # type: ignore
         client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True)
         sso_url = await self._sso_handler.handle_redirect_request(
             request,
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 92ebe838fd..31a1193cd3 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -29,6 +29,7 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.api.filtering import Filter
+from synapse.appservice import ApplicationService
 from synapse.events.utils import format_event_for_client_v2
 from synapse.http.servlet import (
     RestServlet,
@@ -47,11 +48,13 @@ from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
 from synapse.types import (
     JsonDict,
+    Requester,
     RoomAlias,
     RoomID,
     StreamToken,
     ThirdPartyInstanceID,
     UserID,
+    create_requester,
 )
 from synapse.util import json_decoder
 from synapse.util.stringutils import parse_and_validate_server_name, random_string
@@ -309,7 +312,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
         self.room_member_handler = hs.get_room_member_handler()
         self.auth = hs.get_auth()
 
-    async def inherit_depth_from_prev_ids(self, prev_event_ids) -> int:
+    async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int:
         (
             most_recent_prev_event_id,
             most_recent_prev_event_depth,
@@ -349,6 +352,54 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
 
         return depth
 
+    def _create_insertion_event_dict(
+        self, sender: str, room_id: str, origin_server_ts: int
+    ):
+        """Creates an event dict for an "insertion" event with the proper fields
+        and a random chunk ID.
+
+        Args:
+            sender: The event author MXID
+            room_id: The room ID that the event belongs to
+            origin_server_ts: Timestamp when the event was sent
+
+        Returns:
+            Tuple of event ID and stream ordering position
+        """
+
+        next_chunk_id = random_string(8)
+        insertion_event = {
+            "type": EventTypes.MSC2716_INSERTION,
+            "sender": sender,
+            "room_id": room_id,
+            "content": {
+                EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id,
+                EventContentFields.MSC2716_HISTORICAL: True,
+            },
+            "origin_server_ts": origin_server_ts,
+        }
+
+        return insertion_event
+
+    async def _create_requester_for_user_id_from_app_service(
+        self, user_id: str, app_service: ApplicationService
+    ) -> Requester:
+        """Creates a new requester for the given user_id
+        and validates that the app service is allowed to control
+        the given user.
+
+        Args:
+            user_id: The author MXID that the app service is controlling
+            app_service: The app service that controls the user
+
+        Returns:
+            Requester object
+        """
+
+        await self.auth.validate_appservice_can_control_user_id(app_service, user_id)
+
+        return create_requester(user_id, app_service=app_service)
+
     async def on_POST(self, request, room_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
 
@@ -414,7 +465,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
             if event_dict["type"] == EventTypes.Member:
                 membership = event_dict["content"].get("membership", None)
                 event_id, _ = await self.room_member_handler.update_membership(
-                    requester,
+                    await self._create_requester_for_user_id_from_app_service(
+                        state_event["sender"], requester.app_service
+                    ),
                     target=UserID.from_string(event_dict["state_key"]),
                     room_id=room_id,
                     action=membership,
@@ -434,7 +487,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
                     event,
                     _,
                 ) = await self.event_creation_handler.create_and_send_nonmember_event(
-                    requester,
+                    await self._create_requester_for_user_id_from_app_service(
+                        state_event["sender"], requester.app_service
+                    ),
                     event_dict,
                     outlier=True,
                     prev_event_ids=[fake_prev_event_id],
@@ -449,37 +504,73 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
 
         events_to_create = body["events"]
 
-        # If provided, connect the chunk to the last insertion point
-        # The chunk ID passed in comes from the chunk_id in the
-        # "insertion" event from the previous chunk.
+        prev_event_ids = prev_events_from_query
+        inherited_depth = await self._inherit_depth_from_prev_ids(
+            prev_events_from_query
+        )
+
+        # Figure out which chunk to connect to. If they passed in
+        # chunk_id_from_query let's use it. The chunk ID passed in comes
+        # from the chunk_id in the "insertion" event from the previous chunk.
+        last_event_in_chunk = events_to_create[-1]
+        chunk_id_to_connect_to = chunk_id_from_query
+        base_insertion_event = None
         if chunk_id_from_query:
-            last_event_in_chunk = events_to_create[-1]
-            last_event_in_chunk["content"][
-                EventContentFields.MSC2716_CHUNK_ID
-            ] = chunk_id_from_query
+            # TODO: Verify the chunk_id_from_query corresponds to an insertion event
+            pass
+        # Otherwise, create an insertion event to act as a starting point.
+        #
+        # We don't always have an insertion event to start hanging more history
+        # off of (ideally there would be one in the main DAG, but that's not the
+        # case if we're wanting to add history to e.g. existing rooms without
+        # an insertion event), in which case we just create a new insertion event
+        # that can then get pointed to by a "marker" event later.
+        else:
+            base_insertion_event_dict = self._create_insertion_event_dict(
+                sender=requester.user.to_string(),
+                room_id=room_id,
+                origin_server_ts=last_event_in_chunk["origin_server_ts"],
+            )
+            base_insertion_event_dict["prev_events"] = prev_event_ids.copy()
+
+            (
+                base_insertion_event,
+                _,
+            ) = await self.event_creation_handler.create_and_send_nonmember_event(
+                await self._create_requester_for_user_id_from_app_service(
+                    base_insertion_event_dict["sender"],
+                    requester.app_service,
+                ),
+                base_insertion_event_dict,
+                prev_event_ids=base_insertion_event_dict.get("prev_events"),
+                auth_event_ids=auth_event_ids,
+                historical=True,
+                depth=inherited_depth,
+            )
+
+            chunk_id_to_connect_to = base_insertion_event["content"][
+                EventContentFields.MSC2716_NEXT_CHUNK_ID
+            ]
 
-        # Add an "insertion" event to the start of each chunk (next to the oldest
+        # Connect this current chunk to the insertion event from the previous chunk
+        last_event_in_chunk["content"][
+            EventContentFields.MSC2716_CHUNK_ID
+        ] = chunk_id_to_connect_to
+
+        # Add an "insertion" event to the start of each chunk (next to the oldest-in-time
         # event in the chunk) so the next chunk can be connected to this one.
-        next_chunk_id = random_string(64)
-        insertion_event = {
-            "type": EventTypes.MSC2716_INSERTION,
-            "sender": requester.user.to_string(),
-            "content": {
-                EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id,
-                EventContentFields.MSC2716_HISTORICAL: True,
-            },
+        insertion_event = self._create_insertion_event_dict(
+            sender=requester.user.to_string(),
+            room_id=room_id,
             # Since the insertion event is put at the start of the chunk,
-            # where the oldest event is, copy the origin_server_ts from
+            # where the oldest-in-time event is, copy the origin_server_ts from
             # the first event we're inserting
-            "origin_server_ts": events_to_create[0]["origin_server_ts"],
-        }
+            origin_server_ts=events_to_create[0]["origin_server_ts"],
+        )
         # Prepend the insertion event to the start of the chunk
         events_to_create = [insertion_event] + events_to_create
 
-        inherited_depth = await self.inherit_depth_from_prev_ids(prev_events_from_query)
-
         event_ids = []
-        prev_event_ids = prev_events_from_query
         events_to_persist = []
         for ev in events_to_create:
             assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"])
@@ -498,7 +589,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
             }
 
             event, context = await self.event_creation_handler.create_event(
-                requester,
+                await self._create_requester_for_user_id_from_app_service(
+                    ev["sender"], requester.app_service
+                ),
                 event_dict,
                 prev_event_ids=event_dict.get("prev_events"),
                 auth_event_ids=auth_event_ids,
@@ -528,15 +621,23 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
         # where topological_ordering is just depth.
         for (event, context) in reversed(events_to_persist):
             ev = await self.event_creation_handler.handle_new_client_event(
-                requester=requester,
+                await self._create_requester_for_user_id_from_app_service(
+                    event["sender"], requester.app_service
+                ),
                 event=event,
                 context=context,
             )
 
+        # Add the base_insertion_event to the bottom of the list we return
+        if base_insertion_event is not None:
+            event_ids.append(base_insertion_event.event_id)
+
         return 200, {
             "state_events": auth_event_ids,
             "events": event_ids,
-            "next_chunk_id": next_chunk_id,
+            "next_chunk_id": insertion_event["content"][
+                EventContentFields.MSC2716_NEXT_CHUNK_ID
+            ],
         }
 
     def on_GET(self, request, room_id):
@@ -682,7 +783,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
         server = parse_string(request, "server", default=None)
         content = parse_json_object_from_request(request)
 
-        limit = int(content.get("limit", 100))  # type: Optional[int]
+        limit: Optional[int] = int(content.get("limit", 100))
         since_token = content.get("since", None)
         search_filter = content.get("filter", None)
 
@@ -828,9 +929,7 @@ class RoomMessageListRestServlet(RestServlet):
         filter_str = parse_string(request, "filter", encoding="utf-8")
         if filter_str:
             filter_json = urlparse.unquote(filter_str)
-            event_filter = Filter(
-                json_decoder.decode(filter_json)
-            )  # type: Optional[Filter]
+            event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
             if (
                 event_filter
                 and event_filter.filter_json.get("event_format", "client")
@@ -943,9 +1042,7 @@ class RoomEventContextServlet(RestServlet):
         filter_str = parse_string(request, "filter", encoding="utf-8")
         if filter_str:
             filter_json = urlparse.unquote(filter_str)
-            event_filter = Filter(
-                json_decoder.decode(filter_json)
-            )  # type: Optional[Filter]
+            event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
         else:
             event_filter = None
 
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index 2d1ad3d3fb..3ebe401861 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -14,7 +14,7 @@
 
 import logging
 
-from synapse.api.errors import AuthError, SynapseError
+from synapse.api.errors import SynapseError
 from synapse.http.server import respond_with_html
 from synapse.http.servlet import RestServlet
 
@@ -92,11 +92,6 @@ class AccountValiditySendMailServlet(RestServlet):
         )
 
     async def on_POST(self, request):
-        if not self.account_validity_renew_by_email_enabled:
-            raise AuthError(
-                403, "Account renewal via email is disabled on this server."
-            )
-
         requester = await self.auth.get_user_by_req(request, allow_expired=True)
         user_id = requester.user.to_string()
         await self.account_activity_handler.send_renewal_email_to_user(user_id)
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index f8dcee603c..d537d811d8 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -59,7 +59,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
             requester, message_type, content["messages"]
         )
 
-        response = (200, {})  # type: Tuple[int, dict]
+        response: Tuple[int, dict] = (200, {})
         return response
 
 
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index e52570cd8e..4282e2b228 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -117,7 +117,7 @@ class ConsentResource(DirectServeHtmlResource):
         has_consented = False
         public_version = username == ""
         if not public_version:
-            args = request.args  # type: Dict[bytes, List[bytes]]
+            args: Dict[bytes, List[bytes]] = request.args
             userhmac_bytes = parse_bytes_from_args(args, "h", required=True)
 
             self._check_hash(username, userhmac_bytes)
@@ -154,7 +154,7 @@ class ConsentResource(DirectServeHtmlResource):
         """
         version = parse_string(request, "v", required=True)
         username = parse_string(request, "u", required=True)
-        args = request.args  # type: Dict[bytes, List[bytes]]
+        args: Dict[bytes, List[bytes]] = request.args
         userhmac = parse_bytes_from_args(args, "h", required=True)
 
         self._check_hash(username, userhmac)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index d56a1ae482..63a40b1852 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -97,7 +97,7 @@ class RemoteKey(DirectServeJsonResource):
     async def _async_render_GET(self, request):
         if len(request.postpath) == 1:
             (server,) = request.postpath
-            query = {server.decode("ascii"): {}}  # type: dict
+            query: dict = {server.decode("ascii"): {}}
         elif len(request.postpath) == 2:
             server, key_id = request.postpath
             minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
@@ -141,7 +141,7 @@ class RemoteKey(DirectServeJsonResource):
         time_now_ms = self.clock.time_msec()
 
         # Note that the value is unused.
-        cache_misses = {}  # type: Dict[str, Dict[str, int]]
+        cache_misses: Dict[str, Dict[str, int]] = {}
         for (server_name, key_id, _), results in cached.items():
             results = [(result["ts_added_ms"], result) for result in results]
 
diff --git a/synapse/rest/media/v1/__init__.py b/synapse/rest/media/v1/__init__.py
index d20186bbd0..3dd16d4bb5 100644
--- a/synapse/rest/media/v1/__init__.py
+++ b/synapse/rest/media/v1/__init__.py
@@ -17,7 +17,7 @@ import PIL.Image
 # check for JPEG support.
 try:
     PIL.Image._getdecoder("rgb", "jpeg", None)
-except IOError as e:
+except OSError as e:
     if str(e).startswith("decoder jpeg not available"):
         raise Exception(
             "FATAL: jpeg codec not supported. Install pillow correctly! "
@@ -32,7 +32,7 @@ except Exception:
 # check for PNG support.
 try:
     PIL.Image._getdecoder("rgb", "zip", None)
-except IOError as e:
+except OSError as e:
     if str(e).startswith("decoder zip not available"):
         raise Exception(
             "FATAL: zip codec not supported. Install pillow correctly! "
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 0fb4cd81f1..90364ebcf7 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -49,7 +49,7 @@ TEXT_CONTENT_TYPES = [
 def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
     try:
         # The type on postpath seems incorrect in Twisted 21.2.0.
-        postpath = request.postpath  # type: List[bytes]  # type: ignore
+        postpath: List[bytes] = request.postpath  # type: ignore
         assert postpath
 
         # This allows users to append e.g. /test.png to the URL. Useful for
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 21c43c340c..4f702f890c 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -78,16 +78,16 @@ class MediaRepository:
 
         Thumbnailer.set_limits(self.max_image_pixels)
 
-        self.primary_base_path = hs.config.media_store_path  # type: str
-        self.filepaths = MediaFilePaths(self.primary_base_path)  # type: MediaFilePaths
+        self.primary_base_path: str = hs.config.media_store_path
+        self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path)
 
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.thumbnail_requirements = hs.config.thumbnail_requirements
 
         self.remote_media_linearizer = Linearizer(name="media_remote")
 
-        self.recently_accessed_remotes = set()  # type: Set[Tuple[str, str]]
-        self.recently_accessed_locals = set()  # type: Set[str]
+        self.recently_accessed_remotes: Set[Tuple[str, str]] = set()
+        self.recently_accessed_locals: Set[str] = set()
 
         self.federation_domain_whitelist = hs.config.federation_domain_whitelist
 
@@ -711,7 +711,7 @@ class MediaRepository:
 
         # We deduplicate the thumbnail sizes by ignoring the cropped versions if
         # they have the same dimensions of a scaled one.
-        thumbnails = {}  # type: Dict[Tuple[int, int, str], str]
+        thumbnails: Dict[Tuple[int, int, str], str] = {}
         for r_width, r_height, r_method, r_type in requirements:
             if r_method == "crop":
                 thumbnails.setdefault((r_width, r_height, r_type), r_method)
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index c7fd97c46c..56cdc1b4ed 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -191,7 +191,7 @@ class MediaStorage:
 
         for provider in self.storage_providers:
             for path in paths:
-                res = await provider.fetch(path, file_info)  # type: Any
+                res: Any = await provider.fetch(path, file_info)
                 if res:
                     logger.debug("Streaming %s from %s", path, provider)
                     return res
@@ -233,7 +233,7 @@ class MediaStorage:
             os.makedirs(dirname)
 
         for provider in self.storage_providers:
-            res = await provider.fetch(path, file_info)  # type: Any
+            res: Any = await provider.fetch(path, file_info)
             if res:
                 with res:
                     consumer = BackgroundFileConsumer(
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 0adfb1a70f..8e7fead3a2 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -169,12 +169,12 @@ class PreviewUrlResource(DirectServeJsonResource):
 
         # memory cache mapping urls to an ObservableDeferred returning
         # JSON-encoded OG metadata
-        self._cache = ExpiringCache(
+        self._cache: ExpiringCache[str, ObservableDeferred] = ExpiringCache(
             cache_name="url_previews",
             clock=self.clock,
             # don't spider URLs more often than once an hour
             expiry_ms=ONE_HOUR,
-        )  # type: ExpiringCache[str, ObservableDeferred]
+        )
 
         if self._worker_run_media_background_jobs:
             self._cleaner_loop = self.clock.looping_call(
@@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
 
         # If this URL can be accessed via oEmbed, use that instead.
-        url_to_download = url  # type: Optional[str]
+        url_to_download: Optional[str] = url
         oembed_url = self._get_oembed_url(url)
         if oembed_url:
             # The result might be a new URL to download, or it might be HTML content.
@@ -788,7 +788,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
     # "og:video:height" : "720",
     # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
 
-    og = {}  # type: Dict[str, Optional[str]]
+    og: Dict[str, Optional[str]] = {}
     for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
         if "content" in tag.attrib:
             # if we've got more than 50 tags, someone is taking the piss
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 62dc4aae2d..146adca8f1 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -61,11 +61,11 @@ class UploadResource(DirectServeJsonResource):
                 errcode=Codes.TOO_LARGE,
             )
 
-        args = request.args  # type: Dict[bytes, List[bytes]]  # type: ignore
+        args: Dict[bytes, List[bytes]] = request.args  # type: ignore
         upload_name_bytes = parse_bytes_from_args(args, "filename")
         if upload_name_bytes:
             try:
-                upload_name = upload_name_bytes.decode("utf8")  # type: Optional[str]
+                upload_name: Optional[str] = upload_name_bytes.decode("utf8")
             except UnicodeDecodeError:
                 raise SynapseError(
                     msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400
@@ -89,7 +89,7 @@ class UploadResource(DirectServeJsonResource):
         # TODO(markjh): parse content-dispostion
 
         try:
-            content = request.content  # type: IO  # type: ignore
+            content: IO = request.content  # type: ignore
             content_uri = await self.media_repo.create_content(
                 media_type, upload_name, content, content_length, requester.user
             )
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
index 9b002cc15e..ab24ec0a8e 100644
--- a/synapse/rest/synapse/client/pick_username.py
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -118,9 +118,9 @@ class AccountDetailsResource(DirectServeHtmlResource):
             use_display_name = parse_boolean(request, "use_display_name", default=False)
 
             try:
-                emails_to_use = [
+                emails_to_use: List[str] = [
                     val.decode("utf-8") for val in request.args.get(b"use_email", [])
-                ]  # type: List[str]
+                ]
             except ValueError:
                 raise SynapseError(400, "Query parameter use_email must be utf-8")
         except SynapseError as e:
diff --git a/synapse/server.py b/synapse/server.py
index 2c27d2a7e8..095dba9ad0 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -247,15 +247,15 @@ class HomeServer(metaclass=abc.ABCMeta):
         # the key we use to sign events and requests
         self.signing_key = config.key.signing_key[0]
         self.config = config
-        self._listening_services = []  # type: List[twisted.internet.tcp.Port]
-        self.start_time = None  # type: Optional[int]
+        self._listening_services: List[twisted.internet.tcp.Port] = []
+        self.start_time: Optional[int] = None
 
         self._instance_id = random_string(5)
         self._instance_name = config.worker.instance_name
 
         self.version_string = version_string
 
-        self.datastores = None  # type: Optional[Databases]
+        self.datastores: Optional[Databases] = None
 
         self._module_web_resources: Dict[str, IResource] = {}
         self._module_web_resources_consumed = False
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
index e65f6f88fe..4e0f814035 100644
--- a/synapse/server_notices/consent_server_notices.py
+++ b/synapse/server_notices/consent_server_notices.py
@@ -34,7 +34,7 @@ class ConsentServerNotices:
         self._server_notices_manager = hs.get_server_notices_manager()
         self._store = hs.get_datastore()
 
-        self._users_in_progress = set()  # type: Set[str]
+        self._users_in_progress: Set[str] = set()
 
         self._current_consent_version = hs.config.user_consent_version
         self._server_notice_content = hs.config.user_consent_server_notice_content
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index e4b0bc5c72..073b0d754f 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -205,7 +205,7 @@ class ResourceLimitsServerNotices:
             # The user has yet to join the server notices room
             pass
 
-        referenced_events = []  # type: List[str]
+        referenced_events: List[str] = []
         if pinned_state_event is not None:
             referenced_events = list(pinned_state_event.content.get("pinned", []))
 
diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py
index c875b15b32..cdf0973d05 100644
--- a/synapse/server_notices/server_notices_sender.py
+++ b/synapse/server_notices/server_notices_sender.py
@@ -32,10 +32,12 @@ class ServerNoticesSender(WorkerServerNoticesSender):
 
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
-        self._server_notices = (
+        self._server_notices: Iterable[
+            Union[ConsentServerNotices, ResourceLimitsServerNotices]
+        ] = (
             ConsentServerNotices(hs),
             ResourceLimitsServerNotices(hs),
-        )  # type: Iterable[Union[ConsentServerNotices, ResourceLimitsServerNotices]]
+        )
 
     async def on_user_syncing(self, user_id: str) -> None:
         """Called when the user performs a sync operation.
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index a1770f620e..6223daf522 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -309,9 +309,9 @@ class StateHandler:
 
         if old_state:
             # if we're given the state before the event, then we use that
-            state_ids_before_event = {
+            state_ids_before_event: StateMap[str] = {
                 (s.type, s.state_key): s.event_id for s in old_state
-            }  # type: StateMap[str]
+            }
             state_group_before_event = None
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
@@ -513,23 +513,25 @@ class StateResolutionHandler:
         self.resolve_linearizer = Linearizer(name="state_resolve_lock")
 
         # dict of set of event_ids -> _StateCacheEntry.
-        self._state_cache = ExpiringCache(
+        self._state_cache: ExpiringCache[
+            FrozenSet[int], _StateCacheEntry
+        ] = ExpiringCache(
             cache_name="state_cache",
             clock=self.clock,
             max_len=100000,
             expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
             iterable=True,
             reset_expiry_on_get=True,
-        )  # type: ExpiringCache[FrozenSet[int], _StateCacheEntry]
+        )
 
         #
         # stuff for tracking time spent on state-res by room
         #
 
         # tracks the amount of work done on state res per room
-        self._state_res_metrics = defaultdict(
+        self._state_res_metrics: DefaultDict[str, _StateResMetrics] = defaultdict(
             _StateResMetrics
-        )  # type: DefaultDict[str, _StateResMetrics]
+        )
 
         self.clock.looping_call(self._report_metrics, 120 * 1000)
 
@@ -700,9 +702,9 @@ class StateResolutionHandler:
         items = self._state_res_metrics.items()
 
         # log the N biggest rooms
-        biggest = heapq.nlargest(
+        biggest: List[Tuple[str, _StateResMetrics]] = heapq.nlargest(
             n_to_log, items, key=lambda i: extract_key(i[1])
-        )  # type: List[Tuple[str, _StateResMetrics]]
+        )
         metrics_logger.debug(
             "%i biggest rooms for state-res by %s: %s",
             len(biggest),
@@ -754,7 +756,7 @@ def _make_state_cache_entry(
 
     # failing that, look for the closest match.
     prev_group = None
-    delta_ids = None  # type: Optional[StateMap[str]]
+    delta_ids: Optional[StateMap[str]] = None
 
     for old_group, old_state in state_groups_ids.items():
         n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v}
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 318e998813..267193cedf 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -159,7 +159,7 @@ def _seperate(
     """
     state_set_iterator = iter(state_sets)
     unconflicted_state = dict(next(state_set_iterator))
-    conflicted_state = {}  # type: MutableStateMap[Set[str]]
+    conflicted_state: MutableStateMap[Set[str]] = {}
 
     for state_set in state_set_iterator:
         for key, value in state_set.items():
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 008644cd98..e66e6571c8 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -276,7 +276,7 @@ async def _get_auth_chain_difference(
     # event IDs if they appear in the `event_map`. This is the intersection of
     # the event's auth chain with the events in the `event_map` *plus* their
     # auth event IDs.
-    events_to_auth_chain = {}  # type: Dict[str, Set[str]]
+    events_to_auth_chain: Dict[str, Set[str]] = {}
     for event in event_map.values():
         chain = {event.event_id}
         events_to_auth_chain[event.event_id] = chain
@@ -301,17 +301,17 @@ async def _get_auth_chain_difference(
         # ((type, state_key)->event_id) mappings; and (b) we have stripped out
         # unpersisted events and replaced them with the persisted events in
         # their auth chain.
-        state_sets_ids = []  # type: List[Set[str]]
+        state_sets_ids: List[Set[str]] = []
 
         # For each state set, the unpersisted event IDs reachable (by their auth
         # chain) from the events in that set.
-        unpersisted_set_ids = []  # type: List[Set[str]]
+        unpersisted_set_ids: List[Set[str]] = []
 
         for state_set in state_sets:
-            set_ids = set()  # type: Set[str]
+            set_ids: Set[str] = set()
             state_sets_ids.append(set_ids)
 
-            unpersisted_ids = set()  # type: Set[str]
+            unpersisted_ids: Set[str] = set()
             unpersisted_set_ids.append(unpersisted_ids)
 
             for event_id in state_set.values():
@@ -334,7 +334,7 @@ async def _get_auth_chain_difference(
         union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
         intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
 
-        difference_from_event_map = union - intersection  # type: Collection[str]
+        difference_from_event_map: Collection[str] = union - intersection
     else:
         difference_from_event_map = ()
         state_sets_ids = [set(state_set.values()) for state_set in state_sets]
@@ -458,7 +458,7 @@ async def _reverse_topological_power_sort(
         The sorted list
     """
 
-    graph = {}  # type: Dict[str, Set[str]]
+    graph: Dict[str, Set[str]] = {}
     for idx, event_id in enumerate(event_ids, start=1):
         await _add_event_and_auth_chain_to_graph(
             graph, room_id, event_id, event_map, state_res_store, auth_diff
@@ -657,7 +657,7 @@ async def _get_mainline_depth_for_event(
     """
 
     room_id = event.room_id
-    tmp_event = event  # type: Optional[EventBase]
+    tmp_event: Optional[EventBase] = event
 
     # We do an iterative search, replacing `event with the power level in its
     # auth events (if any)
@@ -767,7 +767,7 @@ def lexicographical_topological_sort(
     # outgoing edges, c.f.
     # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
     outdegree_map = graph
-    reverse_graph = {}  # type: Dict[str, Set[str]]
+    reverse_graph: Dict[str, Set[str]] = {}
 
     # Lists of nodes with zero out degree. Is actually a tuple of
     # `(key(node), node)` so that sorting does the right thing
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 142787fdfd..82b31d24f1 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -92,14 +92,12 @@ class BackgroundUpdater:
         self.db_pool = database
 
         # if a background update is currently running, its name.
-        self._current_background_update = None  # type: Optional[str]
-
-        self._background_update_performance = (
-            {}
-        )  # type: Dict[str, BackgroundUpdatePerformance]
-        self._background_update_handlers = (
-            {}
-        )  # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]]
+        self._current_background_update: Optional[str] = None
+
+        self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
+        self._background_update_handlers: Dict[
+            str, Callable[[JsonDict, int], Awaitable[int]]
+        ] = {}
         self._all_done = False
 
     def start_doing_background_updates(self) -> None:
@@ -411,7 +409,7 @@ class BackgroundUpdater:
             c.execute(sql)
 
         if isinstance(self.db_pool.engine, engines.PostgresEngine):
-            runner = create_index_psql  # type: Optional[Callable[[Connection], None]]
+            runner: Optional[Callable[[Connection], None]] = create_index_psql
         elif psql_only:
             runner = None
         else:
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 33c42cf95a..ccf9ac51ef 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -670,8 +670,8 @@ class DatabasePool:
         Returns:
             The result of func
         """
-        after_callbacks = []  # type: List[_CallbackListEntry]
-        exception_callbacks = []  # type: List[_CallbackListEntry]
+        after_callbacks: List[_CallbackListEntry] = []
+        exception_callbacks: List[_CallbackListEntry] = []
 
         if not current_context():
             logger.warning("Starting db txn '%s' from sentinel context", desc)
@@ -907,7 +907,7 @@ class DatabasePool:
         # The sort is to ensure that we don't rely on dictionary iteration
         # order.
         keys, vals = zip(
-            *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
+            *(zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i)
         )
 
         for k in keys:
@@ -1090,7 +1090,7 @@ class DatabasePool:
                 return False
 
         # We didn't find any existing rows, so insert a new one
-        allvalues = {}  # type: Dict[str, Any]
+        allvalues: Dict[str, Any] = {}
         allvalues.update(keyvalues)
         allvalues.update(values)
         allvalues.update(insertion_values)
@@ -1121,7 +1121,7 @@ class DatabasePool:
             values: The nonunique columns and their new values
             insertion_values: additional key/values to use only when inserting
         """
-        allvalues = {}  # type: Dict[str, Any]
+        allvalues: Dict[str, Any] = {}
         allvalues.update(keyvalues)
         allvalues.update(insertion_values or {})
 
@@ -1257,7 +1257,7 @@ class DatabasePool:
             value_values: A list of each row's value column values.
                 Ignored if value_names is empty.
         """
-        allnames = []  # type: List[str]
+        allnames: List[str] = []
         allnames.extend(key_names)
         allnames.extend(value_names)
 
@@ -1566,7 +1566,7 @@ class DatabasePool:
         """
         keyvalues = keyvalues or {}
 
-        results = []  # type: List[Dict[str, Any]]
+        results: List[Dict[str, Any]] = []
 
         if not iterable:
             return results
@@ -1978,7 +1978,7 @@ class DatabasePool:
             raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
 
         where_clause = "WHERE " if filters or keyvalues or exclude_keyvalues else ""
-        arg_list = []  # type: List[Any]
+        arg_list: List[Any] = []
         if filters:
             where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
             arg_list += list(filters.values())
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 9f182c2a89..e2d1b758bd 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -48,9 +48,7 @@ def _make_exclusive_regex(
     ]
     if exclusive_user_regexes:
         exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
-        exclusive_user_pattern = re.compile(
-            exclusive_user_regex
-        )  # type: Optional[Pattern]
+        exclusive_user_pattern: Optional[Pattern] = re.compile(exclusive_user_regex)
     else:
         # We handle this case specially otherwise the constructed regex
         # will always match
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 50e7ddd735..c55508867d 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -203,9 +203,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             "delete_messages_for_device", delete_messages_for_device_txn
         )
 
-        log_kv(
-            {"message": "deleted {} messages for device".format(count), "count": count}
-        )
+        log_kv({"message": f"deleted {count} messages for device", "count": count})
 
         # Update the cache, ensuring that we only ever increase the value
         last_deleted_stream_id = self._last_device_delete_cache.get(
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 0e3dd4e9ca..1edc96042b 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -21,6 +21,7 @@ from canonicaljson import encode_canonical_json
 
 from twisted.enterprise.adbapi import Connection
 
+from synapse.api.constants import DeviceKeyAlgorithms
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool, make_in_list_sql_clause
@@ -247,7 +248,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
         txn.execute(sql, query_params)
 
-        result = {}  # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
+        result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
         for (user_id, device_id, display_name, key_json) in txn:
             if include_deleted_devices:
                 deleted_devices.remove((user_id, device_id))
@@ -381,9 +382,15 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
                 " GROUP BY algorithm"
             )
             txn.execute(sql, (user_id, device_id))
-            result = {}
+
+            # Initially set the key count to 0. This ensures that the client will always
+            # receive *some count*, even if it's 0.
+            result = {DeviceKeyAlgorithms.SIGNED_CURVE25519: 0}
+
+            # Override entries with the count of any keys we pulled from the database
             for algorithm, key_count in txn:
                 result[algorithm] = key_count
+
             return result
 
         return await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index c4474df975..d39368c20e 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -62,9 +62,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             )
 
         # Cache of event ID to list of auth event IDs and their depths.
-        self._event_auth_cache = LruCache(
+        self._event_auth_cache: LruCache[str, List[Tuple[str, int]]] = LruCache(
             500000, "_event_auth_cache", size_callback=len
-        )  # type: LruCache[str, List[Tuple[str, int]]]
+        )
 
         self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)
 
@@ -137,10 +137,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         initial_events = set(event_ids)
 
         # All the events that we've found that are reachable from the events.
-        seen_events = set()  # type: Set[str]
+        seen_events: Set[str] = set()
 
         # A map from chain ID to max sequence number of the given events.
-        event_chains = {}  # type: Dict[int, int]
+        event_chains: Dict[int, int] = {}
 
         sql = """
             SELECT event_id, chain_id, sequence_number
@@ -182,7 +182,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         """
 
         # A map from chain ID to max sequence number *reachable* from any event ID.
-        chains = {}  # type: Dict[int, int]
+        chains: Dict[int, int] = {}
 
         # Add all linked chains reachable from initial set of chains.
         for batch in batch_iter(event_chains, 1000):
@@ -353,14 +353,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         initial_events = set(state_sets[0]).union(*state_sets[1:])
 
         # Map from event_id -> (chain ID, seq no)
-        chain_info = {}  # type: Dict[str, Tuple[int, int]]
+        chain_info: Dict[str, Tuple[int, int]] = {}
 
         # Map from chain ID -> seq no -> event Id
-        chain_to_event = {}  # type: Dict[int, Dict[int, str]]
+        chain_to_event: Dict[int, Dict[int, str]] = {}
 
         # All the chains that we've found that are reachable from the state
         # sets.
-        seen_chains = set()  # type: Set[int]
+        seen_chains: Set[int] = set()
 
         sql = """
             SELECT event_id, chain_id, sequence_number
@@ -392,9 +392,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         # Corresponds to `state_sets`, except as a map from chain ID to max
         # sequence number reachable from the state set.
-        set_to_chain = []  # type: List[Dict[int, int]]
+        set_to_chain: List[Dict[int, int]] = []
         for state_set in state_sets:
-            chains = {}  # type: Dict[int, int]
+            chains: Dict[int, int] = {}
             set_to_chain.append(chains)
 
             for event_id in state_set:
@@ -446,7 +446,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         # Mapping from chain ID to the range of sequence numbers that should be
         # pulled from the database.
-        chain_to_gap = {}  # type: Dict[int, Tuple[int, int]]
+        chain_to_gap: Dict[int, Tuple[int, int]] = {}
 
         for chain_id in seen_chains:
             min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
@@ -555,7 +555,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         }
 
         # The sorted list of events whose auth chains we should walk.
-        search = []  # type: List[Tuple[int, str]]
+        search: List[Tuple[int, str]] = []
 
         # We need to get the depth of the initial events for sorting purposes.
         sql = """
@@ -578,7 +578,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         search.sort()
 
         # Map from event to its auth events
-        event_to_auth_events = {}  # type: Dict[str, Set[str]]
+        event_to_auth_events: Dict[str, Set[str]] = {}
 
         base_sql = """
             SELECT a.event_id, auth_id, depth
@@ -1230,7 +1230,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
                 "SELECT coalesce(min(received_ts), 0) FROM federation_inbound_events_staging"
             )
 
-            (age,) = txn.fetchone()
+            (received_ts,) = txn.fetchone()
+
+            age = self._clock.time_msec() - received_ts
 
             return count, age
 
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index d1237c65cc..55caa6bbe7 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -759,7 +759,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # object because we might not have the same amount of rows in each of them. To do
         # this, we use a dict indexed on the user ID and room ID to make it easier to
         # populate.
-        summaries = {}  # type: Dict[Tuple[str, str], _EventPushSummary]
+        summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
         for row in txn:
             summaries[(row[0], row[1])] = _EventPushSummary(
                 unread_count=row[2],
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 897fa06639..a396a201d4 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -109,10 +109,8 @@ class PersistEventsStore:
 
         # Ideally we'd move these ID gens here, unfortunately some other ID
         # generators are chained off them so doing so is a bit of a PITA.
-        self._backfill_id_gen = (
-            self.store._backfill_id_gen
-        )  # type: MultiWriterIdGenerator
-        self._stream_id_gen = self.store._stream_id_gen  # type: MultiWriterIdGenerator
+        self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
+        self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
 
         # This should only exist on instances that are configured to write
         assert (
@@ -221,7 +219,7 @@ class PersistEventsStore:
         Returns:
             Filtered event ids
         """
-        results = []  # type: List[str]
+        results: List[str] = []
 
         def _get_events_which_are_prevs_txn(txn, batch):
             sql = """
@@ -508,7 +506,7 @@ class PersistEventsStore:
         """
 
         # Map from event ID to chain ID/sequence number.
-        chain_map = {}  # type: Dict[str, Tuple[int, int]]
+        chain_map: Dict[str, Tuple[int, int]] = {}
 
         # Set of event IDs to calculate chain ID/seq numbers for.
         events_to_calc_chain_id_for = set(event_to_room_id)
@@ -817,8 +815,8 @@ class PersistEventsStore:
         #      new chain if the sequence number has already been allocated.
         #
 
-        existing_chains = set()  # type: Set[int]
-        tree = []  # type: List[Tuple[str, Optional[str]]]
+        existing_chains: Set[int] = set()
+        tree: List[Tuple[str, Optional[str]]] = []
 
         # We need to do this in a topologically sorted order as we want to
         # generate chain IDs/sequence numbers of an event's auth events before
@@ -848,7 +846,7 @@ class PersistEventsStore:
         )
         txn.execute(sql % (clause,), args)
 
-        chain_to_max_seq_no = {row[0]: row[1] for row in txn}  # type: Dict[Any, int]
+        chain_to_max_seq_no: Dict[Any, int] = {row[0]: row[1] for row in txn}
 
         # Allocate the new events chain ID/sequence numbers.
         #
@@ -858,8 +856,8 @@ class PersistEventsStore:
         # number of new chain IDs in one call, replacing all temporary
         # objects with real allocated chain IDs.
 
-        unallocated_chain_ids = set()  # type: Set[object]
-        new_chain_tuples = {}  # type: Dict[str, Tuple[Any, int]]
+        unallocated_chain_ids: Set[object] = set()
+        new_chain_tuples: Dict[str, Tuple[Any, int]] = {}
         for event_id, auth_event_id in tree:
             # If we reference an auth_event_id we fetch the allocated chain ID,
             # either from the existing `chain_map` or the newly generated
@@ -870,7 +868,7 @@ class PersistEventsStore:
                 if not existing_chain_id:
                     existing_chain_id = chain_map[auth_event_id]
 
-            new_chain_tuple = None  # type: Optional[Tuple[Any, int]]
+            new_chain_tuple: Optional[Tuple[Any, int]] = None
             if existing_chain_id:
                 # We found a chain ID/sequence number candidate, check its
                 # not already taken.
@@ -897,9 +895,9 @@ class PersistEventsStore:
         )
 
         # Map from potentially temporary chain ID to real chain ID
-        chain_id_to_allocated_map = dict(
+        chain_id_to_allocated_map: Dict[Any, int] = dict(
             zip(unallocated_chain_ids, newly_allocated_chain_ids)
-        )  # type: Dict[Any, int]
+        )
         chain_id_to_allocated_map.update((c, c) for c in existing_chains)
 
         return {
@@ -1175,9 +1173,9 @@ class PersistEventsStore:
         Returns:
             list[(EventBase, EventContext)]: filtered list
         """
-        new_events_and_contexts = (
-            OrderedDict()
-        )  # type: OrderedDict[str, Tuple[EventBase, EventContext]]
+        new_events_and_contexts: OrderedDict[
+            str, Tuple[EventBase, EventContext]
+        ] = OrderedDict()
         for event, context in events_and_contexts:
             prev_event_context = new_events_and_contexts.get(event.event_id)
             if prev_event_context:
@@ -1205,7 +1203,7 @@ class PersistEventsStore:
                 we are persisting
             backfilled (bool): True if the events were backfilled
         """
-        depth_updates = {}  # type: Dict[str, int]
+        depth_updates: Dict[str, int] = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
             txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@@ -1580,11 +1578,11 @@ class PersistEventsStore:
         # invalidate the cache for the redacted event
         txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
 
-        self.db_pool.simple_insert_txn(
+        self.db_pool.simple_upsert_txn(
             txn,
             table="redactions",
+            keyvalues={"event_id": event.event_id},
             values={
-                "event_id": event.event_id,
                 "redacts": event.redacts,
                 "received_ts": self._clock.time_msec(),
             },
@@ -1885,7 +1883,7 @@ class PersistEventsStore:
                 ),
             )
 
-            room_to_event_ids = {}  # type: Dict[str, List[str]]
+            room_to_event_ids: Dict[str, List[str]] = {}
             for e, _ in events_and_contexts:
                 room_to_event_ids.setdefault(e.room_id, []).append(e.event_id)
 
@@ -2012,10 +2010,6 @@ class PersistEventsStore:
 
         Forward extremities are handled when we first start persisting the events.
         """
-        events_by_room = {}  # type: Dict[str, List[EventBase]]
-        for ev in events:
-            events_by_room.setdefault(ev.room_id, []).append(ev)
-
         query = (
             "INSERT INTO event_backward_extremities (event_id, room_id)"
             " SELECT ?, ? WHERE NOT EXISTS ("
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 29f33bac55..6fcb2b8353 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -960,9 +960,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         event_to_types = {row[0]: (row[1], row[2]) for row in rows}
 
         # Calculate the new last position we've processed up to.
-        new_last_depth = rows[-1][3] if rows else last_depth  # type: int
-        new_last_stream = rows[-1][4] if rows else last_stream  # type: int
-        new_last_room_id = rows[-1][5] if rows else ""  # type: str
+        new_last_depth: int = rows[-1][3] if rows else last_depth
+        new_last_stream: int = rows[-1][4] if rows else last_stream
+        new_last_room_id: str = rows[-1][5] if rows else ""
 
         # Map from room_id to last depth/stream_ordering processed for the room,
         # excluding the last room (which we're likely still processing). We also
@@ -989,7 +989,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             retcols=("event_id", "auth_id"),
         )
 
-        event_to_auth_chain = {}  # type: Dict[str, List[str]]
+        event_to_auth_chain: Dict[str, List[str]] = {}
         for row in auth_events:
             event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 403a5ddaba..3c86adab56 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1365,10 +1365,10 @@ class EventsWorkerStore(SQLBaseStore):
         # we need to make sure that, for every stream id in the results, we get *all*
         # the rows with that stream id.
 
-        rows = await self.db_pool.runInteraction(
+        rows: List[Tuple] = await self.db_pool.runInteraction(
             "get_all_updated_current_state_deltas",
             get_all_updated_current_state_deltas_txn,
-        )  # type: List[Tuple]
+        )
 
         # if we've got fewer rows than the limit, we're good
         if len(rows) < target_row_count:
@@ -1469,7 +1469,7 @@ class EventsWorkerStore(SQLBaseStore):
         """
 
         mapping = {}
-        txn_id_to_event = {}  # type: Dict[Tuple[str, int, str], str]
+        txn_id_to_event: Dict[Tuple[str, int, str], str] = {}
 
         for event in events:
             token_id = getattr(event.internal_metadata, "token_id", None)
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 66ad363bfb..e70d3649ff 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -27,8 +27,11 @@ from synapse.util import json_encoder
 _DEFAULT_CATEGORY_ID = ""
 _DEFAULT_ROLE_ID = ""
 
+
 # A room in a group.
-_RoomInGroup = TypedDict("_RoomInGroup", {"room_id": str, "is_public": bool})
+class _RoomInGroup(TypedDict):
+    room_id: str
+    is_public: bool
 
 
 class GroupServerWorkerStore(SQLBaseStore):
@@ -92,6 +95,7 @@ class GroupServerWorkerStore(SQLBaseStore):
               "is_public": False                    # Whether this is a public room or not
             }
         """
+
         # TODO: Pagination
 
         def _get_rooms_in_group_txn(txn):
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index 774861074c..3d1dff660b 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -78,7 +78,11 @@ class LockStore(SQLBaseStore):
         """Called when the server is shutting down"""
         logger.info("Dropping held locks due to shutdown")
 
-        for (lock_name, lock_key), token in self._live_tokens.items():
+        # We need to take a copy of the tokens dict as dropping the locks will
+        # cause the dictionary to change.
+        tokens = dict(self._live_tokens)
+
+        for (lock_name, lock_key), token in tokens.items():
             await self._drop_lock(lock_name, lock_key, token)
 
         logger.info("Dropped locks due to shutdown")
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index c3f551d377..dc0bbc56ac 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -316,11 +316,140 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
         return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
 
+    async def count_r30v2_users(self) -> Dict[str, int]:
+        """
+        Counts the number of 30 day retained users, defined as users that:
+         - Appear more than once in the past 60 days
+         - Have more than 30 days between the most and least recent appearances that
+           occurred in the past 60 days.
+
+        (This is the second version of this metric, hence R30'v2')
+
+        Returns:
+             A mapping from client type to the number of 30-day retained users for that client.
+
+             The dict keys are:
+              - "all" (a combined number of users across any and all clients)
+              - "android" (Element Android)
+              - "ios" (Element iOS)
+              - "electron" (Element Desktop)
+              - "web" (any web application -- it's not possible to distinguish Element Web here)
+        """
+
+        def _count_r30v2_users(txn):
+            thirty_days_in_secs = 86400 * 30
+            now = int(self._clock.time())
+            sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
+            one_day_from_now_in_secs = now + 86400
+
+            # This is the 'per-platform' count.
+            sql = """
+                SELECT
+                    client_type,
+                    count(client_type)
+                FROM
+                    (
+                        SELECT
+                            user_id,
+                            CASE
+                                WHEN
+                                    LOWER(user_agent) LIKE '%%riot%%' OR
+                                    LOWER(user_agent) LIKE '%%element%%'
+                                    THEN CASE
+                                        WHEN
+                                            LOWER(user_agent) LIKE '%%electron%%'
+                                            THEN 'electron'
+                                        WHEN
+                                            LOWER(user_agent) LIKE '%%android%%'
+                                            THEN 'android'
+                                        WHEN
+                                            LOWER(user_agent) LIKE '%%ios%%'
+                                            THEN 'ios'
+                                        ELSE 'unknown'
+                                    END
+                                WHEN
+                                    LOWER(user_agent) LIKE '%%mozilla%%' OR
+                                    LOWER(user_agent) LIKE '%%gecko%%'
+                                    THEN 'web'
+                                ELSE 'unknown'
+                            END as client_type
+                        FROM
+                            user_daily_visits
+                        WHERE
+                            timestamp > ?
+                            AND
+                            timestamp < ?
+                        GROUP BY
+                            user_id,
+                            client_type
+                        HAVING
+                            max(timestamp) - min(timestamp) > ?
+                    ) AS temp
+                GROUP BY
+                    client_type
+                ;
+            """
+
+            # We initialise all the client types to zero, so we get an explicit
+            # zero if they don't appear in the query results
+            results = {"ios": 0, "android": 0, "web": 0, "electron": 0}
+            txn.execute(
+                sql,
+                (
+                    sixty_days_ago_in_secs * 1000,
+                    one_day_from_now_in_secs * 1000,
+                    thirty_days_in_secs * 1000,
+                ),
+            )
+
+            for row in txn:
+                if row[0] == "unknown":
+                    continue
+                results[row[0]] = row[1]
+
+            # This is the 'all users' count.
+            sql = """
+                SELECT COUNT(*) FROM (
+                    SELECT
+                        1
+                    FROM
+                        user_daily_visits
+                    WHERE
+                        timestamp > ?
+                        AND
+                        timestamp < ?
+                    GROUP BY
+                        user_id
+                    HAVING
+                        max(timestamp) - min(timestamp) > ?
+                ) AS r30_users
+            """
+
+            txn.execute(
+                sql,
+                (
+                    sixty_days_ago_in_secs * 1000,
+                    one_day_from_now_in_secs * 1000,
+                    thirty_days_in_secs * 1000,
+                ),
+            )
+            row = txn.fetchone()
+            if row is None:
+                results["all"] = 0
+            else:
+                results["all"] = row[0]
+
+            return results
+
+        return await self.db_pool.runInteraction(
+            "count_r30v2_users", _count_r30v2_users
+        )
+
     def _get_start_of_day(self):
         """
         Returns millisecond unixtime for start of UTC day.
         """
-        now = time.gmtime()
+        now = time.gmtime(self._clock.time())
         today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
         return today_start * 1000
 
@@ -352,7 +481,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
                     ) udv
                     ON u.user_id = udv.user_id AND u.device_id=udv.device_id
                     INNER JOIN users ON users.name=u.user_id
-                    WHERE last_seen > ? AND last_seen <= ?
+                    WHERE ? <= last_seen AND last_seen < ?
                     AND udv.timestamp IS NULL AND users.is_guest=0
                     AND users.appservice_id IS NULL
                     GROUP BY u.user_id, u.device_id
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 7fb7780d0f..664c65dac5 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -115,7 +115,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         logger.info("[purge] looking for events to delete")
 
         should_delete_expr = "state_key IS NULL"
-        should_delete_params = ()  # type: Tuple[Any, ...]
+        should_delete_params: Tuple[Any, ...] = ()
         if not delete_local_events:
             should_delete_expr += " AND event_id NOT LIKE ?"
 
@@ -215,6 +215,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "event_relations",
             "event_search",
             "rejections",
+            "redactions",
         ):
             logger.info("[purge] removing events from %s", table)
 
@@ -392,7 +393,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "room_memberships",
             "room_stats_state",
             "room_stats_current",
-            "room_stats_historical",
             "room_stats_earliest_token",
             "rooms",
             "stream_ordering_to_exterm",
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index db52176337..a7fb8cd848 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -79,9 +79,9 @@ class PushRulesWorkerStore(
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
-            self._push_rules_stream_id_gen = StreamIdGenerator(
-                db_conn, "push_rules_stream", "stream_id"
-            )  # type: Union[StreamIdGenerator, SlavedIdTracker]
+            self._push_rules_stream_id_gen: Union[
+                StreamIdGenerator, SlavedIdTracker
+            ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
         else:
             self._push_rules_stream_id_gen = SlavedIdTracker(
                 db_conn, "push_rules_stream", "stream_id"
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e31c5864ac..6ad1a0cf7f 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1744,7 +1744,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
             items = keyvalues.items()
             where_clause = " AND ".join(k + " = ?" for k, _ in items)
-            values = [v for _, v in items]  # type: List[Union[str, int]]
+            values: List[Union[str, int]] = [v for _, v in items]
             # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat
             # is the `except_token_id` param that is tricky to get right, so for now we're just using the same where
             # clause and values before we handle that. This seems to be only used in the "set password" handler.
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 9f0d64a325..6ddafe5434 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -25,6 +25,7 @@ from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.search import SearchStore
+from synapse.storage.types import Cursor
 from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -1022,10 +1023,22 @@ class RoomWorkerStore(SQLBaseStore):
         )
 
 
-class RoomBackgroundUpdateStore(SQLBaseStore):
+class _BackgroundUpdates:
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
     ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column"
+    POPULATE_ROOM_DEPTH_MIN_DEPTH2 = "populate_room_depth_min_depth2"
+    REPLACE_ROOM_DEPTH_MIN_DEPTH = "replace_room_depth_min_depth"
+
+
+_REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
+    "DROP TRIGGER populate_min_depth2_trigger ON room_depth",
+    "DROP FUNCTION populate_min_depth2()",
+    "ALTER TABLE room_depth DROP COLUMN min_depth",
+    "ALTER TABLE room_depth RENAME COLUMN min_depth2 TO min_depth",
+)
+
 
+class RoomBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
 
@@ -1037,15 +1050,25 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         )
 
         self.db_pool.updates.register_background_update_handler(
-            self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE,
+            _BackgroundUpdates.REMOVE_TOMESTONED_ROOMS_BG_UPDATE,
             self._remove_tombstoned_rooms_from_directory,
         )
 
         self.db_pool.updates.register_background_update_handler(
-            self.ADD_ROOMS_ROOM_VERSION_COLUMN,
+            _BackgroundUpdates.ADD_ROOMS_ROOM_VERSION_COLUMN,
             self._background_add_rooms_room_version_column,
         )
 
+        # BG updates to change the type of room_depth.min_depth
+        self.db_pool.updates.register_background_update_handler(
+            _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2,
+            self._background_populate_room_depth_min_depth2,
+        )
+        self.db_pool.updates.register_background_update_handler(
+            _BackgroundUpdates.REPLACE_ROOM_DEPTH_MIN_DEPTH,
+            self._background_replace_room_depth_min_depth,
+        )
+
     async def _background_insert_retention(self, progress, batch_size):
         """Retrieves a list of all rooms within a range and inserts an entry for each of
         them into the room_retention table.
@@ -1164,7 +1187,9 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
                 new_last_room_id = room_id
 
             self.db_pool.updates._background_update_progress_txn(
-                txn, self.ADD_ROOMS_ROOM_VERSION_COLUMN, {"room_id": new_last_room_id}
+                txn,
+                _BackgroundUpdates.ADD_ROOMS_ROOM_VERSION_COLUMN,
+                {"room_id": new_last_room_id},
             )
 
             return False
@@ -1176,7 +1201,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
         if end:
             await self.db_pool.updates._end_background_update(
-                self.ADD_ROOMS_ROOM_VERSION_COLUMN
+                _BackgroundUpdates.ADD_ROOMS_ROOM_VERSION_COLUMN
             )
 
         return batch_size
@@ -1215,7 +1240,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
         if not rooms:
             await self.db_pool.updates._end_background_update(
-                self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE
+                _BackgroundUpdates.REMOVE_TOMESTONED_ROOMS_BG_UPDATE
             )
             return 0
 
@@ -1224,7 +1249,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
             await self.set_room_is_public(room_id, False)
 
         await self.db_pool.updates._background_update_progress(
-            self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]}
+            _BackgroundUpdates.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]}
         )
 
         return len(rooms)
@@ -1268,6 +1293,71 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
         return max_ordering is None
 
+    async def _background_populate_room_depth_min_depth2(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """Populate room_depth.min_depth2
+
+        This is to deal with the fact that min_depth was initially created as a
+        32-bit integer field.
+        """
+
+        def process(txn: Cursor) -> int:
+            last_room = progress.get("last_room", "")
+            txn.execute(
+                """
+                UPDATE room_depth SET min_depth2=min_depth
+                WHERE room_id IN (
+                   SELECT room_id FROM room_depth WHERE room_id > ?
+                   ORDER BY room_id LIMIT ?
+                )
+                RETURNING room_id;
+                """,
+                (last_room, batch_size),
+            )
+            row_count = txn.rowcount
+            if row_count == 0:
+                return 0
+            last_room = max(row[0] for row in txn)
+            logger.info("populated room_depth up to %s", last_room)
+
+            self.db_pool.updates._background_update_progress_txn(
+                txn,
+                _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2,
+                {"last_room": last_room},
+            )
+            return row_count
+
+        result = await self.db_pool.runInteraction(
+            "_background_populate_min_depth2", process
+        )
+
+        if result != 0:
+            return result
+
+        await self.db_pool.updates._end_background_update(
+            _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2
+        )
+        return 0
+
+    async def _background_replace_room_depth_min_depth(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """Drop the old 'min_depth' column and rename 'min_depth2' into its place."""
+
+        def process(txn: Cursor) -> None:
+            for sql in _REPLACE_ROOM_DEPTH_SQL_COMMANDS:
+                logger.info("completing room_depth migration: %s", sql)
+                txn.execute(sql)
+
+        await self.db_pool.runInteraction("_background_replace_room_depth", process)
+
+        await self.db_pool.updates._end_background_update(
+            _BackgroundUpdates.REPLACE_ROOM_DEPTH_MIN_DEPTH,
+        )
+
+        return 0
+
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 2796354a1f..68f1b40ea6 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -649,7 +649,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             event_to_memberships = await self._get_joined_profiles_from_event_ids(
                 missing_member_event_ids
             )
-            users_in_room.update((row for row in event_to_memberships.values() if row))
+            users_in_room.update(row for row in event_to_memberships.values() if row)
 
         if event is not None and event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
@@ -703,13 +703,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     @cached(max_entries=10000)
     async def is_host_joined(self, room_id: str, host: str) -> bool:
+        return await self._check_host_room_membership(room_id, host, Membership.JOIN)
+
+    @cached(max_entries=10000)
+    async def is_host_invited(self, room_id: str, host: str) -> bool:
+        return await self._check_host_room_membership(room_id, host, Membership.INVITE)
+
+    async def _check_host_room_membership(
+        self, room_id: str, host: str, membership: str
+    ) -> bool:
         if "%" in host or "_" in host:
             raise Exception("Invalid host name")
 
         sql = """
             SELECT state_key FROM current_state_events AS c
             INNER JOIN room_memberships AS m USING (event_id)
-            WHERE m.membership = 'join'
+            WHERE m.membership = ?
                 AND type = 'm.room.member'
                 AND c.room_id = ?
                 AND state_key LIKE ?
@@ -722,7 +731,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         like_clause = "%:" + host
 
         rows = await self.db_pool.execute(
-            "is_host_joined", None, sql, room_id, like_clause
+            "is_host_joined", None, sql, membership, room_id, like_clause
         )
 
         if not rows:
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 82a1833509..59d67c255b 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -26,7 +26,6 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import StoreError
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.state_deltas import StateDeltasStore
-from synapse.storage.engines import PostgresEngine
 from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
 
@@ -49,14 +48,6 @@ ABSOLUTE_STATS_FIELDS = {
     "user": ("joined_rooms",),
 }
 
-# these fields are per-timeslice and so should be reset to 0 upon a new slice
-# You can draw these stats on a histogram.
-# Example: number of events sent locally during a time slice
-PER_SLICE_FIELDS = {
-    "room": ("total_events", "total_event_bytes"),
-    "user": ("invites_sent", "rooms_created", "total_events", "total_event_bytes"),
-}
-
 TYPE_TO_TABLE = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")}
 
 # these are the tables (& ID columns) which contain our actual subjects
@@ -106,7 +97,6 @@ class StatsStore(StateDeltasStore):
         self.server_name = hs.hostname
         self.clock = self.hs.get_clock()
         self.stats_enabled = hs.config.stats_enabled
-        self.stats_bucket_size = hs.config.stats_bucket_size
 
         self.stats_delta_processing_lock = DeferredLock()
 
@@ -122,22 +112,6 @@ class StatsStore(StateDeltasStore):
         self.db_pool.updates.register_noop_background_update("populate_stats_cleanup")
         self.db_pool.updates.register_noop_background_update("populate_stats_prepare")
 
-    def quantise_stats_time(self, ts):
-        """
-        Quantises a timestamp to be a multiple of the bucket size.
-
-        Args:
-            ts (int): the timestamp to quantise, in milliseconds since the Unix
-                Epoch
-
-        Returns:
-            int: a timestamp which
-              - is divisible by the bucket size;
-              - is no later than `ts`; and
-              - is the largest such timestamp.
-        """
-        return (ts // self.stats_bucket_size) * self.stats_bucket_size
-
     async def _populate_stats_process_users(self, progress, batch_size):
         """
         This is a background update which regenerates statistics for users.
@@ -288,56 +262,6 @@ class StatsStore(StateDeltasStore):
             desc="update_room_state",
         )
 
-    async def get_statistics_for_subject(
-        self, stats_type: str, stats_id: str, start: str, size: int = 100
-    ) -> List[dict]:
-        """
-        Get statistics for a given subject.
-
-        Args:
-            stats_type: The type of subject
-            stats_id: The ID of the subject (e.g. room_id or user_id)
-            start: Pagination start. Number of entries, not timestamp.
-            size: How many entries to return.
-
-        Returns:
-            A list of dicts, where the dict has the keys of
-            ABSOLUTE_STATS_FIELDS[stats_type],  and "bucket_size" and "end_ts".
-        """
-        return await self.db_pool.runInteraction(
-            "get_statistics_for_subject",
-            self._get_statistics_for_subject_txn,
-            stats_type,
-            stats_id,
-            start,
-            size,
-        )
-
-    def _get_statistics_for_subject_txn(
-        self, txn, stats_type, stats_id, start, size=100
-    ):
-        """
-        Transaction-bound version of L{get_statistics_for_subject}.
-        """
-
-        table, id_col = TYPE_TO_TABLE[stats_type]
-        selected_columns = list(
-            ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type]
-        )
-
-        slice_list = self.db_pool.simple_select_list_paginate_txn(
-            txn,
-            table + "_historical",
-            "end_ts",
-            start,
-            size,
-            retcols=selected_columns + ["bucket_size", "end_ts"],
-            keyvalues={id_col: stats_id},
-            order_direction="DESC",
-        )
-
-        return slice_list
-
     @cached()
     async def get_earliest_token_for_stats(
         self, stats_type: str, id: str
@@ -451,14 +375,10 @@ class StatsStore(StateDeltasStore):
 
         table, id_col = TYPE_TO_TABLE[stats_type]
 
-        quantised_ts = self.quantise_stats_time(int(ts))
-        end_ts = quantised_ts + self.stats_bucket_size
-
         # Lets be paranoid and check that all the given field names are known
         abs_field_names = ABSOLUTE_STATS_FIELDS[stats_type]
-        slice_field_names = PER_SLICE_FIELDS[stats_type]
         for field in chain(fields.keys(), absolute_field_overrides.keys()):
-            if field not in abs_field_names and field not in slice_field_names:
+            if field not in abs_field_names:
                 # guard against potential SQL injection dodginess
                 raise ValueError(
                     "%s is not a recognised field"
@@ -491,20 +411,6 @@ class StatsStore(StateDeltasStore):
             additive_relatives=deltas_of_absolute_fields,
         )
 
-        per_slice_additive_relatives = {
-            key: fields.get(key, 0) for key in slice_field_names
-        }
-        self._upsert_copy_from_table_with_additive_relatives_txn(
-            txn=txn,
-            into_table=table + "_historical",
-            keyvalues={id_col: stats_id},
-            extra_dst_insvalues={"bucket_size": self.stats_bucket_size},
-            extra_dst_keyvalues={"end_ts": end_ts},
-            additive_relatives=per_slice_additive_relatives,
-            src_table=table + "_current",
-            copy_columns=abs_field_names,
-        )
-
     def _upsert_with_additive_relatives_txn(
         self, txn, table, keyvalues, absolutes, additive_relatives
     ):
@@ -528,7 +434,7 @@ class StatsStore(StateDeltasStore):
             ]
 
             relative_updates = [
-                "%(field)s = EXCLUDED.%(field)s + %(table)s.%(field)s"
+                "%(field)s = EXCLUDED.%(field)s + COALESCE(%(table)s.%(field)s, 0)"
                 % {"table": table, "field": field}
                 for field in additive_relatives.keys()
             ]
@@ -568,205 +474,13 @@ class StatsStore(StateDeltasStore):
                 self.db_pool.simple_insert_txn(txn, table, merged_dict)
             else:
                 for (key, val) in additive_relatives.items():
-                    current_row[key] += val
+                    if current_row[key] is None:
+                        current_row[key] = val
+                    else:
+                        current_row[key] += val
                 current_row.update(absolutes)
                 self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row)
 
-    def _upsert_copy_from_table_with_additive_relatives_txn(
-        self,
-        txn,
-        into_table,
-        keyvalues,
-        extra_dst_keyvalues,
-        extra_dst_insvalues,
-        additive_relatives,
-        src_table,
-        copy_columns,
-    ):
-        """Updates the historic stats table with latest updates.
-
-        This involves copying "absolute" fields from the `_current` table, and
-        adding relative fields to any existing values.
-
-        Args:
-             txn: Transaction
-             into_table (str): The destination table to UPSERT the row into
-             keyvalues (dict[str, any]): Row-identifying key values
-             extra_dst_keyvalues (dict[str, any]): Additional keyvalues
-                for `into_table`.
-             extra_dst_insvalues (dict[str, any]): Additional values to insert
-                on new row creation for `into_table`.
-             additive_relatives (dict[str, any]): Fields that will be added onto
-                if existing row present. (Must be disjoint from copy_columns.)
-             src_table (str): The source table to copy from
-             copy_columns (iterable[str]): The list of columns to copy
-        """
-        if self.database_engine.can_native_upsert:
-            ins_columns = chain(
-                keyvalues,
-                copy_columns,
-                additive_relatives,
-                extra_dst_keyvalues,
-                extra_dst_insvalues,
-            )
-            sel_exprs = chain(
-                keyvalues,
-                copy_columns,
-                (
-                    "?"
-                    for _ in chain(
-                        additive_relatives, extra_dst_keyvalues, extra_dst_insvalues
-                    )
-                ),
-            )
-            keyvalues_where = ("%s = ?" % f for f in keyvalues)
-
-            sets_cc = ("%s = EXCLUDED.%s" % (f, f) for f in copy_columns)
-            sets_ar = (
-                "%s = EXCLUDED.%s + %s.%s" % (f, f, into_table, f)
-                for f in additive_relatives
-            )
-
-            sql = """
-                INSERT INTO %(into_table)s (%(ins_columns)s)
-                SELECT %(sel_exprs)s
-                FROM %(src_table)s
-                WHERE %(keyvalues_where)s
-                ON CONFLICT (%(keyvalues)s)
-                DO UPDATE SET %(sets)s
-            """ % {
-                "into_table": into_table,
-                "ins_columns": ", ".join(ins_columns),
-                "sel_exprs": ", ".join(sel_exprs),
-                "keyvalues_where": " AND ".join(keyvalues_where),
-                "src_table": src_table,
-                "keyvalues": ", ".join(
-                    chain(keyvalues.keys(), extra_dst_keyvalues.keys())
-                ),
-                "sets": ", ".join(chain(sets_cc, sets_ar)),
-            }
-
-            qargs = list(
-                chain(
-                    additive_relatives.values(),
-                    extra_dst_keyvalues.values(),
-                    extra_dst_insvalues.values(),
-                    keyvalues.values(),
-                )
-            )
-            txn.execute(sql, qargs)
-        else:
-            self.database_engine.lock_table(txn, into_table)
-            src_row = self.db_pool.simple_select_one_txn(
-                txn, src_table, keyvalues, copy_columns
-            )
-            all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues}
-            dest_current_row = self.db_pool.simple_select_one_txn(
-                txn,
-                into_table,
-                keyvalues=all_dest_keyvalues,
-                retcols=list(chain(additive_relatives.keys(), copy_columns)),
-                allow_none=True,
-            )
-
-            if dest_current_row is None:
-                merged_dict = {
-                    **keyvalues,
-                    **extra_dst_keyvalues,
-                    **extra_dst_insvalues,
-                    **src_row,
-                    **additive_relatives,
-                }
-                self.db_pool.simple_insert_txn(txn, into_table, merged_dict)
-            else:
-                for (key, val) in additive_relatives.items():
-                    src_row[key] = dest_current_row[key] + val
-                self.db_pool.simple_update_txn(
-                    txn, into_table, all_dest_keyvalues, src_row
-                )
-
-    async def get_changes_room_total_events_and_bytes(
-        self, min_pos: int, max_pos: int
-    ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
-        """Fetches the counts of events in the given range of stream IDs.
-
-        Args:
-            min_pos
-            max_pos
-
-        Returns:
-            Mapping of room ID to field changes.
-        """
-
-        return await self.db_pool.runInteraction(
-            "stats_incremental_total_events_and_bytes",
-            self.get_changes_room_total_events_and_bytes_txn,
-            min_pos,
-            max_pos,
-        )
-
-    def get_changes_room_total_events_and_bytes_txn(
-        self, txn, low_pos: int, high_pos: int
-    ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
-        """Gets the total_events and total_event_bytes counts for rooms and
-        senders, in a range of stream_orderings (including backfilled events).
-
-        Args:
-            txn
-            low_pos: Low stream ordering
-            high_pos: High stream ordering
-
-        Returns:
-            The room and user deltas for total_events/total_event_bytes in the
-            format of `stats_id` -> fields
-        """
-
-        if low_pos >= high_pos:
-            # nothing to do here.
-            return {}, {}
-
-        if isinstance(self.database_engine, PostgresEngine):
-            new_bytes_expression = "OCTET_LENGTH(json)"
-        else:
-            new_bytes_expression = "LENGTH(CAST(json AS BLOB))"
-
-        sql = """
-            SELECT events.room_id, COUNT(*) AS new_events, SUM(%s) AS new_bytes
-            FROM events INNER JOIN event_json USING (event_id)
-            WHERE (? < stream_ordering AND stream_ordering <= ?)
-                OR (? <= stream_ordering AND stream_ordering <= ?)
-            GROUP BY events.room_id
-        """ % (
-            new_bytes_expression,
-        )
-
-        txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos))
-
-        room_deltas = {
-            room_id: {"total_events": new_events, "total_event_bytes": new_bytes}
-            for room_id, new_events, new_bytes in txn
-        }
-
-        sql = """
-            SELECT events.sender, COUNT(*) AS new_events, SUM(%s) AS new_bytes
-            FROM events INNER JOIN event_json USING (event_id)
-            WHERE (? < stream_ordering AND stream_ordering <= ?)
-                OR (? <= stream_ordering AND stream_ordering <= ?)
-            GROUP BY events.sender
-        """ % (
-            new_bytes_expression,
-        )
-
-        txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos))
-
-        user_deltas = {
-            user_id: {"total_events": new_events, "total_event_bytes": new_bytes}
-            for user_id, new_events, new_bytes in txn
-            if self.hs.is_mine_id(user_id)
-        }
-
-        return room_deltas, user_deltas
-
     async def _calculate_and_set_initial_state_for_room(
         self, room_id: str
     ) -> Tuple[dict, dict, int]:
@@ -893,6 +607,7 @@ class StatsStore(StateDeltasStore):
                 "invited_members": membership_counts.get(Membership.INVITE, 0),
                 "left_members": membership_counts.get(Membership.LEAVE, 0),
                 "banned_members": membership_counts.get(Membership.BAN, 0),
+                "knocked_members": membership_counts.get(Membership.KNOCK, 0),
                 "local_users_in_room": len(local_users_in_room),
             },
         )
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 7581c7d3ff..959f13de47 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1085,9 +1085,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
         # then filtering the results.
         if from_token.topological is not None:
-            from_bound = (
-                from_token.as_historical_tuple()
-            )  # type: Tuple[Optional[int], int]
+            from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple()
         elif direction == "b":
             from_bound = (
                 None,
@@ -1099,7 +1097,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
                 from_token.stream,
             )
 
-        to_bound = None  # type: Optional[Tuple[Optional[int], int]]
+        to_bound: Optional[Tuple[Optional[int], int]] = None
         if to_token:
             if to_token.topological is not None:
                 to_bound = to_token.as_historical_tuple()
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 1d62c6140f..f93ff0a545 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -42,7 +42,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
         )
 
-        tags_by_room = {}  # type: Dict[str, Dict[str, JsonDict]]
+        tags_by_room: Dict[str, Dict[str, JsonDict]] = {}
         for row in rows:
             room_tags = tags_by_room.setdefault(row["room_id"], {})
             room_tags[row["tag"]] = db_to_json(row["content"])
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 22c05cdde7..38bfdf5dad 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -224,12 +224,12 @@ class UIAuthWorkerStore(SQLBaseStore):
         self, txn: LoggingTransaction, session_id: str, key: str, value: Any
     ):
         # Get the current value.
-        result = self.db_pool.simple_select_one_txn(
+        result: Dict[str, Any] = self.db_pool.simple_select_one_txn(  # type: ignore
             txn,
             table="ui_auth_sessions",
             keyvalues={"session_id": session_id},
             retcols=("serverdict",),
-        )  # type: Dict[str, Any]  # type: ignore
+        )
 
         # Update it and add it back to the database.
         serverdict = db_to_json(result["serverdict"])
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 051095fea9..a39877f0d5 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -307,7 +307,7 @@ class EventsPersistenceStorage:
             matched the transcation ID; the existing event is returned in such
             a case.
         """
-        partitioned = {}  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
+        partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
         for event, ctx in events_and_contexts:
             partitioned.setdefault(event.room_id, []).append((event, ctx))
 
@@ -384,7 +384,7 @@ class EventsPersistenceStorage:
             A dictionary of event ID to event ID we didn't persist as we already
             had another event persisted with the same TXN ID.
         """
-        replaced_events = {}  # type: Dict[str, str]
+        replaced_events: Dict[str, str] = {}
         if not events_and_contexts:
             return replaced_events
 
@@ -440,16 +440,14 @@ class EventsPersistenceStorage:
             # Set of remote users which were in rooms the server has left. We
             # should check if we still share any rooms and if not we mark their
             # device lists as stale.
-            potentially_left_users = set()  # type: Set[str]
+            potentially_left_users: Set[str] = set()
 
             if not backfilled:
                 with Measure(self._clock, "_calculate_state_and_extrem"):
                     # Work out the new "current state" for each room.
                     # We do this by working out what the new extremities are and then
                     # calculating the state from that.
-                    events_by_room = (
-                        {}
-                    )  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
+                    events_by_room: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
                     for event, context in chunk:
                         events_by_room.setdefault(event.room_id, []).append(
                             (event, context)
@@ -622,9 +620,9 @@ class EventsPersistenceStorage:
         )
 
         # Remove any events which are prev_events of any existing events.
-        existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
-            result
-        )  # type: Collection[str]
+        existing_prevs: Collection[
+            str
+        ] = await self.persist_events_store._get_events_which_are_prevs(result)
         result.difference_update(existing_prevs)
 
         # Finally handle the case where the new events have soft-failed prev
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 683e5e3b90..61392b9639 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -256,7 +256,7 @@ def _setup_new_database(
         for database in databases
     )
 
-    directory_entries = []  # type: List[_DirectoryListing]
+    directory_entries: List[_DirectoryListing] = []
     for directory in directories:
         directory_entries.extend(
             _DirectoryListing(file_name, os.path.join(directory, file_name))
@@ -424,10 +424,10 @@ def _upgrade_existing_database(
             directories.append(os.path.join(schema_path, database, "delta", str(v)))
 
         # Used to check if we have any duplicate file names
-        file_name_counter = Counter()  # type: CounterType[str]
+        file_name_counter: CounterType[str] = Counter()
 
         # Now find which directories have anything of interest.
-        directory_entries = []  # type: List[_DirectoryListing]
+        directory_entries: List[_DirectoryListing] = []
         for directory in directories:
             logger.debug("Looking for schema deltas in %s", directory)
             try:
@@ -639,7 +639,7 @@ def get_statements(f: Iterable[str]) -> Generator[str, None, None]:
 
 
 def executescript(txn: Cursor, schema_path: str) -> None:
-    with open(schema_path, "r") as f:
+    with open(schema_path) as f:
         execute_statements_from_stream(txn, f)
 
 
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 0a53b73ccc..36340a652a 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SCHEMA_VERSION = 60
+SCHEMA_VERSION = 61
 """Represents the expectations made by the codebase about the database schema
 
 This should be incremented whenever the codebase changes its requirements on the
@@ -21,6 +21,10 @@ older versions of Synapse).
 
 See `README.md <synapse/storage/schema/README.md>`_  for more information on how this
 works.
+
+Changes in SCHEMA_VERSION = 61:
+    - The `user_stats_historical` and `room_stats_historical` tables are not written and
+      are not read (previously, they were written but not read).
 """
 
 
diff --git a/synapse/storage/schema/main/delta/61/01change_appservices_txns.sql.postgres b/synapse/storage/schema/main/delta/61/01change_appservices_txns.sql.postgres
new file mode 100644
index 0000000000..c8aec78e60
--- /dev/null
+++ b/synapse/storage/schema/main/delta/61/01change_appservices_txns.sql.postgres
@@ -0,0 +1,23 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- we use bigint elsewhere in the database for appservice txn ids (notably,
+-- application_services_state.last_txn), and generally we use bigints everywhere else
+-- we have monotonic counters, so let's bring this one in line.
+--
+-- assuming there aren't thousands of rows for decommisioned/non-functional ASes, this
+-- table should be pretty small, so safe to do a synchronous ALTER TABLE.
+
+ALTER TABLE application_services_txns ALTER COLUMN txn_id SET DATA TYPE BIGINT;
diff --git a/synapse/storage/schema/main/delta/61/02drop_redundant_room_depth_index.sql b/synapse/storage/schema/main/delta/61/02drop_redundant_room_depth_index.sql
new file mode 100644
index 0000000000..35ca7a40c0
--- /dev/null
+++ b/synapse/storage/schema/main/delta/61/02drop_redundant_room_depth_index.sql
@@ -0,0 +1,18 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this index is redundant; there is another UNIQUE index on this table.
+DROP INDEX IF EXISTS room_depth_room;
+
diff --git a/synapse/storage/schema/main/delta/61/03recreate_min_depth.py b/synapse/storage/schema/main/delta/61/03recreate_min_depth.py
new file mode 100644
index 0000000000..f8d7db9f2e
--- /dev/null
+++ b/synapse/storage/schema/main/delta/61/03recreate_min_depth.py
@@ -0,0 +1,70 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This migration handles the process of changing the type of `room_depth.min_depth` to
+a BIGINT.
+"""
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.types import Cursor
+
+
+def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+    if not isinstance(database_engine, PostgresEngine):
+        # this only applies to postgres - sqlite does not distinguish between big and
+        # little ints.
+        return
+
+    # First add a new column to contain the bigger min_depth
+    cur.execute("ALTER TABLE room_depth ADD COLUMN min_depth2 BIGINT")
+
+    # Create a trigger which will keep it populated.
+    cur.execute(
+        """
+        CREATE OR REPLACE FUNCTION populate_min_depth2() RETURNS trigger AS $BODY$
+            BEGIN
+                new.min_depth2 := new.min_depth;
+                RETURN NEW;
+            END;
+        $BODY$ LANGUAGE plpgsql
+        """
+    )
+
+    cur.execute(
+        """
+        CREATE TRIGGER populate_min_depth2_trigger BEFORE INSERT OR UPDATE ON room_depth
+        FOR EACH ROW
+        EXECUTE PROCEDURE populate_min_depth2()
+        """
+    )
+
+    # Start a bg process to populate it for old rooms
+    cur.execute(
+        """
+       INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+            (6103, 'populate_room_depth_min_depth2', '{}')
+       """
+    )
+
+    # and another to switch them over once it completes.
+    cur.execute(
+        """
+        INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+            (6103, 'replace_room_depth_min_depth', '{}', 'populate_room_depth2')
+        """
+    )
+
+
+def run_upgrade(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+    pass
diff --git a/synapse/storage/schema/state/delta/61/02state_groups_state_n_distinct.sql.postgres b/synapse/storage/schema/state/delta/61/02state_groups_state_n_distinct.sql.postgres
new file mode 100644
index 0000000000..35a153da7b
--- /dev/null
+++ b/synapse/storage/schema/state/delta/61/02state_groups_state_n_distinct.sql.postgres
@@ -0,0 +1,34 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- By default the postgres statistics collector massively underestimates the
+-- number of distinct state groups are in the `state_groups_state`, which can
+-- cause postgres to use table scans for queries for multiple state groups.
+--
+-- To work around this we can manually tell postgres the number of distinct state
+-- groups there are by setting `n_distinct` (a negative value here is the number
+-- of distinct values divided by the number of rows, so -0.02 means on average
+-- there are 50 rows per distinct value). We don't need a particularly
+-- accurate number here, as a) we just want it to always use index scans and b)
+-- our estimate is going to be better than the one made by the statistics
+-- collector.
+
+ALTER TABLE state_groups_state ALTER COLUMN state_group SET (n_distinct = -0.02);
+
+-- Ideally we'd do an `ANALYZE state_groups_state (state_group)` here so that
+-- the above gets picked up immediately, but that can take a bit of time so we
+-- rely on the autovacuum eventually getting run and doing that in the
+-- background for us.
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index c9dce726cb..f8fbba9d38 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -91,7 +91,7 @@ class StateFilter:
         Returns:
             The new state filter.
         """
-        type_dict = {}  # type: Dict[str, Optional[Set[str]]]
+        type_dict: Dict[str, Optional[Set[str]]] = {}
         for typ, s in types:
             if typ in type_dict:
                 if type_dict[typ] is None:
@@ -194,7 +194,7 @@ class StateFilter:
         """
 
         where_clause = ""
-        where_args = []  # type: List[str]
+        where_args: List[str] = []
 
         if self.is_full():
             return where_clause, where_args
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f1e62f9e85..c768fdea56 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -112,7 +112,7 @@ class StreamIdGenerator:
         # insertion ordering will ensure its in the correct ordering.
         #
         # The key and values are the same, but we never look at the values.
-        self._unfinished_ids = OrderedDict()  # type: OrderedDict[int, int]
+        self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
 
     def get_next(self):
         """
@@ -236,15 +236,15 @@ class MultiWriterIdGenerator:
         # Note: If we are a negative stream then we still store all the IDs as
         # positive to make life easier for us, and simply negate the IDs when we
         # return them.
-        self._current_positions = {}  # type: Dict[str, int]
+        self._current_positions: Dict[str, int] = {}
 
         # Set of local IDs that we're still processing. The current position
         # should be less than the minimum of this set (if not empty).
-        self._unfinished_ids = set()  # type: Set[int]
+        self._unfinished_ids: Set[int] = set()
 
         # Set of local IDs that we've processed that are larger than the current
         # position, due to there being smaller unpersisted IDs.
-        self._finished_ids = set()  # type: Set[int]
+        self._finished_ids: Set[int] = set()
 
         # We track the max position where we know everything before has been
         # persisted. This is done by a) looking at the min across all instances
@@ -265,7 +265,7 @@ class MultiWriterIdGenerator:
         self._persisted_upto_position = (
             min(self._current_positions.values()) if self._current_positions else 1
         )
-        self._known_persisted_positions = []  # type: List[int]
+        self._known_persisted_positions: List[int] = []
 
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
@@ -465,7 +465,7 @@ class MultiWriterIdGenerator:
             self._unfinished_ids.discard(next_id)
             self._finished_ids.add(next_id)
 
-            new_cur = None  # type: Optional[int]
+            new_cur: Optional[int] = None
 
             if self._unfinished_ids:
                 # If there are unfinished IDs then the new position will be the
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 30b6b8e0ca..bb33e04fb1 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -208,10 +208,10 @@ class LocalSequenceGenerator(SequenceGenerator):
                  get_next_id_txn; should return the curreent maximum id
         """
         # the callback. this is cleared after it is called, so that it can be GCed.
-        self._callback = get_first_callback  # type: Optional[GetFirstCallbackType]
+        self._callback: Optional[GetFirstCallbackType] = get_first_callback
 
         # The current max value, or None if we haven't looked in the DB yet.
-        self._current_max_id = None  # type: Optional[int]
+        self._current_max_id: Optional[int] = None
         self._lock = threading.Lock()
 
     def get_next_id_txn(self, txn: Cursor) -> int:
@@ -274,7 +274,7 @@ def build_sequence_generator(
             `check_consistency` details.
     """
     if isinstance(database_engine, PostgresEngine):
-        seq = PostgresSequenceGenerator(sequence_name)  # type: SequenceGenerator
+        seq: SequenceGenerator = PostgresSequenceGenerator(sequence_name)
     else:
         seq = LocalSequenceGenerator(get_first_callback)
 
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 20fceaa935..99b0aac2fb 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -32,9 +32,9 @@ class EventSources:
     }
 
     def __init__(self, hs):
-        self.sources = {
+        self.sources: Dict[str, Any] = {
             name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items()
-        }  # type: Dict[str, Any]
+        }
         self.store = hs.get_datastore()
 
     def get_current_token(self) -> StreamToken:
diff --git a/synapse/types.py b/synapse/types.py
index 8d2fa00f71..429bb013d2 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -182,14 +182,14 @@ def create_requester(
     )
 
 
-def get_domain_from_id(string):
+def get_domain_from_id(string: str) -> str:
     idx = string.find(":")
     if idx == -1:
         raise SynapseError(400, "Invalid ID: %r" % (string,))
     return string[idx + 1 :]
 
 
-def get_localpart_from_id(string):
+def get_localpart_from_id(string: str) -> str:
     idx = string.find(":")
     if idx == -1:
         raise SynapseError(400, "Invalid ID: %r" % (string,))
@@ -210,7 +210,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
         'domain' : The domain part of the name
     """
 
-    SIGIL = abc.abstractproperty()  # type: str  # type: ignore
+    SIGIL: str = abc.abstractproperty()  # type: ignore
 
     localpart = attr.ib(type=str)
     domain = attr.ib(type=str)
@@ -304,7 +304,7 @@ class GroupID(DomainSpecificString):
 
     @classmethod
     def from_string(cls: Type[DS], s: str) -> DS:
-        group_id = super().from_string(s)  # type: DS # type: ignore
+        group_id: DS = super().from_string(s)  # type: ignore
 
         if not group_id.localpart:
             raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
@@ -577,10 +577,10 @@ class RoomStreamToken:
             entries = []
             for name, pos in self.instance_map.items():
                 instance_id = await store.get_id_for_instance(name)
-                entries.append("{}.{}".format(instance_id, pos))
+                entries.append(f"{instance_id}.{pos}")
 
             encoded_map = "~".join(entries)
-            return "m{}~{}".format(self.stream, encoded_map)
+            return f"m{self.stream}~{encoded_map}"
         else:
             return "s%d" % (self.stream,)
 
@@ -600,7 +600,7 @@ class StreamToken:
     groups_key = attr.ib(type=int)
 
     _SEPARATOR = "_"
-    START = None  # type: StreamToken
+    START: "StreamToken"
 
     @classmethod
     async def from_string(cls, store: "DataStore", string: str) -> "StreamToken":
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 061102c3c8..014db1355b 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -257,7 +257,7 @@ class Linearizer:
             max_count: The maximum number of concurrent accesses
         """
         if name is None:
-            self.name = id(self)  # type: Union[str, int]
+            self.name: Union[str, int] = id(self)
         else:
             self.name = name
 
@@ -269,7 +269,7 @@ class Linearizer:
         self.max_count = max_count
 
         # key_to_defer is a map from the key to a _LinearizerEntry.
-        self.key_to_defer = {}  # type: Dict[Hashable, _LinearizerEntry]
+        self.key_to_defer: Dict[Hashable, _LinearizerEntry] = {}
 
     def is_queued(self, key: Hashable) -> bool:
         """Checks whether there is a process queued up waiting"""
@@ -409,10 +409,10 @@ class ReadWriteLock:
 
     def __init__(self):
         # Latest readers queued
-        self.key_to_current_readers = {}  # type: Dict[str, Set[defer.Deferred]]
+        self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {}
 
         # Latest writer queued
-        self.key_to_current_writer = {}  # type: Dict[str, defer.Deferred]
+        self.key_to_current_writer: Dict[str, defer.Deferred] = {}
 
     async def read(self, key: str) -> ContextManager:
         new_defer = defer.Deferred()
diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py
index 8fd5bfb69b..274cea7eb7 100644
--- a/synapse/util/batching_queue.py
+++ b/synapse/util/batching_queue.py
@@ -93,11 +93,11 @@ class BatchingQueue(Generic[V, R]):
         self._clock = clock
 
         # The set of keys currently being processed.
-        self._processing_keys = set()  # type: Set[Hashable]
+        self._processing_keys: Set[Hashable] = set()
 
         # The currently pending batch of values by key, with a Deferred to call
         # with the result of the corresponding `_process_batch_callback` call.
-        self._next_values = {}  # type: Dict[Hashable, List[Tuple[V, defer.Deferred]]]
+        self._next_values: Dict[Hashable, List[Tuple[V, defer.Deferred]]] = {}
 
         # The function to call with batches of values.
         self._process_batch_callback = process_batch_callback
@@ -108,9 +108,7 @@ class BatchingQueue(Generic[V, R]):
 
         number_of_keys.labels(self._name).set_function(lambda: len(self._next_values))
 
-        self._number_in_flight_metric = number_in_flight.labels(
-            self._name
-        )  # type: Gauge
+        self._number_in_flight_metric: Gauge = number_in_flight.labels(self._name)
 
     async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
         """Adds the value to the queue with the given key, returning the result
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index ca36f07c20..9012034b7a 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -29,8 +29,8 @@ logger = logging.getLogger(__name__)
 TRACK_MEMORY_USAGE = False
 
 
-caches_by_name = {}  # type: Dict[str, Sized]
-collectors_by_name = {}  # type: Dict[str, CacheMetric]
+caches_by_name: Dict[str, Sized] = {}
+collectors_by_name: Dict[str, "CacheMetric"] = {}
 
 cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
 cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py
index a301c9e89b..891bee0b33 100644
--- a/synapse/util/caches/cached_call.py
+++ b/synapse/util/caches/cached_call.py
@@ -63,9 +63,9 @@ class CachedCall(Generic[TV]):
             f: The underlying function. Only one call to this function will be alive
                 at once (per instance of CachedCall)
         """
-        self._callable = f  # type: Optional[Callable[[], Awaitable[TV]]]
-        self._deferred = None  # type: Optional[Deferred]
-        self._result = None  # type: Union[None, Failure, TV]
+        self._callable: Optional[Callable[[], Awaitable[TV]]] = f
+        self._deferred: Optional[Deferred] = None
+        self._result: Union[None, Failure, TV] = None
 
     async def get(self) -> TV:
         """Kick off the call if necessary, and return the result"""
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 1044139119..8c6fafc677 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -80,25 +80,25 @@ class DeferredCache(Generic[KT, VT]):
         cache_type = TreeCache if tree else dict
 
         # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
-        self._pending_deferred_cache = (
-            cache_type()
-        )  # type: Union[TreeCache, MutableMapping[KT, CacheEntry]]
+        self._pending_deferred_cache: Union[
+            TreeCache, "MutableMapping[KT, CacheEntry]"
+        ] = cache_type()
 
         def metrics_cb():
             cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
 
         # cache is used for completed results and maps to the result itself, rather than
         # a Deferred.
-        self.cache = LruCache(
+        self.cache: LruCache[KT, VT] = LruCache(
             max_size=max_entries,
             cache_name=name,
             cache_type=cache_type,
             size_callback=(lambda d: len(d) or 1) if iterable else None,
             metrics_collection_callback=metrics_cb,
             apply_cache_factor_from_config=apply_cache_factor_from_config,
-        )  # type: LruCache[KT, VT]
+        )
 
-        self.thread = None  # type: Optional[threading.Thread]
+        self.thread: Optional[threading.Thread] = None
 
     @property
     def max_entries(self):
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index d77e8edeea..1e8e6b1d01 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -46,17 +46,17 @@ F = TypeVar("F", bound=Callable[..., Any])
 
 
 class _CachedFunction(Generic[F]):
-    invalidate = None  # type: Any
-    invalidate_all = None  # type: Any
-    prefill = None  # type: Any
-    cache = None  # type: Any
-    num_args = None  # type: Any
+    invalidate: Any = None
+    invalidate_all: Any = None
+    prefill: Any = None
+    cache: Any = None
+    num_args: Any = None
 
-    __name__ = None  # type: str
+    __name__: str
 
     # Note: This function signature is actually fiddled with by the synapse mypy
     # plugin to a) make it a bound method, and b) remove any `cache_context` arg.
-    __call__ = None  # type: F
+    __call__: F
 
 
 class _CacheDescriptorBase:
@@ -115,8 +115,8 @@ class _CacheDescriptorBase:
 
 
 class _LruCachedFunction(Generic[F]):
-    cache = None  # type: LruCache[CacheKey, Any]
-    __call__ = None  # type: F
+    cache: LruCache[CacheKey, Any]
+    __call__: F
 
 
 def lru_cache(
@@ -180,10 +180,10 @@ class LruCacheDescriptor(_CacheDescriptorBase):
         self.max_entries = max_entries
 
     def __get__(self, obj, owner):
-        cache = LruCache(
+        cache: LruCache[CacheKey, Any] = LruCache(
             cache_name=self.orig.__name__,
             max_size=self.max_entries,
-        )  # type: LruCache[CacheKey, Any]
+        )
 
         get_cache_key = self.cache_key_builder
         sentinel = LruCacheDescriptor._Sentinel.sentinel
@@ -271,12 +271,12 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         self.iterable = iterable
 
     def __get__(self, obj, owner):
-        cache = DeferredCache(
+        cache: DeferredCache[CacheKey, Any] = DeferredCache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
             tree=self.tree,
             iterable=self.iterable,
-        )  # type: DeferredCache[CacheKey, Any]
+        )
 
         get_cache_key = self.cache_key_builder
 
@@ -359,7 +359,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
 
     def __get__(self, obj, objtype=None):
         cached_method = getattr(obj, self.cached_method_name)
-        cache = cached_method.cache  # type: DeferredCache[CacheKey, Any]
+        cache: DeferredCache[CacheKey, Any] = cached_method.cache
         num_args = cached_method.num_args
 
         @functools.wraps(self.orig)
@@ -472,15 +472,15 @@ class _CacheContext:
 
     Cache = Union[DeferredCache, LruCache]
 
-    _cache_context_objects = (
-        WeakValueDictionary()
-    )  # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext]
+    _cache_context_objects: """WeakValueDictionary[
+        Tuple["_CacheContext.Cache", CacheKey], "_CacheContext"
+    ]""" = WeakValueDictionary()
 
     def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None:
         self._cache = cache
         self._cache_key = cache_key
 
-    def invalidate(self):  # type: () -> None
+    def invalidate(self) -> None:
         """Invalidates the cache entry referred to by the context."""
         self._cache.invalidate(self._cache_key)
 
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 56d94d96ce..3f852edd7f 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -62,13 +62,13 @@ class DictionaryCache(Generic[KT, DKT]):
     """
 
     def __init__(self, name: str, max_entries: int = 1000):
-        self.cache = LruCache(
+        self.cache: LruCache[KT, DictionaryEntry] = LruCache(
             max_size=max_entries, cache_name=name, size_callback=len
-        )  # type: LruCache[KT, DictionaryEntry]
+        )
 
         self.name = name
         self.sequence = 0
-        self.thread = None  # type: Optional[threading.Thread]
+        self.thread: Optional[threading.Thread] = None
 
     def check_thread(self) -> None:
         expected_thread = self.thread
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index ac47a31cd7..bde16b8577 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -27,7 +27,7 @@ from synapse.util.caches import register_cache
 logger = logging.getLogger(__name__)
 
 
-SENTINEL = object()  # type: Any
+SENTINEL: Any = object()
 
 
 T = TypeVar("T")
@@ -71,7 +71,7 @@ class ExpiringCache(Generic[KT, VT]):
         self._expiry_ms = expiry_ms
         self._reset_expiry_on_get = reset_expiry_on_get
 
-        self._cache = OrderedDict()  # type: OrderedDict[KT, _CacheEntry]
+        self._cache: OrderedDict[KT, _CacheEntry] = OrderedDict()
 
         self.iterable = iterable
 
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 4b9d0433ff..5c65d187b6 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -90,8 +90,7 @@ def enumerate_leaves(node, depth):
         yield node
     else:
         for n in node.values():
-            for m in enumerate_leaves(n, depth - 1):
-                yield m
+            yield from enumerate_leaves(n, depth - 1)
 
 
 P = TypeVar("P")
@@ -226,7 +225,7 @@ class _Node:
         # footprint down. Storing `None` is free as its a singleton, while empty
         # lists are 56 bytes (and empty sets are 216 bytes, if we did the naive
         # thing and used sets).
-        self.callbacks = None  # type: Optional[List[Callable[[], None]]]
+        self.callbacks: Optional[List[Callable[[], None]]] = None
 
         self.add_callbacks(callbacks)
 
@@ -362,15 +361,15 @@ class LruCache(Generic[KT, VT]):
 
         # register_cache might call our "set_cache_factor" callback; there's nothing to
         # do yet when we get resized.
-        self._on_resize = None  # type: Optional[Callable[[],None]]
+        self._on_resize: Optional[Callable[[], None]] = None
 
         if cache_name is not None:
-            metrics = register_cache(
+            metrics: Optional[CacheMetric] = register_cache(
                 "lru_cache",
                 cache_name,
                 self,
                 collect_callback=metrics_collection_callback,
-            )  # type: Optional[CacheMetric]
+            )
         else:
             metrics = None
 
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 34c662c4db..ed7204336f 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -66,7 +66,7 @@ class ResponseCache(Generic[KV]):
         # This is poorly-named: it includes both complete and incomplete results.
         # We keep complete results rather than switching to absolute values because
         # that makes it easier to cache Failure results.
-        self.pending_result_cache = {}  # type: Dict[KV, ObservableDeferred]
+        self.pending_result_cache: Dict[KV, ObservableDeferred] = {}
 
         self.clock = clock
         self.timeout_sec = timeout_ms / 1000.0
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index e81e468899..3a41a8baa6 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -45,10 +45,10 @@ class StreamChangeCache:
     ):
         self._original_max_size = max_size
         self._max_size = math.floor(max_size)
-        self._entity_to_key = {}  # type: Dict[EntityType, int]
+        self._entity_to_key: Dict[EntityType, int] = {}
 
         # map from stream id to the a set of entities which changed at that stream id.
-        self._cache = SortedDict()  # type: SortedDict[int, Set[EntityType]]
+        self._cache: SortedDict[int, Set[EntityType]] = SortedDict()
 
         # the earliest stream_pos for which we can reliably answer
         # get_all_entities_changed. In other words, one less than the earliest
@@ -155,7 +155,7 @@ class StreamChangeCache:
         if stream_pos < self._earliest_known_stream_pos:
             return None
 
-        changed_entities = []  # type: List[EntityType]
+        changed_entities: List[EntityType] = []
 
         for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
             changed_entities.extend(self._cache[k])
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index a6df81ebff..4138931e7b 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -138,7 +138,6 @@ def iterate_tree_cache_entry(d):
     """
     if isinstance(d, TreeCacheNode):
         for value_d in d.values():
-            for value in iterate_tree_cache_entry(value_d):
-                yield value
+            yield from iterate_tree_cache_entry(value_d)
     else:
         yield d
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index c276107d56..46afe3f934 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -23,7 +23,7 @@ from synapse.util.caches import register_cache
 
 logger = logging.getLogger(__name__)
 
-SENTINEL = object()  # type: Any
+SENTINEL: Any = object()
 
 T = TypeVar("T")
 KT = TypeVar("KT")
@@ -35,10 +35,10 @@ class TTLCache(Generic[KT, VT]):
 
     def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
         # map from key to _CacheEntry
-        self._data = {}  # type: Dict[KT, _CacheEntry]
+        self._data: Dict[KT, _CacheEntry] = {}
 
         # the _CacheEntries, sorted by expiry time
-        self._expiry_list = SortedList()  # type: SortedList[_CacheEntry]
+        self._expiry_list: SortedList[_CacheEntry] = SortedList()
 
         self._timer = timer
 
diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py
index 31b24dd188..d8532411c2 100644
--- a/synapse/util/daemonize.py
+++ b/synapse/util/daemonize.py
@@ -31,13 +31,13 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
     # If pidfile already exists, we should read pid from there; to overwrite it, if
     # locking will fail, because locking attempt somehow purges the file contents.
     if os.path.isfile(pid_file):
-        with open(pid_file, "r") as pid_fh:
+        with open(pid_file) as pid_fh:
             old_pid = pid_fh.read()
 
     # Create a lockfile so that only one instance of this daemon is running at any time.
     try:
         lock_fh = open(pid_file, "w")
-    except IOError:
+    except OSError:
         print("Unable to create the pidfile.")
         sys.exit(1)
 
@@ -45,7 +45,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
         # Try to get an exclusive lock on the file. This will fail if another process
         # has the file locked.
         fcntl.flock(lock_fh, fcntl.LOCK_EX | fcntl.LOCK_NB)
-    except IOError:
+    except OSError:
         print("Unable to lock on the pidfile.")
         # We need to overwrite the pidfile if we got here.
         #
@@ -113,7 +113,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
     try:
         lock_fh.write("%s" % (os.getpid()))
         lock_fh.flush()
-    except IOError:
+    except OSError:
         logger.error("Unable to write pid to the pidfile.")
         print("Unable to write pid to the pidfile.")
         sys.exit(1)
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index 886afa9d19..8ac3eab2f5 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -68,7 +68,7 @@ def sorted_topologically(
     # This is implemented by Kahn's algorithm.
 
     degree_map = {node: 0 for node in nodes}
-    reverse_graph = {}  # type: Dict[T, Set[T]]
+    reverse_graph: Dict[T, Set[T]] = {}
 
     for node, edges in graph.items():
         if node not in degree_map:
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
index f6ebfd7e7d..d1f76e3dc5 100644
--- a/synapse/util/macaroons.py
+++ b/synapse/util/macaroons.py
@@ -39,7 +39,7 @@ def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
              caveat in the macaroon, or if the caveat was not found in the macaroon.
     """
     prefix = key + " = "
-    result = None  # type: Optional[str]
+    result: Optional[str] = None
     for caveat in macaroon.caveats:
         if not caveat.caveat_id.startswith(prefix):
             continue
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 45353d41c5..1b82dca81b 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -124,7 +124,7 @@ class Measure:
             assert isinstance(curr_context, LoggingContext)
             parent_context = curr_context
         self._logging_context = LoggingContext(str(curr_context), parent_context)
-        self.start = None  # type: Optional[int]
+        self.start: Optional[int] = None
 
     def __enter__(self) -> "Measure":
         if self.start is not None:
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index eed0291cae..99f01e325c 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -41,7 +41,7 @@ def do_patch():
         @functools.wraps(f)
         def wrapped(*args, **kwargs):
             start_context = current_context()
-            changes = []  # type: List[str]
+            changes: List[str] = []
             orig = orig_inline_callbacks(_check_yield_points(f, changes))
 
             try:
@@ -131,7 +131,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
         gen = f(*args, **kwargs)
 
         last_yield_line_no = gen.gi_frame.f_lineno
-        result = None  # type: Any
+        result: Any = None
         while True:
             expected_context = current_context()
 
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 490fb26e81..17532059e9 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -90,13 +90,13 @@ async def filter_events_for_client(
         AccountDataTypes.IGNORED_USER_LIST, user_id
     )
 
-    ignore_list = frozenset()  # type: FrozenSet[str]
+    ignore_list: FrozenSet[str] = frozenset()
     if ignore_dict_content:
         ignored_users_dict = ignore_dict_content.get("ignored_users", {})
         if isinstance(ignored_users_dict, dict):
             ignore_list = frozenset(ignored_users_dict.keys())
 
-    erased_senders = await storage.main.are_users_erased((e.sender for e in events))
+    erased_senders = await storage.main.are_users_erased(e.sender for e in events)
 
     if filter_send_to_client:
         room_ids = {e.room_id for e in events}
@@ -353,7 +353,7 @@ async def filter_events_for_server(
         )
 
     if not check_history_visibility_only:
-        erased_senders = await storage.main.are_users_erased((e.sender for e in events))
+        erased_senders = await storage.main.are_users_erased(e.sender for e in events)
     else:
         # We don't want to check whether users are erased, which is equivalent
         # to no users having been erased.