summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/_base.py8
-rw-r--r--synapse/handlers/account_validity.py128
-rw-r--r--synapse/handlers/admin.py11
-rw-r--r--synapse/handlers/appservice.py6
-rw-r--r--synapse/handlers/auth.py16
-rw-r--r--synapse/handlers/cas.py4
-rw-r--r--synapse/handlers/device.py14
-rw-r--r--synapse/handlers/devicemessage.py2
-rw-r--r--synapse/handlers/directory.py11
-rw-r--r--synapse/handlers/e2e_keys.py40
-rw-r--r--synapse/handlers/event_auth.py62
-rw-r--r--synapse/handlers/events.py6
-rw-r--r--synapse/handlers/federation.py85
-rw-r--r--synapse/handlers/groups_local.py4
-rw-r--r--synapse/handlers/identity.py4
-rw-r--r--synapse/handlers/initial_sync.py14
-rw-r--r--synapse/handlers/message.py52
-rw-r--r--synapse/handlers/oidc.py56
-rw-r--r--synapse/handlers/pagination.py4
-rw-r--r--synapse/handlers/presence.py28
-rw-r--r--synapse/handlers/profile.py4
-rw-r--r--synapse/handlers/receipts.py19
-rw-r--r--synapse/handlers/register.py20
-rw-r--r--synapse/handlers/room.py19
-rw-r--r--synapse/handlers/room_list.py44
-rw-r--r--synapse/handlers/saml.py8
-rw-r--r--synapse/handlers/search.py8
-rw-r--r--synapse/handlers/space_summary.py88
-rw-r--r--synapse/handlers/sso.py12
-rw-r--r--synapse/handlers/stats.py37
-rw-r--r--synapse/handlers/sync.py34
-rw-r--r--synapse/handlers/typing.py28
-rw-r--r--synapse/handlers/user_directory.py2
33 files changed, 577 insertions, 301 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index d800e16912..525f3d39b1 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -38,10 +38,10 @@ class BaseHandler:
     """
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()  # type: synapse.storage.DataStore
+        self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
-        self.state_handler = hs.get_state_handler()  # type: synapse.state.StateHandler
+        self.state_handler = hs.get_state_handler()
         self.distributor = hs.get_distributor()
         self.clock = hs.get_clock()
         self.hs = hs
@@ -55,12 +55,12 @@ class BaseHandler:
         # Check whether ratelimiting room admin message redaction is enabled
         # by the presence of rate limits in the config
         if self.hs.config.rc_admin_redaction:
-            self.admin_redaction_ratelimiter = Ratelimiter(
+            self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
                 store=self.store,
                 clock=self.clock,
                 rate_hz=self.hs.config.rc_admin_redaction.per_second,
                 burst_count=self.hs.config.rc_admin_redaction.burst_count,
-            )  # type: Optional[Ratelimiter]
+            )
         else:
             self.admin_redaction_ratelimiter = None
 
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index d752cf34f0..078accd634 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -15,9 +15,11 @@
 import email.mime.multipart
 import email.utils
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
 
-from synapse.api.errors import StoreError, SynapseError
+from twisted.web.http import Request
+
+from synapse.api.errors import AuthError, StoreError, SynapseError
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.types import UserID
 from synapse.util import stringutils
@@ -27,6 +29,15 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+# Types for callbacks to be registered via the module api
+IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]]
+ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable]
+# Temporary hooks to allow for a transition from `/_matrix/client` endpoints
+# to `/_synapse/client/account_validity`. See `register_account_validity_callbacks`.
+ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable]
+ON_LEGACY_RENEW_CALLBACK = Callable[[str], Awaitable[Tuple[bool, bool, int]]]
+ON_LEGACY_ADMIN_REQUEST = Callable[[Request], Awaitable]
+
 
 class AccountValidityHandler:
     def __init__(self, hs: "HomeServer"):
@@ -70,6 +81,99 @@ class AccountValidityHandler:
             if hs.config.run_background_tasks:
                 self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
 
+        self._is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = []
+        self._on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = []
+        self._on_legacy_send_mail_callback: Optional[
+            ON_LEGACY_SEND_MAIL_CALLBACK
+        ] = None
+        self._on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None
+
+        # The legacy admin requests callback isn't a protected attribute because we need
+        # to access it from the admin servlet, which is outside of this handler.
+        self.on_legacy_admin_request_callback: Optional[ON_LEGACY_ADMIN_REQUEST] = None
+
+    def register_account_validity_callbacks(
+        self,
+        is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
+        on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
+        on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
+        on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
+        on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
+    ):
+        """Register callbacks from module for each hook."""
+        if is_user_expired is not None:
+            self._is_user_expired_callbacks.append(is_user_expired)
+
+        if on_user_registration is not None:
+            self._on_user_registration_callbacks.append(on_user_registration)
+
+        # The builtin account validity feature exposes 3 endpoints (send_mail, renew, and
+        # an admin one). As part of moving the feature into a module, we need to change
+        # the path from /_matrix/client/unstable/account_validity/... to
+        # /_synapse/client/account_validity, because:
+        #
+        #   * the feature isn't part of the Matrix spec thus shouldn't live under /_matrix
+        #   * the way we register servlets means that modules can't register resources
+        #     under /_matrix/client
+        #
+        # We need to allow for a transition period between the old and new endpoints
+        # in order to allow for clients to update (and for emails to be processed).
+        #
+        # Once the email-account-validity module is loaded, it will take control of account
+        # validity by moving the rows from our `account_validity` table into its own table.
+        #
+        # Therefore, we need to allow modules (in practice just the one implementing the
+        # email-based account validity) to temporarily hook into the legacy endpoints so we
+        # can route the traffic coming into the old endpoints into the module, which is
+        # why we have the following three temporary hooks.
+        if on_legacy_send_mail is not None:
+            if self._on_legacy_send_mail_callback is not None:
+                raise RuntimeError("Tried to register on_legacy_send_mail twice")
+
+            self._on_legacy_send_mail_callback = on_legacy_send_mail
+
+        if on_legacy_renew is not None:
+            if self._on_legacy_renew_callback is not None:
+                raise RuntimeError("Tried to register on_legacy_renew twice")
+
+            self._on_legacy_renew_callback = on_legacy_renew
+
+        if on_legacy_admin_request is not None:
+            if self.on_legacy_admin_request_callback is not None:
+                raise RuntimeError("Tried to register on_legacy_admin_request twice")
+
+            self.on_legacy_admin_request_callback = on_legacy_admin_request
+
+    async def is_user_expired(self, user_id: str) -> bool:
+        """Checks if a user has expired against third-party modules.
+
+        Args:
+            user_id: The user to check the expiry of.
+
+        Returns:
+            Whether the user has expired.
+        """
+        for callback in self._is_user_expired_callbacks:
+            expired = await callback(user_id)
+            if expired is not None:
+                return expired
+
+        if self._account_validity_enabled:
+            # If no module could determine whether the user has expired and the legacy
+            # configuration is enabled, fall back to it.
+            return await self.store.is_account_expired(user_id, self.clock.time_msec())
+
+        return False
+
+    async def on_user_registration(self, user_id: str):
+        """Tell third-party modules about a user's registration.
+
+        Args:
+            user_id: The ID of the newly registered user.
+        """
+        for callback in self._on_user_registration_callbacks:
+            await callback(user_id)
+
     @wrap_as_background_process("send_renewals")
     async def _send_renewal_emails(self) -> None:
         """Gets the list of users whose account is expiring in the amount of time
@@ -95,6 +199,17 @@ class AccountValidityHandler:
         Raises:
             SynapseError if the user is not set to renew.
         """
+        # If a module supports sending a renewal email from here, do that, otherwise do
+        # the legacy dance.
+        if self._on_legacy_send_mail_callback is not None:
+            await self._on_legacy_send_mail_callback(user_id)
+            return
+
+        if not self._account_validity_renew_by_email_enabled:
+            raise AuthError(
+                403, "Account renewal via email is disabled on this server."
+            )
+
         expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
 
         # If this user isn't set to be expired, raise an error.
@@ -209,6 +324,10 @@ class AccountValidityHandler:
         token is considered stale. A token is stale if the 'token_used_ts_ms' db column
         is non-null.
 
+        This method exists to support handling the legacy account validity /renew
+        endpoint. If a module implements the on_legacy_renew callback, then this process
+        is delegated to the module instead.
+
         Args:
             renewal_token: Token sent with the renewal request.
         Returns:
@@ -218,6 +337,11 @@ class AccountValidityHandler:
               * An int representing the user's expiry timestamp as milliseconds since the
                 epoch, or 0 if the token was invalid.
         """
+        # If a module supports triggering a renew from here, do that, otherwise do the
+        # legacy dance.
+        if self._on_legacy_renew_callback is not None:
+            return await self._on_legacy_renew_callback(renewal_token)
+
         try:
             (
                 user_id,
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index f72ded038e..bfa7f2c545 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -62,9 +62,16 @@ class AdminHandler(BaseHandler):
         if ret:
             profile = await self.store.get_profileinfo(user.localpart)
             threepids = await self.store.user_get_threepids(user.to_string())
+            external_ids = [
+                ({"auth_provider": auth_provider, "external_id": external_id})
+                for auth_provider, external_id in await self.store.get_external_ids_by_user(
+                    user.to_string()
+                )
+            ]
             ret["displayname"] = profile.display_name
             ret["avatar_url"] = profile.avatar_url
             ret["threepids"] = threepids
+            ret["external_ids"] = external_ids
         return ret
 
     async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
@@ -132,7 +139,7 @@ class AdminHandler(BaseHandler):
             to_key = RoomStreamToken(None, stream_ordering)
 
             # Events that we've processed in this room
-            written_events = set()  # type: Set[str]
+            written_events: Set[str] = set()
 
             # We need to track gaps in the events stream so that we can then
             # write out the state at those events. We do this by keeping track
@@ -145,7 +152,7 @@ class AdminHandler(BaseHandler):
             # The reverse mapping to above, i.e. map from unseen event to events
             # that have the unseen event in their prev_events, i.e. the unseen
             # events "children".
-            unseen_to_child_events = {}  # type: Dict[str, Set[str]]
+            unseen_to_child_events: Dict[str, Set[str]] = {}
 
             # We fetch events in the room the user could see by fetching *all*
             # events that we have and then filtering, this isn't the most
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 862638cc4f..21a17cd2e8 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -96,7 +96,7 @@ class ApplicationServicesHandler:
                         self.current_max, limit
                     )
 
-                    events_by_room = {}  # type: Dict[str, List[EventBase]]
+                    events_by_room: Dict[str, List[EventBase]] = {}
                     for event in events:
                         events_by_room.setdefault(event.room_id, []).append(event)
 
@@ -275,7 +275,7 @@ class ApplicationServicesHandler:
     async def _handle_presence(
         self, service: ApplicationService, users: Collection[Union[str, UserID]]
     ) -> List[JsonDict]:
-        events = []  # type: List[JsonDict]
+        events: List[JsonDict] = []
         presence_source = self.event_sources.sources["presence"]
         from_key = await self.store.get_type_stream_id_for_appservice(
             service, "presence"
@@ -375,7 +375,7 @@ class ApplicationServicesHandler:
         self, only_protocol: Optional[str] = None
     ) -> Dict[str, JsonDict]:
         services = self.store.get_app_services()
-        protocols = {}  # type: Dict[str, List[JsonDict]]
+        protocols: Dict[str, List[JsonDict]] = {}
 
         # Collect up all the individual protocol responses out of the ASes
         for s in services:
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index e2ac595a62..22a8552241 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -191,7 +191,7 @@ class AuthHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.checkers = {}  # type: Dict[str, UserInteractiveAuthChecker]
+        self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
         for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
             inst = auth_checker_class(hs)
             if inst.is_enabled():
@@ -296,7 +296,7 @@ class AuthHandler(BaseHandler):
 
         # A mapping of user ID to extra attributes to include in the login
         # response.
-        self._extra_attributes = {}  # type: Dict[str, SsoLoginExtraAttributes]
+        self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {}
 
     async def validate_user_via_ui_auth(
         self,
@@ -500,7 +500,7 @@ class AuthHandler(BaseHandler):
                 all the stages in any of the permitted flows.
         """
 
-        sid = None  # type: Optional[str]
+        sid: Optional[str] = None
         authdict = clientdict.pop("auth", {})
         if "session" in authdict:
             sid = authdict["session"]
@@ -588,9 +588,9 @@ class AuthHandler(BaseHandler):
             )
 
         # check auth type currently being presented
-        errordict = {}  # type: Dict[str, Any]
+        errordict: Dict[str, Any] = {}
         if "type" in authdict:
-            login_type = authdict["type"]  # type: str
+            login_type: str = authdict["type"]
             try:
                 result = await self._check_auth_dict(authdict, clientip)
                 if result:
@@ -766,7 +766,7 @@ class AuthHandler(BaseHandler):
             LoginType.TERMS: self._get_params_terms,
         }
 
-        params = {}  # type: Dict[str, Any]
+        params: Dict[str, Any] = {}
 
         for f in public_flows:
             for stage in f:
@@ -1530,9 +1530,9 @@ class AuthHandler(BaseHandler):
         except StoreError:
             raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
 
-        user_id_to_verify = await self.get_session_data(
+        user_id_to_verify: str = await self.get_session_data(
             session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
-        )  # type: str
+        )
 
         idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
             user_id_to_verify
diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py
index 7346ccfe93..0325f86e20 100644
--- a/synapse/handlers/cas.py
+++ b/synapse/handlers/cas.py
@@ -40,7 +40,7 @@ class CasError(Exception):
 
     def __str__(self):
         if self.error_description:
-            return "{}: {}".format(self.error, self.error_description)
+            return f"{self.error}: {self.error_description}"
         return self.error
 
 
@@ -171,7 +171,7 @@ class CasHandler:
 
         # Iterate through the nodes and pull out the user and any extra attributes.
         user = None
-        attributes = {}  # type: Dict[str, List[Optional[str]]]
+        attributes: Dict[str, List[Optional[str]]] = {}
         for child in root[0]:
             if child.tag.endswith("user"):
                 user = child.text
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 95bdc5902a..46ee834407 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -452,7 +452,7 @@ class DeviceHandler(DeviceWorkerHandler):
             user_id
         )
 
-        hosts = set()  # type: Set[str]
+        hosts: Set[str] = set()
         if self.hs.is_mine_id(user_id):
             hosts.update(get_domain_from_id(u) for u in users_who_share_room)
             hosts.discard(self.server_name)
@@ -613,20 +613,20 @@ class DeviceListUpdater:
         self._remote_edu_linearizer = Linearizer(name="remote_device_list")
 
         # user_id -> list of updates waiting to be handled.
-        self._pending_updates = (
-            {}
-        )  # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]]
+        self._pending_updates: Dict[
+            str, List[Tuple[str, str, Iterable[str], JsonDict]]
+        ] = {}
 
         # Recently seen stream ids. We don't bother keeping these in the DB,
         # but they're useful to have them about to reduce the number of spurious
         # resyncs.
-        self._seen_updates = ExpiringCache(
+        self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache(
             cache_name="device_update_edu",
             clock=self.clock,
             max_len=10000,
             expiry_ms=30 * 60 * 1000,
             iterable=True,
-        )  # type: ExpiringCache[str, Set[str]]
+        )
 
         # Attempt to resync out of sync device lists every 30s.
         self._resync_retry_in_progress = False
@@ -755,7 +755,7 @@ class DeviceListUpdater:
         """Given a list of updates for a user figure out if we need to do a full
         resync, or whether we have enough data that we can just apply the delta.
         """
-        seen_updates = self._seen_updates.get(user_id, set())  # type: Set[str]
+        seen_updates: Set[str] = self._seen_updates.get(user_id, set())
 
         extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
 
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 580b941595..679b47f081 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -203,7 +203,7 @@ class DeviceMessageHandler:
         log_kv({"number_of_to_device_messages": len(messages)})
         set_tag("sender", sender_user_id)
         local_messages = {}
-        remote_messages = {}  # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
+        remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
         for user_id, by_device in messages.items():
             # Ratelimit local cross-user key requests by the sending device.
             if (
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 4064a2b859..d487fee627 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -22,6 +22,7 @@ from synapse.api.errors import (
     CodeMessageException,
     Codes,
     NotFoundError,
+    RequestSendFailed,
     ShadowBanError,
     StoreError,
     SynapseError,
@@ -236,9 +237,9 @@ class DirectoryHandler(BaseHandler):
     async def get_association(self, room_alias: RoomAlias) -> JsonDict:
         room_id = None
         if self.hs.is_mine(room_alias):
-            result = await self.get_association_from_room_alias(
-                room_alias
-            )  # type: Optional[RoomAliasMapping]
+            result: Optional[
+                RoomAliasMapping
+            ] = await self.get_association_from_room_alias(room_alias)
 
             if result:
                 room_id = result.room_id
@@ -252,12 +253,14 @@ class DirectoryHandler(BaseHandler):
                     retry_on_dns_fail=False,
                     ignore_backoff=True,
                 )
+            except RequestSendFailed:
+                raise SynapseError(502, "Failed to fetch alias")
             except CodeMessageException as e:
                 logging.warning("Error retrieving alias")
                 if e.code == 404:
                     fed_result = None
                 else:
-                    raise
+                    raise SynapseError(502, "Failed to fetch alias")
 
             if fed_result and "room_id" in fed_result and "servers" in fed_result:
                 room_id = fed_result["room_id"]
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 3972849d4d..d92370859f 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -115,9 +115,9 @@ class E2eKeysHandler:
                 the number of in-flight queries at a time.
         """
         with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
-            device_keys_query = query_body.get(
+            device_keys_query: Dict[str, Iterable[str]] = query_body.get(
                 "device_keys", {}
-            )  # type: Dict[str, Iterable[str]]
+            )
 
             # separate users by domain.
             # make a map from domain to user_id to device_ids
@@ -136,7 +136,7 @@ class E2eKeysHandler:
 
             # First get local devices.
             # A map of destination -> failure response.
-            failures = {}  # type: Dict[str, JsonDict]
+            failures: Dict[str, JsonDict] = {}
             results = {}
             if local_query:
                 local_result = await self.query_local_devices(local_query)
@@ -151,11 +151,9 @@ class E2eKeysHandler:
 
             # Now attempt to get any remote devices from our local cache.
             # A map of destination -> user ID -> device IDs.
-            remote_queries_not_in_cache = (
-                {}
-            )  # type: Dict[str, Dict[str, Iterable[str]]]
+            remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
             if remote_queries:
-                query_list = []  # type: List[Tuple[str, Optional[str]]]
+                query_list: List[Tuple[str, Optional[str]]] = []
                 for user_id, device_ids in remote_queries.items():
                     if device_ids:
                         query_list.extend(
@@ -362,9 +360,9 @@ class E2eKeysHandler:
             A map from user_id -> device_id -> device details
         """
         set_tag("local_query", query)
-        local_query = []  # type: List[Tuple[str, Optional[str]]]
+        local_query: List[Tuple[str, Optional[str]]] = []
 
-        result_dict = {}  # type: Dict[str, Dict[str, dict]]
+        result_dict: Dict[str, Dict[str, dict]] = {}
         for user_id, device_ids in query.items():
             # we use UserID.from_string to catch invalid user ids
             if not self.is_mine(UserID.from_string(user_id)):
@@ -402,9 +400,9 @@ class E2eKeysHandler:
         self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
     ) -> JsonDict:
         """Handle a device key query from a federated server"""
-        device_keys_query = query_body.get(
+        device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
             "device_keys", {}
-        )  # type: Dict[str, Optional[List[str]]]
+        )
         res = await self.query_local_devices(device_keys_query)
         ret = {"device_keys": res}
 
@@ -421,8 +419,8 @@ class E2eKeysHandler:
     async def claim_one_time_keys(
         self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
     ) -> JsonDict:
-        local_query = []  # type: List[Tuple[str, str, str]]
-        remote_queries = {}  # type: Dict[str, Dict[str, Dict[str, str]]]
+        local_query: List[Tuple[str, str, str]] = []
+        remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
 
         for user_id, one_time_keys in query.get("one_time_keys", {}).items():
             # we use UserID.from_string to catch invalid user ids
@@ -439,8 +437,8 @@ class E2eKeysHandler:
         results = await self.store.claim_e2e_one_time_keys(local_query)
 
         # A map of user ID -> device ID -> key ID -> key.
-        json_result = {}  # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
-        failures = {}  # type: Dict[str, JsonDict]
+        json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+        failures: Dict[str, JsonDict] = {}
         for user_id, device_keys in results.items():
             for device_id, keys in device_keys.items():
                 for key_id, json_str in keys.items():
@@ -768,8 +766,8 @@ class E2eKeysHandler:
         Raises:
             SynapseError: if the input is malformed
         """
-        signature_list = []  # type: List[SignatureListItem]
-        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
+        signature_list: List["SignatureListItem"] = []
+        failures: Dict[str, Dict[str, JsonDict]] = {}
         if not signatures:
             return signature_list, failures
 
@@ -930,8 +928,8 @@ class E2eKeysHandler:
         Raises:
             SynapseError: if the input is malformed
         """
-        signature_list = []  # type: List[SignatureListItem]
-        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
+        signature_list: List["SignatureListItem"] = []
+        failures: Dict[str, Dict[str, JsonDict]] = {}
         if not signatures:
             return signature_list, failures
 
@@ -1300,7 +1298,7 @@ class SigningKeyEduUpdater:
         self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
 
         # user_id -> list of updates waiting to be handled.
-        self._pending_updates = {}  # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
+        self._pending_updates: Dict[str, List[Tuple[JsonDict, JsonDict]]] = {}
 
     async def incoming_signing_key_update(
         self, origin: str, edu_content: JsonDict
@@ -1349,7 +1347,7 @@ class SigningKeyEduUpdater:
                 # This can happen since we batch updates
                 return
 
-            device_ids = []  # type: List[str]
+            device_ids: List[str] = []
 
             logger.info("pending updates: %r", pending_updates)
 
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 989996b628..41dbdfd0a1 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Collection, Optional
+from typing import TYPE_CHECKING, Collection, List, Optional, Union
 
+from synapse import event_auth
 from synapse.api.constants import (
     EventTypes,
     JoinRules,
@@ -20,9 +21,11 @@ from synapse.api.constants import (
     RestrictedJoinRuleTypes,
 )
 from synapse.api.errors import AuthError
-from synapse.api.room_versions import RoomVersion
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
+from synapse.events.builder import EventBuilder
 from synapse.types import StateMap
+from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -34,8 +37,63 @@ class EventAuthHandler:
     """
 
     def __init__(self, hs: "HomeServer"):
+        self._clock = hs.get_clock()
         self._store = hs.get_datastore()
 
+    async def check_from_context(
+        self, room_version: str, event, context, do_sig_check=True
+    ) -> None:
+        auth_event_ids = event.auth_event_ids()
+        auth_events_by_id = await self._store.get_events(auth_event_ids)
+        auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
+
+        room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+        event_auth.check(
+            room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
+        )
+
+    def compute_auth_events(
+        self,
+        event: Union[EventBase, EventBuilder],
+        current_state_ids: StateMap[str],
+        for_verification: bool = False,
+    ) -> List[str]:
+        """Given an event and current state return the list of event IDs used
+        to auth an event.
+
+        If `for_verification` is False then only return auth events that
+        should be added to the event's `auth_events`.
+
+        Returns:
+            List of event IDs.
+        """
+
+        if event.type == EventTypes.Create:
+            return []
+
+        # Currently we ignore the `for_verification` flag even though there are
+        # some situations where we can drop particular auth events when adding
+        # to the event's `auth_events` (e.g. joins pointing to previous joins
+        # when room is publicly joinable). Dropping event IDs has the
+        # advantage that the auth chain for the room grows slower, but we use
+        # the auth chain in state resolution v2 to order events, which means
+        # care must be taken if dropping events to ensure that it doesn't
+        # introduce undesirable "state reset" behaviour.
+        #
+        # All of which sounds a bit tricky so we don't bother for now.
+
+        auth_ids = []
+        for etype, state_key in event_auth.auth_types_for_event(event):
+            auth_ev_id = current_state_ids.get((etype, state_key))
+            if auth_ev_id:
+                auth_ids.append(auth_ev_id)
+
+        return auth_ids
+
+    async def check_host_in_room(self, room_id: str, host: str) -> bool:
+        with Measure(self._clock, "check_host_in_room"):
+            return await self._store.is_host_joined(room_id, host)
+
     async def check_restricted_join_rules(
         self,
         state_ids: StateMap[str],
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index f134f1e234..4b3f037072 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -93,7 +93,7 @@ class EventStreamHandler(BaseHandler):
 
             # When the user joins a new room, or another user joins a currently
             # joined room, we need to send down presence for those users.
-            to_add = []  # type: List[JsonDict]
+            to_add: List[JsonDict] = []
             for event in events:
                 if not isinstance(event, EventBase):
                     continue
@@ -103,9 +103,9 @@ class EventStreamHandler(BaseHandler):
                     # Send down presence.
                     if event.state_key == auth_user_id:
                         # Send down presence for everyone in the room.
-                        users = await self.store.get_users_in_room(
+                        users: Iterable[str] = await self.store.get_users_in_room(
                             event.room_id
-                        )  # type: Iterable[str]
+                        )
                     else:
                         users = [event.state_key]
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index d929c65131..cf389be3e4 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -181,7 +181,7 @@ class FederationHandler(BaseHandler):
 
         # When joining a room we need to queue any events for that room up.
         # For each room, a list of (pdu, origin) tuples.
-        self.room_queues = {}  # type: Dict[str, List[Tuple[EventBase, str]]]
+        self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {}
         self._room_pdu_linearizer = Linearizer("fed_room_pdu")
 
         self._room_backfill = Linearizer("room_backfill")
@@ -250,7 +250,9 @@ class FederationHandler(BaseHandler):
         #
         # Note that if we were never in the room then we would have already
         # dropped the event, since we wouldn't know the room version.
-        is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
+        is_in_room = await self._event_auth_handler.check_host_in_room(
+            room_id, self.server_name
+        )
         if not is_in_room:
             logger.info(
                 "Ignoring PDU from %s as we're not in the room",
@@ -366,7 +368,7 @@ class FederationHandler(BaseHandler):
                     ours = await self.state_store.get_state_groups_ids(room_id, seen)
 
                     # state_maps is a list of mappings from (type, state_key) to event_id
-                    state_maps = list(ours.values())  # type: List[StateMap[str]]
+                    state_maps: List[StateMap[str]] = list(ours.values())
 
                     # we don't need this any more, let's delete it.
                     del ours
@@ -733,7 +735,7 @@ class FederationHandler(BaseHandler):
         # we need to make sure we re-load from the database to get the rejected
         # state correct.
         fetched_events.update(
-            (await self.store.get_events(missing_desired_events, allow_rejected=True))
+            await self.store.get_events(missing_desired_events, allow_rejected=True)
         )
 
         # check for events which were in the wrong room.
@@ -843,7 +845,7 @@ class FederationHandler(BaseHandler):
                 # exact key to expect. Otherwise check it matches any key we
                 # have for that device.
 
-                current_keys = []  # type: Container[str]
+                current_keys: Container[str] = []
 
                 if device:
                     keys = device.get("keys", {}).get("keys", {})
@@ -1183,7 +1185,7 @@ class FederationHandler(BaseHandler):
                 if e_type == EventTypes.Member and event.membership == Membership.JOIN
             ]
 
-            joined_domains = {}  # type: Dict[str, int]
+            joined_domains: Dict[str, int] = {}
             for u, d in joined_users:
                 try:
                     dom = get_domain_from_id(u)
@@ -1312,7 +1314,7 @@ class FederationHandler(BaseHandler):
 
         room_version = await self.store.get_room_version(room_id)
 
-        event_map = {}  # type: Dict[str, EventBase]
+        event_map: Dict[str, EventBase] = {}
 
         async def get_event(event_id: str):
             with nested_logging_context(event_id):
@@ -1412,12 +1414,15 @@ class FederationHandler(BaseHandler):
 
         Invites must be signed by the invitee's server before distribution.
         """
-        pdu = await self.federation_client.send_invite(
-            destination=target_host,
-            room_id=event.room_id,
-            event_id=event.event_id,
-            pdu=event,
-        )
+        try:
+            pdu = await self.federation_client.send_invite(
+                destination=target_host,
+                room_id=event.room_id,
+                event_id=event.event_id,
+                pdu=event,
+            )
+        except RequestSendFailed:
+            raise SynapseError(502, f"Can't connect to server {target_host}")
 
         return pdu
 
@@ -1591,7 +1596,7 @@ class FederationHandler(BaseHandler):
 
         # Ask the remote server to create a valid knock event for us. Once received,
         # we sign the event
-        params = {"ver": supported_room_versions}  # type: Dict[str, Iterable[str]]
+        params: Dict[str, Iterable[str]] = {"ver": supported_room_versions}
         origin, event, event_format_version = await self._make_and_verify_event(
             target_hosts, room_id, knockee, Membership.KNOCK, content, params=params
         )
@@ -1674,7 +1679,9 @@ class FederationHandler(BaseHandler):
         room_version = await self.store.get_room_version_id(room_id)
 
         # now check that we are *still* in the room
-        is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
+        is_in_room = await self._event_auth_handler.check_host_in_room(
+            room_id, self.server_name
+        )
         if not is_in_room:
             logger.info(
                 "Got /make_join request for room %s we are no longer in",
@@ -1705,7 +1712,7 @@ class FederationHandler(BaseHandler):
 
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
-        await self.auth.check_from_context(
+        await self._event_auth_handler.check_from_context(
             room_version, event, context, do_sig_check=False
         )
 
@@ -1877,7 +1884,7 @@ class FederationHandler(BaseHandler):
         try:
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # when we get the event back in `on_send_leave_request`
-            await self.auth.check_from_context(
+            await self._event_auth_handler.check_from_context(
                 room_version, event, context, do_sig_check=False
             )
         except AuthError as e:
@@ -1939,7 +1946,7 @@ class FederationHandler(BaseHandler):
         try:
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # when we get the event back in `on_send_knock_request`
-            await self.auth.check_from_context(
+            await self._event_auth_handler.check_from_context(
                 room_version, event, context, do_sig_check=False
             )
         except AuthError as e:
@@ -2111,7 +2118,7 @@ class FederationHandler(BaseHandler):
     async def on_backfill_request(
         self, origin: str, room_id: str, pdu_list: List[str], limit: int
     ) -> List[EventBase]:
-        in_room = await self.auth.check_host_in_room(room_id, origin)
+        in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
@@ -2146,7 +2153,9 @@ class FederationHandler(BaseHandler):
         )
 
         if event:
-            in_room = await self.auth.check_host_in_room(event.room_id, origin)
+            in_room = await self._event_auth_handler.check_host_in_room(
+                event.room_id, origin
+            )
             if not in_room:
                 raise AuthError(403, "Host not in room.")
 
@@ -2444,14 +2453,14 @@ class FederationHandler(BaseHandler):
             state_sets_d = await self.state_store.get_state_groups(
                 event.room_id, extrem_ids
             )
-            state_sets = list(state_sets_d.values())  # type: List[Iterable[EventBase]]
+            state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
             state_sets.append(state)
             current_states = await self.state_handler.resolve_events(
                 room_version, state_sets, event
             )
-            current_state_ids = {
+            current_state_ids: StateMap[str] = {
                 k: e.event_id for k, e in current_states.items()
-            }  # type: StateMap[str]
+            }
         else:
             current_state_ids = await self.state_handler.get_current_state_ids(
                 event.room_id, latest_event_ids=extrem_ids
@@ -2499,7 +2508,7 @@ class FederationHandler(BaseHandler):
         latest_events: List[str],
         limit: int,
     ) -> List[EventBase]:
-        in_room = await self.auth.check_host_in_room(room_id, origin)
+        in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
@@ -2562,7 +2571,7 @@ class FederationHandler(BaseHandler):
 
         if not auth_events:
             prev_state_ids = await context.get_prev_state_ids()
-            auth_events_ids = self.auth.compute_auth_events(
+            auth_events_ids = self._event_auth_handler.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
             auth_events_x = await self.store.get_events(auth_events_ids)
@@ -2808,7 +2817,7 @@ class FederationHandler(BaseHandler):
         """
         # exclude the state key of the new event from the current_state in the context.
         if event.is_state():
-            event_key = (event.type, event.state_key)  # type: Optional[Tuple[str, str]]
+            event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
         else:
             event_key = None
         state_updates = {
@@ -2991,7 +3000,7 @@ class FederationHandler(BaseHandler):
             "state_key": target_user_id,
         }
 
-        if await self.auth.check_host_in_room(room_id, self.hs.hostname):
+        if await self._event_auth_handler.check_host_in_room(room_id, self.hs.hostname):
             room_version = await self.store.get_room_version_id(room_id)
             builder = self.event_builder_factory.new(room_version, event_dict)
 
@@ -3011,7 +3020,9 @@ class FederationHandler(BaseHandler):
             event.internal_metadata.send_on_behalf_of = self.hs.hostname
 
             try:
-                await self.auth.check_from_context(room_version, event, context)
+                await self._event_auth_handler.check_from_context(
+                    room_version, event, context
+                )
             except AuthError as e:
                 logger.warning("Denying new third party invite %r because %s", event, e)
                 raise e
@@ -3023,9 +3034,13 @@ class FederationHandler(BaseHandler):
             await member_handler.send_membership_event(None, event, context)
         else:
             destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
-            await self.federation_client.forward_third_party_invite(
-                destinations, room_id, event_dict
-            )
+
+            try:
+                await self.federation_client.forward_third_party_invite(
+                    destinations, room_id, event_dict
+                )
+            except (RequestSendFailed, HttpResponseException):
+                raise SynapseError(502, "Failed to forward third party invite")
 
     async def on_exchange_third_party_invite_request(
         self, event_dict: JsonDict
@@ -3054,7 +3069,9 @@ class FederationHandler(BaseHandler):
         )
 
         try:
-            await self.auth.check_from_context(room_version, event, context)
+            await self._event_auth_handler.check_from_context(
+                room_version, event, context
+            )
         except AuthError as e:
             logger.warning("Denying third party invite %r because %s", event, e)
             raise e
@@ -3139,10 +3156,10 @@ class FederationHandler(BaseHandler):
 
         logger.debug("Checking auth on event %r", event.content)
 
-        last_exception = None  # type: Optional[Exception]
+        last_exception: Optional[Exception] = None
 
         # for each public key in the 3pid invite event
-        for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
+        for public_key_object in event_auth.get_public_keys(invite_event):
             try:
                 # for each sig on the third_party_invite block of the actual invite
                 for server, signature_block in signed["signatures"].items():
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 157f2ff218..1a6c5c64a2 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -214,7 +214,7 @@ class GroupsLocalWorkerHandler:
     async def bulk_get_publicised_groups(
         self, user_ids: Iterable[str], proxy: bool = True
     ) -> JsonDict:
-        destinations = {}  # type: Dict[str, Set[str]]
+        destinations: Dict[str, Set[str]] = {}
         local_users = set()
 
         for user_id in user_ids:
@@ -227,7 +227,7 @@ class GroupsLocalWorkerHandler:
             raise SynapseError(400, "Some user_ids are not local")
 
         results = {}
-        failed_results = []  # type: List[str]
+        failed_results: List[str] = []
         for destination, dest_user_ids in destinations.items():
             try:
                 r = await self.transport_client.bulk_get_publicised_groups(
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 33d16fbf9c..0961dec5ab 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -302,7 +302,7 @@ class IdentityHandler(BaseHandler):
             )
 
         url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
-        url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
+        url_bytes = b"/_matrix/identity/api/v1/3pid/unbind"
 
         content = {
             "mxid": mxid,
@@ -695,7 +695,7 @@ class IdentityHandler(BaseHandler):
                 return data["mxid"]
         except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
-        except IOError as e:
+        except OSError as e:
             logger.warning("Error from v1 identity server lookup: %s" % (e,))
 
         return None
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 76242865ae..5d49640760 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -46,9 +46,17 @@ class InitialSyncHandler(BaseHandler):
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
         self.validator = EventValidator()
-        self.snapshot_cache = ResponseCache(
-            hs.get_clock(), "initial_sync_cache"
-        )  # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
+        self.snapshot_cache: ResponseCache[
+            Tuple[
+                str,
+                Optional[StreamToken],
+                Optional[StreamToken],
+                str,
+                Optional[int],
+                bool,
+                bool,
+            ]
+        ] = ResponseCache(hs.get_clock(), "initial_sync_cache")
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index db12abd59d..c7fe4ff89e 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -81,7 +81,7 @@ class MessageHandler:
 
         # The scheduled call to self._expire_event. None if no call is currently
         # scheduled.
-        self._scheduled_expiry = None  # type: Optional[IDelayedCall]
+        self._scheduled_expiry: Optional[IDelayedCall] = None
 
         if not hs.config.worker_app:
             run_as_background_process(
@@ -196,9 +196,7 @@ class MessageHandler:
                 room_state_events = await self.state_store.get_state_for_events(
                     [event.event_id], state_filter=state_filter
                 )
-                room_state = room_state_events[
-                    event.event_id
-                ]  # type: Mapping[Any, EventBase]
+                room_state: Mapping[Any, EventBase] = room_state_events[event.event_id]
             else:
                 raise AuthError(
                     403,
@@ -385,6 +383,7 @@ class EventCreationHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
+        self._event_auth_handler = hs.get_event_auth_handler()
         self.store = hs.get_datastore()
         self.storage = hs.get_storage()
         self.state = hs.get_state_handler()
@@ -420,9 +419,9 @@ class EventCreationHandler:
         self.action_generator = hs.get_action_generator()
 
         self.spam_checker = hs.get_spam_checker()
-        self.third_party_event_rules = (
+        self.third_party_event_rules: "ThirdPartyEventRules" = (
             self.hs.get_third_party_event_rules()
-        )  # type: ThirdPartyEventRules
+        )
 
         self._block_events_without_consent_error = (
             self.config.block_events_without_consent_error
@@ -439,7 +438,7 @@ class EventCreationHandler:
         #
         # map from room id to time-of-last-attempt.
         #
-        self._rooms_to_exclude_from_dummy_event_insertion = {}  # type: Dict[str, int]
+        self._rooms_to_exclude_from_dummy_event_insertion: Dict[str, int] = {}
         # The number of forward extremeities before a dummy event is sent.
         self._dummy_events_threshold = hs.config.dummy_events_threshold
 
@@ -464,9 +463,7 @@ class EventCreationHandler:
         # Stores the state groups we've recently added to the joined hosts
         # external cache. Note that the timeout must be significantly less than
         # the TTL on the external cache.
-        self._external_cache_joined_hosts_updates = (
-            None
-        )  # type: Optional[ExpiringCache]
+        self._external_cache_joined_hosts_updates: Optional[ExpiringCache] = None
         if self._external_cache.is_enabled():
             self._external_cache_joined_hosts_updates = ExpiringCache(
                 "_external_cache_joined_hosts_updates",
@@ -509,12 +506,17 @@ class EventCreationHandler:
                 Should normally be left as None, which will cause them to be calculated
                 based on the room state at the prev_events.
 
+                If non-None, prev_event_ids must also be provided.
+
             require_consent: Whether to check if the requester has
                 consented to the privacy policy.
 
             outlier: Indicates whether the event is an `outlier`, i.e. if
                 it's from an arbitrary point and floating in the DAG as
                 opposed to being inline with the current DAG.
+            historical: Indicates whether the message is being inserted
+                back in time around some existing events. This is used to skip
+                a few checks and mark the event as backfilled.
             depth: Override the depth used to order the event in the DAG.
                 Should normally be set to None, which will cause the depth to be calculated
                 based on the prev_events.
@@ -581,6 +583,9 @@ class EventCreationHandler:
         # Strip down the auth_event_ids to only what we need to auth the event.
         # For example, we don't need extra m.room.member that don't match event.sender
         if auth_event_ids is not None:
+            # If auth events are provided, prev events must be also.
+            assert prev_event_ids is not None
+
             temp_event = await builder.build(
                 prev_event_ids=prev_event_ids,
                 auth_event_ids=auth_event_ids,
@@ -592,7 +597,7 @@ class EventCreationHandler:
                 (e.type, e.state_key): e.event_id for e in auth_events
             }
             # Actually strip down and use the necessary auth events
-            auth_event_ids = self.auth.compute_auth_events(
+            auth_event_ids = self._event_auth_handler.compute_auth_events(
                 event=temp_event,
                 current_state_ids=auth_event_state_map,
                 for_verification=False,
@@ -766,6 +771,7 @@ class EventCreationHandler:
         txn_id: Optional[str] = None,
         ignore_shadow_ban: bool = False,
         outlier: bool = False,
+        historical: bool = False,
         depth: Optional[int] = None,
     ) -> Tuple[EventBase, int]:
         """
@@ -784,6 +790,8 @@ class EventCreationHandler:
                 The event ids to use as the auth_events for the new event.
                 Should normally be left as None, which will cause them to be calculated
                 based on the room state at the prev_events.
+
+                If non-None, prev_event_ids must also be provided.
             ratelimit: Whether to rate limit this send.
             txn_id: The transaction ID.
             ignore_shadow_ban: True if shadow-banned users should be allowed to
@@ -791,6 +799,9 @@ class EventCreationHandler:
             outlier: Indicates whether the event is an `outlier`, i.e. if
                 it's from an arbitrary point and floating in the DAG as
                 opposed to being inline with the current DAG.
+            historical: Indicates whether the message is being inserted
+                back in time around some existing events. This is used to skip
+                a few checks and mark the event as backfilled.
             depth: Override the depth used to order the event in the DAG.
                 Should normally be set to None, which will cause the depth to be calculated
                 based on the prev_events.
@@ -839,6 +850,7 @@ class EventCreationHandler:
                 prev_event_ids=prev_event_ids,
                 auth_event_ids=auth_event_ids,
                 outlier=outlier,
+                historical=historical,
                 depth=depth,
             )
 
@@ -1049,7 +1061,9 @@ class EventCreationHandler:
             assert event.content["membership"] == Membership.LEAVE
         else:
             try:
-                await self.auth.check_from_context(room_version, event, context)
+                await self._event_auth_handler.check_from_context(
+                    room_version, event, context
+                )
             except AuthError as err:
                 logger.warning("Denying new event %r because %s", event, err)
                 raise err
@@ -1281,7 +1295,7 @@ class EventCreationHandler:
             # Validate a newly added alias or newly added alt_aliases.
 
             original_alias = None
-            original_alt_aliases = []  # type: List[str]
+            original_alt_aliases: List[str] = []
 
             original_event_id = event.unsigned.get("replaces_state")
             if original_event_id:
@@ -1374,7 +1388,7 @@ class EventCreationHandler:
                     raise AuthError(403, "Redacting server ACL events is not permitted")
 
             prev_state_ids = await context.get_prev_state_ids()
-            auth_events_ids = self.auth.compute_auth_events(
+            auth_events_ids = self._event_auth_handler.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
             auth_events_map = await self.store.get_events(auth_events_ids)
@@ -1584,11 +1598,13 @@ class EventCreationHandler:
         for k, v in original_event.internal_metadata.get_dict().items():
             setattr(builder.internal_metadata, k, v)
 
-        # the event type hasn't changed, so there's no point in re-calculating the
-        # auth events.
+        # modules can send new state events, so we re-calculate the auth events just in
+        # case.
+        prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
+
         event = await builder.build(
-            prev_event_ids=original_event.prev_event_ids(),
-            auth_event_ids=original_event.auth_event_ids(),
+            prev_event_ids=prev_event_ids,
+            auth_event_ids=None,
         )
 
         # we rebuild the event context, to be on the safe side. If nothing else,
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index ee6e41c0e4..eca8f16040 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -72,26 +72,26 @@ _SESSION_COOKIES = [
     (b"oidc_session_no_samesite", b"HttpOnly"),
 ]
 
+
 #: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
 #: OpenID.Core sec 3.1.3.3.
-Token = TypedDict(
-    "Token",
-    {
-        "access_token": str,
-        "token_type": str,
-        "id_token": Optional[str],
-        "refresh_token": Optional[str],
-        "expires_in": int,
-        "scope": Optional[str],
-    },
-)
+class Token(TypedDict):
+    access_token: str
+    token_type: str
+    id_token: Optional[str]
+    refresh_token: Optional[str]
+    expires_in: int
+    scope: Optional[str]
+
 
 #: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but
 #: there is no real point of doing this in our case.
 JWK = Dict[str, str]
 
+
 #: A JWK Set, as per RFC7517 sec 5.
-JWKS = TypedDict("JWKS", {"keys": List[JWK]})
+class JWKS(TypedDict):
+    keys: List[JWK]
 
 
 class OidcHandler:
@@ -105,9 +105,9 @@ class OidcHandler:
         assert provider_confs
 
         self._token_generator = OidcSessionTokenGenerator(hs)
-        self._providers = {
+        self._providers: Dict[str, "OidcProvider"] = {
             p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
-        }  # type: Dict[str, OidcProvider]
+        }
 
     async def load_metadata(self) -> None:
         """Validate the config and load the metadata from the remote endpoint.
@@ -178,7 +178,7 @@ class OidcHandler:
         # are two.
 
         for cookie_name, _ in _SESSION_COOKIES:
-            session = request.getCookie(cookie_name)  # type: Optional[bytes]
+            session: Optional[bytes] = request.getCookie(cookie_name)
             if session is not None:
                 break
         else:
@@ -255,7 +255,7 @@ class OidcError(Exception):
 
     def __str__(self):
         if self.error_description:
-            return "{}: {}".format(self.error, self.error_description)
+            return f"{self.error}: {self.error_description}"
         return self.error
 
 
@@ -277,7 +277,7 @@ class OidcProvider:
         self._token_generator = token_generator
 
         self._config = provider
-        self._callback_url = hs.config.oidc_callback_url  # type: str
+        self._callback_url: str = hs.config.oidc_callback_url
 
         # Calculate the prefix for OIDC callback paths based on the public_baseurl.
         # We'll insert this into the Path= parameter of any session cookies we set.
@@ -290,7 +290,7 @@ class OidcProvider:
         self._scopes = provider.scopes
         self._user_profile_method = provider.user_profile_method
 
-        client_secret = None  # type: Union[None, str, JwtClientSecret]
+        client_secret: Optional[Union[str, JwtClientSecret]] = None
         if provider.client_secret:
             client_secret = provider.client_secret
         elif provider.client_secret_jwt_key:
@@ -305,7 +305,7 @@ class OidcProvider:
             provider.client_id,
             client_secret,
             provider.client_auth_method,
-        )  # type: ClientAuth
+        )
         self._client_auth_method = provider.client_auth_method
 
         # cache of metadata for the identity provider (endpoint uris, mostly). This is
@@ -324,7 +324,7 @@ class OidcProvider:
         self._allow_existing_users = provider.allow_existing_users
 
         self._http_client = hs.get_proxied_http_client()
-        self._server_name = hs.config.server_name  # type: str
+        self._server_name: str = hs.config.server_name
 
         # identifier for the external_ids table
         self.idp_id = provider.idp_id
@@ -639,7 +639,7 @@ class OidcProvider:
             )
             logger.warning(description)
             # Body was still valid JSON. Might be useful to log it for debugging.
-            logger.warning("Code exchange response: {resp!r}".format(resp=resp))
+            logger.warning("Code exchange response: %r", resp)
             raise OidcError("server_error", description)
 
         return resp
@@ -1217,10 +1217,12 @@ class OidcSessionData:
     ui_auth_session_id = attr.ib(type=str)
 
 
-UserAttributeDict = TypedDict(
-    "UserAttributeDict",
-    {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]},
-)
+class UserAttributeDict(TypedDict):
+    localpart: Optional[str]
+    display_name: Optional[str]
+    emails: List[str]
+
+
 C = TypeVar("C")
 
 
@@ -1381,7 +1383,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
         if display_name == "":
             display_name = None
 
-        emails = []  # type: List[str]
+        emails: List[str] = []
         email = render_template_field(self._config.email_template)
         if email:
             emails.append(email)
@@ -1391,7 +1393,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
         )
 
     async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
-        extras = {}  # type: Dict[str, str]
+        extras: Dict[str, str] = {}
         for key, template in self._config.extra_attributes.items():
             try:
                 extras[key] = template.render(user=userinfo).strip()
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 1e1186c29e..1dbafd253d 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -81,9 +81,9 @@ class PaginationHandler:
         self._server_name = hs.hostname
 
         self.pagination_lock = ReadWriteLock()
-        self._purges_in_progress_by_room = set()  # type: Set[str]
+        self._purges_in_progress_by_room: Set[str] = set()
         # map from purge id to PurgeStatus
-        self._purges_by_id = {}  # type: Dict[str, PurgeStatus]
+        self._purges_by_id: Dict[str, PurgeStatus] = {}
         self._event_serializer = hs.get_event_client_serializer()
 
         self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 44ed7a0712..016c5df2ca 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -378,14 +378,14 @@ class WorkerPresenceHandler(BasePresenceHandler):
 
         # The number of ongoing syncs on this process, by user id.
         # Empty if _presence_enabled is false.
-        self._user_to_num_current_syncs = {}  # type: Dict[str, int]
+        self._user_to_num_current_syncs: Dict[str, int] = {}
 
         self.notifier = hs.get_notifier()
         self.instance_id = hs.get_instance_id()
 
         # user_id -> last_sync_ms. Lists the users that have stopped syncing but
         # we haven't notified the presence writer of that yet
-        self.users_going_offline = {}  # type: Dict[str, int]
+        self.users_going_offline: Dict[str, int] = {}
 
         self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs)
         self._set_state_client = ReplicationPresenceSetState.make_client(hs)
@@ -650,7 +650,7 @@ class PresenceHandler(BasePresenceHandler):
 
         # Set of users who have presence in the `user_to_current_state` that
         # have not yet been persisted
-        self.unpersisted_users_changes = set()  # type: Set[str]
+        self.unpersisted_users_changes: Set[str] = set()
 
         hs.get_reactor().addSystemEventTrigger(
             "before",
@@ -664,7 +664,7 @@ class PresenceHandler(BasePresenceHandler):
 
         # Keeps track of the number of *ongoing* syncs on this process. While
         # this is non zero a user will never go offline.
-        self.user_to_num_current_syncs = {}  # type: Dict[str, int]
+        self.user_to_num_current_syncs: Dict[str, int] = {}
 
         # Keeps track of the number of *ongoing* syncs on other processes.
         # While any sync is ongoing on another process the user will never
@@ -674,8 +674,8 @@ class PresenceHandler(BasePresenceHandler):
         # we assume that all the sync requests on that process have stopped.
         # Stored as a dict from process_id to set of user_id, and a dict of
         # process_id to millisecond timestamp last updated.
-        self.external_process_to_current_syncs = {}  # type: Dict[str, Set[str]]
-        self.external_process_last_updated_ms = {}  # type: Dict[str, int]
+        self.external_process_to_current_syncs: Dict[str, Set[str]] = {}
+        self.external_process_last_updated_ms: Dict[str, int] = {}
 
         self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
 
@@ -1581,9 +1581,7 @@ class PresenceEventSource:
 
             # The set of users that we're interested in and that have had a presence update.
             # We'll actually pull the presence updates for these users at the end.
-            interested_and_updated_users = (
-                set()
-            )  # type: Union[Set[str], FrozenSet[str]]
+            interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
 
             if from_key:
                 # First get all users that have had a presence update
@@ -1950,8 +1948,8 @@ async def get_interested_parties(
         A 2-tuple of `(room_ids_to_states, users_to_states)`,
         with each item being a dict of `entity_name` -> `[UserPresenceState]`
     """
-    room_ids_to_states = {}  # type: Dict[str, List[UserPresenceState]]
-    users_to_states = {}  # type: Dict[str, List[UserPresenceState]]
+    room_ids_to_states: Dict[str, List[UserPresenceState]] = {}
+    users_to_states: Dict[str, List[UserPresenceState]] = {}
     for state in states:
         room_ids = await store.get_rooms_for_user(state.user_id)
         for room_id in room_ids:
@@ -2063,12 +2061,12 @@ class PresenceFederationQueue:
         # stream_id, destinations, user_ids)`. We don't store the full states
         # for efficiency, and remote workers will already have the full states
         # cached.
-        self._queue = []  # type: List[Tuple[int, int, Collection[str], Set[str]]]
+        self._queue: List[Tuple[int, int, Collection[str], Set[str]]] = []
 
         self._next_id = 1
 
         # Map from instance name to current token
-        self._current_tokens = {}  # type: Dict[str, int]
+        self._current_tokens: Dict[str, int] = {}
 
         if self._queue_presence_updates:
             self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
@@ -2168,7 +2166,7 @@ class PresenceFederationQueue:
         # handle the case where `from_token` stream ID has already been dropped.
         start_idx = max(from_token + 1 - self._next_id, -len(self._queue))
 
-        to_send = []  # type: List[Tuple[int, Tuple[str, str]]]
+        to_send: List[Tuple[int, Tuple[str, str]]] = []
         limited = False
         new_id = upto_token
         for _, stream_id, destinations, user_ids in self._queue[start_idx:]:
@@ -2216,7 +2214,7 @@ class PresenceFederationQueue:
         if not self._federation:
             return
 
-        hosts_to_users = {}  # type: Dict[str, Set[str]]
+        hosts_to_users: Dict[str, Set[str]] = {}
         for row in rows:
             hosts_to_users.setdefault(row.destination, set()).add(row.user_id)
 
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 05b4a97b59..20a033d0ba 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -197,7 +197,7 @@ class ProfileHandler(BaseHandler):
                 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
             )
 
-        displayname_to_set = new_displayname  # type: Optional[str]
+        displayname_to_set: Optional[str] = new_displayname
         if new_displayname == "":
             displayname_to_set = None
 
@@ -286,7 +286,7 @@ class ProfileHandler(BaseHandler):
                 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
             )
 
-        avatar_url_to_set = new_avatar_url  # type: Optional[str]
+        avatar_url_to_set: Optional[str] = new_avatar_url
         if new_avatar_url == "":
             avatar_url_to_set = None
 
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index f782d9db32..283483fc2c 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -30,6 +30,8 @@ class ReceiptsHandler(BaseHandler):
 
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
+        self.event_auth_handler = hs.get_event_auth_handler()
+
         self.hs = hs
 
         # We only need to poke the federation sender explicitly if its on the
@@ -59,6 +61,19 @@ class ReceiptsHandler(BaseHandler):
         """Called when we receive an EDU of type m.receipt from a remote HS."""
         receipts = []
         for room_id, room_values in content.items():
+            # If we're not in the room just ditch the event entirely. This is
+            # probably an old server that has come back and thinks we're still in
+            # the room (or we've been rejoined to the room by a state reset).
+            is_in_room = await self.event_auth_handler.check_host_in_room(
+                room_id, self.server_name
+            )
+            if not is_in_room:
+                logger.info(
+                    "Ignoring receipt from %s as we're not in the room",
+                    origin,
+                )
+                continue
+
             for receipt_type, users in room_values.items():
                 for user_id, user_values in users.items():
                     if get_domain_from_id(user_id) != origin:
@@ -83,8 +98,8 @@ class ReceiptsHandler(BaseHandler):
 
     async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
         """Takes a list of receipts, stores them and informs the notifier."""
-        min_batch_id = None  # type: Optional[int]
-        max_batch_id = None  # type: Optional[int]
+        min_batch_id: Optional[int] = None
+        max_batch_id: Optional[int] = None
 
         for receipt in receipts:
             res = await self.store.insert_receipt(
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 26ef016179..8cf614136e 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -55,15 +55,12 @@ login_counter = Counter(
     ["guest", "auth_provider"],
 )
 
-LoginDict = TypedDict(
-    "LoginDict",
-    {
-        "device_id": str,
-        "access_token": str,
-        "valid_until_ms": Optional[int],
-        "refresh_token": Optional[str],
-    },
-)
+
+class LoginDict(TypedDict):
+    device_id: str
+    access_token: str
+    valid_until_ms: Optional[int]
+    refresh_token: Optional[str]
 
 
 class RegistrationHandler(BaseHandler):
@@ -77,6 +74,7 @@ class RegistrationHandler(BaseHandler):
         self.identity_handler = self.hs.get_identity_handler()
         self.ratelimiter = hs.get_registration_ratelimiter()
         self.macaroon_gen = hs.get_macaroon_generator()
+        self._account_validity_handler = hs.get_account_validity_handler()
         self._server_notices_mxid = hs.config.server_notices_mxid
         self._server_name = hs.hostname
 
@@ -700,6 +698,10 @@ class RegistrationHandler(BaseHandler):
                 shadow_banned=shadow_banned,
             )
 
+            # Only call the account validity module(s) on the main process, to avoid
+            # repeating e.g. database writes on all of the workers.
+            await self._account_validity_handler.on_user_registration(user_id)
+
     async def register_device(
         self,
         user_id: str,
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 835d874cee..64656fda22 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -83,10 +83,11 @@ class RoomCreationHandler(BaseHandler):
         self.spam_checker = hs.get_spam_checker()
         self.event_creation_handler = hs.get_event_creation_handler()
         self.room_member_handler = hs.get_room_member_handler()
+        self._event_auth_handler = hs.get_event_auth_handler()
         self.config = hs.config
 
         # Room state based off defined presets
-        self._presets_dict = {
+        self._presets_dict: Dict[str, Dict[str, Any]] = {
             RoomCreationPreset.PRIVATE_CHAT: {
                 "join_rules": JoinRules.INVITE,
                 "history_visibility": HistoryVisibility.SHARED,
@@ -108,7 +109,7 @@ class RoomCreationHandler(BaseHandler):
                 "guest_can_join": False,
                 "power_level_content_override": {},
             },
-        }  # type: Dict[str, Dict[str, Any]]
+        }
 
         # Modify presets to selectively enable encryption by default per homeserver config
         for preset_name, preset_config in self._presets_dict.items():
@@ -126,9 +127,9 @@ class RoomCreationHandler(BaseHandler):
         # If a user tries to update the same room multiple times in quick
         # succession, only process the first attempt and return its result to
         # subsequent requests
-        self._upgrade_response_cache = ResponseCache(
+        self._upgrade_response_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
             hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
-        )  # type: ResponseCache[Tuple[str, str]]
+        )
         self._server_notices_mxid = hs.config.server_notices_mxid
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
@@ -226,7 +227,7 @@ class RoomCreationHandler(BaseHandler):
             },
         )
         old_room_version = await self.store.get_room_version_id(old_room_id)
-        await self.auth.check_from_context(
+        await self._event_auth_handler.check_from_context(
             old_room_version, tombstone_event, tombstone_context
         )
 
@@ -376,10 +377,10 @@ class RoomCreationHandler(BaseHandler):
         if not await self.spam_checker.user_may_create_room(user_id):
             raise SynapseError(403, "You are not permitted to create rooms")
 
-        creation_content = {
+        creation_content: JsonDict = {
             "room_version": new_room_version.identifier,
             "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
-        }  # type: JsonDict
+        }
 
         # Check if old room was non-federatable
 
@@ -935,7 +936,7 @@ class RoomCreationHandler(BaseHandler):
                 etype=EventTypes.PowerLevels, content=pl_content
             )
         else:
-            power_level_content = {
+            power_level_content: JsonDict = {
                 "users": {creator_id: 100},
                 "users_default": 0,
                 "events": {
@@ -954,7 +955,7 @@ class RoomCreationHandler(BaseHandler):
                 "kick": 50,
                 "redact": 50,
                 "invite": 50,
-            }  # type: JsonDict
+            }
 
             if config["original_invitees_have_ops"]:
                 for invitee in invite_list:
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 5e3ef7ce3a..6284bcdfbc 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -20,7 +20,12 @@ import msgpack
 from unpaddedbase64 import decode_base64, encode_base64
 
 from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
-from synapse.api.errors import Codes, HttpResponseException
+from synapse.api.errors import (
+    Codes,
+    HttpResponseException,
+    RequestSendFailed,
+    SynapseError,
+)
 from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.response_cache import ResponseCache
@@ -42,12 +47,12 @@ class RoomListHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.enable_room_list_search = hs.config.enable_room_list_search
-        self.response_cache = ResponseCache(
-            hs.get_clock(), "room_list"
-        )  # type: ResponseCache[Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]]
-        self.remote_response_cache = ResponseCache(
-            hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000
-        )  # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
+        self.response_cache: ResponseCache[
+            Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]
+        ] = ResponseCache(hs.get_clock(), "room_list")
+        self.remote_response_cache: ResponseCache[
+            Tuple[str, Optional[int], Optional[str], bool, Optional[str]]
+        ] = ResponseCache(hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000)
 
     async def get_local_public_room_list(
         self,
@@ -134,10 +139,10 @@ class RoomListHandler(BaseHandler):
         if since_token:
             batch_token = RoomListNextBatch.from_token(since_token)
 
-            bounds = (
+            bounds: Optional[Tuple[int, str]] = (
                 batch_token.last_joined_members,
                 batch_token.last_room_id,
-            )  # type: Optional[Tuple[int, str]]
+            )
             forwards = batch_token.direction_is_forward
             has_batch_token = True
         else:
@@ -177,7 +182,7 @@ class RoomListHandler(BaseHandler):
 
         results = [build_room_entry(r) for r in results]
 
-        response = {}  # type: JsonDict
+        response: JsonDict = {}
         num_results = len(results)
         if limit is not None:
             more_to_come = num_results == probing_limit
@@ -417,14 +422,17 @@ class RoomListHandler(BaseHandler):
         repl_layer = self.hs.get_federation_client()
         if search_filter:
             # We can't cache when asking for search
-            return await repl_layer.get_public_rooms(
-                server_name,
-                limit=limit,
-                since_token=since_token,
-                search_filter=search_filter,
-                include_all_networks=include_all_networks,
-                third_party_instance_id=third_party_instance_id,
-            )
+            try:
+                return await repl_layer.get_public_rooms(
+                    server_name,
+                    limit=limit,
+                    since_token=since_token,
+                    search_filter=search_filter,
+                    include_all_networks=include_all_networks,
+                    third_party_instance_id=third_party_instance_id,
+                )
+            except (RequestSendFailed, HttpResponseException):
+                raise SynapseError(502, "Failed to fetch room list")
 
         key = (
             server_name,
diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py
index 80ba65b9e0..e6e71e9729 100644
--- a/synapse/handlers/saml.py
+++ b/synapse/handlers/saml.py
@@ -83,7 +83,7 @@ class SamlHandler(BaseHandler):
         self.unstable_idp_brand = None
 
         # a map from saml session id to Saml2SessionData object
-        self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
+        self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {}
 
         self._sso_handler = hs.get_sso_handler()
         self._sso_handler.register_identity_provider(self)
@@ -372,7 +372,7 @@ class SamlHandler(BaseHandler):
 
 
 DOT_REPLACE_PATTERN = re.compile(
-    ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+    "[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)
 )
 
 
@@ -386,10 +386,10 @@ def dot_replace_for_mxid(username: str) -> str:
     return username
 
 
-MXID_MAPPER_MAP = {
+MXID_MAPPER_MAP: Dict[str, Callable[[str], str]] = {
     "hexencode": map_username_to_mxid_localpart,
     "dotreplace": dot_replace_for_mxid,
-}  # type: Dict[str, Callable[[str], str]]
+}
 
 
 @attr.s
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 4e718d3f63..8226d6f5a1 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -192,7 +192,7 @@ class SearchHandler(BaseHandler):
         # If doing a subset of all rooms seearch, check if any of the rooms
         # are from an upgraded room, and search their contents as well
         if search_filter.rooms:
-            historical_room_ids = []  # type: List[str]
+            historical_room_ids: List[str] = []
             for room_id in search_filter.rooms:
                 # Add any previous rooms to the search if they exist
                 ids = await self.get_old_rooms_from_upgraded_room(room_id)
@@ -216,9 +216,9 @@ class SearchHandler(BaseHandler):
         rank_map = {}  # event_id -> rank of event
         allowed_events = []
         # Holds result of grouping by room, if applicable
-        room_groups = {}  # type: Dict[str, JsonDict]
+        room_groups: Dict[str, JsonDict] = {}
         # Holds result of grouping by sender, if applicable
-        sender_group = {}  # type: Dict[str, JsonDict]
+        sender_group: Dict[str, JsonDict] = {}
 
         # Holds the next_batch for the entire result set if one of those exists
         global_next_batch = None
@@ -262,7 +262,7 @@ class SearchHandler(BaseHandler):
                 s["results"].append(e.event_id)
 
         elif order_by == "recent":
-            room_events = []  # type: List[EventBase]
+            room_events: List[EventBase] = []
             i = 0
 
             pagination_token = batch_token
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
index 266f369883..5f7d4602bd 100644
--- a/synapse/handlers/space_summary.py
+++ b/synapse/handlers/space_summary.py
@@ -24,6 +24,7 @@ from synapse.api.constants import (
     EventContentFields,
     EventTypes,
     HistoryVisibility,
+    JoinRules,
     Membership,
     RoomTypes,
 )
@@ -89,14 +90,14 @@ class SpaceSummaryHandler:
         room_queue = deque((_RoomQueueEntry(room_id, ()),))
 
         # rooms we have already processed
-        processed_rooms = set()  # type: Set[str]
+        processed_rooms: Set[str] = set()
 
         # events we have already processed. We don't necessarily have their event ids,
         # so instead we key on (room id, state key)
-        processed_events = set()  # type: Set[Tuple[str, str]]
+        processed_events: Set[Tuple[str, str]] = set()
 
-        rooms_result = []  # type: List[JsonDict]
-        events_result = []  # type: List[JsonDict]
+        rooms_result: List[JsonDict] = []
+        events_result: List[JsonDict] = []
 
         while room_queue and len(rooms_result) < MAX_ROOMS:
             queue_entry = room_queue.popleft()
@@ -150,14 +151,21 @@ class SpaceSummaryHandler:
                     # The room should only be included in the summary if:
                     #     a. the user is in the room;
                     #     b. the room is world readable; or
-                    #     c. the user is in a space that has been granted access to
-                    #        the room.
+                    #     c. the user could join the room, e.g. the join rules
+                    #        are set to public or the user is in a space that
+                    #        has been granted access to the room.
                     #
                     # Note that we know the user is not in the root room (which is
                     # why the remote call was made in the first place), but the user
                     # could be in one of the children rooms and we just didn't know
                     # about the link.
-                    include_room = room.get("world_readable") is True
+
+                    # The API doesn't return the room version so assume that a
+                    # join rule of knock is valid.
+                    include_room = (
+                        room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK)
+                        or room.get("world_readable") is True
+                    )
 
                     # Check if the user is a member of any of the allowed spaces
                     # from the response.
@@ -264,10 +272,10 @@ class SpaceSummaryHandler:
         # the set of rooms that we should not walk further. Initialise it with the
         # excluded-rooms list; we will add other rooms as we process them so that
         # we do not loop.
-        processed_rooms = set(exclude_rooms)  # type: Set[str]
+        processed_rooms: Set[str] = set(exclude_rooms)
 
-        rooms_result = []  # type: List[JsonDict]
-        events_result = []  # type: List[JsonDict]
+        rooms_result: List[JsonDict] = []
+        events_result: List[JsonDict] = []
 
         while room_queue and len(rooms_result) < MAX_ROOMS:
             room_id = room_queue.popleft()
@@ -345,7 +353,7 @@ class SpaceSummaryHandler:
             max_children = MAX_ROOMS_PER_SPACE
 
         now = self._clock.time_msec()
-        events_result = []  # type: List[JsonDict]
+        events_result: List[JsonDict] = []
         for edge_event in itertools.islice(child_events, max_children):
             events_result.append(
                 await self._event_serializer.serialize_event(
@@ -420,9 +428,8 @@ class SpaceSummaryHandler:
 
         It should be included if:
 
-        * The requester is joined or invited to the room.
-        * The requester can join without an invite (per MSC3083).
-        * The origin server has any user that is joined or invited to the room.
+        * The requester is joined or can join the room (per MSC3173).
+        * The origin server has any user that is joined or can join the room.
         * The history visibility is set to world readable.
 
         Args:
@@ -441,13 +448,39 @@ class SpaceSummaryHandler:
 
         # If there's no state for the room, it isn't known.
         if not state_ids:
+            # The user might have a pending invite for the room.
+            if requester and await self._store.get_invite_for_local_user_in_room(
+                requester, room_id
+            ):
+                return True
+
             logger.info("room %s is unknown, omitting from summary", room_id)
             return False
 
         room_version = await self._store.get_room_version(room_id)
 
-        # if we have an authenticated requesting user, first check if they are able to view
-        # stripped state in the room.
+        # Include the room if it has join rules of public or knock.
+        join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""))
+        if join_rules_event_id:
+            join_rules_event = await self._store.get_event(join_rules_event_id)
+            join_rule = join_rules_event.content.get("join_rule")
+            if join_rule == JoinRules.PUBLIC or (
+                room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
+            ):
+                return True
+
+        # Include the room if it is peekable.
+        hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, ""))
+        if hist_vis_event_id:
+            hist_vis_ev = await self._store.get_event(hist_vis_event_id)
+            hist_vis = hist_vis_ev.content.get("history_visibility")
+            if hist_vis == HistoryVisibility.WORLD_READABLE:
+                return True
+
+        # Otherwise we need to check information specific to the user or server.
+
+        # If we have an authenticated requesting user, check if they are a member
+        # of the room (or can join the room).
         if requester:
             member_event_id = state_ids.get((EventTypes.Member, requester), None)
 
@@ -470,9 +503,11 @@ class SpaceSummaryHandler:
                     return True
 
         # If this is a request over federation, check if the host is in the room or
-        # is in one of the spaces specified via the join rules.
+        # has a user who could join the room.
         elif origin:
-            if await self._auth.check_host_in_room(room_id, origin):
+            if await self._event_auth_handler.check_host_in_room(
+                room_id, origin
+            ) or await self._store.is_host_invited(room_id, origin):
                 return True
 
             # Alternately, if the host has a user in any of the spaces specified
@@ -485,21 +520,15 @@ class SpaceSummaryHandler:
                     await self._event_auth_handler.get_rooms_that_allow_join(state_ids)
                 )
                 for space_id in allowed_rooms:
-                    if await self._auth.check_host_in_room(space_id, origin):
+                    if await self._event_auth_handler.check_host_in_room(
+                        space_id, origin
+                    ):
                         return True
 
-        # otherwise, check if the room is peekable
-        hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, ""), None)
-        if hist_vis_event_id:
-            hist_vis_ev = await self._store.get_event(hist_vis_event_id)
-            hist_vis = hist_vis_ev.content.get("history_visibility")
-            if hist_vis == HistoryVisibility.WORLD_READABLE:
-                return True
-
         logger.info(
-            "room %s is unpeekable and user %s is not a member / not allowed to join, omitting from summary",
+            "room %s is unpeekable and requester %s is not a member / not allowed to join, omitting from summary",
             room_id,
-            requester,
+            requester or origin,
         )
         return False
 
@@ -533,6 +562,7 @@ class SpaceSummaryHandler:
             "canonical_alias": stats["canonical_alias"],
             "num_joined_members": stats["joined_members"],
             "avatar_url": stats["avatar"],
+            "join_rules": stats["join_rules"],
             "world_readable": (
                 stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
             ),
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 0b297e54c4..1b855a685c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -202,10 +202,10 @@ class SsoHandler:
         self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
 
         # a map from session id to session data
-        self._username_mapping_sessions = {}  # type: Dict[str, UsernameMappingSession]
+        self._username_mapping_sessions: Dict[str, UsernameMappingSession] = {}
 
         # map from idp_id to SsoIdentityProvider
-        self._identity_providers = {}  # type: Dict[str, SsoIdentityProvider]
+        self._identity_providers: Dict[str, SsoIdentityProvider] = {}
 
         self._consent_at_registration = hs.config.consent.user_consent_at_registration
 
@@ -296,7 +296,7 @@ class SsoHandler:
             )
 
         # if the client chose an IdP, use that
-        idp = None  # type: Optional[SsoIdentityProvider]
+        idp: Optional[SsoIdentityProvider] = None
         if idp_id:
             idp = self._identity_providers.get(idp_id)
             if not idp:
@@ -669,9 +669,9 @@ class SsoHandler:
             remote_user_id,
         )
 
-        user_id_to_verify = await self._auth_handler.get_session_data(
+        user_id_to_verify: str = await self._auth_handler.get_session_data(
             ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
-        )  # type: str
+        )
 
         if not user_id:
             logger.warning(
@@ -793,7 +793,7 @@ class SsoHandler:
         session.use_display_name = use_display_name
 
         emails_from_idp = set(session.emails)
-        filtered_emails = set()  # type: Set[str]
+        filtered_emails: Set[str] = set()
 
         # we iterate through the list rather than just building a set conjunction, so
         # that we can log attempts to use unknown addresses
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 4e45d1da57..3fd89af2a4 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -45,12 +45,11 @@ class StatsHandler:
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
         self.is_mine_id = hs.is_mine_id
-        self.stats_bucket_size = hs.config.stats_bucket_size
 
         self.stats_enabled = hs.config.stats_enabled
 
         # The current position in the current_state_delta stream
-        self.pos = None  # type: Optional[int]
+        self.pos: Optional[int] = None
 
         # Guard to ensure we only process deltas one at a time
         self._is_processing = False
@@ -106,20 +105,6 @@ class StatsHandler:
                 room_deltas = {}
                 user_deltas = {}
 
-            # Then count deltas for total_events and total_event_bytes.
-            (
-                room_count,
-                user_count,
-            ) = await self.store.get_changes_room_total_events_and_bytes(
-                self.pos, max_pos
-            )
-
-            for room_id, fields in room_count.items():
-                room_deltas.setdefault(room_id, Counter()).update(fields)
-
-            for user_id, fields in user_count.items():
-                user_deltas.setdefault(user_id, Counter()).update(fields)
-
             logger.debug("room_deltas: %s", room_deltas)
             logger.debug("user_deltas: %s", user_deltas)
 
@@ -146,10 +131,10 @@ class StatsHandler:
             mapping from room/user ID to changes in the various fields.
         """
 
-        room_to_stats_deltas = {}  # type: Dict[str, CounterType[str]]
-        user_to_stats_deltas = {}  # type: Dict[str, CounterType[str]]
+        room_to_stats_deltas: Dict[str, CounterType[str]] = {}
+        user_to_stats_deltas: Dict[str, CounterType[str]] = {}
 
-        room_to_state_updates = {}  # type: Dict[str, Dict[str, Any]]
+        room_to_state_updates: Dict[str, Dict[str, Any]] = {}
 
         for delta in deltas:
             typ = delta["type"]
@@ -179,14 +164,12 @@ class StatsHandler:
                 )
                 continue
 
-            event_content = {}  # type: JsonDict
+            event_content: JsonDict = {}
 
-            sender = None
             if event_id is not None:
                 event = await self.store.get_event(event_id, allow_none=True)
                 if event:
                     event_content = event.content or {}
-                    sender = event.sender
 
             # All the values in this dict are deltas (RELATIVE changes)
             room_stats_delta = room_to_stats_deltas.setdefault(room_id, Counter())
@@ -244,12 +227,6 @@ class StatsHandler:
                     room_stats_delta["joined_members"] += 1
                 elif membership == Membership.INVITE:
                     room_stats_delta["invited_members"] += 1
-
-                    if sender and self.is_mine_id(sender):
-                        user_to_stats_deltas.setdefault(sender, Counter())[
-                            "invites_sent"
-                        ] += 1
-
                 elif membership == Membership.LEAVE:
                     room_stats_delta["left_members"] += 1
                 elif membership == Membership.BAN:
@@ -279,10 +256,6 @@ class StatsHandler:
                 room_state["is_federatable"] = (
                     event_content.get("m.federate", True) is True
                 )
-                if sender and self.is_mine_id(sender):
-                    user_to_stats_deltas.setdefault(sender, Counter())[
-                        "rooms_created"
-                    ] += 1
             elif typ == EventTypes.JoinRules:
                 room_state["join_rules"] = event_content.get("join_rule")
             elif typ == EventTypes.RoomHistoryVisibility:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b9a0361059..150a4f291e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -278,12 +278,14 @@ class SyncHandler:
         self.state_store = self.storage.state
 
         # ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
-        self.lazy_loaded_members_cache = ExpiringCache(
+        self.lazy_loaded_members_cache: ExpiringCache[
+            Tuple[str, Optional[str]], LruCache[str, str]
+        ] = ExpiringCache(
             "lazy_loaded_members_cache",
             self.clock,
             max_len=0,
             expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
-        )  # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]
+        )
 
     async def wait_for_sync_for_user(
         self,
@@ -440,7 +442,7 @@ class SyncHandler:
             )
             now_token = now_token.copy_and_replace("typing_key", typing_key)
 
-            ephemeral_by_room = {}  # type: JsonDict
+            ephemeral_by_room: JsonDict = {}
 
             for event in typing:
                 # we want to exclude the room_id from the event, but modifying the
@@ -502,7 +504,7 @@ class SyncHandler:
                 # We check if there are any state events, if there are then we pass
                 # all current state events to the filter_events function. This is to
                 # ensure that we always include current state in the timeline
-                current_state_ids = frozenset()  # type: FrozenSet[str]
+                current_state_ids: FrozenSet[str] = frozenset()
                 if any(e.is_state() for e in recents):
                     current_state_ids_map = await self.store.get_current_state_ids(
                         room_id
@@ -783,9 +785,9 @@ class SyncHandler:
     def get_lazy_loaded_members_cache(
         self, cache_key: Tuple[str, Optional[str]]
     ) -> LruCache[str, str]:
-        cache = self.lazy_loaded_members_cache.get(
+        cache: Optional[LruCache[str, str]] = self.lazy_loaded_members_cache.get(
             cache_key
-        )  # type: Optional[LruCache[str, str]]
+        )
         if cache is None:
             logger.debug("creating LruCache for %r", cache_key)
             cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
@@ -984,7 +986,7 @@ class SyncHandler:
                     if t[0] == EventTypes.Member:
                         cache.set(t[1], event_id)
 
-        state = {}  # type: Dict[str, EventBase]
+        state: Dict[str, EventBase] = {}
         if state_ids:
             state = await self.store.get_events(list(state_ids.values()))
 
@@ -1088,8 +1090,8 @@ class SyncHandler:
 
         logger.debug("Fetching OTK data")
         device_id = sync_config.device_id
-        one_time_key_counts = {}  # type: JsonDict
-        unused_fallback_key_types = []  # type: List[str]
+        one_time_key_counts: JsonDict = {}
+        unused_fallback_key_types: List[str] = []
         if device_id:
             one_time_key_counts = await self.store.count_e2e_one_time_keys(
                 user_id, device_id
@@ -1437,7 +1439,7 @@ class SyncHandler:
         )
 
         if block_all_room_ephemeral:
-            ephemeral_by_room = {}  # type: Dict[str, List[JsonDict]]
+            ephemeral_by_room: Dict[str, List[JsonDict]] = {}
         else:
             now_token, ephemeral_by_room = await self.ephemeral_by_room(
                 sync_result_builder,
@@ -1468,7 +1470,7 @@ class SyncHandler:
 
         # If there is ignored users account data and it matches the proper type,
         # then use it.
-        ignored_users = frozenset()  # type: FrozenSet[str]
+        ignored_users: FrozenSet[str] = frozenset()
         if ignored_account_data:
             ignored_users_data = ignored_account_data.get("ignored_users", {})
             if isinstance(ignored_users_data, dict):
@@ -1586,7 +1588,7 @@ class SyncHandler:
             user_id, since_token.room_key, now_token.room_key
         )
 
-        mem_change_events_by_room_id = {}  # type: Dict[str, List[EventBase]]
+        mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
         for event in rooms_changed:
             mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
 
@@ -1599,7 +1601,7 @@ class SyncHandler:
             logger.debug(
                 "Membership changes in %s: [%s]",
                 room_id,
-                ", ".join(("%s (%s)" % (e.event_id, e.membership) for e in events)),
+                ", ".join("%s (%s)" % (e.event_id, e.membership) for e in events),
             )
 
             non_joins = [e for e in events if e.membership != Membership.JOIN]
@@ -1722,7 +1724,7 @@ class SyncHandler:
                 # This is all screaming out for a refactor, as the logic here is
                 # subtle and the moving parts numerous.
                 if leave_event.internal_metadata.is_out_of_band_membership():
-                    batch_events = [leave_event]  # type: Optional[List[EventBase]]
+                    batch_events: Optional[List[EventBase]] = [leave_event]
                 else:
                     batch_events = None
 
@@ -1971,7 +1973,7 @@ class SyncHandler:
             room_id, batch, sync_config, since_token, now_token, full_state=full_state
         )
 
-        summary = {}  # type: Optional[JsonDict]
+        summary: Optional[JsonDict] = {}
 
         # we include a summary in room responses when we're lazy loading
         # members (as the client otherwise doesn't have enough info to form
@@ -1995,7 +1997,7 @@ class SyncHandler:
             )
 
         if room_builder.rtype == "joined":
-            unread_notifications = {}  # type: Dict[str, int]
+            unread_notifications: Dict[str, int] = {}
             room_sync = JoinedSyncResult(
                 room_id=room_id,
                 timeline=batch,
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index e22393adc4..0cb651a400 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -68,11 +68,11 @@ class FollowerTypingHandler:
             )
 
         # map room IDs to serial numbers
-        self._room_serials = {}  # type: Dict[str, int]
+        self._room_serials: Dict[str, int] = {}
         # map room IDs to sets of users currently typing
-        self._room_typing = {}  # type: Dict[str, Set[str]]
+        self._room_typing: Dict[str, Set[str]] = {}
 
-        self._member_last_federation_poke = {}  # type: Dict[RoomMember, int]
+        self._member_last_federation_poke: Dict[RoomMember, int] = {}
         self.wheel_timer = WheelTimer(bucket_size=5000)
         self._latest_room_serial = 0
 
@@ -208,6 +208,7 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
+        self.event_auth_handler = hs.get_event_auth_handler()
 
         self.hs = hs
 
@@ -216,7 +217,7 @@ class TypingWriterHandler(FollowerTypingHandler):
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
         # clock time we expect to stop
-        self._member_typing_until = {}  # type: Dict[RoomMember, int]
+        self._member_typing_until: Dict[RoomMember, int] = {}
 
         # caches which room_ids changed at which serials
         self._typing_stream_change_cache = StreamChangeCache(
@@ -326,6 +327,19 @@ class TypingWriterHandler(FollowerTypingHandler):
         room_id = content["room_id"]
         user_id = content["user_id"]
 
+        # If we're not in the room just ditch the event entirely. This is
+        # probably an old server that has come back and thinks we're still in
+        # the room (or we've been rejoined to the room by a state reset).
+        is_in_room = await self.event_auth_handler.check_host_in_room(
+            room_id, self.server_name
+        )
+        if not is_in_room:
+            logger.info(
+                "Ignoring typing update from %s as we're not in the room",
+                origin,
+            )
+            return
+
         member = RoomMember(user_id=user_id, room_id=room_id)
 
         # Check that the string is a valid user id
@@ -391,9 +405,9 @@ class TypingWriterHandler(FollowerTypingHandler):
         if last_id == current_id:
             return [], current_id, False
 
-        changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
-            last_id
-        )  # type: Optional[Iterable[str]]
+        changed_rooms: Optional[
+            Iterable[str]
+        ] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
 
         if changed_rooms is None:
             changed_rooms = self._room_serials
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index dacc4f3076..6edb1da50a 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -52,7 +52,7 @@ class UserDirectoryHandler(StateDeltasHandler):
         self.search_all_users = hs.config.user_directory_search_all_users
         self.spam_checker = hs.get_spam_checker()
         # The current position in the current_state_delta stream
-        self.pos = None  # type: Optional[int]
+        self.pos: Optional[int] = None
 
         # Guard to ensure we only process deltas one at a time
         self._is_processing = False