summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/auth.py14
-rw-r--r--synapse/api/errors.py69
-rw-r--r--synapse/api/filtering.py18
-rw-r--r--synapse/api/presence.py51
-rw-r--r--synapse/api/ratelimiting.py4
-rw-r--r--synapse/api/urls.py13
-rw-r--r--synapse/app/_base.py2
-rw-r--r--synapse/app/admin_cmd.py14
-rw-r--r--synapse/config/_base.pyi2
-rw-r--r--synapse/config/experimental.py3
-rw-r--r--synapse/config/homeserver.py2
-rw-r--r--synapse/config/password_auth_providers.py53
-rw-r--r--synapse/config/retention.py226
-rw-r--r--synapse/config/server.py201
-rw-r--r--synapse/event_auth.py33
-rw-r--r--synapse/events/__init__.py16
-rw-r--r--synapse/events/builder.py4
-rw-r--r--synapse/events/presence_router.py21
-rw-r--r--synapse/events/snapshot.py110
-rw-r--r--synapse/events/spamcheck.py25
-rw-r--r--synapse/events/third_party_rules.py25
-rw-r--r--synapse/events/utils.py119
-rw-r--r--synapse/events/validator.py18
-rw-r--r--synapse/handlers/auth.py528
-rw-r--r--synapse/handlers/device.py15
-rw-r--r--synapse/handlers/event_auth.py3
-rw-r--r--synapse/handlers/federation.py154
-rw-r--r--synapse/handlers/federation_event.py284
-rw-r--r--synapse/handlers/message.py57
-rw-r--r--synapse/handlers/pagination.py13
-rw-r--r--synapse/handlers/presence.py2
-rw-r--r--synapse/handlers/room.py22
-rw-r--r--synapse/handlers/room_batch.py40
-rw-r--r--synapse/handlers/user_directory.py159
-rw-r--r--synapse/module_api/__init__.py15
-rw-r--r--synapse/module_api/errors.py11
-rw-r--r--synapse/py.typed0
-rw-r--r--synapse/replication/tcp/protocol.py18
-rw-r--r--synapse/replication/tcp/redis.py18
-rw-r--r--synapse/rest/client/relations.py8
-rw-r--r--synapse/rest/client/room_batch.py15
-rw-r--r--synapse/rest/media/v1/filepath.py26
-rw-r--r--synapse/rest/media/v1/oembed.py13
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py110
-rw-r--r--synapse/server.py6
-rw-r--r--synapse/state/v1.py4
-rw-r--r--synapse/state/v2.py2
-rw-r--r--synapse/storage/databases/main/client_ips.py140
-rw-r--r--synapse/storage/databases/main/event_federation.py2
-rw-r--r--synapse/storage/databases/main/events.py10
-rw-r--r--synapse/storage/databases/main/events_worker.py142
-rw-r--r--synapse/storage/databases/main/registration.py8
-rw-r--r--synapse/storage/databases/main/room.py8
-rw-r--r--synapse/storage/databases/main/room_batch.py13
-rw-r--r--synapse/storage/prepare_database.py2
-rw-r--r--synapse/storage/schema/__init__.py6
-rw-r--r--synapse/storage/schema/main/delta/65/01msc2716_insertion_event_edges.sql19
57 files changed, 1773 insertions, 1143 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index e6ca9232ee..44883c6663 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -245,7 +245,7 @@ class Auth:
 
     async def validate_appservice_can_control_user_id(
         self, app_service: ApplicationService, user_id: str
-    ):
+    ) -> None:
         """Validates that the app service is allowed to control
         the given user.
 
@@ -618,5 +618,13 @@ class Auth:
                 % (user_id, room_id),
             )
 
-    async def check_auth_blocking(self, *args, **kwargs) -> None:
-        await self._auth_blocking.check_auth_blocking(*args, **kwargs)
+    async def check_auth_blocking(
+        self,
+        user_id: Optional[str] = None,
+        threepid: Optional[dict] = None,
+        user_type: Optional[str] = None,
+        requester: Optional[Requester] = None,
+    ) -> None:
+        await self._auth_blocking.check_auth_blocking(
+            user_id=user_id, threepid=threepid, user_type=user_type, requester=requester
+        )
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 9480f448d7..685d1c25cf 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -18,7 +18,7 @@
 import logging
 import typing
 from http import HTTPStatus
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 from twisted.web import http
 
@@ -143,7 +143,7 @@ class SynapseError(CodeMessageException):
         super().__init__(code, msg)
         self.errcode = errcode
 
-    def error_dict(self):
+    def error_dict(self) -> "JsonDict":
         return cs_error(self.msg, self.errcode)
 
 
@@ -175,7 +175,7 @@ class ProxiedRequestError(SynapseError):
         else:
             self._additional_fields = dict(additional_fields)
 
-    def error_dict(self):
+    def error_dict(self) -> "JsonDict":
         return cs_error(self.msg, self.errcode, **self._additional_fields)
 
 
@@ -196,7 +196,7 @@ class ConsentNotGivenError(SynapseError):
         )
         self._consent_uri = consent_uri
 
-    def error_dict(self):
+    def error_dict(self) -> "JsonDict":
         return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri)
 
 
@@ -262,14 +262,10 @@ class InteractiveAuthIncompleteError(Exception):
 class UnrecognizedRequestError(SynapseError):
     """An error indicating we don't understand the request you're trying to make"""
 
-    def __init__(self, *args, **kwargs):
-        if "errcode" not in kwargs:
-            kwargs["errcode"] = Codes.UNRECOGNIZED
-        if len(args) == 0:
-            message = "Unrecognized request"
-        else:
-            message = args[0]
-        super().__init__(400, message, **kwargs)
+    def __init__(
+        self, msg: str = "Unrecognized request", errcode: str = Codes.UNRECOGNIZED
+    ):
+        super().__init__(400, msg, errcode)
 
 
 class NotFoundError(SynapseError):
@@ -284,10 +280,8 @@ class AuthError(SynapseError):
     other poorly-defined times.
     """
 
-    def __init__(self, *args, **kwargs):
-        if "errcode" not in kwargs:
-            kwargs["errcode"] = Codes.FORBIDDEN
-        super().__init__(*args, **kwargs)
+    def __init__(self, code: int, msg: str, errcode: str = Codes.FORBIDDEN):
+        super().__init__(code, msg, errcode)
 
 
 class InvalidClientCredentialsError(SynapseError):
@@ -321,7 +315,7 @@ class InvalidClientTokenError(InvalidClientCredentialsError):
         super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN")
         self._soft_logout = soft_logout
 
-    def error_dict(self):
+    def error_dict(self) -> "JsonDict":
         d = super().error_dict()
         d["soft_logout"] = self._soft_logout
         return d
@@ -345,7 +339,7 @@ class ResourceLimitError(SynapseError):
         self.limit_type = limit_type
         super().__init__(code, msg, errcode=errcode)
 
-    def error_dict(self):
+    def error_dict(self) -> "JsonDict":
         return cs_error(
             self.msg,
             self.errcode,
@@ -357,32 +351,17 @@ class ResourceLimitError(SynapseError):
 class EventSizeError(SynapseError):
     """An error raised when an event is too big."""
 
-    def __init__(self, *args, **kwargs):
-        if "errcode" not in kwargs:
-            kwargs["errcode"] = Codes.TOO_LARGE
-        super().__init__(413, *args, **kwargs)
-
-
-class EventStreamError(SynapseError):
-    """An error raised when there a problem with the event stream."""
-
-    def __init__(self, *args, **kwargs):
-        if "errcode" not in kwargs:
-            kwargs["errcode"] = Codes.BAD_PAGINATION
-        super().__init__(*args, **kwargs)
+    def __init__(self, msg: str):
+        super().__init__(413, msg, Codes.TOO_LARGE)
 
 
 class LoginError(SynapseError):
     """An error raised when there was a problem logging in."""
 
-    pass
-
 
 class StoreError(SynapseError):
     """An error raised when there was a problem storing some data."""
 
-    pass
-
 
 class InvalidCaptchaError(SynapseError):
     def __init__(
@@ -395,7 +374,7 @@ class InvalidCaptchaError(SynapseError):
         super().__init__(code, msg, errcode)
         self.error_url = error_url
 
-    def error_dict(self):
+    def error_dict(self) -> "JsonDict":
         return cs_error(self.msg, self.errcode, error_url=self.error_url)
 
 
@@ -412,7 +391,7 @@ class LimitExceededError(SynapseError):
         super().__init__(code, msg, errcode)
         self.retry_after_ms = retry_after_ms
 
-    def error_dict(self):
+    def error_dict(self) -> "JsonDict":
         return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms)
 
 
@@ -443,10 +422,8 @@ class UnsupportedRoomVersionError(SynapseError):
 class ThreepidValidationError(SynapseError):
     """An error raised when there was a problem authorising an event."""
 
-    def __init__(self, *args, **kwargs):
-        if "errcode" not in kwargs:
-            kwargs["errcode"] = Codes.FORBIDDEN
-        super().__init__(*args, **kwargs)
+    def __init__(self, msg: str, errcode: str = Codes.FORBIDDEN):
+        super().__init__(400, msg, errcode)
 
 
 class IncompatibleRoomVersionError(SynapseError):
@@ -466,7 +443,7 @@ class IncompatibleRoomVersionError(SynapseError):
 
         self._room_version = room_version
 
-    def error_dict(self):
+    def error_dict(self) -> "JsonDict":
         return cs_error(self.msg, self.errcode, room_version=self._room_version)
 
 
@@ -494,7 +471,7 @@ class RequestSendFailed(RuntimeError):
     errors (like programming errors).
     """
 
-    def __init__(self, inner_exception, can_retry):
+    def __init__(self, inner_exception: BaseException, can_retry: bool):
         super().__init__(
             "Failed to send request: %s: %s"
             % (type(inner_exception).__name__, inner_exception)
@@ -503,7 +480,7 @@ class RequestSendFailed(RuntimeError):
         self.can_retry = can_retry
 
 
-def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs):
+def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs: Any) -> "JsonDict":
     """Utility method for constructing an error response for client-server
     interactions.
 
@@ -551,7 +528,7 @@ class FederationError(RuntimeError):
         msg = "%s %s: %s" % (level, code, reason)
         super().__init__(msg)
 
-    def get_dict(self):
+    def get_dict(self) -> "JsonDict":
         return {
             "level": self.level,
             "code": self.code,
@@ -580,7 +557,7 @@ class HttpResponseException(CodeMessageException):
         super().__init__(code, msg)
         self.response = response
 
-    def to_synapse_error(self):
+    def to_synapse_error(self) -> SynapseError:
         """Make a SynapseError based on an HTTPResponseException
 
         This is useful when a proxied request has failed, and we need to
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 20e91a115d..bc550ae646 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -231,24 +231,24 @@ class FilterCollection:
     def include_redundant_members(self) -> bool:
         return self._room_state_filter.include_redundant_members()
 
-    def filter_presence(self, events):
+    def filter_presence(
+        self, events: Iterable[UserPresenceState]
+    ) -> List[UserPresenceState]:
         return self._presence_filter.filter(events)
 
-    def filter_account_data(self, events):
+    def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
         return self._account_data.filter(events)
 
-    def filter_room_state(self, events):
+    def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
         return self._room_state_filter.filter(self._room_filter.filter(events))
 
-    def filter_room_timeline(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
+    def filter_room_timeline(self, events: Iterable[EventBase]) -> List[EventBase]:
         return self._room_timeline_filter.filter(self._room_filter.filter(events))
 
-    def filter_room_ephemeral(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
+    def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
         return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
 
-    def filter_room_account_data(
-        self, events: Iterable[FilterEvent]
-    ) -> List[FilterEvent]:
+    def filter_room_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
         return self._room_account_data.filter(self._room_filter.filter(events))
 
     def blocks_all_presence(self) -> bool:
@@ -309,7 +309,7 @@ class Filter:
         # except for presence which actually gets passed around as its own
         # namedtuple type.
         if isinstance(event, UserPresenceState):
-            sender = event.user_id
+            sender: Optional[str] = event.user_id
             room_id = None
             ev_type = "m.presence"
             contains_url = False
diff --git a/synapse/api/presence.py b/synapse/api/presence.py
index a3bf0348d1..b80aa83cb3 100644
--- a/synapse/api/presence.py
+++ b/synapse/api/presence.py
@@ -12,49 +12,48 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections import namedtuple
+from typing import Any, Optional
+
+import attr
 
 from synapse.api.constants import PresenceState
+from synapse.types import JsonDict
 
 
-class UserPresenceState(
-    namedtuple(
-        "UserPresenceState",
-        (
-            "user_id",
-            "state",
-            "last_active_ts",
-            "last_federation_update_ts",
-            "last_user_sync_ts",
-            "status_msg",
-            "currently_active",
-        ),
-    )
-):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UserPresenceState:
     """Represents the current presence state of the user.
 
-    user_id (str)
-    last_active (int): Time in msec that the user last interacted with server.
-    last_federation_update (int): Time in msec since either a) we sent a presence
+    user_id
+    last_active: Time in msec that the user last interacted with server.
+    last_federation_update: Time in msec since either a) we sent a presence
         update to other servers or b) we received a presence update, depending
         on if is a local user or not.
-    last_user_sync (int): Time in msec that the user last *completed* a sync
+    last_user_sync: Time in msec that the user last *completed* a sync
         (or event stream).
-    status_msg (str): User set status message.
+    status_msg: User set status message.
     """
 
-    def as_dict(self):
-        return dict(self._asdict())
+    user_id: str
+    state: str
+    last_active_ts: int
+    last_federation_update_ts: int
+    last_user_sync_ts: int
+    status_msg: Optional[str]
+    currently_active: bool
+
+    def as_dict(self) -> JsonDict:
+        return attr.asdict(self)
 
     @staticmethod
-    def from_dict(d):
+    def from_dict(d: JsonDict) -> "UserPresenceState":
         return UserPresenceState(**d)
 
-    def copy_and_replace(self, **kwargs):
-        return self._replace(**kwargs)
+    def copy_and_replace(self, **kwargs: Any) -> "UserPresenceState":
+        return attr.evolve(self, **kwargs)
 
     @classmethod
-    def default(cls, user_id):
+    def default(cls, user_id: str) -> "UserPresenceState":
         """Returns a default presence state."""
         return cls(
             user_id=user_id,
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index e8964097d3..849c18ceda 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -161,7 +161,7 @@ class Ratelimiter:
 
         return allowed, time_allowed
 
-    def _prune_message_counts(self, time_now_s: float):
+    def _prune_message_counts(self, time_now_s: float) -> None:
         """Remove message count entries that have not exceeded their defined
         rate_hz limit
 
@@ -190,7 +190,7 @@ class Ratelimiter:
         update: bool = True,
         n_actions: int = 1,
         _time_now_s: Optional[float] = None,
-    ):
+    ) -> None:
         """Checks if an action can be performed. If not, raises a LimitExceededError
 
         Checks if the user has ratelimiting disabled in the database by looking
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index 032c69b210..6e84b1524f 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -19,6 +19,7 @@ from hashlib import sha256
 from urllib.parse import urlencode
 
 from synapse.config import ConfigError
+from synapse.config.homeserver import HomeServerConfig
 
 SYNAPSE_CLIENT_API_PREFIX = "/_synapse/client"
 CLIENT_API_PREFIX = "/_matrix/client"
@@ -34,11 +35,7 @@ LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
 
 
 class ConsentURIBuilder:
-    def __init__(self, hs_config):
-        """
-        Args:
-            hs_config (synapse.config.homeserver.HomeServerConfig):
-        """
+    def __init__(self, hs_config: HomeServerConfig):
         if hs_config.key.form_secret is None:
             raise ConfigError("form_secret not set in config")
         if hs_config.server.public_baseurl is None:
@@ -47,15 +44,15 @@ class ConsentURIBuilder:
         self._hmac_secret = hs_config.key.form_secret.encode("utf-8")
         self._public_baseurl = hs_config.server.public_baseurl
 
-    def build_user_consent_uri(self, user_id):
+    def build_user_consent_uri(self, user_id: str) -> str:
         """Build a URI which we can give to the user to do their privacy
         policy consent
 
         Args:
-            user_id (str): mxid or username of user
+            user_id: mxid or username of user
 
         Returns
-            (str) the URI where the user can do consent
+            The URI where the user can do consent
         """
         mac = hmac.new(
             key=self._hmac_secret, msg=user_id.encode("ascii"), digestmod=sha256
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 4a204a5823..bb4d53d778 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -42,6 +42,7 @@ from synapse.crypto import context_factory
 from synapse.events.presence_router import load_legacy_presence_router
 from synapse.events.spamcheck import load_legacy_spam_checkers
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
+from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.logging.context import PreserveLoggingContext
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.metrics.jemalloc import setup_jemalloc_stats
@@ -379,6 +380,7 @@ async def start(hs: "HomeServer"):
     load_legacy_spam_checkers(hs)
     load_legacy_third_party_event_rules(hs)
     load_legacy_presence_router(hs)
+    load_legacy_password_auth_providers(hs)
 
     # If we've configured an expiry time for caches, start the background job now.
     setup_expire_lru_cache_entries(hs)
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 13d20af457..b156b93bf3 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -39,6 +39,7 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.server import HomeServer
+from synapse.storage.databases.main.room import RoomWorkerStore
 from synapse.util.logcontext import LoggingContext
 from synapse.util.versionstring import get_version_string
 
@@ -58,6 +59,7 @@ class AdminCmdSlavedStore(
     SlavedEventStore,
     SlavedClientIpStore,
     BaseSlavedStore,
+    RoomWorkerStore,
 ):
     pass
 
@@ -185,11 +187,7 @@ def start(config_options):
     # a full worker config.
     config.worker.worker_app = "synapse.app.admin_cmd"
 
-    if (
-        not config.worker.worker_daemonize
-        and not config.worker.worker_log_file
-        and not config.worker.worker_log_config
-    ):
+    if not config.worker.worker_daemonize and not config.worker.worker_log_config:
         # Since we're meant to be run as a "command" let's not redirect stdio
         # unless we've actually set log config.
         config.logging.no_redirect_stdio = True
@@ -198,9 +196,9 @@ def start(config_options):
     config.server.update_user_directory = False
     config.worker.run_background_tasks = False
     config.worker.start_pushers = False
-    config.pusher_shard_config.instances = []
+    config.worker.pusher_shard_config.instances = []
     config.worker.send_federation = False
-    config.federation_shard_config.instances = []
+    config.worker.federation_shard_config.instances = []
 
     synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts
 
@@ -221,7 +219,7 @@ def start(config_options):
 
     async def run():
         with LoggingContext("command"):
-            _base.start(ss)
+            await _base.start(ss)
             await args.func(ss, args)
 
     _base.start_worker_reactor(
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 06fbd1166b..c1d9069798 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -26,6 +26,7 @@ from synapse.config import (
     redis,
     registration,
     repository,
+    retention,
     room_directory,
     saml2,
     server,
@@ -91,6 +92,7 @@ class RootConfig:
     modules: modules.ModulesConfig
     caches: cache.CacheConfig
     federation: federation.FederationConfig
+    retention: retention.RetentionConfig
 
     config_classes: List = ...
     def __init__(self) -> None: ...
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 7b0381c06a..b013a3918c 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -24,6 +24,9 @@ class ExperimentalConfig(Config):
     def read_config(self, config: JsonDict, **kwargs):
         experimental = config.get("experimental_features") or {}
 
+        # Whether to enable experimental MSC1849 (aka relations) support
+        self.msc1849_enabled = config.get("experimental_msc1849_support_enabled", True)
+
         # MSC3026 (busy presence state)
         self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False)
 
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 442f1b9ac0..001605c265 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -38,6 +38,7 @@ from .ratelimiting import RatelimitConfig
 from .redis import RedisConfig
 from .registration import RegistrationConfig
 from .repository import ContentRepositoryConfig
+from .retention import RetentionConfig
 from .room import RoomConfig
 from .room_directory import RoomDirectoryConfig
 from .saml2 import SAML2Config
@@ -59,6 +60,7 @@ class HomeServerConfig(RootConfig):
     config_classes = [
         ModulesConfig,
         ServerConfig,
+        RetentionConfig,
         TlsConfig,
         FederationConfig,
         CacheConfig,
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index 83994df798..f980102b45 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -25,6 +25,29 @@ class PasswordAuthProviderConfig(Config):
     section = "authproviders"
 
     def read_config(self, config, **kwargs):
+        """Parses the old password auth providers config. The config format looks like this:
+
+        password_providers:
+           # Example config for an LDAP auth provider
+           - module: "ldap_auth_provider.LdapAuthProvider"
+             config:
+               enabled: true
+               uri: "ldap://ldap.example.com:389"
+               start_tls: true
+               base: "ou=users,dc=example,dc=com"
+               attributes:
+                  uid: "cn"
+                  mail: "email"
+                  name: "givenName"
+               #bind_dn:
+               #bind_password:
+               #filter: "(objectClass=posixAccount)"
+
+        We expect admins to use modules for this feature (which is why it doesn't appear
+        in the sample config file), but we want to keep support for it around for a bit
+        for backwards compatibility.
+        """
+
         self.password_providers: List[Tuple[Type, Any]] = []
         providers = []
 
@@ -49,33 +72,3 @@ class PasswordAuthProviderConfig(Config):
             )
 
             self.password_providers.append((provider_class, provider_config))
-
-    def generate_config_section(self, **kwargs):
-        return """\
-        # Password providers allow homeserver administrators to integrate
-        # their Synapse installation with existing authentication methods
-        # ex. LDAP, external tokens, etc.
-        #
-        # For more information and known implementations, please see
-        # https://matrix-org.github.io/synapse/latest/password_auth_providers.html
-        #
-        # Note: instances wishing to use SAML or CAS authentication should
-        # instead use the `saml2_config` or `cas_config` options,
-        # respectively.
-        #
-        password_providers:
-        #    # Example config for an LDAP auth provider
-        #    - module: "ldap_auth_provider.LdapAuthProvider"
-        #      config:
-        #        enabled: true
-        #        uri: "ldap://ldap.example.com:389"
-        #        start_tls: true
-        #        base: "ou=users,dc=example,dc=com"
-        #        attributes:
-        #           uid: "cn"
-        #           mail: "email"
-        #           name: "givenName"
-        #        #bind_dn:
-        #        #bind_password:
-        #        #filter: "(objectClass=posixAccount)"
-        """
diff --git a/synapse/config/retention.py b/synapse/config/retention.py
new file mode 100644
index 0000000000..aed9bf458f
--- /dev/null
+++ b/synapse/config/retention.py
@@ -0,0 +1,226 @@
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+import logging
+from typing import List, Optional
+
+import attr
+
+from synapse.config._base import Config, ConfigError
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RetentionPurgeJob:
+    """Object describing the configuration of the manhole"""
+
+    interval: int
+    shortest_max_lifetime: Optional[int]
+    longest_max_lifetime: Optional[int]
+
+
+class RetentionConfig(Config):
+    section = "retention"
+
+    def read_config(self, config, **kwargs):
+        retention_config = config.get("retention")
+        if retention_config is None:
+            retention_config = {}
+
+        self.retention_enabled = retention_config.get("enabled", False)
+
+        retention_default_policy = retention_config.get("default_policy")
+
+        if retention_default_policy is not None:
+            self.retention_default_min_lifetime = retention_default_policy.get(
+                "min_lifetime"
+            )
+            if self.retention_default_min_lifetime is not None:
+                self.retention_default_min_lifetime = self.parse_duration(
+                    self.retention_default_min_lifetime
+                )
+
+            self.retention_default_max_lifetime = retention_default_policy.get(
+                "max_lifetime"
+            )
+            if self.retention_default_max_lifetime is not None:
+                self.retention_default_max_lifetime = self.parse_duration(
+                    self.retention_default_max_lifetime
+                )
+
+            if (
+                self.retention_default_min_lifetime is not None
+                and self.retention_default_max_lifetime is not None
+                and (
+                    self.retention_default_min_lifetime
+                    > self.retention_default_max_lifetime
+                )
+            ):
+                raise ConfigError(
+                    "The default retention policy's 'min_lifetime' can not be greater"
+                    " than its 'max_lifetime'"
+                )
+        else:
+            self.retention_default_min_lifetime = None
+            self.retention_default_max_lifetime = None
+
+        if self.retention_enabled:
+            logger.info(
+                "Message retention policies support enabled with the following default"
+                " policy: min_lifetime = %s ; max_lifetime = %s",
+                self.retention_default_min_lifetime,
+                self.retention_default_max_lifetime,
+            )
+
+        self.retention_allowed_lifetime_min = retention_config.get(
+            "allowed_lifetime_min"
+        )
+        if self.retention_allowed_lifetime_min is not None:
+            self.retention_allowed_lifetime_min = self.parse_duration(
+                self.retention_allowed_lifetime_min
+            )
+
+        self.retention_allowed_lifetime_max = retention_config.get(
+            "allowed_lifetime_max"
+        )
+        if self.retention_allowed_lifetime_max is not None:
+            self.retention_allowed_lifetime_max = self.parse_duration(
+                self.retention_allowed_lifetime_max
+            )
+
+        if (
+            self.retention_allowed_lifetime_min is not None
+            and self.retention_allowed_lifetime_max is not None
+            and self.retention_allowed_lifetime_min
+            > self.retention_allowed_lifetime_max
+        ):
+            raise ConfigError(
+                "Invalid retention policy limits: 'allowed_lifetime_min' can not be"
+                " greater than 'allowed_lifetime_max'"
+            )
+
+        self.retention_purge_jobs: List[RetentionPurgeJob] = []
+        for purge_job_config in retention_config.get("purge_jobs", []):
+            interval_config = purge_job_config.get("interval")
+
+            if interval_config is None:
+                raise ConfigError(
+                    "A retention policy's purge jobs configuration must have the"
+                    " 'interval' key set."
+                )
+
+            interval = self.parse_duration(interval_config)
+
+            shortest_max_lifetime = purge_job_config.get("shortest_max_lifetime")
+
+            if shortest_max_lifetime is not None:
+                shortest_max_lifetime = self.parse_duration(shortest_max_lifetime)
+
+            longest_max_lifetime = purge_job_config.get("longest_max_lifetime")
+
+            if longest_max_lifetime is not None:
+                longest_max_lifetime = self.parse_duration(longest_max_lifetime)
+
+            if (
+                shortest_max_lifetime is not None
+                and longest_max_lifetime is not None
+                and shortest_max_lifetime > longest_max_lifetime
+            ):
+                raise ConfigError(
+                    "A retention policy's purge jobs configuration's"
+                    " 'shortest_max_lifetime' value can not be greater than its"
+                    " 'longest_max_lifetime' value."
+                )
+
+            self.retention_purge_jobs.append(
+                RetentionPurgeJob(interval, shortest_max_lifetime, longest_max_lifetime)
+            )
+
+        if not self.retention_purge_jobs:
+            self.retention_purge_jobs = [
+                RetentionPurgeJob(self.parse_duration("1d"), None, None)
+            ]
+
+    def generate_config_section(self, config_dir_path, server_name, **kwargs):
+        return """\
+        # Message retention policy at the server level.
+        #
+        # Room admins and mods can define a retention period for their rooms using the
+        # 'm.room.retention' state event, and server admins can cap this period by setting
+        # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options.
+        #
+        # If this feature is enabled, Synapse will regularly look for and purge events
+        # which are older than the room's maximum retention period. Synapse will also
+        # filter events received over federation so that events that should have been
+        # purged are ignored and not stored again.
+        #
+        retention:
+          # The message retention policies feature is disabled by default. Uncomment the
+          # following line to enable it.
+          #
+          #enabled: true
+
+          # Default retention policy. If set, Synapse will apply it to rooms that lack the
+          # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't
+          # matter much because Synapse doesn't take it into account yet.
+          #
+          #default_policy:
+          #  min_lifetime: 1d
+          #  max_lifetime: 1y
+
+          # Retention policy limits. If set, and the state of a room contains a
+          # 'm.room.retention' event in its state which contains a 'min_lifetime' or a
+          # 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy
+          # to these limits when running purge jobs.
+          #
+          #allowed_lifetime_min: 1d
+          #allowed_lifetime_max: 1y
+
+          # Server admins can define the settings of the background jobs purging the
+          # events which lifetime has expired under the 'purge_jobs' section.
+          #
+          # If no configuration is provided, a single job will be set up to delete expired
+          # events in every room daily.
+          #
+          # Each job's configuration defines which range of message lifetimes the job
+          # takes care of. For example, if 'shortest_max_lifetime' is '2d' and
+          # 'longest_max_lifetime' is '3d', the job will handle purging expired events in
+          # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and
+          # lower than or equal to 3 days. Both the minimum and the maximum value of a
+          # range are optional, e.g. a job with no 'shortest_max_lifetime' and a
+          # 'longest_max_lifetime' of '3d' will handle every room with a retention policy
+          # which 'max_lifetime' is lower than or equal to three days.
+          #
+          # The rationale for this per-job configuration is that some rooms might have a
+          # retention policy with a low 'max_lifetime', where history needs to be purged
+          # of outdated messages on a more frequent basis than for the rest of the rooms
+          # (e.g. every 12h), but not want that purge to be performed by a job that's
+          # iterating over every room it knows, which could be heavy on the server.
+          #
+          # If any purge job is configured, it is strongly recommended to have at least
+          # a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime'
+          # set, or one job without 'shortest_max_lifetime' and one job without
+          # 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if
+          # 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a
+          # room's policy to these values is done after the policies are retrieved from
+          # Synapse's database (which is done using the range specified in a purge job's
+          # configuration).
+          #
+          #purge_jobs:
+          #  - longest_max_lifetime: 3d
+          #    interval: 12h
+          #  - shortest_max_lifetime: 3d
+          #    interval: 1d
+        """
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 818b806357..ed094bdc44 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -225,15 +225,6 @@ class ManholeConfig:
     pub_key: Optional[Key]
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class RetentionConfig:
-    """Object describing the configuration of the manhole"""
-
-    interval: int
-    shortest_max_lifetime: Optional[int]
-    longest_max_lifetime: Optional[int]
-
-
 @attr.s(frozen=True)
 class LimitRemoteRoomsConfig:
     enabled: bool = attr.ib(validator=attr.validators.instance_of(bool), default=False)
@@ -376,11 +367,6 @@ class ServerConfig(Config):
         # (other than those sent by local server admins)
         self.block_non_admin_invites = config.get("block_non_admin_invites", False)
 
-        # Whether to enable experimental MSC1849 (aka relations) support
-        self.experimental_msc1849_support_enabled = config.get(
-            "experimental_msc1849_support_enabled", True
-        )
-
         # Options to control access by tracking MAU
         self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
         self.max_mau_value = 0
@@ -466,124 +452,6 @@ class ServerConfig(Config):
         # events with profile information that differ from the target's global profile.
         self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
 
-        retention_config = config.get("retention")
-        if retention_config is None:
-            retention_config = {}
-
-        self.retention_enabled = retention_config.get("enabled", False)
-
-        retention_default_policy = retention_config.get("default_policy")
-
-        if retention_default_policy is not None:
-            self.retention_default_min_lifetime = retention_default_policy.get(
-                "min_lifetime"
-            )
-            if self.retention_default_min_lifetime is not None:
-                self.retention_default_min_lifetime = self.parse_duration(
-                    self.retention_default_min_lifetime
-                )
-
-            self.retention_default_max_lifetime = retention_default_policy.get(
-                "max_lifetime"
-            )
-            if self.retention_default_max_lifetime is not None:
-                self.retention_default_max_lifetime = self.parse_duration(
-                    self.retention_default_max_lifetime
-                )
-
-            if (
-                self.retention_default_min_lifetime is not None
-                and self.retention_default_max_lifetime is not None
-                and (
-                    self.retention_default_min_lifetime
-                    > self.retention_default_max_lifetime
-                )
-            ):
-                raise ConfigError(
-                    "The default retention policy's 'min_lifetime' can not be greater"
-                    " than its 'max_lifetime'"
-                )
-        else:
-            self.retention_default_min_lifetime = None
-            self.retention_default_max_lifetime = None
-
-        if self.retention_enabled:
-            logger.info(
-                "Message retention policies support enabled with the following default"
-                " policy: min_lifetime = %s ; max_lifetime = %s",
-                self.retention_default_min_lifetime,
-                self.retention_default_max_lifetime,
-            )
-
-        self.retention_allowed_lifetime_min = retention_config.get(
-            "allowed_lifetime_min"
-        )
-        if self.retention_allowed_lifetime_min is not None:
-            self.retention_allowed_lifetime_min = self.parse_duration(
-                self.retention_allowed_lifetime_min
-            )
-
-        self.retention_allowed_lifetime_max = retention_config.get(
-            "allowed_lifetime_max"
-        )
-        if self.retention_allowed_lifetime_max is not None:
-            self.retention_allowed_lifetime_max = self.parse_duration(
-                self.retention_allowed_lifetime_max
-            )
-
-        if (
-            self.retention_allowed_lifetime_min is not None
-            and self.retention_allowed_lifetime_max is not None
-            and self.retention_allowed_lifetime_min
-            > self.retention_allowed_lifetime_max
-        ):
-            raise ConfigError(
-                "Invalid retention policy limits: 'allowed_lifetime_min' can not be"
-                " greater than 'allowed_lifetime_max'"
-            )
-
-        self.retention_purge_jobs: List[RetentionConfig] = []
-        for purge_job_config in retention_config.get("purge_jobs", []):
-            interval_config = purge_job_config.get("interval")
-
-            if interval_config is None:
-                raise ConfigError(
-                    "A retention policy's purge jobs configuration must have the"
-                    " 'interval' key set."
-                )
-
-            interval = self.parse_duration(interval_config)
-
-            shortest_max_lifetime = purge_job_config.get("shortest_max_lifetime")
-
-            if shortest_max_lifetime is not None:
-                shortest_max_lifetime = self.parse_duration(shortest_max_lifetime)
-
-            longest_max_lifetime = purge_job_config.get("longest_max_lifetime")
-
-            if longest_max_lifetime is not None:
-                longest_max_lifetime = self.parse_duration(longest_max_lifetime)
-
-            if (
-                shortest_max_lifetime is not None
-                and longest_max_lifetime is not None
-                and shortest_max_lifetime > longest_max_lifetime
-            ):
-                raise ConfigError(
-                    "A retention policy's purge jobs configuration's"
-                    " 'shortest_max_lifetime' value can not be greater than its"
-                    " 'longest_max_lifetime' value."
-                )
-
-            self.retention_purge_jobs.append(
-                RetentionConfig(interval, shortest_max_lifetime, longest_max_lifetime)
-            )
-
-        if not self.retention_purge_jobs:
-            self.retention_purge_jobs = [
-                RetentionConfig(self.parse_duration("1d"), None, None)
-            ]
-
         self.listeners = [parse_listener_def(x) for x in config.get("listeners", [])]
 
         # no_tls is not really supported any more, but let's grandfather it in
@@ -1255,75 +1123,6 @@ class ServerConfig(Config):
         #
         #user_ips_max_age: 14d
 
-        # Message retention policy at the server level.
-        #
-        # Room admins and mods can define a retention period for their rooms using the
-        # 'm.room.retention' state event, and server admins can cap this period by setting
-        # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options.
-        #
-        # If this feature is enabled, Synapse will regularly look for and purge events
-        # which are older than the room's maximum retention period. Synapse will also
-        # filter events received over federation so that events that should have been
-        # purged are ignored and not stored again.
-        #
-        retention:
-          # The message retention policies feature is disabled by default. Uncomment the
-          # following line to enable it.
-          #
-          #enabled: true
-
-          # Default retention policy. If set, Synapse will apply it to rooms that lack the
-          # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't
-          # matter much because Synapse doesn't take it into account yet.
-          #
-          #default_policy:
-          #  min_lifetime: 1d
-          #  max_lifetime: 1y
-
-          # Retention policy limits. If set, and the state of a room contains a
-          # 'm.room.retention' event in its state which contains a 'min_lifetime' or a
-          # 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy
-          # to these limits when running purge jobs.
-          #
-          #allowed_lifetime_min: 1d
-          #allowed_lifetime_max: 1y
-
-          # Server admins can define the settings of the background jobs purging the
-          # events which lifetime has expired under the 'purge_jobs' section.
-          #
-          # If no configuration is provided, a single job will be set up to delete expired
-          # events in every room daily.
-          #
-          # Each job's configuration defines which range of message lifetimes the job
-          # takes care of. For example, if 'shortest_max_lifetime' is '2d' and
-          # 'longest_max_lifetime' is '3d', the job will handle purging expired events in
-          # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and
-          # lower than or equal to 3 days. Both the minimum and the maximum value of a
-          # range are optional, e.g. a job with no 'shortest_max_lifetime' and a
-          # 'longest_max_lifetime' of '3d' will handle every room with a retention policy
-          # which 'max_lifetime' is lower than or equal to three days.
-          #
-          # The rationale for this per-job configuration is that some rooms might have a
-          # retention policy with a low 'max_lifetime', where history needs to be purged
-          # of outdated messages on a more frequent basis than for the rest of the rooms
-          # (e.g. every 12h), but not want that purge to be performed by a job that's
-          # iterating over every room it knows, which could be heavy on the server.
-          #
-          # If any purge job is configured, it is strongly recommended to have at least
-          # a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime'
-          # set, or one job without 'shortest_max_lifetime' and one job without
-          # 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if
-          # 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a
-          # room's policy to these values is done after the policies are retrieved from
-          # Synapse's database (which is done using the range specified in a purge job's
-          # configuration).
-          #
-          #purge_jobs:
-          #  - longest_max_lifetime: 3d
-          #    interval: 12h
-          #  - shortest_max_lifetime: 3d
-          #    interval: 1d
-
         # Inhibits the /requestToken endpoints from returning an error that might leak
         # information about whether an e-mail address is in use or not on this
         # homeserver.
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index ca0293a3dc..e885961698 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
@@ -113,7 +113,7 @@ def validate_event_for_room_version(
 
 
 def check_auth_rules_for_event(
-    room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase]
+    room_version_obj: RoomVersion, event: EventBase, auth_events: Iterable[EventBase]
 ) -> None:
     """Check that an event complies with the auth rules
 
@@ -137,8 +137,6 @@ def check_auth_rules_for_event(
     Raises:
         AuthError if the checks fail
     """
-    assert isinstance(auth_events, dict)
-
     # We need to ensure that the auth events are actually for the same room, to
     # stop people from using powers they've been granted in other rooms for
     # example.
@@ -147,7 +145,7 @@ def check_auth_rules_for_event(
     # the state res algorithm isn't silly enough to give us events from different rooms.
     # Still, it's easier to do it anyway.
     room_id = event.room_id
-    for auth_event in auth_events.values():
+    for auth_event in auth_events:
         if auth_event.room_id != room_id:
             raise AuthError(
                 403,
@@ -186,8 +184,10 @@ def check_auth_rules_for_event(
         logger.debug("Allowing! %s", event)
         return
 
+    auth_dict = {(e.type, e.state_key): e for e in auth_events}
+
     # 3. If event does not have a m.room.create in its auth_events, reject.
-    creation_event = auth_events.get((EventTypes.Create, ""), None)
+    creation_event = auth_dict.get((EventTypes.Create, ""), None)
     if not creation_event:
         raise AuthError(403, "No create event in auth events")
 
@@ -195,7 +195,7 @@ def check_auth_rules_for_event(
     creating_domain = get_domain_from_id(event.room_id)
     originating_domain = get_domain_from_id(event.sender)
     if creating_domain != originating_domain:
-        if not _can_federate(event, auth_events):
+        if not _can_federate(event, auth_dict):
             raise AuthError(403, "This room has been marked as unfederatable.")
 
     # 4. If type is m.room.aliases
@@ -217,23 +217,20 @@ def check_auth_rules_for_event(
         logger.debug("Allowing! %s", event)
         return
 
-    if logger.isEnabledFor(logging.DEBUG):
-        logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()])
-
     # 5. If type is m.room.membership
     if event.type == EventTypes.Member:
-        _is_membership_change_allowed(room_version_obj, event, auth_events)
+        _is_membership_change_allowed(room_version_obj, event, auth_dict)
         logger.debug("Allowing! %s", event)
         return
 
-    _check_event_sender_in_room(event, auth_events)
+    _check_event_sender_in_room(event, auth_dict)
 
     # Special case to allow m.room.third_party_invite events wherever
     # a user is allowed to issue invites.  Fixes
     # https://github.com/vector-im/vector-web/issues/1208 hopefully
     if event.type == EventTypes.ThirdPartyInvite:
-        user_level = get_user_power_level(event.user_id, auth_events)
-        invite_level = get_named_level(auth_events, "invite", 0)
+        user_level = get_user_power_level(event.user_id, auth_dict)
+        invite_level = get_named_level(auth_dict, "invite", 0)
 
         if user_level < invite_level:
             raise AuthError(403, "You don't have permission to invite users")
@@ -241,20 +238,20 @@ def check_auth_rules_for_event(
             logger.debug("Allowing! %s", event)
             return
 
-    _can_send_event(event, auth_events)
+    _can_send_event(event, auth_dict)
 
     if event.type == EventTypes.PowerLevels:
-        _check_power_levels(room_version_obj, event, auth_events)
+        _check_power_levels(room_version_obj, event, auth_dict)
 
     if event.type == EventTypes.Redaction:
-        check_redaction(room_version_obj, event, auth_events)
+        check_redaction(room_version_obj, event, auth_dict)
 
     if (
         event.type == EventTypes.MSC2716_INSERTION
         or event.type == EventTypes.MSC2716_BATCH
         or event.type == EventTypes.MSC2716_MARKER
     ):
-        check_historical(room_version_obj, event, auth_events)
+        check_historical(room_version_obj, event, auth_dict)
 
     logger.debug("Allowing! %s", event)
 
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 49190459c8..157669ea88 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -348,12 +348,16 @@ class EventBase(metaclass=abc.ABCMeta):
         return self.__repr__()
 
     def __repr__(self):
-        return "<%s event_id=%r, type=%r, state_key=%r, outlier=%s>" % (
-            self.__class__.__name__,
-            self.event_id,
-            self.get("type", None),
-            self.get("state_key", None),
-            self.internal_metadata.is_outlier(),
+        rejection = f"REJECTED={self.rejected_reason}, " if self.rejected_reason else ""
+
+        return (
+            f"<{self.__class__.__name__} "
+            f"{rejection}"
+            f"event_id={self.event_id}, "
+            f"type={self.get('type')}, "
+            f"state_key={self.get('state_key')}, "
+            f"outlier={self.internal_metadata.is_outlier()}"
+            ">"
         )
 
 
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 50f2a4c1f4..4f409f31e1 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -90,13 +90,13 @@ class EventBuilder:
     )
 
     @property
-    def state_key(self):
+    def state_key(self) -> str:
         if self._state_key is not None:
             return self._state_key
 
         raise AttributeError("state_key")
 
-    def is_state(self):
+    def is_state(self) -> bool:
         return self._state_key is not None
 
     async def build(
diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py
index 68b8b19024..a58f313e8b 100644
--- a/synapse/events/presence_router.py
+++ b/synapse/events/presence_router.py
@@ -14,6 +14,7 @@
 import logging
 from typing import (
     TYPE_CHECKING,
+    Any,
     Awaitable,
     Callable,
     Dict,
@@ -33,14 +34,13 @@ if TYPE_CHECKING:
 GET_USERS_FOR_STATES_CALLBACK = Callable[
     [Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
 ]
-GET_INTERESTED_USERS_CALLBACK = Callable[
-    [str], Awaitable[Union[Set[str], "PresenceRouter.ALL_USERS"]]
-]
+# This must either return a set of strings or the constant PresenceRouter.ALL_USERS.
+GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]]
 
 logger = logging.getLogger(__name__)
 
 
-def load_legacy_presence_router(hs: "HomeServer"):
+def load_legacy_presence_router(hs: "HomeServer") -> None:
     """Wrapper that loads a presence router module configured using the old
     configuration, and registers the hooks they implement.
     """
@@ -69,9 +69,10 @@ def load_legacy_presence_router(hs: "HomeServer"):
         if f is None:
             return None
 
-        def run(*args, **kwargs):
-            # mypy doesn't do well across function boundaries so we need to tell it
-            # f is definitely not None.
+        def run(*args: Any, **kwargs: Any) -> Awaitable:
+            # Assertion required because mypy can't prove we won't change `f`
+            # back to `None`. See
+            # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
             assert f is not None
 
             return maybe_awaitable(f(*args, **kwargs))
@@ -104,7 +105,7 @@ class PresenceRouter:
         self,
         get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
         get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
-    ):
+    ) -> None:
         # PresenceRouter modules are required to implement both of these methods
         # or neither of them as they are assumed to act in a complementary manner
         paired_methods = [get_users_for_states, get_interested_users]
@@ -142,7 +143,7 @@ class PresenceRouter:
             # Don't include any extra destinations for presence updates
             return {}
 
-        users_for_states = {}
+        users_for_states: Dict[str, Set[UserPresenceState]] = {}
         # run all the callbacks for get_users_for_states and combine the results
         for callback in self._get_users_for_states_callbacks:
             try:
@@ -171,7 +172,7 @@ class PresenceRouter:
 
         return users_for_states
 
-    async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]:
+    async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
         """
         Retrieve a list of users that `user_id` is interested in receiving the
         presence of. This will be in addition to those they share a room with.
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 5ba01eeef9..d7527008c4 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -11,17 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import attr
 from frozendict import frozendict
 
+from twisted.internet.defer import Deferred
+
 from synapse.appservice import ApplicationService
 from synapse.events import EventBase
 from synapse.logging.context import make_deferred_yieldable, run_in_background
-from synapse.types import StateMap
+from synapse.types import JsonDict, StateMap
 
 if TYPE_CHECKING:
+    from synapse.storage import Storage
     from synapse.storage.databases.main import DataStore
 
 
@@ -112,13 +115,13 @@ class EventContext:
 
     @staticmethod
     def with_state(
-        state_group,
-        state_group_before_event,
-        current_state_ids,
-        prev_state_ids,
-        prev_group=None,
-        delta_ids=None,
-    ):
+        state_group: Optional[int],
+        state_group_before_event: Optional[int],
+        current_state_ids: Optional[StateMap[str]],
+        prev_state_ids: Optional[StateMap[str]],
+        prev_group: Optional[int] = None,
+        delta_ids: Optional[StateMap[str]] = None,
+    ) -> "EventContext":
         return EventContext(
             current_state_ids=current_state_ids,
             prev_state_ids=prev_state_ids,
@@ -129,22 +132,22 @@ class EventContext:
         )
 
     @staticmethod
-    def for_outlier():
+    def for_outlier() -> "EventContext":
         """Return an EventContext instance suitable for persisting an outlier event"""
         return EventContext(
             current_state_ids={},
             prev_state_ids={},
         )
 
-    async def serialize(self, event: EventBase, store: "DataStore") -> dict:
+    async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
         """Converts self to a type that can be serialized as JSON, and then
         deserialized by `deserialize`
 
         Args:
-            event (FrozenEvent): The event that this context relates to
+            event: The event that this context relates to
 
         Returns:
-            dict
+            The serialized event.
         """
 
         # We don't serialize the full state dicts, instead they get pulled out
@@ -170,17 +173,16 @@ class EventContext:
         }
 
     @staticmethod
-    def deserialize(storage, input):
+    def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
         """Converts a dict that was produced by `serialize` back into a
         EventContext.
 
         Args:
-            storage (Storage): Used to convert AS ID to AS object and fetch
-                state.
-            input (dict): A dict produced by `serialize`
+            storage: Used to convert AS ID to AS object and fetch state.
+            input: A dict produced by `serialize`
 
         Returns:
-            EventContext
+            The event context.
         """
         context = _AsyncEventContextImpl(
             # We use the state_group and prev_state_id stuff to pull the
@@ -241,22 +243,25 @@ class EventContext:
         await self._ensure_fetched()
         return self._current_state_ids
 
-    async def get_prev_state_ids(self):
+    async def get_prev_state_ids(self) -> StateMap[str]:
         """
         Gets the room state map, excluding this event.
 
         For a non-state event, this will be the same as get_current_state_ids().
 
         Returns:
-            dict[(str, str), str]|None: Returns None if state_group
-                is None, which happens when the associated event is an outlier.
-                Maps a (type, state_key) to the event ID of the state event matching
-                this tuple.
+            Returns {} if state_group is None, which happens when the associated
+            event is an outlier.
+
+            Maps a (type, state_key) to the event ID of the state event matching
+            this tuple.
         """
         await self._ensure_fetched()
+        # There *should* be previous state IDs now.
+        assert self._prev_state_ids is not None
         return self._prev_state_ids
 
-    def get_cached_current_state_ids(self):
+    def get_cached_current_state_ids(self) -> Optional[StateMap[str]]:
         """Gets the current state IDs if we have them already cached.
 
         It is an error to access this for a rejected event, since rejected state should
@@ -264,16 +269,17 @@ class EventContext:
         ``rejected`` is set.
 
         Returns:
-            dict[(str, str), str]|None: Returns None if we haven't cached the
-            state or if state_group is None, which happens when the associated
-            event is an outlier.
+            Returns None if we haven't cached the state or if state_group is None
+            (which happens when the associated event is an outlier).
+
+            Otherwise, returns the the current state IDs.
         """
         if self.rejected:
             raise RuntimeError("Attempt to access state_ids of rejected event")
 
         return self._current_state_ids
 
-    async def _ensure_fetched(self):
+    async def _ensure_fetched(self) -> None:
         return None
 
 
@@ -285,46 +291,46 @@ class _AsyncEventContextImpl(EventContext):
 
     Attributes:
 
-        _storage (Storage)
+        _storage
 
-        _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
-            been calculated. None if we haven't started calculating yet
+        _fetching_state_deferred: Resolves when *_state_ids have been calculated.
+            None if we haven't started calculating yet
 
-        _event_type (str): The type of the event the context is associated with.
+        _event_type: The type of the event the context is associated with.
 
-        _event_state_key (str): The state_key of the event the context is
-            associated with.
+        _event_state_key: The state_key of the event the context is associated with.
 
-        _prev_state_id (str|None): If the event associated with the context is
-            a state event, then `_prev_state_id` is the event_id of the state
-            that was replaced.
+        _prev_state_id: If the event associated with the context is a state event,
+            then `_prev_state_id` is the event_id of the state that was replaced.
     """
 
     # This needs to have a default as we're inheriting
-    _storage = attr.ib(default=None)
-    _prev_state_id = attr.ib(default=None)
-    _event_type = attr.ib(default=None)
-    _event_state_key = attr.ib(default=None)
-    _fetching_state_deferred = attr.ib(default=None)
+    _storage: "Storage" = attr.ib(default=None)
+    _prev_state_id: Optional[str] = attr.ib(default=None)
+    _event_type: str = attr.ib(default=None)
+    _event_state_key: Optional[str] = attr.ib(default=None)
+    _fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None)
 
-    async def _ensure_fetched(self):
+    async def _ensure_fetched(self) -> None:
         if not self._fetching_state_deferred:
             self._fetching_state_deferred = run_in_background(self._fill_out_state)
 
-        return await make_deferred_yieldable(self._fetching_state_deferred)
+        await make_deferred_yieldable(self._fetching_state_deferred)
 
-    async def _fill_out_state(self):
+    async def _fill_out_state(self) -> None:
         """Called to populate the _current_state_ids and _prev_state_ids
         attributes by loading from the database.
         """
         if self.state_group is None:
             return
 
-        self._current_state_ids = await self._storage.state.get_state_ids_for_group(
+        current_state_ids = await self._storage.state.get_state_ids_for_group(
             self.state_group
         )
+        # Set this separately so mypy knows current_state_ids is not None.
+        self._current_state_ids = current_state_ids
         if self._event_state_key is not None:
-            self._prev_state_ids = dict(self._current_state_ids)
+            self._prev_state_ids = dict(current_state_ids)
 
             key = (self._event_type, self._event_state_key)
             if self._prev_state_id:
@@ -332,10 +338,12 @@ class _AsyncEventContextImpl(EventContext):
             else:
                 self._prev_state_ids.pop(key, None)
         else:
-            self._prev_state_ids = self._current_state_ids
+            self._prev_state_ids = current_state_ids
 
 
-def _encode_state_dict(state_dict):
+def _encode_state_dict(
+    state_dict: Optional[StateMap[str]],
+) -> Optional[List[Tuple[str, str, str]]]:
     """Since dicts of (type, state_key) -> event_id cannot be serialized in
     JSON we need to convert them to a form that can.
     """
@@ -345,7 +353,9 @@ def _encode_state_dict(state_dict):
     return [(etype, state_key, v) for (etype, state_key), v in state_dict.items()]
 
 
-def _decode_state_dict(input):
+def _decode_state_dict(
+    input: Optional[List[Tuple[str, str, str]]]
+) -> Optional[StateMap[str]]:
     """Decodes a state dict encoded using `_encode_state_dict` above"""
     if input is None:
         return None
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index ae4c8ab257..3134beb8d3 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -77,7 +77,7 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
 ]
 
 
-def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
+def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
     """Wrapper that loads spam checkers configured using the old configuration, and
     registers the spam checker hooks they implement.
     """
@@ -129,9 +129,9 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
                         request_info: Collection[Tuple[str, str]],
                         auth_provider_id: Optional[str],
                     ) -> Union[Awaitable[RegistrationBehaviour], RegistrationBehaviour]:
-                        # We've already made sure f is not None above, but mypy doesn't
-                        # do well across function boundaries so we need to tell it f is
-                        # definitely not None.
+                        # Assertion required because mypy can't prove we won't
+                        # change `f` back to `None`. See
+                        # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
                         assert f is not None
 
                         return f(
@@ -146,9 +146,10 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
                         "Bad signature for callback check_registration_for_spam",
                     )
 
-            def run(*args, **kwargs):
-                # mypy doesn't do well across function boundaries so we need to tell it
-                # wrapped_func is definitely not None.
+            def run(*args: Any, **kwargs: Any) -> Awaitable:
+                # Assertion required because mypy can't prove we won't change `f`
+                # back to `None`. See
+                # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
                 assert wrapped_func is not None
 
                 return maybe_awaitable(wrapped_func(*args, **kwargs))
@@ -165,7 +166,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
 
 
 class SpamChecker:
-    def __init__(self):
+    def __init__(self) -> None:
         self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
         self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
         self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
@@ -209,7 +210,7 @@ class SpamChecker:
             CHECK_REGISTRATION_FOR_SPAM_CALLBACK
         ] = None,
         check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
-    ):
+    ) -> None:
         """Register callbacks from module for each hook."""
         if check_event_for_spam is not None:
             self._check_event_for_spam_callbacks.append(check_event_for_spam)
@@ -275,7 +276,9 @@ class SpamChecker:
 
         return False
 
-    async def user_may_join_room(self, user_id: str, room_id: str, is_invited: bool):
+    async def user_may_join_room(
+        self, user_id: str, room_id: str, is_invited: bool
+    ) -> bool:
         """Checks if a given users is allowed to join a room.
         Not called when a user creates a room.
 
@@ -285,7 +288,7 @@ class SpamChecker:
             is_invited: Whether the user is invited into the room
 
         Returns:
-            bool: Whether the user may join the room
+            Whether the user may join the room
         """
         for callback in self._user_may_join_room_callbacks:
             if await callback(user_id, room_id, is_invited) is False:
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 976d9fa446..2a6dabdab6 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple
 
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
@@ -38,7 +38,7 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
 ]
 
 
-def load_legacy_third_party_event_rules(hs: "HomeServer"):
+def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
     """Wrapper that loads a third party event rules module configured using the old
     configuration, and registers the hooks they implement.
     """
@@ -77,9 +77,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
                 event: EventBase,
                 state_events: StateMap[EventBase],
             ) -> Tuple[bool, Optional[dict]]:
-                # We've already made sure f is not None above, but mypy doesn't do well
-                # across function boundaries so we need to tell it f is definitely not
-                # None.
+                # Assertion required because mypy can't prove we won't change
+                # `f` back to `None`. See
+                # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
                 assert f is not None
 
                 res = await f(event, state_events)
@@ -98,9 +98,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
             async def wrap_on_create_room(
                 requester: Requester, config: dict, is_requester_admin: bool
             ) -> None:
-                # We've already made sure f is not None above, but mypy doesn't do well
-                # across function boundaries so we need to tell it f is definitely not
-                # None.
+                # Assertion required because mypy can't prove we won't change
+                # `f` back to `None`. See
+                # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
                 assert f is not None
 
                 res = await f(requester, config, is_requester_admin)
@@ -112,9 +112,10 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
 
             return wrap_on_create_room
 
-        def run(*args, **kwargs):
-            # mypy doesn't do well across function boundaries so we need to tell it
-            # f is definitely not None.
+        def run(*args: Any, **kwargs: Any) -> Awaitable:
+            # Assertion required because mypy can't prove we won't change  `f`
+            # back to `None`. See
+            # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
             assert f is not None
 
             return maybe_awaitable(f(*args, **kwargs))
@@ -162,7 +163,7 @@ class ThirdPartyEventRules:
         check_visibility_can_be_modified: Optional[
             CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
         ] = None,
-    ):
+    ) -> None:
         """Register callbacks from modules for each hook."""
         if check_event_allowed is not None:
             self._check_event_allowed_callbacks.append(check_event_allowed)
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 520edbbf61..3f3eba86a8 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -13,18 +13,32 @@
 # limitations under the License.
 import collections.abc
 import re
-from typing import Any, Mapping, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Union,
+)
 
 from frozendict import frozendict
 
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersion
+from synapse.types import JsonDict
 from synapse.util.async_helpers import yieldable_gather_results
 from synapse.util.frozenutils import unfreeze
 
 from . import EventBase
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
 # (?<!stuff) matches if the current position in the string is not preceded
 # by a match for 'stuff'.
@@ -65,7 +79,7 @@ def prune_event(event: EventBase) -> EventBase:
     return pruned_event
 
 
-def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
+def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict:
     """Redacts the event_dict in the same way as `prune_event`, except it
     operates on dicts rather than event objects
 
@@ -97,7 +111,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
 
     new_content = {}
 
-    def add_fields(*fields):
+    def add_fields(*fields: str) -> None:
         for field in fields:
             if field in event_dict["content"]:
                 new_content[field] = event_dict["content"][field]
@@ -151,7 +165,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
 
     allowed_fields["content"] = new_content
 
-    unsigned = {}
+    unsigned: JsonDict = {}
     allowed_fields["unsigned"] = unsigned
 
     event_unsigned = event_dict.get("unsigned", {})
@@ -164,16 +178,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
     return allowed_fields
 
 
-def _copy_field(src, dst, field):
+def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None:
     """Copy the field in 'src' to 'dst'.
 
     For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
     then dst={"foo":{"bar":5}}.
 
     Args:
-        src(dict): The dict to read from.
-        dst(dict): The dict to modify.
-        field(list<str>): List of keys to drill down to in 'src'.
+        src: The dict to read from.
+        dst: The dict to modify.
+        field: List of keys to drill down to in 'src'.
     """
     if len(field) == 0:  # this should be impossible
         return
@@ -205,7 +219,7 @@ def _copy_field(src, dst, field):
     sub_out_dict[key_to_move] = sub_dict[key_to_move]
 
 
-def only_fields(dictionary, fields):
+def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict:
     """Return a new dict with only the fields in 'dictionary' which are present
     in 'fields'.
 
@@ -215,11 +229,11 @@ def only_fields(dictionary, fields):
     A literal '.' character in a field name may be escaped using a '\'.
 
     Args:
-        dictionary(dict): The dictionary to read from.
-        fields(list<str>): A list of fields to copy over. Only shallow refs are
+        dictionary: The dictionary to read from.
+        fields: A list of fields to copy over. Only shallow refs are
         taken.
     Returns:
-        dict: A new dictionary with only the given fields. If fields was empty,
+        A new dictionary with only the given fields. If fields was empty,
         the same dictionary is returned.
     """
     if len(fields) == 0:
@@ -235,17 +249,17 @@ def only_fields(dictionary, fields):
         [f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
     ]
 
-    output = {}
+    output: JsonDict = {}
     for field_array in split_fields:
         _copy_field(dictionary, output, field_array)
     return output
 
 
-def format_event_raw(d):
+def format_event_raw(d: JsonDict) -> JsonDict:
     return d
 
 
-def format_event_for_client_v1(d):
+def format_event_for_client_v1(d: JsonDict) -> JsonDict:
     d = format_event_for_client_v2(d)
 
     sender = d.get("sender")
@@ -267,7 +281,7 @@ def format_event_for_client_v1(d):
     return d
 
 
-def format_event_for_client_v2(d):
+def format_event_for_client_v2(d: JsonDict) -> JsonDict:
     drop_keys = (
         "auth_events",
         "prev_events",
@@ -282,37 +296,37 @@ def format_event_for_client_v2(d):
     return d
 
 
-def format_event_for_client_v2_without_room_id(d):
+def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
     d = format_event_for_client_v2(d)
     d.pop("room_id", None)
     return d
 
 
 def serialize_event(
-    e,
-    time_now_ms,
-    as_client_event=True,
-    event_format=format_event_for_client_v1,
-    token_id=None,
-    only_event_fields=None,
-    include_stripped_room_state=False,
-):
+    e: Union[JsonDict, EventBase],
+    time_now_ms: int,
+    as_client_event: bool = True,
+    event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
+    token_id: Optional[str] = None,
+    only_event_fields: Optional[List[str]] = None,
+    include_stripped_room_state: bool = False,
+) -> JsonDict:
     """Serialize event for clients
 
     Args:
-        e (EventBase)
-        time_now_ms (int)
-        as_client_event (bool)
+        e
+        time_now_ms
+        as_client_event
         event_format
         token_id
         only_event_fields
-        include_stripped_room_state (bool): Some events can have stripped room state
+        include_stripped_room_state: Some events can have stripped room state
             stored in the `unsigned` field. This is required for invite and knock
             functionality. If this option is False, that state will be removed from the
             event before it is returned. Otherwise, it will be kept.
 
     Returns:
-        dict
+        The serialized event dictionary.
     """
 
     # FIXME(erikj): To handle the case of presence events and the like
@@ -369,25 +383,27 @@ class EventClientSerializer:
     clients.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
-        self.experimental_msc1849_support_enabled = (
-            hs.config.server.experimental_msc1849_support_enabled
-        )
+        self._msc1849_enabled = hs.config.experimental.msc1849_enabled
 
     async def serialize_event(
-        self, event, time_now, bundle_aggregations=True, **kwargs
-    ):
+        self,
+        event: Union[JsonDict, EventBase],
+        time_now: int,
+        bundle_aggregations: bool = True,
+        **kwargs: Any,
+    ) -> JsonDict:
         """Serializes a single event.
 
         Args:
-            event (EventBase)
-            time_now (int): The current time in milliseconds
-            bundle_aggregations (bool): Whether to bundle in related events
+            event
+            time_now: The current time in milliseconds
+            bundle_aggregations: Whether to bundle in related events
             **kwargs: Arguments to pass to `serialize_event`
 
         Returns:
-            dict: The serialized event
+            The serialized event
         """
         # To handle the case of presence events and the like
         if not isinstance(event, EventBase):
@@ -400,7 +416,7 @@ class EventClientSerializer:
         # we need to bundle in with the event.
         # Do not bundle relations if the event has been redacted
         if not event.internal_metadata.is_redacted() and (
-            self.experimental_msc1849_support_enabled and bundle_aggregations
+            self._msc1849_enabled and bundle_aggregations
         ):
             annotations = await self.store.get_aggregation_groups_for_event(event_id)
             references = await self.store.get_relations_for_event(
@@ -448,25 +464,27 @@ class EventClientSerializer:
 
         return serialized_event
 
-    def serialize_events(self, events, time_now, **kwargs):
+    async def serialize_events(
+        self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
+    ) -> List[JsonDict]:
         """Serializes multiple events.
 
         Args:
-            event (iter[EventBase])
-            time_now (int): The current time in milliseconds
+            event
+            time_now: The current time in milliseconds
             **kwargs: Arguments to pass to `serialize_event`
 
         Returns:
-            Deferred[list[dict]]: The list of serialized events
+            The list of serialized events
         """
-        return yieldable_gather_results(
+        return await yieldable_gather_results(
             self.serialize_event, events, time_now=time_now, **kwargs
         )
 
 
 def copy_power_levels_contents(
     old_power_levels: Mapping[str, Union[int, Mapping[str, int]]]
-):
+) -> Dict[str, Union[int, Dict[str, int]]]:
     """Copy the content of a power_levels event, unfreezing frozendicts along the way
 
     Raises:
@@ -475,7 +493,7 @@ def copy_power_levels_contents(
     if not isinstance(old_power_levels, collections.abc.Mapping):
         raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))
 
-    power_levels = {}
+    power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
     for k, v in old_power_levels.items():
 
         if isinstance(v, int):
@@ -483,7 +501,8 @@ def copy_power_levels_contents(
             continue
 
         if isinstance(v, collections.abc.Mapping):
-            power_levels[k] = h = {}
+            h: Dict[str, int] = {}
+            power_levels[k] = h
             for k1, v1 in v.items():
                 # we should only have one level of nesting
                 if not isinstance(v1, int):
@@ -498,7 +517,7 @@ def copy_power_levels_contents(
     return power_levels
 
 
-def validate_canonicaljson(value: Any):
+def validate_canonicaljson(value: Any) -> None:
     """
     Ensure that the JSON object is valid according to the rules of canonical JSON.
 
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 6eb6544c4c..4d459c17f1 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import collections.abc
-from typing import Union
+from typing import Iterable, Union
 
 import jsonschema
 
@@ -28,11 +28,11 @@ from synapse.events.utils import (
     validate_canonicaljson,
 )
 from synapse.federation.federation_server import server_matches_acl_event
-from synapse.types import EventID, RoomID, UserID
+from synapse.types import EventID, JsonDict, RoomID, UserID
 
 
 class EventValidator:
-    def validate_new(self, event: EventBase, config: HomeServerConfig):
+    def validate_new(self, event: EventBase, config: HomeServerConfig) -> None:
         """Validates the event has roughly the right format
 
         Args:
@@ -116,7 +116,7 @@ class EventValidator:
                     errcode=Codes.BAD_JSON,
                 )
 
-    def _validate_retention(self, event: EventBase):
+    def _validate_retention(self, event: EventBase) -> None:
         """Checks that an event that defines the retention policy for a room respects the
         format enforced by the spec.
 
@@ -156,7 +156,7 @@ class EventValidator:
                 errcode=Codes.BAD_JSON,
             )
 
-    def validate_builder(self, event: Union[EventBase, EventBuilder]):
+    def validate_builder(self, event: Union[EventBase, EventBuilder]) -> None:
         """Validates that the builder/event has roughly the right format. Only
         checks values that we expect a proto event to have, rather than all the
         fields an event would have
@@ -204,14 +204,14 @@ class EventValidator:
 
             self._ensure_state_event(event)
 
-    def _ensure_strings(self, d, keys):
+    def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None:
         for s in keys:
             if s not in d:
                 raise SynapseError(400, "'%s' not in content" % (s,))
             if not isinstance(d[s], str):
                 raise SynapseError(400, "'%s' not a string type" % (s,))
 
-    def _ensure_state_event(self, event):
+    def _ensure_state_event(self, event: Union[EventBase, EventBuilder]) -> None:
         if not event.is_state():
             raise SynapseError(400, "'%s' must be state events" % (event.type,))
 
@@ -244,7 +244,9 @@ POWER_LEVELS_SCHEMA = {
 }
 
 
-def _create_power_level_validator():
+# This could return something newer than Draft 7, but that's the current "latest"
+# validator.
+def _create_power_level_validator() -> jsonschema.Draft7Validator:
     validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA)
 
     # by default jsonschema does not consider a frozendict to be an object so
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index f4612a5b92..ebe75a9e9b 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -200,46 +200,13 @@ class AuthHandler:
 
         self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
 
-        # we can't use hs.get_module_api() here, because to do so will create an
-        # import loop.
-        #
-        # TODO: refactor this class to separate the lower-level stuff that
-        #   ModuleApi can use from the higher-level stuff that uses ModuleApi, as
-        #   better way to break the loop
-        account_handler = ModuleApi(hs, self)
-
-        self.password_providers = [
-            PasswordProvider.load(module, config, account_handler)
-            for module, config in hs.config.authproviders.password_providers
-        ]
-
-        logger.info("Extra password_providers: %s", self.password_providers)
+        self.password_auth_provider = hs.get_password_auth_provider()
 
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.auth.password_enabled
         self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
 
-        # start out by assuming PASSWORD is enabled; we will remove it later if not.
-        login_types = set()
-        if self._password_localdb_enabled:
-            login_types.add(LoginType.PASSWORD)
-
-        for provider in self.password_providers:
-            login_types.update(provider.get_supported_login_types().keys())
-
-        if not self._password_enabled:
-            login_types.discard(LoginType.PASSWORD)
-
-        # Some clients just pick the first type in the list. In this case, we want
-        # them to use PASSWORD (rather than token or whatever), so we want to make sure
-        # that comes first, where it's present.
-        self._supported_login_types = []
-        if LoginType.PASSWORD in login_types:
-            self._supported_login_types.append(LoginType.PASSWORD)
-            login_types.remove(LoginType.PASSWORD)
-        self._supported_login_types.extend(login_types)
-
         # Ratelimiter for failed auth during UIA. Uses same ratelimit config
         # as per `rc_login.failed_attempts`.
         self._failed_uia_attempts_ratelimiter = Ratelimiter(
@@ -427,11 +394,10 @@ class AuthHandler:
                     ui_auth_types.add(LoginType.PASSWORD)
 
         # also allow auth from password providers
-        for provider in self.password_providers:
-            for t in provider.get_supported_login_types().keys():
-                if t == LoginType.PASSWORD and not self._password_enabled:
-                    continue
-                ui_auth_types.add(t)
+        for t in self.password_auth_provider.get_supported_login_types().keys():
+            if t == LoginType.PASSWORD and not self._password_enabled:
+                continue
+            ui_auth_types.add(t)
 
         # if sso is enabled, allow the user to log in via SSO iff they have a mapping
         # from sso to mxid.
@@ -1038,7 +1004,25 @@ class AuthHandler:
         Returns:
             login types
         """
-        return self._supported_login_types
+        # Load any login types registered by modules
+        # This is stored in the password_auth_provider so this doesn't trigger
+        # any callbacks
+        types = list(self.password_auth_provider.get_supported_login_types().keys())
+
+        # This list should include PASSWORD if (either _password_localdb_enabled is
+        # true or if one of the modules registered it) AND _password_enabled is true
+        # Also:
+        # Some clients just pick the first type in the list. In this case, we want
+        # them to use PASSWORD (rather than token or whatever), so we want to make sure
+        # that comes first, where it's present.
+        if LoginType.PASSWORD in types:
+            types.remove(LoginType.PASSWORD)
+            if self._password_enabled:
+                types.insert(0, LoginType.PASSWORD)
+        elif self._password_localdb_enabled and self._password_enabled:
+            types.insert(0, LoginType.PASSWORD)
+
+        return types
 
     async def validate_login(
         self,
@@ -1217,15 +1201,20 @@ class AuthHandler:
 
         known_login_type = False
 
-        for provider in self.password_providers:
-            supported_login_types = provider.get_supported_login_types()
-            if login_type not in supported_login_types:
-                # this password provider doesn't understand this login type
-                continue
-
+        # Check if login_type matches a type registered by one of the modules
+        # We don't need to remove LoginType.PASSWORD from the list if password login is
+        # disabled, since if that were the case then by this point we know that the
+        # login_type is not LoginType.PASSWORD
+        supported_login_types = self.password_auth_provider.get_supported_login_types()
+        # check if the login type being used is supported by a module
+        if login_type in supported_login_types:
+            # Make a note that this login type is supported by the server
             known_login_type = True
+            # Get all the fields expected for this login types
             login_fields = supported_login_types[login_type]
 
+            # go through the login submission and keep track of which required fields are
+            # provided/not provided
             missing_fields = []
             login_dict = {}
             for f in login_fields:
@@ -1233,6 +1222,7 @@ class AuthHandler:
                     missing_fields.append(f)
                 else:
                     login_dict[f] = login_submission[f]
+            # raise an error if any of the expected fields for that login type weren't provided
             if missing_fields:
                 raise SynapseError(
                     400,
@@ -1240,10 +1230,15 @@ class AuthHandler:
                     % (login_type, missing_fields),
                 )
 
-            result = await provider.check_auth(username, login_type, login_dict)
+            # call all of the check_auth hooks for that login_type
+            # it will return a result once the first success is found (or None otherwise)
+            result = await self.password_auth_provider.check_auth(
+                username, login_type, login_dict
+            )
             if result:
                 return result
 
+        # if no module managed to authenticate the user, then fallback to built in password based auth
         if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
             known_login_type = True
 
@@ -1282,11 +1277,16 @@ class AuthHandler:
             completed login/registration, or `None`. If authentication was
             unsuccessful, `user_id` and `callback` are both `None`.
         """
-        for provider in self.password_providers:
-            result = await provider.check_3pid_auth(medium, address, password)
-            if result:
-                return result
+        # call all of the check_3pid_auth callbacks
+        # Result will be from the first callback that returns something other than None
+        # If all the callbacks return None, then result is also set to None
+        result = await self.password_auth_provider.check_3pid_auth(
+            medium, address, password
+        )
+        if result:
+            return result
 
+        # if result is None then return (None, None)
         return None, None
 
     async def _check_local_password(self, user_id: str, password: str) -> Optional[str]:
@@ -1365,13 +1365,12 @@ class AuthHandler:
         user_info = await self.auth.get_user_by_access_token(access_token)
         await self.store.delete_access_token(access_token)
 
-        # see if any of our auth providers want to know about this
-        for provider in self.password_providers:
-            await provider.on_logged_out(
-                user_id=user_info.user_id,
-                device_id=user_info.device_id,
-                access_token=access_token,
-            )
+        # see if any modules want to know about this
+        await self.password_auth_provider.on_logged_out(
+            user_id=user_info.user_id,
+            device_id=user_info.device_id,
+            access_token=access_token,
+        )
 
         # delete pushers associated with this access token
         if user_info.token_id is not None:
@@ -1398,12 +1397,11 @@ class AuthHandler:
             user_id, except_token_id=except_token_id, device_id=device_id
         )
 
-        # see if any of our auth providers want to know about this
-        for provider in self.password_providers:
-            for token, _, device_id in tokens_and_devices:
-                await provider.on_logged_out(
-                    user_id=user_id, device_id=device_id, access_token=token
-                )
+        # see if any modules want to know about this
+        for token, _, device_id in tokens_and_devices:
+            await self.password_auth_provider.on_logged_out(
+                user_id=user_id, device_id=device_id, access_token=token
+            )
 
         # delete pushers associated with the access tokens
         await self.hs.get_pusherpool().remove_pushers_by_access_token(
@@ -1811,40 +1809,228 @@ class MacaroonGenerator:
         return macaroon
 
 
-class PasswordProvider:
-    """Wrapper for a password auth provider module
+def load_legacy_password_auth_providers(hs: "HomeServer") -> None:
+    module_api = hs.get_module_api()
+    for module, config in hs.config.authproviders.password_providers:
+        load_single_legacy_password_auth_provider(
+            module=module, config=config, api=module_api
+        )
 
-    This class abstracts out all of the backwards-compatibility hacks for
-    password providers, to provide a consistent interface.
-    """
 
-    @classmethod
-    def load(
-        cls, module: Type, config: JsonDict, module_api: ModuleApi
-    ) -> "PasswordProvider":
-        try:
-            pp = module(config=config, account_handler=module_api)
-        except Exception as e:
-            logger.error("Error while initializing %r: %s", module, e)
-            raise
-        return cls(pp, module_api)
+def load_single_legacy_password_auth_provider(
+    module: Type, config: JsonDict, api: ModuleApi
+) -> None:
+    try:
+        provider = module(config=config, account_handler=api)
+    except Exception as e:
+        logger.error("Error while initializing %r: %s", module, e)
+        raise
+
+    # The known hooks. If a module implements a method who's name appears in this set
+    # we'll want to register it
+    password_auth_provider_methods = {
+        "check_3pid_auth",
+        "on_logged_out",
+    }
+
+    # All methods that the module provides should be async, but this wasn't enforced
+    # in the old module system, so we wrap them if needed
+    def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
+        # f might be None if the callback isn't implemented by the module. In this
+        # case we don't want to register a callback at all so we return None.
+        if f is None:
+            return None
+
+        # We need to wrap check_password because its old form would return a boolean
+        # but we now want it to behave just like check_auth() and return the matrix id of
+        # the user if authentication succeeded or None otherwise
+        if f.__name__ == "check_password":
+
+            async def wrapped_check_password(
+                username: str, login_type: str, login_dict: JsonDict
+            ) -> Optional[Tuple[str, Optional[Callable]]]:
+                # We've already made sure f is not None above, but mypy doesn't do well
+                # across function boundaries so we need to tell it f is definitely not
+                # None.
+                assert f is not None
+
+                matrix_user_id = api.get_qualified_user_id(username)
+                password = login_dict["password"]
+
+                is_valid = await f(matrix_user_id, password)
+
+                if is_valid:
+                    return matrix_user_id, None
+
+                return None
 
-    def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
-        self._pp = pp
-        self._module_api = module_api
+            return wrapped_check_password
+
+        # We need to wrap check_auth as in the old form it could return
+        # just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
+        if f.__name__ == "check_auth":
+
+            async def wrapped_check_auth(
+                username: str, login_type: str, login_dict: JsonDict
+            ) -> Optional[Tuple[str, Optional[Callable]]]:
+                # We've already made sure f is not None above, but mypy doesn't do well
+                # across function boundaries so we need to tell it f is definitely not
+                # None.
+                assert f is not None
+
+                result = await f(username, login_type, login_dict)
+
+                if isinstance(result, str):
+                    return result, None
+
+                return result
+
+            return wrapped_check_auth
+
+        # We need to wrap check_3pid_auth as in the old form it could return
+        # just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
+        if f.__name__ == "check_3pid_auth":
+
+            async def wrapped_check_3pid_auth(
+                medium: str, address: str, password: str
+            ) -> Optional[Tuple[str, Optional[Callable]]]:
+                # We've already made sure f is not None above, but mypy doesn't do well
+                # across function boundaries so we need to tell it f is definitely not
+                # None.
+                assert f is not None
+
+                result = await f(medium, address, password)
+
+                if isinstance(result, str):
+                    return result, None
+
+                return result
 
-        self._supported_login_types = {}
+            return wrapped_check_3pid_auth
 
-        # grandfather in check_password support
-        if hasattr(self._pp, "check_password"):
-            self._supported_login_types[LoginType.PASSWORD] = ("password",)
+        def run(*args: Tuple, **kwargs: Dict) -> Awaitable:
+            # mypy doesn't do well across function boundaries so we need to tell it
+            # f is definitely not None.
+            assert f is not None
 
-        g = getattr(self._pp, "get_supported_login_types", None)
-        if g:
-            self._supported_login_types.update(g())
+            return maybe_awaitable(f(*args, **kwargs))
 
-    def __str__(self) -> str:
-        return str(self._pp)
+        return run
+
+    # populate hooks with the implemented methods, wrapped with async_wrapper
+    hooks = {
+        hook: async_wrapper(getattr(provider, hook, None))
+        for hook in password_auth_provider_methods
+    }
+
+    supported_login_types = {}
+    # call get_supported_login_types and add that to the dict
+    g = getattr(provider, "get_supported_login_types", None)
+    if g is not None:
+        # Note the old module style also called get_supported_login_types at loading time
+        # and it is synchronous
+        supported_login_types.update(g())
+
+    auth_checkers = {}
+    # Legacy modules have a check_auth method which expects to be called with one of
+    # the keys returned by get_supported_login_types. New style modules register a
+    # dictionary of login_type->check_auth_method mappings
+    check_auth = async_wrapper(getattr(provider, "check_auth", None))
+    if check_auth is not None:
+        for login_type, fields in supported_login_types.items():
+            # need tuple(fields) since fields can be any Iterable type (so may not be hashable)
+            auth_checkers[(login_type, tuple(fields))] = check_auth
+
+    # if it has a "check_password" method then it should handle all auth checks
+    # with login type of LoginType.PASSWORD
+    check_password = async_wrapper(getattr(provider, "check_password", None))
+    if check_password is not None:
+        # need to use a tuple here for ("password",) not a list since lists aren't hashable
+        auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password
+
+    api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers)
+
+
+CHECK_3PID_AUTH_CALLBACK = Callable[
+    [str, str, str],
+    Awaitable[
+        Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
+    ],
+]
+ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable]
+CHECK_AUTH_CALLBACK = Callable[
+    [str, str, JsonDict],
+    Awaitable[
+        Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
+    ],
+]
+
+
+class PasswordAuthProvider:
+    """
+    A class that the AuthHandler calls when authenticating users
+    It allows modules to provide alternative methods for authentication
+    """
+
+    def __init__(self) -> None:
+        # lists of callbacks
+        self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
+        self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
+
+        # Mapping from login type to login parameters
+        self._supported_login_types: Dict[str, Iterable[str]] = {}
+
+        # Mapping from login type to auth checker callbacks
+        self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
+
+    def register_password_auth_provider_callbacks(
+        self,
+        check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
+        on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
+        auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None,
+    ) -> None:
+        # Register check_3pid_auth callback
+        if check_3pid_auth is not None:
+            self.check_3pid_auth_callbacks.append(check_3pid_auth)
+
+        # register on_logged_out callback
+        if on_logged_out is not None:
+            self.on_logged_out_callbacks.append(on_logged_out)
+
+        if auth_checkers is not None:
+            # register a new supported login_type
+            # Iterate through all of the types being registered
+            for (login_type, fields), callback in auth_checkers.items():
+                # Note: fields may be empty here. This would allow a modules auth checker to
+                # be called with just 'login_type' and no password or other secrets
+
+                # Need to check that all the field names are strings or may get nasty errors later
+                for f in fields:
+                    if not isinstance(f, str):
+                        raise RuntimeError(
+                            "A module tried to register support for login type: %s with parameters %s"
+                            " but all parameter names must be strings"
+                            % (login_type, fields)
+                        )
+
+                # 2 modules supporting the same login type must expect the same fields
+                # e.g. 1 can't expect "pass" if the other expects "password"
+                # so throw an exception if that happens
+                if login_type not in self._supported_login_types.get(login_type, []):
+                    self._supported_login_types[login_type] = fields
+                else:
+                    fields_currently_supported = self._supported_login_types.get(
+                        login_type
+                    )
+                    if fields_currently_supported != fields:
+                        raise RuntimeError(
+                            "A module tried to register support for login type: %s with parameters %s"
+                            " but another module had already registered support for that type with parameters %s"
+                            % (login_type, fields, fields_currently_supported)
+                        )
+
+                # Add the new method to the list of auth_checker_callbacks for this login type
+                self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
 
     def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
         """Get the login types supported by this password provider
@@ -1852,20 +2038,15 @@ class PasswordProvider:
         Returns a map from a login type identifier (such as m.login.password) to an
         iterable giving the fields which must be provided by the user in the submission
         to the /login API.
-
-        This wrapper adds m.login.password to the list if the underlying password
-        provider supports the check_password() api.
         """
+
         return self._supported_login_types
 
     async def check_auth(
         self, username: str, login_type: str, login_dict: JsonDict
-    ) -> Optional[Tuple[str, Optional[Callable]]]:
+    ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
         """Check if the user has presented valid login credentials
 
-        This wrapper also calls check_password() if the underlying password provider
-        supports the check_password() api and the login type is m.login.password.
-
         Args:
             username: user id presented by the client. Either an MXID or an unqualified
                 username.
@@ -1879,63 +2060,130 @@ class PasswordProvider:
             user, and `callback` is an optional callback which will be called with the
             result from the /login call (including access_token, device_id, etc.)
         """
-        # first grandfather in a call to check_password
-        if login_type == LoginType.PASSWORD:
-            check_password = getattr(self._pp, "check_password", None)
-            if check_password:
-                qualified_user_id = self._module_api.get_qualified_user_id(username)
-                is_valid = await check_password(
-                    qualified_user_id, login_dict["password"]
-                )
-                if is_valid:
-                    return qualified_user_id, None
 
-        check_auth = getattr(self._pp, "check_auth", None)
-        if not check_auth:
-            return None
-        result = await check_auth(username, login_type, login_dict)
+        # Go through all callbacks for the login type until one returns with a value
+        # other than None (i.e. until a callback returns a success)
+        for callback in self.auth_checker_callbacks[login_type]:
+            try:
+                result = await callback(username, login_type, login_dict)
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+                continue
 
-        # Check if the return value is a str or a tuple
-        if isinstance(result, str):
-            # If it's a str, set callback function to None
-            return result, None
+            if result is not None:
+                # Check that the callback returned a Tuple[str, Optional[Callable]]
+                # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
+                # result is always the right type, but as it is 3rd party code it might not be
+
+                if not isinstance(result, tuple) or len(result) != 2:
+                    logger.warning(
+                        "Wrong type returned by module API callback %s: %s, expected"
+                        " Optional[Tuple[str, Optional[Callable]]]",
+                        callback,
+                        result,
+                    )
+                    continue
 
-        return result
+                # pull out the two parts of the tuple so we can do type checking
+                str_result, callback_result = result
+
+                # the 1st item in the tuple should be a str
+                if not isinstance(str_result, str):
+                    logger.warning(  # type: ignore[unreachable]
+                        "Wrong type returned by module API callback %s: %s, expected"
+                        " Optional[Tuple[str, Optional[Callable]]]",
+                        callback,
+                        result,
+                    )
+                    continue
+
+                # the second should be Optional[Callable]
+                if callback_result is not None:
+                    if not callable(callback_result):
+                        logger.warning(  # type: ignore[unreachable]
+                            "Wrong type returned by module API callback %s: %s, expected"
+                            " Optional[Tuple[str, Optional[Callable]]]",
+                            callback,
+                            result,
+                        )
+                        continue
+
+                # The result is a (str, Optional[callback]) tuple so return the successful result
+                return result
+
+        # If this point has been reached then none of the callbacks successfully authenticated
+        # the user so return None
+        return None
 
     async def check_3pid_auth(
         self, medium: str, address: str, password: str
-    ) -> Optional[Tuple[str, Optional[Callable]]]:
-        g = getattr(self._pp, "check_3pid_auth", None)
-        if not g:
-            return None
-
+    ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
         # This function is able to return a deferred that either
         # resolves None, meaning authentication failure, or upon
         # success, to a str (which is the user_id) or a tuple of
         # (user_id, callback_func), where callback_func should be run
         # after we've finished everything else
-        result = await g(medium, address, password)
 
-        # Check if the return value is a str or a tuple
-        if isinstance(result, str):
-            # If it's a str, set callback function to None
-            return result, None
+        for callback in self.check_3pid_auth_callbacks:
+            try:
+                result = await callback(medium, address, password)
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+                continue
 
-        return result
+            if result is not None:
+                # Check that the callback returned a Tuple[str, Optional[Callable]]
+                # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
+                # result is always the right type, but as it is 3rd party code it might not be
+
+                if not isinstance(result, tuple) or len(result) != 2:
+                    logger.warning(
+                        "Wrong type returned by module API callback %s: %s, expected"
+                        " Optional[Tuple[str, Optional[Callable]]]",
+                        callback,
+                        result,
+                    )
+                    continue
+
+                # pull out the two parts of the tuple so we can do type checking
+                str_result, callback_result = result
+
+                # the 1st item in the tuple should be a str
+                if not isinstance(str_result, str):
+                    logger.warning(  # type: ignore[unreachable]
+                        "Wrong type returned by module API callback %s: %s, expected"
+                        " Optional[Tuple[str, Optional[Callable]]]",
+                        callback,
+                        result,
+                    )
+                    continue
+
+                # the second should be Optional[Callable]
+                if callback_result is not None:
+                    if not callable(callback_result):
+                        logger.warning(  # type: ignore[unreachable]
+                            "Wrong type returned by module API callback %s: %s, expected"
+                            " Optional[Tuple[str, Optional[Callable]]]",
+                            callback,
+                            result,
+                        )
+                        continue
+
+                # The result is a (str, Optional[callback]) tuple so return the successful result
+                return result
+
+        # If this point has been reached then none of the callbacks successfully authenticated
+        # the user so return None
+        return None
 
     async def on_logged_out(
         self, user_id: str, device_id: Optional[str], access_token: str
     ) -> None:
-        g = getattr(self._pp, "on_logged_out", None)
-        if not g:
-            return
 
-        # This might return an awaitable, if it does block the log out
-        # until it completes.
-        await maybe_awaitable(
-            g(
-                user_id=user_id,
-                device_id=device_id,
-                access_token=access_token,
-            )
-        )
+        # call all of the on_logged_out callbacks
+        for callback in self.on_logged_out_callbacks:
+            try:
+                callback(user_id, device_id, access_token)
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+                continue
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 75e6019760..6eafbea25d 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -14,7 +14,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+)
 
 from synapse.api import errors
 from synapse.api.constants import EventTypes
@@ -595,7 +606,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
 
 def _update_device_from_client_ips(
-    device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict]
+    device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
 ) -> None:
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
     device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index d089c56286..365063ebdf 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -55,8 +55,7 @@ class EventAuthHandler:
         """Check an event passes the auth rules at its own auth events"""
         auth_event_ids = event.auth_event_ids()
         auth_events_by_id = await self._store.get_events(auth_event_ids)
-        auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
-        check_auth_rules_for_event(room_version_obj, event, auth_events)
+        check_auth_rules_for_event(room_version_obj, event, auth_events_by_id.values())
 
     def compute_auth_events(
         self,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 3e341bd287..3112cc88b1 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -15,7 +15,6 @@
 
 """Contains handlers for federation events."""
 
-import itertools
 import logging
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
@@ -27,12 +26,7 @@ from unpaddedbase64 import decode_base64
 from twisted.internet import defer
 
 from synapse import event_auth
-from synapse.api.constants import (
-    EventContentFields,
-    EventTypes,
-    Membership,
-    RejectedReason,
-)
+from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.api.errors import (
     AuthError,
     CodeMessageException,
@@ -43,12 +37,9 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.crypto.event_signing import compute_event_signature
-from synapse.event_auth import (
-    check_auth_rules_for_event,
-    validate_event_for_room_version,
-)
+from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
@@ -238,18 +229,10 @@ class FederationHandler:
             )
             return False
 
-        logger.debug(
-            "room_id: %s, backfill: current_depth: %s, max_depth: %s, extrems: %s",
-            room_id,
-            current_depth,
-            max_depth,
-            sorted_extremeties_tuple,
-        )
-
         # We ignore extremities that have a greater depth than our current depth
         # as:
         #    1. we don't really care about getting events that have happened
-        #       before our current position; and
+        #       after our current position; and
         #    2. we have likely previously tried and failed to backfill from that
         #       extremity, so to avoid getting "stuck" requesting the same
         #       backfill repeatedly we drop those extremities.
@@ -257,9 +240,19 @@ class FederationHandler:
             t for t in sorted_extremeties_tuple if int(t[1]) <= current_depth
         ]
 
+        logger.debug(
+            "room_id: %s, backfill: current_depth: %s, limit: %s, max_depth: %s, extrems: %s filtered_sorted_extremeties_tuple: %s",
+            room_id,
+            current_depth,
+            limit,
+            max_depth,
+            sorted_extremeties_tuple,
+            filtered_sorted_extremeties_tuple,
+        )
+
         # However, we need to check that the filtered extremities are non-empty.
         # If they are empty then either we can a) bail or b) still attempt to
-        # backill. We opt to try backfilling anyway just in case we do get
+        # backfill. We opt to try backfilling anyway just in case we do get
         # relevant events.
         if filtered_sorted_extremeties_tuple:
             sorted_extremeties_tuple = filtered_sorted_extremeties_tuple
@@ -389,7 +382,7 @@ class FederationHandler:
             for key, state_dict in states.items()
         }
 
-        for e_id, _ in sorted_extremeties_tuple:
+        for e_id in event_ids:
             likely_extremeties_domains = get_domains_from_state(states[e_id])
 
             success = await try_backfill(
@@ -517,7 +510,7 @@ class FederationHandler:
                 auth_events=auth_chain,
             )
 
-            max_stream_id = await self._persist_auth_tree(
+            max_stream_id = await self._federation_event_handler.process_remote_join(
                 origin, room_id, auth_chain, state, event, room_version_obj
             )
 
@@ -1093,119 +1086,6 @@ class FederationHandler:
         else:
             return None
 
-    async def _persist_auth_tree(
-        self,
-        origin: str,
-        room_id: str,
-        auth_events: List[EventBase],
-        state: List[EventBase],
-        event: EventBase,
-        room_version: RoomVersion,
-    ) -> int:
-        """Checks the auth chain is valid (and passes auth checks) for the
-        state and event. Then persists the auth chain and state atomically.
-        Persists the event separately. Notifies about the persisted events
-        where appropriate.
-
-        Will attempt to fetch missing auth events.
-
-        Args:
-            origin: Where the events came from
-            room_id,
-            auth_events
-            state
-            event
-            room_version: The room version we expect this room to have, and
-                will raise if it doesn't match the version in the create event.
-        """
-        events_to_context = {}
-        for e in itertools.chain(auth_events, state):
-            e.internal_metadata.outlier = True
-            events_to_context[e.event_id] = EventContext.for_outlier()
-
-        event_map = {
-            e.event_id: e for e in itertools.chain(auth_events, state, [event])
-        }
-
-        create_event = None
-        for e in auth_events:
-            if (e.type, e.state_key) == (EventTypes.Create, ""):
-                create_event = e
-                break
-
-        if create_event is None:
-            # If the state doesn't have a create event then the room is
-            # invalid, and it would fail auth checks anyway.
-            raise SynapseError(400, "No create event in state")
-
-        room_version_id = create_event.content.get(
-            "room_version", RoomVersions.V1.identifier
-        )
-
-        if room_version.identifier != room_version_id:
-            raise SynapseError(400, "Room version mismatch")
-
-        missing_auth_events = set()
-        for e in itertools.chain(auth_events, state, [event]):
-            for e_id in e.auth_event_ids():
-                if e_id not in event_map:
-                    missing_auth_events.add(e_id)
-
-        for e_id in missing_auth_events:
-            m_ev = await self.federation_client.get_pdu(
-                [origin],
-                e_id,
-                room_version=room_version,
-                outlier=True,
-                timeout=10000,
-            )
-            if m_ev and m_ev.event_id == e_id:
-                event_map[e_id] = m_ev
-            else:
-                logger.info("Failed to find auth event %r", e_id)
-
-        for e in itertools.chain(auth_events, state, [event]):
-            auth_for_e = {
-                (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
-                for e_id in e.auth_event_ids()
-                if e_id in event_map
-            }
-            if create_event:
-                auth_for_e[(EventTypes.Create, "")] = create_event
-
-            try:
-                validate_event_for_room_version(room_version, e)
-                check_auth_rules_for_event(room_version, e, auth_for_e)
-            except SynapseError as err:
-                # we may get SynapseErrors here as well as AuthErrors. For
-                # instance, there are a couple of (ancient) events in some
-                # rooms whose senders do not have the correct sigil; these
-                # cause SynapseErrors in auth.check. We don't want to give up
-                # the attempt to federate altogether in such cases.
-
-                logger.warning("Rejecting %s because %s", e.event_id, err.msg)
-
-                if e == event:
-                    raise
-                events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
-
-        if auth_events or state:
-            await self._federation_event_handler.persist_events_and_notify(
-                room_id,
-                [
-                    (e, events_to_context[e.event_id])
-                    for e in itertools.chain(auth_events, state)
-                ],
-            )
-
-        new_event_context = await self.state_handler.compute_event_context(
-            event, old_state=state
-        )
-
-        return await self._federation_event_handler.persist_events_and_notify(
-            room_id, [(event, new_event_context)]
-        )
-
     async def on_get_missing_events(
         self,
         origin: str,
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index f640b417b3..5a2f2e5ebb 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import itertools
 import logging
 from http import HTTPStatus
 from typing import (
@@ -45,7 +46,7 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions
 from synapse.event_auth import (
     auth_types_for_event,
     check_auth_rules_for_event,
@@ -214,7 +215,7 @@ class FederationEventHandler:
 
         if missing_prevs:
             # We only backfill backwards to the min depth.
-            min_depth = await self.get_min_depth_for_context(pdu.room_id)
+            min_depth = await self._store.get_min_depth(pdu.room_id)
             logger.debug("min_depth: %d", min_depth)
 
             if min_depth is not None and pdu.depth > min_depth:
@@ -390,9 +391,122 @@ class FederationEventHandler:
             prev_member_event,
         )
 
+    async def process_remote_join(
+        self,
+        origin: str,
+        room_id: str,
+        auth_events: List[EventBase],
+        state: List[EventBase],
+        event: EventBase,
+        room_version: RoomVersion,
+    ) -> int:
+        """Persists the events returned by a send_join
+
+        Checks the auth chain is valid (and passes auth checks) for the
+        state and event. Then persists the auth chain and state atomically.
+        Persists the event separately. Notifies about the persisted events
+        where appropriate.
+
+        Will attempt to fetch missing auth events.
+
+        Args:
+            origin: Where the events came from
+            room_id,
+            auth_events
+            state
+            event
+            room_version: The room version we expect this room to have, and
+                will raise if it doesn't match the version in the create event.
+        """
+        events_to_context = {}
+        for e in itertools.chain(auth_events, state):
+            e.internal_metadata.outlier = True
+            events_to_context[e.event_id] = EventContext.for_outlier()
+
+        event_map = {
+            e.event_id: e for e in itertools.chain(auth_events, state, [event])
+        }
+
+        create_event = None
+        for e in auth_events:
+            if (e.type, e.state_key) == (EventTypes.Create, ""):
+                create_event = e
+                break
+
+        if create_event is None:
+            # If the state doesn't have a create event then the room is
+            # invalid, and it would fail auth checks anyway.
+            raise SynapseError(400, "No create event in state")
+
+        room_version_id = create_event.content.get(
+            "room_version", RoomVersions.V1.identifier
+        )
+
+        if room_version.identifier != room_version_id:
+            raise SynapseError(400, "Room version mismatch")
+
+        missing_auth_events = set()
+        for e in itertools.chain(auth_events, state, [event]):
+            for e_id in e.auth_event_ids():
+                if e_id not in event_map:
+                    missing_auth_events.add(e_id)
+
+        for e_id in missing_auth_events:
+            m_ev = await self._federation_client.get_pdu(
+                [origin],
+                e_id,
+                room_version=room_version,
+                outlier=True,
+                timeout=10000,
+            )
+            if m_ev and m_ev.event_id == e_id:
+                event_map[e_id] = m_ev
+            else:
+                logger.info("Failed to find auth event %r", e_id)
+
+        for e in itertools.chain(auth_events, state, [event]):
+            auth_for_e = [
+                event_map[e_id] for e_id in e.auth_event_ids() if e_id in event_map
+            ]
+            if create_event:
+                auth_for_e.append(create_event)
+
+            try:
+                validate_event_for_room_version(room_version, e)
+                check_auth_rules_for_event(room_version, e, auth_for_e)
+            except SynapseError as err:
+                # we may get SynapseErrors here as well as AuthErrors. For
+                # instance, there are a couple of (ancient) events in some
+                # rooms whose senders do not have the correct sigil; these
+                # cause SynapseErrors in auth.check. We don't want to give up
+                # the attempt to federate altogether in such cases.
+
+                logger.warning("Rejecting %s because %s", e.event_id, err.msg)
+
+                if e == event:
+                    raise
+                events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
+
+        if auth_events or state:
+            await self.persist_events_and_notify(
+                room_id,
+                [
+                    (e, events_to_context[e.event_id])
+                    for e in itertools.chain(auth_events, state)
+                ],
+            )
+
+        new_event_context = await self._state_handler.compute_event_context(
+            event, old_state=state
+        )
+
+        return await self.persist_events_and_notify(
+            room_id, [(event, new_event_context)]
+        )
+
     @log_function
     async def backfill(
-        self, dest: str, room_id: str, limit: int, extremities: List[str]
+        self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
     ) -> None:
         """Trigger a backfill request to `dest` for the given `room_id`
 
@@ -1116,14 +1230,12 @@ class FederationEventHandler:
 
         await concurrently_execute(get_event, event_ids, 5)
         logger.info("Fetched %i events of %i requested", len(events), len(event_ids))
-        await self._auth_and_persist_fetched_events(destination, room_id, events)
+        await self._auth_and_persist_outliers(room_id, events)
 
-    async def _auth_and_persist_fetched_events(
-        self, origin: str, room_id: str, events: Iterable[EventBase]
+    async def _auth_and_persist_outliers(
+        self, room_id: str, events: Iterable[EventBase]
     ) -> None:
-        """Persist the events fetched by _get_events_and_persist or _get_remote_auth_chain_for_event
-
-        The events to be persisted must be outliers.
+        """Persist a batch of outlier events fetched from remote servers.
 
         We first sort the events to make sure that we process each event's auth_events
         before the event itself, and then auth and persist them.
@@ -1131,7 +1243,6 @@ class FederationEventHandler:
         Notifies about the events where appropriate.
 
         Params:
-            origin: where the events came from
             room_id: the room that the events are meant to be in (though this has
                not yet been checked)
             events: the events that have been fetched
@@ -1167,15 +1278,15 @@ class FederationEventHandler:
                 shortstr(e.event_id for e in roots),
             )
 
-            await self._auth_and_persist_fetched_events_inner(origin, room_id, roots)
+            await self._auth_and_persist_outliers_inner(room_id, roots)
 
             for ev in roots:
                 del event_map[ev.event_id]
 
-    async def _auth_and_persist_fetched_events_inner(
-        self, origin: str, room_id: str, fetched_events: Collection[EventBase]
+    async def _auth_and_persist_outliers_inner(
+        self, room_id: str, fetched_events: Collection[EventBase]
     ) -> None:
-        """Helper for _auth_and_persist_fetched_events
+        """Helper for _auth_and_persist_outliers
 
         Persists a batch of events where we have (theoretically) already persisted all
         of their auth events.
@@ -1203,7 +1314,7 @@ class FederationEventHandler:
 
         def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
             with nested_logging_context(suffix=event.event_id):
-                auth = {}
+                auth = []
                 for auth_event_id in event.auth_event_ids():
                     ae = persisted_events.get(auth_event_id)
                     if not ae:
@@ -1216,7 +1327,7 @@ class FederationEventHandler:
                         # exist, which means it is premature to reject `event`. Instead we
                         # just ignore it for now.
                         return None
-                    auth[(ae.type, ae.state_key)] = ae
+                    auth.append(ae)
 
                 context = EventContext.for_outlier()
                 try:
@@ -1256,6 +1367,10 @@ class FederationEventHandler:
 
         Returns:
             The updated context object.
+
+        Raises:
+            AuthError if we were unable to find copies of the event's auth events.
+               (Most other failures just cause us to set `context.rejected`.)
         """
         # This method should only be used for non-outliers
         assert not event.internal_metadata.outlier
@@ -1272,7 +1387,26 @@ class FederationEventHandler:
             context.rejected = RejectedReason.AUTH_ERROR
             return context
 
-        # calculate what the auth events *should* be, to use as a basis for auth.
+        # next, check that we have all of the event's auth events.
+        #
+        # Note that this can raise AuthError, which we want to propagate to the
+        # caller rather than swallow with `context.rejected` (since we cannot be
+        # certain that there is a permanent problem with the event).
+        claimed_auth_events = await self._load_or_fetch_auth_events_for_event(
+            origin, event
+        )
+
+        # ... and check that the event passes auth at those auth events.
+        try:
+            check_auth_rules_for_event(room_version_obj, event, claimed_auth_events)
+        except AuthError as e:
+            logger.warning(
+                "While checking auth of %r against auth_events: %s", event, e
+            )
+            context.rejected = RejectedReason.AUTH_ERROR
+            return context
+
+        # now check auth against what we think the auth events *should* be.
         prev_state_ids = await context.get_prev_state_ids()
         auth_events_ids = self._event_auth_handler.compute_auth_events(
             event, prev_state_ids, for_verification=True
@@ -1305,7 +1439,9 @@ class FederationEventHandler:
             auth_events_for_auth = calculated_auth_event_map
 
         try:
-            check_auth_rules_for_event(room_version_obj, event, auth_events_for_auth)
+            check_auth_rules_for_event(
+                room_version_obj, event, auth_events_for_auth.values()
+            )
         except AuthError as e:
             logger.warning("Failed auth resolution for %r because %s", event, e)
             context.rejected = RejectedReason.AUTH_ERROR
@@ -1403,11 +1539,9 @@ class FederationEventHandler:
         current_state_ids_list = [
             e for k, e in current_state_ids.items() if k in auth_types
         ]
-
-        auth_events_map = await self._store.get_events(current_state_ids_list)
-        current_auth_events = {
-            (e.type, e.state_key): e for e in auth_events_map.values()
-        }
+        current_auth_events = await self._store.get_events_as_list(
+            current_state_ids_list
+        )
 
         try:
             check_auth_rules_for_event(room_version_obj, event, current_auth_events)
@@ -1472,6 +1606,9 @@ class FederationEventHandler:
         # if we have missing events, we need to fetch those events from somewhere.
         #
         # we start by checking if they are in the store, and then try calling /event_auth/.
+        #
+        # TODO: this code is now redundant, since it should be impossible for us to
+        #   get here without already having the auth events.
         if missing_auth:
             have_events = await self._store.have_seen_events(
                 event.room_id, missing_auth
@@ -1575,7 +1712,7 @@ class FederationEventHandler:
         logger.info(
             "After state res: updating auth_events with new state %s",
             {
-                (d.type, d.state_key): d.event_id
+                d
                 for d in new_state.values()
                 if auth_events.get((d.type, d.state_key)) != d
             },
@@ -1589,6 +1726,75 @@ class FederationEventHandler:
 
         return context, auth_events
 
+    async def _load_or_fetch_auth_events_for_event(
+        self, destination: str, event: EventBase
+    ) -> Collection[EventBase]:
+        """Fetch this event's auth_events, from database or remote
+
+        Loads any of the auth_events that we already have from the database/cache. If
+        there are any that are missing, calls /event_auth to get the complete auth
+        chain for the event (and then attempts to load the auth_events again).
+
+        If any of the auth_events cannot be found, raises an AuthError. This can happen
+        for a number of reasons; eg: the events don't exist, or we were unable to talk
+        to `destination`, or we couldn't validate the signature on the event (which
+        in turn has multiple potential causes).
+
+        Args:
+            destination: where to send the /event_auth request. Typically the server
+               that sent us `event` in the first place.
+            event: the event whose auth_events we want
+
+        Returns:
+            all of the events in `event.auth_events`, after deduplication
+
+        Raises:
+            AuthError if we were unable to fetch the auth_events for any reason.
+        """
+        event_auth_event_ids = set(event.auth_event_ids())
+        event_auth_events = await self._store.get_events(
+            event_auth_event_ids, allow_rejected=True
+        )
+        missing_auth_event_ids = event_auth_event_ids.difference(
+            event_auth_events.keys()
+        )
+        if not missing_auth_event_ids:
+            return event_auth_events.values()
+
+        logger.info(
+            "Event %s refers to unknown auth events %s: fetching auth chain",
+            event,
+            missing_auth_event_ids,
+        )
+        try:
+            await self._get_remote_auth_chain_for_event(
+                destination, event.room_id, event.event_id
+            )
+        except Exception as e:
+            logger.warning("Failed to get auth chain for %s: %s", event, e)
+            # in this case, it's very likely we still won't have all the auth
+            # events - but we pick that up below.
+
+        # try to fetch the auth events we missed list time.
+        extra_auth_events = await self._store.get_events(
+            missing_auth_event_ids, allow_rejected=True
+        )
+        missing_auth_event_ids.difference_update(extra_auth_events.keys())
+        event_auth_events.update(extra_auth_events)
+        if not missing_auth_event_ids:
+            return event_auth_events.values()
+
+        # we still don't have all the auth events.
+        logger.warning(
+            "Missing auth events for %s: %s",
+            event,
+            shortstr(missing_auth_event_ids),
+        )
+        # the fact we can't find the auth event doesn't mean it doesn't
+        # exist, which means it is premature to store `event` as rejected.
+        # instead we raise an AuthError, which will make the caller ignore it.
+        raise AuthError(code=HTTPStatus.FORBIDDEN, msg="Auth events could not be found")
+
     async def _get_remote_auth_chain_for_event(
         self, destination: str, room_id: str, event_id: str
     ) -> None:
@@ -1624,9 +1830,7 @@ class FederationEventHandler:
         for s in seen_remotes:
             remote_event_map.pop(s, None)
 
-        await self._auth_and_persist_fetched_events(
-            destination, room_id, remote_event_map.values()
-        )
+        await self._auth_and_persist_outliers(room_id, remote_event_map.values())
 
     async def _update_context_for_auth_events(
         self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
@@ -1696,16 +1900,27 @@ class FederationEventHandler:
         # persist_events_and_notify directly.)
         assert not event.internal_metadata.outlier
 
-        try:
-            if (
-                not backfilled
-                and not context.rejected
-                and (await self._store.get_min_depth(event.room_id)) <= event.depth
-            ):
+        if not backfilled and not context.rejected:
+            min_depth = await self._store.get_min_depth(event.room_id)
+            if min_depth is None or min_depth > event.depth:
+                # XXX richvdh 2021/10/07: I don't really understand what this
+                # condition is doing. I think it's trying not to send pushes
+                # for events that predate our join - but that's not really what
+                # min_depth means, and anyway ancient events are a more general
+                # problem.
+                #
+                # for now I'm just going to log about it.
+                logger.info(
+                    "Skipping push actions for old event with depth %s < %s",
+                    event.depth,
+                    min_depth,
+                )
+            else:
                 await self._action_generator.handle_push_actions_for_event(
                     event, context
                 )
 
+        try:
             await self.persist_events_and_notify(
                 event.room_id, [(event, context)], backfilled=backfilled
             )
@@ -1837,6 +2052,3 @@ class FederationEventHandler:
                 len(ev.auth_event_ids()),
             )
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
-
-    async def get_min_depth_for_context(self, context: str) -> int:
-        return await self._store.get_min_depth(context)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 4de9f4b828..2e024b551f 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -607,29 +607,6 @@ class EventCreationHandler:
 
         builder.internal_metadata.historical = historical
 
-        # Strip down the auth_event_ids to only what we need to auth the event.
-        # For example, we don't need extra m.room.member that don't match event.sender
-        if auth_event_ids is not None:
-            # If auth events are provided, prev events must be also.
-            assert prev_event_ids is not None
-
-            temp_event = await builder.build(
-                prev_event_ids=prev_event_ids,
-                auth_event_ids=auth_event_ids,
-                depth=depth,
-            )
-            auth_events = await self.store.get_events_as_list(auth_event_ids)
-            # Create a StateMap[str]
-            auth_event_state_map = {
-                (e.type, e.state_key): e.event_id for e in auth_events
-            }
-            # Actually strip down and use the necessary auth events
-            auth_event_ids = self._event_auth_handler.compute_auth_events(
-                event=temp_event,
-                current_state_ids=auth_event_state_map,
-                for_verification=False,
-            )
-
         event, context = await self.create_new_client_event(
             builder=builder,
             requester=requester,
@@ -936,6 +913,33 @@ class EventCreationHandler:
             Tuple of created event, context
         """
 
+        # Strip down the auth_event_ids to only what we need to auth the event.
+        # For example, we don't need extra m.room.member that don't match event.sender
+        full_state_ids_at_event = None
+        if auth_event_ids is not None:
+            # If auth events are provided, prev events must be also.
+            assert prev_event_ids is not None
+
+            # Copy the full auth state before it stripped down
+            full_state_ids_at_event = auth_event_ids.copy()
+
+            temp_event = await builder.build(
+                prev_event_ids=prev_event_ids,
+                auth_event_ids=auth_event_ids,
+                depth=depth,
+            )
+            auth_events = await self.store.get_events_as_list(auth_event_ids)
+            # Create a StateMap[str]
+            auth_event_state_map = {
+                (e.type, e.state_key): e.event_id for e in auth_events
+            }
+            # Actually strip down and use the necessary auth events
+            auth_event_ids = self._event_auth_handler.compute_auth_events(
+                event=temp_event,
+                current_state_ids=auth_event_state_map,
+                for_verification=False,
+            )
+
         if prev_event_ids is not None:
             assert (
                 len(prev_event_ids) <= 10
@@ -965,6 +969,13 @@ class EventCreationHandler:
         if builder.internal_metadata.outlier:
             event.internal_metadata.outlier = True
             context = EventContext.for_outlier()
+        elif (
+            event.type == EventTypes.MSC2716_INSERTION
+            and full_state_ids_at_event
+            and builder.internal_metadata.is_historical()
+        ):
+            old_state = await self.store.get_events_as_list(full_state_ids_at_event)
+            context = await self.state.compute_event_context(event, old_state=old_state)
         else:
             context = await self.state.compute_event_context(event)
 
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 176e4dfdd4..60ff896386 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -86,19 +86,22 @@ class PaginationHandler:
         self._event_serializer = hs.get_event_client_serializer()
 
         self._retention_default_max_lifetime = (
-            hs.config.server.retention_default_max_lifetime
+            hs.config.retention.retention_default_max_lifetime
         )
 
         self._retention_allowed_lifetime_min = (
-            hs.config.server.retention_allowed_lifetime_min
+            hs.config.retention.retention_allowed_lifetime_min
         )
         self._retention_allowed_lifetime_max = (
-            hs.config.server.retention_allowed_lifetime_max
+            hs.config.retention.retention_allowed_lifetime_max
         )
 
-        if hs.config.worker.run_background_tasks and hs.config.server.retention_enabled:
+        if (
+            hs.config.worker.run_background_tasks
+            and hs.config.retention.retention_enabled
+        ):
             # Run the purge jobs described in the configuration file.
-            for job in hs.config.server.retention_purge_jobs:
+            for job in hs.config.retention.retention_purge_jobs:
                 logger.info("Setting up purge job with config: %s", job)
 
                 self.clock.looping_call(
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 404afb9402..b5968e047b 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1489,7 +1489,7 @@ def format_user_presence_state(
     The "user_id" is optional so that this function can be used to format presence
     updates for client /sync responses and for federation /send requests.
     """
-    content = {"presence": state.state}
+    content: JsonDict = {"presence": state.state}
     if include_user_id:
         content["user_id"] = state.user_id
     if state.last_active_ts:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 7072bca1fc..6f39e9446f 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -465,17 +465,35 @@ class RoomCreationHandler:
         # the room has been created
         # Calculate the minimum power level needed to clone the room
         event_power_levels = power_levels.get("events", {})
+        if not isinstance(event_power_levels, dict):
+            event_power_levels = {}
         state_default = power_levels.get("state_default", 50)
+        try:
+            state_default_int = int(state_default)  # type: ignore[arg-type]
+        except (TypeError, ValueError):
+            state_default_int = 50
         ban = power_levels.get("ban", 50)
-        needed_power_level = max(state_default, ban, max(event_power_levels.values()))
+        try:
+            ban = int(ban)  # type: ignore[arg-type]
+        except (TypeError, ValueError):
+            ban = 50
+        needed_power_level = max(
+            state_default_int, ban, max(event_power_levels.values())
+        )
 
         # Get the user's current power level, this matches the logic in get_user_power_level,
         # but without the entire state map.
         user_power_levels = power_levels.setdefault("users", {})
+        if not isinstance(user_power_levels, dict):
+            user_power_levels = {}
         users_default = power_levels.get("users_default", 0)
         current_power_level = user_power_levels.get(user_id, users_default)
+        try:
+            current_power_level_int = int(current_power_level)  # type: ignore[arg-type]
+        except (TypeError, ValueError):
+            current_power_level_int = 0
         # Raise the requester's power level in the new room if necessary
-        if current_power_level < needed_power_level:
+        if current_power_level_int < needed_power_level:
             user_power_levels[user_id] = needed_power_level
 
         await self._send_events_for_new_room(
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 51dd4e7555..2f5a3e4d19 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -13,6 +13,10 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+def generate_fake_event_id() -> str:
+    return "$fake_" + random_string(43)
+
+
 class RoomBatchHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
@@ -177,6 +181,11 @@ class RoomBatchHandler:
 
         state_event_ids_at_start = []
         auth_event_ids = initial_auth_event_ids.copy()
+
+        # Make the state events float off on their own so we don't have a
+        # bunch of `@mxid joined the room` noise between each batch
+        prev_event_id_for_state_chain = generate_fake_event_id()
+
         for state_event in state_events_at_start:
             assert_params_in_dict(
                 state_event, ["type", "origin_server_ts", "content", "sender"]
@@ -200,10 +209,6 @@ class RoomBatchHandler:
             # Mark all events as historical
             event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
 
-            # Make the state events float off on their own so we don't have a
-            # bunch of `@mxid joined the room` noise between each batch
-            fake_prev_event_id = "$" + random_string(43)
-
             # TODO: This is pretty much the same as some other code to handle inserting state in this file
             if event_dict["type"] == EventTypes.Member:
                 membership = event_dict["content"].get("membership", None)
@@ -216,7 +221,7 @@ class RoomBatchHandler:
                     action=membership,
                     content=event_dict["content"],
                     outlier=True,
-                    prev_event_ids=[fake_prev_event_id],
+                    prev_event_ids=[prev_event_id_for_state_chain],
                     # Make sure to use a copy of this list because we modify it
                     # later in the loop here. Otherwise it will be the same
                     # reference and also update in the event when we append later.
@@ -235,7 +240,7 @@ class RoomBatchHandler:
                     ),
                     event_dict,
                     outlier=True,
-                    prev_event_ids=[fake_prev_event_id],
+                    prev_event_ids=[prev_event_id_for_state_chain],
                     # Make sure to use a copy of this list because we modify it
                     # later in the loop here. Otherwise it will be the same
                     # reference and also update in the event when we append later.
@@ -245,6 +250,8 @@ class RoomBatchHandler:
 
             state_event_ids_at_start.append(event_id)
             auth_event_ids.append(event_id)
+            # Connect all the state in a floating chain
+            prev_event_id_for_state_chain = event_id
 
         return state_event_ids_at_start
 
@@ -289,6 +296,10 @@ class RoomBatchHandler:
         for ev in events_to_create:
             assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"])
 
+            assert self.hs.is_mine_id(ev["sender"]), "User must be our own: %s" % (
+                ev["sender"],
+            )
+
             event_dict = {
                 "type": ev["type"],
                 "origin_server_ts": ev["origin_server_ts"],
@@ -311,6 +322,19 @@ class RoomBatchHandler:
                 historical=True,
                 depth=inherited_depth,
             )
+
+            assert context._state_group
+
+            # Normally this is done when persisting the event but we have to
+            # pre-emptively do it here because we create all the events first,
+            # then persist them in another pass below. And we want to share
+            # state_groups across the whole batch so this lookup needs to work
+            # for the next event in the batch in this loop.
+            await self.store.store_state_group_id_for_event_id(
+                event_id=event.event_id,
+                state_group_id=context._state_group,
+            )
+
             logger.debug(
                 "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s",
                 event,
@@ -318,10 +342,6 @@ class RoomBatchHandler:
                 auth_event_ids,
             )
 
-            assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
-                event.sender,
-            )
-
             events_to_persist.append((event, context))
             event_id = event.event_id
 
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 8810f048ba..991fee7e58 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -196,63 +196,12 @@ class UserDirectoryHandler(StateDeltasHandler):
                     room_id, prev_event_id, event_id, typ
                 )
             elif typ == EventTypes.Member:
-                change = await self._get_key_change(
+                await self._handle_room_membership_event(
+                    room_id,
                     prev_event_id,
                     event_id,
-                    key_name="membership",
-                    public_value=Membership.JOIN,
+                    state_key,
                 )
-
-                is_remote = not self.is_mine_id(state_key)
-                if change is MatchChange.now_false:
-                    # Need to check if the server left the room entirely, if so
-                    # we might need to remove all the users in that room
-                    is_in_room = await self.store.is_host_joined(
-                        room_id, self.server_name
-                    )
-                    if not is_in_room:
-                        logger.debug("Server left room: %r", room_id)
-                        # Fetch all the users that we marked as being in user
-                        # directory due to being in the room and then check if
-                        # need to remove those users or not
-                        user_ids = await self.store.get_users_in_dir_due_to_room(
-                            room_id
-                        )
-
-                        for user_id in user_ids:
-                            await self._handle_remove_user(room_id, user_id)
-                        continue
-                    else:
-                        logger.debug("Server is still in room: %r", room_id)
-
-                include_in_dir = (
-                    is_remote
-                    or await self.store.should_include_local_user_in_dir(state_key)
-                )
-                if include_in_dir:
-                    if change is MatchChange.no_change:
-                        # Handle any profile changes for remote users.
-                        # (For local users we are not forced to scan membership
-                        # events; instead the rest of the application calls
-                        # `handle_local_profile_change`.)
-                        if is_remote:
-                            await self._handle_profile_change(
-                                state_key, room_id, prev_event_id, event_id
-                            )
-                        continue
-
-                    if change is MatchChange.now_true:  # The user joined
-                        # This may be the first time we've seen a remote user. If
-                        # so, ensure we have a directory entry for them. (We don't
-                        # need to do this for local users: their directory entry
-                        # is created at the point of registration.
-                        if is_remote:
-                            await self._upsert_directory_entry_for_remote_user(
-                                state_key, event_id
-                            )
-                        await self._track_user_joined_room(room_id, state_key)
-                    else:  # The user left
-                        await self._handle_remove_user(room_id, state_key)
             else:
                 logger.debug("Ignoring irrelevant type: %r", typ)
 
@@ -317,14 +266,83 @@ class UserDirectoryHandler(StateDeltasHandler):
         for user_id in users_in_room:
             await self.store.remove_user_who_share_room(user_id, room_id)
 
-        # Then, re-add them to the tables.
+        # Then, re-add all remote users and some local users to the tables.
         # NOTE: this is not the most efficient method, as _track_user_joined_room sets
         # up local_user -> other_user and other_user_whos_local -> local_user,
         # which when ran over an entire room, will result in the same values
         # being added multiple times. The batching upserts shouldn't make this
         # too bad, though.
         for user_id in users_in_room:
-            await self._track_user_joined_room(room_id, user_id)
+            if not self.is_mine_id(
+                user_id
+            ) or await self.store.should_include_local_user_in_dir(user_id):
+                await self._track_user_joined_room(room_id, user_id)
+
+    async def _handle_room_membership_event(
+        self,
+        room_id: str,
+        prev_event_id: str,
+        event_id: str,
+        state_key: str,
+    ) -> None:
+        """Process a single room membershp event.
+
+        We have to do two things:
+
+        1. Update the room-sharing tables.
+           This applies to remote users and non-excluded local users.
+        2. Update the user_directory and user_directory_search tables.
+           This applies to remote users only, because we only become aware of
+           the (and any profile changes) by listening to these events.
+           The rest of the application knows exactly when local users are
+           created or their profile changed---it will directly call methods
+           on this class.
+        """
+        joined = await self._get_key_change(
+            prev_event_id,
+            event_id,
+            key_name="membership",
+            public_value=Membership.JOIN,
+        )
+
+        # Both cases ignore excluded local users, so start by discarding them.
+        is_remote = not self.is_mine_id(state_key)
+        if not is_remote and not await self.store.should_include_local_user_in_dir(
+            state_key
+        ):
+            return
+
+        if joined is MatchChange.now_false:
+            # Need to check if the server left the room entirely, if so
+            # we might need to remove all the users in that room
+            is_in_room = await self.store.is_host_joined(room_id, self.server_name)
+            if not is_in_room:
+                logger.debug("Server left room: %r", room_id)
+                # Fetch all the users that we marked as being in user
+                # directory due to being in the room and then check if
+                # need to remove those users or not
+                user_ids = await self.store.get_users_in_dir_due_to_room(room_id)
+
+                for user_id in user_ids:
+                    await self._handle_remove_user(room_id, user_id)
+            else:
+                logger.debug("Server is still in room: %r", room_id)
+                await self._handle_remove_user(room_id, state_key)
+        elif joined is MatchChange.no_change:
+            # Handle any profile changes for remote users.
+            # (For local users the rest of the application calls
+            # `handle_local_profile_change`.)
+            if is_remote:
+                await self._handle_possible_remote_profile_change(
+                    state_key, room_id, prev_event_id, event_id
+                )
+        elif joined is MatchChange.now_true:  # The user joined
+            # This may be the first time we've seen a remote user. If
+            # so, ensure we have a directory entry for them. (For local users,
+            # the rest of the application calls `handle_local_profile_change`.)
+            if is_remote:
+                await self._upsert_directory_entry_for_remote_user(state_key, event_id)
+            await self._track_user_joined_room(room_id, state_key)
 
     async def _upsert_directory_entry_for_remote_user(
         self, user_id: str, event_id: str
@@ -349,8 +367,8 @@ class UserDirectoryHandler(StateDeltasHandler):
         """Someone's just joined a room. Update `users_in_public_rooms` or
         `users_who_share_private_rooms` as appropriate.
 
-        The caller is responsible for ensuring that the given user is not excluded
-        from the user directory.
+        The caller is responsible for ensuring that the given user should be
+        included in the user directory.
         """
         is_public = await self.store.is_room_world_readable_or_publicly_joinable(
             room_id
@@ -386,24 +404,32 @@ class UserDirectoryHandler(StateDeltasHandler):
                 await self.store.add_users_who_share_private_room(room_id, to_insert)
 
     async def _handle_remove_user(self, room_id: str, user_id: str) -> None:
-        """Called when we might need to remove user from directory
+        """Called when when someone leaves a room. The user may be local or remote.
+
+        (If the person who left was the last local user in this room, the server
+        is no longer in the room. We call this function to forget that the remaining
+        remote users are in the room, even though they haven't left. So the name is
+        a little misleading!)
 
         Args:
             room_id: The room ID that user left or stopped being public that
             user_id
         """
-        logger.debug("Removing user %r", user_id)
+        logger.debug("Removing user %r from room %r", user_id, room_id)
 
         # Remove user from sharing tables
         await self.store.remove_user_who_share_room(user_id, room_id)
 
-        # Are they still in any rooms? If not, remove them entirely.
-        rooms_user_is_in = await self.store.get_user_dir_rooms_user_is_in(user_id)
+        # Additionally, if they're a remote user and we're no longer joined
+        # to any rooms they're in, remove them from the user directory.
+        if not self.is_mine_id(user_id):
+            rooms_user_is_in = await self.store.get_user_dir_rooms_user_is_in(user_id)
 
-        if len(rooms_user_is_in) == 0:
-            await self.store.remove_from_user_dir(user_id)
+            if len(rooms_user_is_in) == 0:
+                logger.debug("Removing user %r from directory", user_id)
+                await self.store.remove_from_user_dir(user_id)
 
-    async def _handle_profile_change(
+    async def _handle_possible_remote_profile_change(
         self,
         user_id: str,
         room_id: str,
@@ -411,7 +437,8 @@ class UserDirectoryHandler(StateDeltasHandler):
         event_id: Optional[str],
     ) -> None:
         """Check member event changes for any profile changes and update the
-        database if there are.
+        database if there are. This is intended for remote users only. The caller
+        is responsible for checking that the given user is remote.
         """
         if not prev_event_id or not event_id:
             return
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 8ae21bc43c..ab7ef8f950 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -45,6 +45,7 @@ from synapse.http.servlet import parse_json_object_from_request
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.rest.client.login import LoginResponse
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.roommember import ProfileInfo
 from synapse.storage.state import StateFilter
@@ -83,6 +84,8 @@ __all__ = [
     "DirectServeJsonResource",
     "ModuleApi",
     "PRESENCE_ALL_USERS",
+    "LoginResponse",
+    "JsonDict",
 ]
 
 logger = logging.getLogger(__name__)
@@ -139,6 +142,7 @@ class ModuleApi:
         self._spam_checker = hs.get_spam_checker()
         self._account_validity_handler = hs.get_account_validity_handler()
         self._third_party_event_rules = hs.get_third_party_event_rules()
+        self._password_auth_provider = hs.get_password_auth_provider()
         self._presence_router = hs.get_presence_router()
 
     #################################################################################
@@ -164,6 +168,11 @@ class ModuleApi:
         """Registers callbacks for presence router capabilities."""
         return self._presence_router.register_presence_router_callbacks
 
+    @property
+    def register_password_auth_provider_callbacks(self):
+        """Registers callbacks for password auth provider capabilities."""
+        return self._password_auth_provider.register_password_auth_provider_callbacks
+
     def register_web_resource(self, path: str, resource: IResource):
         """Registers a web resource to be served at the given path.
 
@@ -773,9 +782,9 @@ class ModuleApi:
             # Sanitize some of the data. We don't want to return tokens.
             return [
                 UserIpAndAgent(
-                    ip=str(data["ip"]),
-                    user_agent=str(data["user_agent"]),
-                    last_seen=int(data["last_seen"]),
+                    ip=data["ip"],
+                    user_agent=data["user_agent"],
+                    last_seen=data["last_seen"],
                 )
                 for data in raw_data
             ]
diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py
index 98ea911a81..1db900e41f 100644
--- a/synapse/module_api/errors.py
+++ b/synapse/module_api/errors.py
@@ -14,9 +14,16 @@
 
 """Exception types which are exposed as part of the stable module API"""
 
-from synapse.api.errors import (  # noqa: F401
+from synapse.api.errors import (
     InvalidClientCredentialsError,
     RedirectException,
     SynapseError,
 )
-from synapse.config._base import ConfigError  # noqa: F401
+from synapse.config._base import ConfigError
+
+__all__ = [
+    "InvalidClientCredentialsError",
+    "RedirectException",
+    "SynapseError",
+    "ConfigError",
+]
diff --git a/synapse/py.typed b/synapse/py.typed
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/py.typed
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 8c80153ab6..7bae36db16 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -182,9 +182,13 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
-        self._logging_context = BackgroundProcessLoggingContext(
-            "replication-conn", self.conn_id
-        )
+        with PreserveLoggingContext():
+            # thanks to `PreserveLoggingContext()`, the new logcontext is guaranteed to
+            # capture the sentinel context as its containing context and won't prevent
+            # GC of / unintentionally reactivate what would be the current context.
+            self._logging_context = BackgroundProcessLoggingContext(
+                "replication-conn", self.conn_id
+            )
 
     def connectionMade(self):
         logger.info("[%s] Connection established", self.id())
@@ -434,8 +438,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         if self.transport:
             self.transport.unregisterProducer()
 
-        # mark the logging context as finished
-        self._logging_context.__exit__(None, None, None)
+        # mark the logging context as finished by triggering `__exit__()`
+        with PreserveLoggingContext():
+            with self._logging_context:
+                pass
+            # the sentinel context is now active, which may not be correct.
+            # PreserveLoggingContext() will restore the correct logging context.
 
     def __str__(self):
         addr = None
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 062fe2f33e..8d28bd3f3f 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -100,9 +100,13 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
 
         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
-        self._logging_context = BackgroundProcessLoggingContext(
-            "replication_command_handler"
-        )
+        with PreserveLoggingContext():
+            # thanks to `PreserveLoggingContext()`, the new logcontext is guaranteed to
+            # capture the sentinel context as its containing context and won't prevent
+            # GC of / unintentionally reactivate what would be the current context.
+            self._logging_context = BackgroundProcessLoggingContext(
+                "replication_command_handler"
+            )
 
     def connectionMade(self):
         logger.info("Connected to redis")
@@ -182,8 +186,12 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
         super().connectionLost(reason)
         self.synapse_handler.lost_connection(self)
 
-        # mark the logging context as finished
-        self._logging_context.__exit__(None, None, None)
+        # mark the logging context as finished by triggering `__exit__()`
+        with PreserveLoggingContext():
+            with self._logging_context:
+                pass
+            # the sentinel context is now active, which may not be correct.
+            # PreserveLoggingContext() will restore the correct logging context.
 
     def send_command(self, cmd: Command):
         """Send a command if connection has been established.
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 0b0711c03c..d695c18be2 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -232,12 +232,12 @@ class RelationPaginationServlet(RestServlet):
         # Similarly, we don't allow relations to be applied to relations, so we
         # return the original relations without any aggregations on top of them
         # here.
-        events = await self._event_serializer.serialize_events(
+        serialized_events = await self._event_serializer.serialize_events(
             events, now, bundle_aggregations=False
         )
 
         return_value = pagination_chunk.to_dict()
-        return_value["chunk"] = events
+        return_value["chunk"] = serialized_events
         return_value["original_event"] = original_event
 
         return 200, return_value
@@ -416,10 +416,10 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         )
 
         now = self.clock.time_msec()
-        events = await self._event_serializer.serialize_events(events, now)
+        serialized_events = await self._event_serializer.serialize_events(events, now)
 
         return_value = result.to_dict()
-        return_value["chunk"] = events
+        return_value["chunk"] = serialized_events
 
         return 200, return_value
 
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index 38ad4c2447..99f8156ad0 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -32,7 +32,6 @@ from synapse.http.servlet import (
 from synapse.http.site import SynapseRequest
 from synapse.rest.client.transactions import HttpTransactionCache
 from synapse.types import JsonDict
-from synapse.util.stringutils import random_string
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -160,11 +159,6 @@ class RoomBatchSendEventRestServlet(RestServlet):
         base_insertion_event = None
         if batch_id_from_query:
             batch_id_to_connect_to = batch_id_from_query
-            #  All but the first base insertion event should point at a fake
-            #  event, which causes the HS to ask for the state at the start of
-            #  the batch later.
-            fake_prev_event_id = "$" + random_string(43)
-            prev_event_ids = [fake_prev_event_id]
         # Otherwise, create an insertion event to act as a starting point.
         #
         # We don't always have an insertion event to start hanging more history
@@ -173,8 +167,6 @@ class RoomBatchSendEventRestServlet(RestServlet):
         # an insertion event), in which case we just create a new insertion event
         # that can then get pointed to by a "marker" event later.
         else:
-            prev_event_ids = prev_event_ids_from_query
-
             base_insertion_event_dict = (
                 self.room_batch_handler.create_insertion_event_dict(
                     sender=requester.user.to_string(),
@@ -182,7 +174,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
                     origin_server_ts=last_event_in_batch["origin_server_ts"],
                 )
             )
-            base_insertion_event_dict["prev_events"] = prev_event_ids.copy()
+            base_insertion_event_dict["prev_events"] = prev_event_ids_from_query.copy()
 
             (
                 base_insertion_event,
@@ -203,6 +195,11 @@ class RoomBatchSendEventRestServlet(RestServlet):
                 EventContentFields.MSC2716_NEXT_BATCH_ID
             ]
 
+        # Also connect the historical event chain to the end of the floating
+        # state chain, which causes the HS to ask for the state at the start of
+        # the batch later.
+        prev_event_ids = [state_event_ids_at_start[-1]]
+
         # Create and persist all of the historical events as well as insertion
         # and batch meta events to make the batch navigable in the DAG.
         event_ids, next_batch_id = await self.room_batch_handler.handle_batch_of_events(
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 08bd85f664..bec77088ee 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -16,12 +16,15 @@
 import functools
 import os
 import re
-from typing import Any, Callable, List
+from typing import Any, Callable, List, TypeVar, cast
 
 NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
 
 
-def _wrap_in_base_path(func: Callable[..., str]) -> Callable[..., str]:
+F = TypeVar("F", bound=Callable[..., str])
+
+
+def _wrap_in_base_path(func: F) -> F:
     """Takes a function that returns a relative path and turns it into an
     absolute path based on the location of the primary media store
     """
@@ -31,7 +34,7 @@ def _wrap_in_base_path(func: Callable[..., str]) -> Callable[..., str]:
         path = func(self, *args, **kwargs)
         return os.path.join(self.base_path, path)
 
-    return _wrapped
+    return cast(F, _wrapped)
 
 
 class MediaFilePaths:
@@ -45,23 +48,6 @@ class MediaFilePaths:
     def __init__(self, primary_base_path: str):
         self.base_path = primary_base_path
 
-    def default_thumbnail_rel(
-        self,
-        default_top_level: str,
-        default_sub_type: str,
-        width: int,
-        height: int,
-        content_type: str,
-        method: str,
-    ) -> str:
-        top_level_type, sub_type = content_type.split("/")
-        file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
-        return os.path.join(
-            "default_thumbnails", default_top_level, default_sub_type, file_name
-        )
-
-    default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
-
     def local_media_filepath_rel(self, media_id: str) -> str:
         return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
 
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
index 78b1603f19..2a59552c20 100644
--- a/synapse/rest/media/v1/oembed.py
+++ b/synapse/rest/media/v1/oembed.py
@@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, List, Optional
 
 import attr
 
-from synapse.http.client import SimpleHttpClient
 from synapse.types import JsonDict
 from synapse.util import json_decoder
 
@@ -48,7 +47,7 @@ class OEmbedProvider:
     requesting/parsing oEmbed content.
     """
 
-    def __init__(self, hs: "HomeServer", client: SimpleHttpClient):
+    def __init__(self, hs: "HomeServer"):
         self._oembed_patterns = {}
         for oembed_endpoint in hs.config.oembed.oembed_patterns:
             api_endpoint = oembed_endpoint.api_endpoint
@@ -69,7 +68,6 @@ class OEmbedProvider:
             # Iterate through each URL pattern and point it to the endpoint.
             for pattern in oembed_endpoint.url_patterns:
                 self._oembed_patterns[pattern] = api_endpoint
-        self._client = client
 
     def get_oembed_url(self, url: str) -> Optional[str]:
         """
@@ -139,10 +137,11 @@ class OEmbedProvider:
             # oEmbed responses *must* be UTF-8 according to the spec.
             oembed = json_decoder.decode(raw_body.decode("utf-8"))
 
-            # Ensure there's a version of 1.0.
-            oembed_version = oembed["version"]
-            if oembed_version != "1.0":
-                raise RuntimeError(f"Invalid version: {oembed_version}")
+            # The version is a required string field, but not always provided,
+            # or sometimes provided as a float. Be lenient.
+            oembed_version = oembed.get("version", "1.0")
+            if oembed_version != "1.0" and oembed_version != 1:
+                raise RuntimeError(f"Invalid oEmbed version: {oembed_version}")
 
             # Ensure the cache age is None or an int.
             cache_age = oembed.get("cache_age")
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 1fe0fc8aa9..278fd901e2 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import codecs
 import datetime
 import errno
 import fnmatch
@@ -22,7 +23,7 @@ import re
 import shutil
 import sys
 import traceback
-from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Tuple, Union
 from urllib import parse as urlparse
 
 import attr
@@ -140,7 +141,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         self.primary_base_path = media_repo.primary_base_path
         self.media_storage = media_storage
 
-        self._oembed = OEmbedProvider(hs, self.client)
+        self._oembed = OEmbedProvider(hs)
 
         # We run the background jobs if we're the instance specified (or no
         # instance is specified, where we assume there is only one instance
@@ -295,8 +296,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             with open(media_info.filename, "rb") as file:
                 body = file.read()
 
-            encoding = get_html_media_encoding(body, media_info.media_type)
-            tree = decode_body(body, encoding)
+            tree = decode_body(body, media_info.uri, media_info.media_type)
             if tree is not None:
                 # Check if this HTML document points to oEmbed information and
                 # defer to that.
@@ -632,16 +632,27 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.debug("No media removed from url cache")
 
 
-def get_html_media_encoding(body: bytes, content_type: str) -> str:
+def _normalise_encoding(encoding: str) -> Optional[str]:
+    """Use the Python codec's name as the normalised entry."""
+    try:
+        return codecs.lookup(encoding).name
+    except LookupError:
+        return None
+
+
+def get_html_media_encodings(body: bytes, content_type: Optional[str]) -> Iterable[str]:
     """
-    Get the encoding of the body based on the (presumably) HTML body or media_type.
+    Get potential encoding of the body based on the (presumably) HTML body or the content-type header.
 
     The precedence used for finding a character encoding is:
 
-    1. meta tag with a charset declared.
+    1. <meta> tag with a charset declared.
     2. The XML document's character encoding attribute.
     3. The Content-Type header.
-    4. Fallback to UTF-8.
+    4. Fallback to utf-8.
+    5. Fallback to windows-1252.
+
+    This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector.
 
     Args:
         body: The HTML document, as bytes.
@@ -650,39 +661,55 @@ def get_html_media_encoding(body: bytes, content_type: str) -> str:
     Returns:
         The character encoding of the body, as a string.
     """
+    # There's no point in returning an encoding more than once.
+    attempted_encodings: Set[str] = set()
+
     # Limit searches to the first 1kb, since it ought to be at the top.
     body_start = body[:1024]
 
-    # Let's try and figure out if it has an encoding set in a meta tag.
+    # Check if it has an encoding set in a meta tag.
     match = _charset_match.search(body_start)
     if match:
-        return match.group(1).decode("ascii")
+        encoding = _normalise_encoding(match.group(1).decode("ascii"))
+        if encoding:
+            attempted_encodings.add(encoding)
+            yield encoding
 
     # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
 
-    # If we didn't find a match, see if it an XML document with an encoding.
+    # Check if it has an XML document with an encoding.
     match = _xml_encoding_match.match(body_start)
     if match:
-        return match.group(1).decode("ascii")
-
-    # If we don't find a match, we'll look at the HTTP Content-Type, and
-    # if that doesn't exist, we'll fall back to UTF-8.
-    content_match = _content_type_match.match(content_type)
-    if content_match:
-        return content_match.group(1)
-
-    return "utf-8"
+        encoding = _normalise_encoding(match.group(1).decode("ascii"))
+        if encoding and encoding not in attempted_encodings:
+            attempted_encodings.add(encoding)
+            yield encoding
+
+    # Check the HTTP Content-Type header for a character set.
+    if content_type:
+        content_match = _content_type_match.match(content_type)
+        if content_match:
+            encoding = _normalise_encoding(content_match.group(1))
+            if encoding and encoding not in attempted_encodings:
+                attempted_encodings.add(encoding)
+                yield encoding
+
+    # Finally, fallback to UTF-8, then windows-1252.
+    for fallback in ("utf-8", "cp1252"):
+        if fallback not in attempted_encodings:
+            yield fallback
 
 
 def decode_body(
-    body: bytes, request_encoding: Optional[str] = None
+    body: bytes, uri: str, content_type: Optional[str] = None
 ) -> Optional["etree.Element"]:
     """
     This uses lxml to parse the HTML document.
 
     Args:
         body: The HTML document, as bytes.
-        request_encoding: The character encoding of the body, as a string.
+        uri: The URI used to download the body.
+        content_type: The Content-Type header.
 
     Returns:
         The parsed HTML body, or None if an error occurred during processed.
@@ -691,32 +718,25 @@ def decode_body(
     if not body:
         return None
 
-    from lxml import etree
-
-    # Create an HTML parser. If this fails, log and return no metadata.
-    try:
-        parser = etree.HTMLParser(recover=True, encoding=request_encoding)
-    except LookupError:
-        # blindly consider the encoding as utf-8.
-        parser = etree.HTMLParser(recover=True, encoding="utf-8")
-    except Exception as e:
-        logger.warning("Unable to create HTML parser: %s" % (e,))
+    for encoding in get_html_media_encodings(body, content_type):
+        try:
+            body_str = body.decode(encoding)
+        except Exception:
+            pass
+        else:
+            break
+    else:
+        logger.warning("Unable to decode HTML body for %s", uri)
         return None
 
-    def _attempt_decode_body(
-        body_attempt: Union[bytes, str]
-    ) -> Optional["etree.Element"]:
-        # Attempt to parse the body. Returns None if the body was successfully
-        # parsed, but no tree was found.
-        return etree.fromstring(body_attempt, parser)
+    from lxml import etree
 
-    # Attempt to parse the body. If this fails, log and return no metadata.
-    try:
-        return _attempt_decode_body(body)
-    except UnicodeDecodeError:
-        # blindly try decoding the body as utf-8, which seems to fix
-        # the charset mismatches on https://google.com
-        return _attempt_decode_body(body.decode("utf-8", "ignore"))
+    # Create an HTML parser.
+    parser = etree.HTMLParser(recover=True, encoding="utf-8")
+
+    # Attempt to parse the body. Returns None if the body was successfully
+    # parsed, but no tree was found.
+    return etree.fromstring(body_str, parser)
 
 
 def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
diff --git a/synapse/server.py b/synapse/server.py
index 5bc045d615..a64c846d1c 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -65,7 +65,7 @@ from synapse.handlers.account_data import AccountDataHandler
 from synapse.handlers.account_validity import AccountValidityHandler
 from synapse.handlers.admin import AdminHandler
 from synapse.handlers.appservice import ApplicationServicesHandler
-from synapse.handlers.auth import AuthHandler, MacaroonGenerator
+from synapse.handlers.auth import AuthHandler, MacaroonGenerator, PasswordAuthProvider
 from synapse.handlers.cas import CasHandler
 from synapse.handlers.deactivate_account import DeactivateAccountHandler
 from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
@@ -688,6 +688,10 @@ class HomeServer(metaclass=abc.ABCMeta):
         return ThirdPartyEventRules(self)
 
     @cache_in_self
+    def get_password_auth_provider(self) -> PasswordAuthProvider:
+        return PasswordAuthProvider()
+
+    @cache_in_self
     def get_room_member_handler(self) -> RoomMemberHandler:
         if self.config.worker.worker_app:
             return RoomMemberWorkerHandler(self)
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index ffe6207a3c..6edadea550 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -332,7 +332,7 @@ def _resolve_auth_events(
             event_auth.check_auth_rules_for_event(
                 RoomVersions.V1,
                 event,
-                auth_events,
+                auth_events.values(),
             )
             prev_event = event
         except AuthError:
@@ -350,7 +350,7 @@ def _resolve_normal_events(
             event_auth.check_auth_rules_for_event(
                 RoomVersions.V1,
                 event,
-                auth_events,
+                auth_events.values(),
             )
             return event
         except AuthError:
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index bd18eefd58..c618df2fde 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -549,7 +549,7 @@ async def _iterative_auth_checks(
             event_auth.check_auth_rules_for_event(
                 room_version,
                 event,
-                auth_events,
+                auth_events.values(),
             )
 
             resolved_state[(event.type, event.state_key)] = event_id
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 6c1ef09049..b81d9218ce 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -13,14 +13,26 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast
+
+from typing_extensions import TypedDict
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
-from synapse.types import UserID
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+    make_tuple_comparison_clause,
+)
+from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
+from synapse.storage.types import Connection
+from synapse.types import JsonDict, UserID
 from synapse.util.caches.lrucache import LruCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 # Number of msec of granularity to store the user IP 'last seen' time. Smaller
@@ -29,8 +41,31 @@ logger = logging.getLogger(__name__)
 LAST_SEEN_GRANULARITY = 120 * 1000
 
 
+class DeviceLastConnectionInfo(TypedDict):
+    """Metadata for the last connection seen for a user and device combination"""
+
+    # These types must match the columns in the `devices` table
+    user_id: str
+    device_id: str
+
+    ip: Optional[str]
+    user_agent: Optional[str]
+    last_seen: Optional[int]
+
+
+class LastConnectionInfo(TypedDict):
+    """Metadata for the last connection seen for an access token and IP combination"""
+
+    # These types must match the columns in the `user_ips` table
+    access_token: str
+    ip: str
+
+    user_agent: str
+    last_seen: int
+
+
 class ClientIpBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -81,8 +116,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             "devices_last_seen", self._devices_last_seen_update
         )
 
-    async def _remove_user_ip_nonunique(self, progress, batch_size):
-        def f(conn):
+    async def _remove_user_ip_nonunique(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        def f(conn: LoggingDatabaseConnection) -> None:
             txn = conn.cursor()
             txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
             txn.close()
@@ -93,14 +130,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
         )
         return 1
 
-    async def _analyze_user_ip(self, progress, batch_size):
+    async def _analyze_user_ip(self, progress: JsonDict, batch_size: int) -> int:
         # Background update to analyze user_ips table before we run the
         # deduplication background update. The table may not have been analyzed
         # for ages due to the table locks.
         #
         # This will lock out the naive upserts to user_ips while it happens, but
         # the analyze should be quick (28GB table takes ~10s)
-        def user_ips_analyze(txn):
+        def user_ips_analyze(txn: LoggingTransaction) -> None:
             txn.execute("ANALYZE user_ips")
 
         await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
@@ -109,16 +146,16 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
         return 1
 
-    async def _remove_user_ip_dupes(self, progress, batch_size):
+    async def _remove_user_ip_dupes(self, progress: JsonDict, batch_size: int) -> int:
         # This works function works by scanning the user_ips table in batches
         # based on `last_seen`. For each row in a batch it searches the rest of
         # the table to see if there are any duplicates, if there are then they
         # are removed and replaced with a suitable row.
 
         # Fetch the start of the batch
-        begin_last_seen = progress.get("last_seen", 0)
+        begin_last_seen: int = progress.get("last_seen", 0)
 
-        def get_last_seen(txn):
+        def get_last_seen(txn: LoggingTransaction) -> Optional[int]:
             txn.execute(
                 """
                 SELECT last_seen FROM user_ips
@@ -129,7 +166,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
                 """,
                 (begin_last_seen, batch_size),
             )
-            row = txn.fetchone()
+            row = cast(Optional[Tuple[int]], txn.fetchone())
             if row:
                 return row[0]
             else:
@@ -149,7 +186,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             end_last_seen,
         )
 
-        def remove(txn):
+        def remove(txn: LoggingTransaction) -> None:
             # This works by looking at all entries in the given time span, and
             # then for each (user_id, access_token, ip) tuple in that range
             # checking for any duplicates in the rest of the table (via a join).
@@ -161,10 +198,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
             # Define the search space, which requires handling the last batch in
             # a different way
+            args: Tuple[int, ...]
             if last:
                 clause = "? <= last_seen"
                 args = (begin_last_seen,)
             else:
+                assert end_last_seen is not None
                 clause = "? <= last_seen AND last_seen < ?"
                 args = (begin_last_seen, end_last_seen)
 
@@ -189,7 +228,9 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
                 ),
                 args,
             )
-            res = txn.fetchall()
+            res = cast(
+                List[Tuple[str, str, str, Optional[str], str, int, int]], txn.fetchall()
+            )
 
             # We've got some duplicates
             for i in res:
@@ -278,13 +319,15 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
         return batch_size
 
-    async def _devices_last_seen_update(self, progress, batch_size):
+    async def _devices_last_seen_update(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to insert last seen info into devices table"""
 
-        last_user_id = progress.get("last_user_id", "")
-        last_device_id = progress.get("last_device_id", "")
+        last_user_id: str = progress.get("last_user_id", "")
+        last_device_id: str = progress.get("last_device_id", "")
 
-        def _devices_last_seen_update_txn(txn):
+        def _devices_last_seen_update_txn(txn: LoggingTransaction) -> int:
             # This consists of two queries:
             #
             #   1. The sub-query searches for the next N devices and joins
@@ -296,6 +339,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             #      we'll just end up updating the same device row multiple
             #      times, which is fine.
 
+            where_args: List[Union[str, int]]
             where_clause, where_args = make_tuple_comparison_clause(
                 [("user_id", last_user_id), ("device_id", last_device_id)],
             )
@@ -319,7 +363,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             }
             txn.execute(sql, where_args + [batch_size])
 
-            rows = txn.fetchall()
+            rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
             if not rows:
                 return 0
 
@@ -350,7 +394,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
 
 class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.user_ips_max_age = hs.config.server.user_ips_max_age
@@ -359,7 +403,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
             self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
 
     @wrap_as_background_process("prune_old_user_ips")
-    async def _prune_old_user_ips(self):
+    async def _prune_old_user_ips(self) -> None:
         """Removes entries in user IPs older than the configured period."""
 
         if self.user_ips_max_age is None:
@@ -394,9 +438,9 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
             )
         """
 
-        timestamp = self.clock.time_msec() - self.user_ips_max_age
+        timestamp = self._clock.time_msec() - self.user_ips_max_age
 
-        def _prune_old_user_ips_txn(txn):
+        def _prune_old_user_ips_txn(txn: LoggingTransaction) -> None:
             txn.execute(sql, (timestamp,))
 
         await self.db_pool.runInteraction(
@@ -405,7 +449,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
 
     async def get_last_client_ip_by_device(
         self, user_id: str, device_id: Optional[str]
-    ) -> Dict[Tuple[str, str], dict]:
+    ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
         """For each device_id listed, give the user_ip it was last seen on.
 
         The result might be slightly out of date as client IPs are inserted in batches.
@@ -423,26 +467,32 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
         if device_id is not None:
             keyvalues["device_id"] = device_id
 
-        res = await self.db_pool.simple_select_list(
-            table="devices",
-            keyvalues=keyvalues,
-            retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+        res = cast(
+            List[DeviceLastConnectionInfo],
+            await self.db_pool.simple_select_list(
+                table="devices",
+                keyvalues=keyvalues,
+                retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+            ),
         )
 
         return {(d["user_id"], d["device_id"]): d for d in res}
 
 
-class ClientIpStore(ClientIpWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
 
-        self.client_ip_last_seen = LruCache(
+        # (user_id, access_token, ip,) -> last_seen
+        self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
             cache_name="client_ip_last_seen", max_size=50000
         )
 
         super().__init__(database, db_conn, hs)
 
         # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
-        self._batch_row_update = {}
+        self._batch_row_update: Dict[
+            Tuple[str, str, str], Tuple[str, Optional[str], int]
+        ] = {}
 
         self._client_ip_looper = self._clock.looping_call(
             self._update_client_ips_batch, 5 * 1000
@@ -452,8 +502,14 @@ class ClientIpStore(ClientIpWorkerStore):
         )
 
     async def insert_client_ip(
-        self, user_id, access_token, ip, user_agent, device_id, now=None
-    ):
+        self,
+        user_id: str,
+        access_token: str,
+        ip: str,
+        user_agent: str,
+        device_id: Optional[str],
+        now: Optional[int] = None,
+    ) -> None:
         if not now:
             now = int(self._clock.time_msec())
         key = (user_id, access_token, ip)
@@ -485,7 +541,11 @@ class ClientIpStore(ClientIpWorkerStore):
             "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
         )
 
-    def _update_client_ips_batch_txn(self, txn, to_update):
+    def _update_client_ips_batch_txn(
+        self,
+        txn: LoggingTransaction,
+        to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
+    ) -> None:
         if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
             not self.database_engine.can_native_upsert
         ):
@@ -525,7 +585,7 @@ class ClientIpStore(ClientIpWorkerStore):
 
     async def get_last_client_ip_by_device(
         self, user_id: str, device_id: Optional[str]
-    ) -> Dict[Tuple[str, str], dict]:
+    ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
         """For each device_id listed, give the user_ip it was last seen on
 
         Args:
@@ -561,12 +621,12 @@ class ClientIpStore(ClientIpWorkerStore):
 
     async def get_user_ip_and_agents(
         self, user: UserID, since_ts: int = 0
-    ) -> List[Dict[str, Union[str, int]]]:
+    ) -> List[LastConnectionInfo]:
         """
         Fetch IP/User Agent connection since a given timestamp.
         """
         user_id = user.to_string()
-        results = {}
+        results: Dict[Tuple[str, str], Tuple[str, int]] = {}
 
         for key in self._batch_row_update:
             (
@@ -579,7 +639,7 @@ class ClientIpStore(ClientIpWorkerStore):
                 if last_seen >= since_ts:
                     results[(access_token, ip)] = (user_agent, last_seen)
 
-        def get_recent(txn):
+        def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
             txn.execute(
                 """
                 SELECT access_token, ip, user_agent, last_seen FROM user_ips
@@ -589,7 +649,7 @@ class ClientIpStore(ClientIpWorkerStore):
                 """,
                 (since_ts, user_id),
             )
-            return txn.fetchall()
+            return cast(List[Tuple[str, str, str, int]], txn.fetchall())
 
         rows = await self.db_pool.runInteraction(
             desc="get_user_ip_and_agents", func=get_recent
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 10184d6ae7..ba9f71a230 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -906,7 +906,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             desc="get_latest_event_ids_in_room",
         )
 
-    async def get_min_depth(self, room_id: str) -> int:
+    async def get_min_depth(self, room_id: str) -> Optional[int]:
         """For the given room, get the minimum depth we have seen for it."""
         return await self.db_pool.runInteraction(
             "get_min_depth", self._get_min_depth_interaction, room_id
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 19f55c19c5..37439f8562 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -2069,12 +2069,14 @@ class PersistEventsStore:
 
             state_groups[event.event_id] = context.state_group
 
-        self.db_pool.simple_insert_many_txn(
+        self.db_pool.simple_upsert_many_txn(
             txn,
             table="event_to_state_groups",
-            values=[
-                {"state_group": state_group_id, "event_id": event_id}
-                for event_id, state_group_id in state_groups.items()
+            key_names=["event_id"],
+            key_values=[[event_id] for event_id, _ in state_groups.items()],
+            value_names=["state_group"],
+            value_values=[
+                [state_group_id] for _, state_group_id in state_groups.items()
             ],
         )
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4a1a2f4a6a..ae37901be9 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -55,8 +55,9 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import BackfillStream
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.engines import PostgresEngine
+from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
@@ -86,6 +87,47 @@ class _EventCacheEntry:
     redacted_event: Optional[EventBase]
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventRow:
+    """
+    An event, as pulled from the database.
+
+    Properties:
+        event_id: The event ID of the event.
+
+        stream_ordering: stream ordering for this event
+
+        json: json-encoded event structure
+
+        internal_metadata: json-encoded internal metadata dict
+
+        format_version: The format of the event. Hopefully one of EventFormatVersions.
+            'None' means the event predates EventFormatVersions (so the event is format V1).
+
+        room_version_id: The version of the room which contains the event. Hopefully
+            one of RoomVersions.
+
+           Due to historical reasons, there may be a few events in the database which
+           do not have an associated room; in this case None will be returned here.
+
+        rejected_reason: if the event was rejected, the reason why.
+
+        redactions: a list of event-ids which (claim to) redact this event.
+
+        outlier: True if this event is an outlier.
+    """
+
+    event_id: str
+    stream_ordering: int
+    json: str
+    internal_metadata: str
+    format_version: Optional[int]
+    room_version_id: Optional[int]
+    rejected_reason: Optional[str]
+    redactions: List[str]
+    outlier: bool
+
+
 class EventRedactBehaviour(Names):
     """
     What to do when retrieving a redacted event from the database.
@@ -686,7 +728,7 @@ class EventsWorkerStore(SQLBaseStore):
             for e in state_to_include.values()
         ]
 
-    def _do_fetch(self, conn):
+    def _do_fetch(self, conn: Connection) -> None:
         """Takes a database connection and waits for requests for events from
         the _event_fetch_list queue.
         """
@@ -713,13 +755,15 @@ class EventsWorkerStore(SQLBaseStore):
 
             self._fetch_event_list(conn, event_list)
 
-    def _fetch_event_list(self, conn, event_list):
+    def _fetch_event_list(
+        self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
+    ) -> None:
         """Handle a load of requests from the _event_fetch_list queue
 
         Args:
-            conn (twisted.enterprise.adbapi.Connection): database connection
+            conn: database connection
 
-            event_list (list[Tuple[list[str], Deferred]]):
+            event_list:
                 The fetch requests. Each entry consists of a list of event
                 ids to be fetched, and a deferred to be completed once the
                 events have been fetched.
@@ -788,7 +832,7 @@ class EventsWorkerStore(SQLBaseStore):
                 row = row_map.get(event_id)
                 fetched_events[event_id] = row
                 if row:
-                    redaction_ids.update(row["redactions"])
+                    redaction_ids.update(row.redactions)
 
             events_to_fetch = redaction_ids.difference(fetched_events.keys())
             if events_to_fetch:
@@ -799,32 +843,32 @@ class EventsWorkerStore(SQLBaseStore):
         for event_id, row in fetched_events.items():
             if not row:
                 continue
-            assert row["event_id"] == event_id
+            assert row.event_id == event_id
 
-            rejected_reason = row["rejected_reason"]
+            rejected_reason = row.rejected_reason
 
             # If the event or metadata cannot be parsed, log the error and act
             # as if the event is unknown.
             try:
-                d = db_to_json(row["json"])
+                d = db_to_json(row.json)
             except ValueError:
                 logger.error("Unable to parse json from event: %s", event_id)
                 continue
             try:
-                internal_metadata = db_to_json(row["internal_metadata"])
+                internal_metadata = db_to_json(row.internal_metadata)
             except ValueError:
                 logger.error(
                     "Unable to parse internal_metadata from event: %s", event_id
                 )
                 continue
 
-            format_version = row["format_version"]
+            format_version = row.format_version
             if format_version is None:
                 # This means that we stored the event before we had the concept
                 # of a event format version, so it must be a V1 event.
                 format_version = EventFormatVersions.V1
 
-            room_version_id = row["room_version_id"]
+            room_version_id = row.room_version_id
 
             if not room_version_id:
                 # this should only happen for out-of-band membership events which
@@ -889,8 +933,8 @@ class EventsWorkerStore(SQLBaseStore):
                 internal_metadata_dict=internal_metadata,
                 rejected_reason=rejected_reason,
             )
-            original_ev.internal_metadata.stream_ordering = row["stream_ordering"]
-            original_ev.internal_metadata.outlier = row["outlier"]
+            original_ev.internal_metadata.stream_ordering = row.stream_ordering
+            original_ev.internal_metadata.outlier = row.outlier
 
             event_map[event_id] = original_ev
 
@@ -898,7 +942,7 @@ class EventsWorkerStore(SQLBaseStore):
         # the cache entries.
         result_map = {}
         for event_id, original_ev in event_map.items():
-            redactions = fetched_events[event_id]["redactions"]
+            redactions = fetched_events[event_id].redactions
             redacted_event = self._maybe_redact_event_row(
                 original_ev, redactions, event_map
             )
@@ -912,17 +956,17 @@ class EventsWorkerStore(SQLBaseStore):
 
         return result_map
 
-    async def _enqueue_events(self, events):
+    async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
         """Fetches events from the database using the _event_fetch_list. This
         allows batch and bulk fetching of events - it allows us to fetch events
         without having to create a new transaction for each request for events.
 
         Args:
-            events (Iterable[str]): events to be fetched.
+            events: events to be fetched.
 
         Returns:
-            Dict[str, Dict]: map from event id to row data from the database.
-                May contain events that weren't requested.
+            A map from event id to row data from the database. May contain events
+            that weren't requested.
         """
 
         events_d = defer.Deferred()
@@ -949,43 +993,19 @@ class EventsWorkerStore(SQLBaseStore):
 
         return row_map
 
-    def _fetch_event_rows(self, txn, event_ids):
+    def _fetch_event_rows(
+        self, txn: LoggingTransaction, event_ids: Iterable[str]
+    ) -> Dict[str, _EventRow]:
         """Fetch event rows from the database
 
         Events which are not found are omitted from the result.
 
-        The returned per-event dicts contain the following keys:
-
-         * event_id (str)
-
-         * stream_ordering (int): stream ordering for this event
-
-         * json (str): json-encoded event structure
-
-         * internal_metadata (str): json-encoded internal metadata dict
-
-         * format_version (int|None): The format of the event. Hopefully one
-           of EventFormatVersions. 'None' means the event predates
-           EventFormatVersions (so the event is format V1).
-
-         * room_version_id (str|None): The version of the room which contains the event.
-           Hopefully one of RoomVersions.
-
-           Due to historical reasons, there may be a few events in the database which
-           do not have an associated room; in this case None will be returned here.
-
-         * rejected_reason (str|None): if the event was rejected, the reason
-           why.
-
-         * redactions (List[str]): a list of event-ids which (claim to) redact
-           this event.
-
         Args:
-            txn (twisted.enterprise.adbapi.Connection):
-            event_ids (Iterable[str]): event IDs to fetch
+            txn: The database transaction.
+            event_ids: event IDs to fetch
 
         Returns:
-            Dict[str, Dict]: a map from event id to event info.
+            A map from event id to event info.
         """
         event_dict = {}
         for evs in batch_iter(event_ids, 200):
@@ -1013,17 +1033,17 @@ class EventsWorkerStore(SQLBaseStore):
 
             for row in txn:
                 event_id = row[0]
-                event_dict[event_id] = {
-                    "event_id": event_id,
-                    "stream_ordering": row[1],
-                    "internal_metadata": row[2],
-                    "json": row[3],
-                    "format_version": row[4],
-                    "room_version_id": row[5],
-                    "rejected_reason": row[6],
-                    "redactions": [],
-                    "outlier": row[7],
-                }
+                event_dict[event_id] = _EventRow(
+                    event_id=event_id,
+                    stream_ordering=row[1],
+                    internal_metadata=row[2],
+                    json=row[3],
+                    format_version=row[4],
+                    room_version_id=row[5],
+                    rejected_reason=row[6],
+                    redactions=[],
+                    outlier=row[7],
+                )
 
             # check for redactions
             redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "
@@ -1035,7 +1055,7 @@ class EventsWorkerStore(SQLBaseStore):
             for (redacter, redacted) in txn:
                 d = event_dict.get(redacted)
                 if d:
-                    d["redactions"].append(redacter)
+                    d.redactions.append(redacter)
 
         return event_dict
 
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 181841ee06..0ab56d8a07 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -2237,7 +2237,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                     # accident.
                     row = {"client_secret": None, "validated_at": None}
                 else:
-                    raise ThreepidValidationError(400, "Unknown session_id")
+                    raise ThreepidValidationError("Unknown session_id")
 
             retrieved_client_secret = row["client_secret"]
             validated_at = row["validated_at"]
@@ -2252,14 +2252,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
             if not row:
                 raise ThreepidValidationError(
-                    400, "Validation token not found or has expired"
+                    "Validation token not found or has expired"
                 )
             expires = row["expires"]
             next_link = row["next_link"]
 
             if retrieved_client_secret != client_secret:
                 raise ThreepidValidationError(
-                    400, "This client_secret does not match the provided session_id"
+                    "This client_secret does not match the provided session_id"
                 )
 
             # If the session is already validated, no need to revalidate
@@ -2268,7 +2268,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
             if expires <= current_ts:
                 raise ThreepidValidationError(
-                    400, "This token has expired. Please request a new one"
+                    "This token has expired. Please request a new one"
                 )
 
             # Looks good. Validate the session
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index d69eaf80ce..835d7889cb 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -679,8 +679,8 @@ class RoomWorkerStore(SQLBaseStore):
         # policy.
         if not ret:
             return {
-                "min_lifetime": self.config.server.retention_default_min_lifetime,
-                "max_lifetime": self.config.server.retention_default_max_lifetime,
+                "min_lifetime": self.config.retention.retention_default_min_lifetime,
+                "max_lifetime": self.config.retention.retention_default_max_lifetime,
             }
 
         row = ret[0]
@@ -690,10 +690,10 @@ class RoomWorkerStore(SQLBaseStore):
         # The default values will be None if no default policy has been defined, or if one
         # of the attributes is missing from the default policy.
         if row["min_lifetime"] is None:
-            row["min_lifetime"] = self.config.server.retention_default_min_lifetime
+            row["min_lifetime"] = self.config.retention.retention_default_min_lifetime
 
         if row["max_lifetime"] is None:
-            row["max_lifetime"] = self.config.server.retention_default_max_lifetime
+            row["max_lifetime"] = self.config.retention.retention_default_max_lifetime
 
         return row
 
diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py
index 300a563c9e..dcbce8fdcf 100644
--- a/synapse/storage/databases/main/room_batch.py
+++ b/synapse/storage/databases/main/room_batch.py
@@ -36,3 +36,16 @@ class RoomBatchStore(SQLBaseStore):
             retcol="event_id",
             allow_none=True,
         )
+
+    async def store_state_group_id_for_event_id(
+        self, event_id: str, state_group_id: int
+    ) -> Optional[str]:
+        {
+            await self.db_pool.simple_upsert(
+                table="event_to_state_groups",
+                keyvalues={"event_id": event_id},
+                values={"state_group": state_group_id, "event_id": event_id},
+                # Unique constraint on event_id so we don't have to lock
+                lock=False,
+            )
+        }
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 11ca47ea28..1629d2a53c 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -549,6 +549,8 @@ def _apply_module_schemas(
         database_engine:
         config: application config
     """
+    # This is the old way for password_auth_provider modules to make changes
+    # to the database. This should instead be done using the module API
     for (mod, _config) in config.authproviders.password_providers:
         if not hasattr(mod, "get_db_schema_files"):
             continue
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 1aee741a8b..a1d2332326 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SCHEMA_VERSION = 64  # remember to update the list below when updating
+SCHEMA_VERSION = 65  # remember to update the list below when updating
 """Represents the expectations made by the codebase about the database schema
 
 This should be incremented whenever the codebase changes its requirements on the
@@ -41,6 +41,10 @@ Changes in SCHEMA_VERSION = 63:
 
 Changes in SCHEMA_VERSION = 64:
     - MSC2716: Rename related tables and columns from "chunks" to "batches".
+
+Changes in SCHEMA_VERSION = 65:
+    - MSC2716: Remove unique event_id constraint from insertion_event_edges
+      because an insertion event can have multiple edges.
 """
 
 
diff --git a/synapse/storage/schema/main/delta/65/01msc2716_insertion_event_edges.sql b/synapse/storage/schema/main/delta/65/01msc2716_insertion_event_edges.sql
new file mode 100644
index 0000000000..98b25daf45
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/01msc2716_insertion_event_edges.sql
@@ -0,0 +1,19 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Recreate the insertion_event_edges event_id index without the unique constraint
+-- because an insertion event can have multiple edges.
+DROP INDEX insertion_event_edges_event_id;
+CREATE INDEX IF NOT EXISTS insertion_event_edges_event_id ON insertion_event_edges(event_id);