Diffstat (limited to 'synapse')
-rwxr-xr-x  synapse/_scripts/hash_password.py | 10
-rw-r--r--  synapse/api/constants.py | 2
-rw-r--r--  synapse/api/errors.py | 7
-rw-r--r--  synapse/api/filtering.py | 5
-rw-r--r--  synapse/api/room_versions.py | 32
-rw-r--r--  synapse/app/_base.py | 44
-rw-r--r--  synapse/app/homeserver.py | 36
-rw-r--r--  synapse/config/_base.py | 81
-rw-r--r--  synapse/config/_base.pyi | 15
-rw-r--r--  synapse/config/cache.py | 82
-rw-r--r--  synapse/config/oembed.py | 6
-rw-r--r--  synapse/config/room.py | 47
-rw-r--r--  synapse/config/server.py | 2
-rw-r--r--  synapse/event_auth.py | 21
-rw-r--r--  synapse/events/__init__.py | 45
-rw-r--r--  synapse/events/snapshot.py | 195
-rw-r--r--  synapse/events/spamcheck.py | 158
-rw-r--r--  synapse/federation/federation_base.py | 5
-rw-r--r--  synapse/federation/federation_server.py | 48
-rw-r--r--  synapse/federation/sender/__init__.py | 24
-rw-r--r--  synapse/federation/transport/server/_base.py | 21
-rw-r--r--  synapse/groups/groups_server.py | 2
-rw-r--r--  synapse/handlers/account_data.py | 10
-rw-r--r--  synapse/handlers/appservice.py | 39
-rw-r--r--  synapse/handlers/device.py | 7
-rw-r--r--  synapse/handlers/devicemessage.py | 6
-rw-r--r--  synapse/handlers/directory.py | 3
-rw-r--r--  synapse/handlers/e2e_keys.py | 29
-rw-r--r--  synapse/handlers/event_auth.py | 10
-rw-r--r--  synapse/handlers/federation.py | 15
-rw-r--r--  synapse/handlers/federation_event.py | 44
-rw-r--r--  synapse/handlers/initial_sync.py | 19
-rw-r--r--  synapse/handlers/message.py | 67
-rw-r--r--  synapse/handlers/oidc.py | 4
-rw-r--r--  synapse/handlers/pagination.py | 10
-rw-r--r--  synapse/handlers/presence.py | 6
-rw-r--r--  synapse/handlers/receipts.py | 106
-rw-r--r--  synapse/handlers/relations.py | 20
-rw-r--r--  synapse/handlers/room.py | 188
-rw-r--r--  synapse/handlers/room_batch.py | 2
-rw-r--r--  synapse/handlers/room_member.py | 9
-rw-r--r--  synapse/handlers/room_summary.py | 9
-rw-r--r--  synapse/handlers/search.py | 10
-rw-r--r--  synapse/handlers/sync.py | 29
-rw-r--r--  synapse/handlers/typing.py | 4
-rw-r--r--  synapse/http/client.py | 18
-rw-r--r--  synapse/http/connectproxyclient.py | 39
-rw-r--r--  synapse/http/federation/matrix_federation_agent.py | 2
-rw-r--r--  synapse/http/federation/srv_resolver.py | 4
-rw-r--r--  synapse/http/federation/well_known_resolver.py | 6
-rw-r--r--  synapse/http/matrixfederationclient.py | 33
-rw-r--r--  synapse/http/proxyagent.py | 2
-rw-r--r--  synapse/http/request_metrics.py | 10
-rw-r--r--  synapse/http/server.py | 74
-rw-r--r--  synapse/http/site.py | 25
-rw-r--r--  synapse/logging/_remote.py | 20
-rw-r--r--  synapse/logging/formatter.py | 14
-rw-r--r--  synapse/logging/handlers.py | 4
-rw-r--r--  synapse/logging/scopecontextmanager.py | 28
-rw-r--r--  synapse/metrics/jemalloc.py | 114
-rw-r--r--  synapse/module_api/__init__.py | 10
-rw-r--r--  synapse/module_api/errors.py | 2
-rw-r--r--  synapse/notifier.py | 5
-rw-r--r--  synapse/push/__init__.py | 74
-rw-r--r--  synapse/push/action_generator.py | 44
-rw-r--r--  synapse/push/bulk_push_rule_evaluator.py | 54
-rw-r--r--  synapse/push/httppusher.py | 2
-rw-r--r--  synapse/push/push_rule_evaluator.py | 70
-rw-r--r--  synapse/replication/http/_base.py | 21
-rw-r--r--  synapse/replication/tcp/client.py | 18
-rw-r--r--  synapse/replication/tcp/commands.py | 12
-rw-r--r--  synapse/replication/tcp/handler.py | 34
-rw-r--r--  synapse/replication/tcp/redis.py | 39
-rw-r--r--  synapse/rest/client/push_rule.py | 4
-rw-r--r--  synapse/rest/client/receipts.py | 13
-rw-r--r--  synapse/rest/client/room.py | 6
-rw-r--r--  synapse/rest/media/v1/preview_url_resource.py | 30
-rw-r--r--  synapse/server.py | 8
-rw-r--r--  synapse/server_notices/resource_limits_server_notices.py | 40
-rw-r--r--  synapse/server_notices/server_notices_manager.py | 74
-rw-r--r--  synapse/spam_checker_api/__init__.py | 27
-rw-r--r--  synapse/state/__init__.py | 48
-rw-r--r--  synapse/storage/background_updates.py | 42
-rw-r--r--  synapse/storage/database.py | 46
-rw-r--r--  synapse/storage/databases/main/__init__.py | 8
-rw-r--r--  synapse/storage/databases/main/appservice.py | 47
-rw-r--r--  synapse/storage/databases/main/cache.py | 8
-rw-r--r--  synapse/storage/databases/main/e2e_room_keys.py | 4
-rw-r--r--  synapse/storage/databases/main/event_federation.py | 187
-rw-r--r--  synapse/storage/databases/main/events.py | 236
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 35
-rw-r--r--  synapse/storage/databases/main/metrics.py | 74
-rw-r--r--  synapse/storage/databases/main/purge_events.py | 4
-rw-r--r--  synapse/storage/databases/main/push_rule.py | 284
-rw-r--r--  synapse/storage/databases/main/relations.py | 6
-rw-r--r--  synapse/storage/databases/main/room.py | 45
-rw-r--r--  synapse/storage/databases/main/roommember.py | 129
-rw-r--r--  synapse/storage/databases/main/search.py | 33
-rw-r--r--  synapse/storage/databases/main/stream.py | 34
-rw-r--r--  synapse/storage/databases/state/bg_updates.py | 16
-rw-r--r--  synapse/storage/databases/state/store.py | 2
-rw-r--r--  synapse/storage/engines/__init__.py | 12
-rw-r--r--  synapse/storage/engines/_base.py | 26
-rw-r--r--  synapse/storage/engines/postgres.py | 92
-rw-r--r--  synapse/storage/engines/sqlite.py | 72
-rw-r--r--  synapse/storage/persist_events.py | 63
-rw-r--r--  synapse/storage/prepare_database.py | 2
-rw-r--r--  synapse/storage/schema/__init__.py | 10
-rw-r--r--  synapse/storage/schema/main/delta/69/02cache_invalidation_index.sql | 18
-rw-r--r--  synapse/storage/schema/state/delta/70/08_state_group_edges_unique.sql | 17
-rw-r--r--  synapse/storage/state.py | 82
-rw-r--r--  synapse/storage/types.py | 80
-rw-r--r--  synapse/types.py | 74
-rw-r--r--  synapse/util/caches/lrucache.py | 79
-rw-r--r--  synapse/util/retryutils.py | 4
-rw-r--r--  synapse/visibility.py | 6
116 files changed, 2935 insertions, 1504 deletions
diff --git a/synapse/_scripts/hash_password.py b/synapse/_scripts/hash_password.py
index 3aa29de5bd..3bed367be2 100755
--- a/synapse/_scripts/hash_password.py
+++ b/synapse/_scripts/hash_password.py
@@ -46,14 +46,14 @@ def main() -> None:
             "Path to server config file. "
             "Used to read in bcrypt_rounds and password_pepper."
         ),
+        required=True,
     )
 
     args = parser.parse_args()
-    if "config" in args and args.config:
-        config = yaml.safe_load(args.config)
-        bcrypt_rounds = config.get("bcrypt_rounds", bcrypt_rounds)
-        password_config = config.get("password_config", None) or {}
-        password_pepper = password_config.get("pepper", password_pepper)
+    config = yaml.safe_load(args.config)
+    bcrypt_rounds = config.get("bcrypt_rounds", bcrypt_rounds)
+    password_config = config.get("password_config", None) or {}
+    password_pepper = password_config.get("pepper", password_pepper)
     password = args.password
 
     if not password:
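
The script now requires -c/--config and always reads bcrypt_rounds and password_pepper from it. A rough sketch of the equivalent hashing flow, assuming the bcrypt and PyYAML packages are installed; the config path, the default round count and the password-plus-pepper concatenation are illustrative assumptions, not taken from the diff:

    import bcrypt
    import yaml

    # Read the server config, mirroring the now-mandatory --config handling above.
    with open("homeserver.yaml") as f:  # hypothetical path
        config = yaml.safe_load(f) or {}

    bcrypt_rounds = config.get("bcrypt_rounds", 12)
    password_pepper = (config.get("password_config") or {}).get("pepper", "")

    password = "correct horse battery staple"  # stand-in for the prompted password

    # Hash the peppered password with the configured work factor.
    hashed = bcrypt.hashpw(
        (password + password_pepper).encode("utf8"),
        bcrypt.gensalt(bcrypt_rounds),
    )
    print(hashed.decode("ascii"))
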
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 0ccd4c9558..330de21f6b 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -65,6 +65,8 @@ class JoinRules:
     PRIVATE: Final = "private"
     # As defined for MSC3083.
     RESTRICTED: Final = "restricted"
+    # As defined for MSC3787.
+    KNOCK_RESTRICTED: Final = "knock_restricted"
 
 
 class RestrictedJoinRuleTypes:
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index cb3b7323d5..6650e826d5 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -17,6 +17,7 @@
 
 import logging
 import typing
+from enum import Enum
 from http import HTTPStatus
 from typing import Any, Dict, List, Optional, Union
 
@@ -30,7 +31,11 @@ if typing.TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class Codes:
+class Codes(str, Enum):
+    """
+    All known error codes, as an enum of strings.
+    """
+
     UNRECOGNIZED = "M_UNRECOGNIZED"
     UNAUTHORIZED = "M_UNAUTHORIZED"
     FORBIDDEN = "M_FORBIDDEN"
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 4a808e33fe..b91ce06de7 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -19,6 +19,7 @@ from typing import (
     TYPE_CHECKING,
     Awaitable,
     Callable,
+    Collection,
     Dict,
     Iterable,
     List,
@@ -444,9 +445,9 @@ class Filter:
         return room_ids
 
     async def _check_event_relations(
-        self, events: Iterable[FilterEvent]
+        self, events: Collection[FilterEvent]
     ) -> List[FilterEvent]:
-        # The event IDs to check, mypy doesn't understand the ifinstance check.
+        # The event IDs to check, mypy doesn't understand the isinstance check.
         event_ids = [event.event_id for event in events if isinstance(event, EventBase)]  # type: ignore[attr-defined]
         event_ids_to_keep = set(
             await self._store.events_have_relations(
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index a747a40814..3f85d61b46 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -81,6 +81,9 @@ class RoomVersion:
     msc2716_historical: bool
     # MSC2716: Adds support for redacting "insertion", "chunk", and "marker" events
     msc2716_redactions: bool
+    # MSC3787: Adds support for a `knock_restricted` join rule, mixing concepts of
+    # knocks and restricted join rules into the same join condition.
+    msc3787_knock_restricted_join_rule: bool
 
 
 class RoomVersions:
@@ -99,6 +102,7 @@ class RoomVersions:
         msc2403_knocking=False,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc3787_knock_restricted_join_rule=False,
     )
     V2 = RoomVersion(
         "2",
@@ -115,6 +119,7 @@ class RoomVersions:
         msc2403_knocking=False,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc3787_knock_restricted_join_rule=False,
     )
     V3 = RoomVersion(
         "3",
@@ -131,6 +136,7 @@ class RoomVersions:
         msc2403_knocking=False,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc3787_knock_restricted_join_rule=False,
     )
     V4 = RoomVersion(
         "4",
@@ -147,6 +153,7 @@ class RoomVersions:
         msc2403_knocking=False,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc3787_knock_restricted_join_rule=False,
     )
     V5 = RoomVersion(
         "5",
@@ -163,6 +170,7 @@ class RoomVersions:
         msc2403_knocking=False,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc3787_knock_restricted_join_rule=False,
     )
     V6 = RoomVersion(
         "6",
@@ -179,6 +187,7 @@ class RoomVersions:
         msc2403_knocking=False,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc3787_knock_restricted_join_rule=False,
     )
     MSC2176 = RoomVersion(
         "org.matrix.msc2176",
@@ -195,6 +204,7 @@ class RoomVersions:
         msc2403_knocking=False,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc3787_knock_restricted_join_rule=False,
     )
     V7 = RoomVersion(
         "7",
@@ -211,6 +221,7 @@ class RoomVersions:
         msc2403_knocking=True,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc3787_knock_restricted_join_rule=False,
     )
     V8 = RoomVersion(
         "8",
@@ -227,6 +238,7 @@ class RoomVersions:
         msc2403_knocking=True,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc3787_knock_restricted_join_rule=False,
     )
     V9 = RoomVersion(
         "9",
@@ -243,6 +255,7 @@ class RoomVersions:
         msc2403_knocking=True,
         msc2716_historical=False,
         msc2716_redactions=False,
+        msc3787_knock_restricted_join_rule=False,
     )
     MSC2716v3 = RoomVersion(
         "org.matrix.msc2716v3",
@@ -259,6 +272,24 @@ class RoomVersions:
         msc2403_knocking=True,
         msc2716_historical=True,
         msc2716_redactions=True,
+        msc3787_knock_restricted_join_rule=False,
+    )
+    MSC3787 = RoomVersion(
+        "org.matrix.msc3787",
+        RoomDisposition.UNSTABLE,
+        EventFormatVersions.V3,
+        StateResolutionVersions.V2,
+        enforce_key_validity=True,
+        special_case_aliases_auth=False,
+        strict_canonicaljson=True,
+        limit_notifications_power_levels=True,
+        msc2176_redaction_rules=False,
+        msc3083_join_rules=True,
+        msc3375_redaction_rules=True,
+        msc2403_knocking=True,
+        msc2716_historical=False,
+        msc2716_redactions=False,
+        msc3787_knock_restricted_join_rule=True,
     )
 
 
@@ -276,6 +307,7 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
         RoomVersions.V8,
         RoomVersions.V9,
         RoomVersions.MSC2716v3,
+        RoomVersions.MSC3787,
     )
 }
 
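
The new room version is only advertised through KNOWN_ROOM_VERSIONS, and callers are expected to gate behaviour on the capability flag rather than on the version string. A hedged sketch of that lookup, assuming a Synapse checkout is importable; the set of "restricted-like" join rules built here is illustrative only:

    from synapse.api.constants import JoinRules
    from synapse.api.room_versions import KNOWN_ROOM_VERSIONS

    # Look the version up by its identifier, as done when creating or joining a room.
    room_version = KNOWN_ROOM_VERSIONS["org.matrix.msc3787"]

    # Feature-gate on the flag, never on the version string itself.
    if room_version.msc3787_knock_restricted_join_rule:
        restricted_like_rules = {JoinRules.RESTRICTED, JoinRules.KNOCK_RESTRICTED}
    else:
        restricted_like_rules = {JoinRules.RESTRICTED}

    print(sorted(restricted_like_rules))
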
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 3623c1724d..a3446ac6e8 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -49,9 +49,12 @@ from twisted.logger import LoggingFile, LogLevel
 from twisted.protocols.tls import TLSMemoryBIOFactory
 from twisted.python.threadpool import ThreadPool
 
+import synapse.util.caches
 from synapse.api.constants import MAX_PDU_SIZE
 from synapse.app import check_bind_error
 from synapse.app.phone_stats_home import start_phone_stats_home
+from synapse.config import ConfigError
+from synapse.config._base import format_config_error
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import ManholeConfig
 from synapse.crypto import context_factory
@@ -432,6 +435,10 @@ async def start(hs: "HomeServer") -> None:
         signal.signal(signal.SIGHUP, run_sighup)
 
         register_sighup(refresh_certificate, hs)
+        register_sighup(reload_cache_config, hs.config)
+
+    # Apply the cache config.
+    hs.config.caches.resize_all_caches()
 
     # Load the certificate from disk.
     refresh_certificate(hs)
@@ -486,6 +493,43 @@ async def start(hs: "HomeServer") -> None:
         atexit.register(gc.freeze)
 
 
+def reload_cache_config(config: HomeServerConfig) -> None:
+    """Reload cache config from disk and immediately apply it.resize caches accordingly.
+
+    If the config is invalid, a `ConfigError` is logged and no changes are made.
+
+    Otherwise, this:
+        - replaces the `caches` section on the given `config` object,
+        - resizes all caches according to the new cache factors, and
+
+    Note that the following cache config keys are read, but not applied:
+        - event_cache_size: used to set a max_size and _original_max_size on
+              EventsWorkerStore._get_event_cache when it is created. We'd have to update
+              the _original_max_size (and maybe the max_size) of that cache.
+        - sync_response_cache_duration: would have to update the timeout_sec attribute on
+              HomeServer -> SyncHandler -> ResponseCache.
+        - track_memory_usage: affects synapse.util.caches.TRACK_MEMORY_USAGE, which
+              influences Synapse's self-reported metrics.
+
+    Also, the HTTPConnectionPool in SimpleHTTPClient sets its maxPersistentPerHost
+    parameter based on the global_factor. This won't be applied on a config reload.
+    """
+    try:
+        previous_cache_config = config.reload_config_section("caches")
+    except ConfigError as e:
+        logger.warning("Failed to reload cache config")
+        for f in format_config_error(e):
+            logger.warning(f)
+    else:
+        logger.debug(
+            "New cache config. Was:\n %s\nNow:\n",
+            previous_cache_config.__dict__,
+            config.caches.__dict__,
+        )
+        synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage
+        config.caches.resize_all_caches()
+
+
 def setup_sentry(hs: "HomeServer") -> None:
     """Enable sentry integration, if enabled in configuration"""
 
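
Since reload_cache_config is registered as a SIGHUP handler, cache factors can now be changed on a running homeserver by editing the config and signalling the process. A small operator-side sketch; the pid-file path is an assumption, the real location is wherever the pid_file setting points:

    import os
    import signal

    # Hypothetical pid file location.
    PID_FILE = "/var/run/matrix-synapse.pid"

    with open(PID_FILE) as f:
        pid = int(f.read().strip())

    # After editing caches.global_factor / caches.per_cache_factors in the config,
    # a SIGHUP makes the running process re-read and apply just the cache section.
    os.kill(pid, signal.SIGHUP)
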
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 0f75e7b9d4..4c6c0658ab 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -16,7 +16,7 @@
 import logging
 import os
 import sys
-from typing import Dict, Iterable, Iterator, List
+from typing import Dict, Iterable, List
 
 from matrix_common.versionstring import get_distribution_version_string
 
@@ -45,7 +45,7 @@ from synapse.app._base import (
     redirect_stdio_to_logs,
     register_start,
 )
-from synapse.config._base import ConfigError
+from synapse.config._base import ConfigError, format_config_error
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import ListenerConfig
@@ -399,38 +399,6 @@ def setup(config_options: List[str]) -> SynapseHomeServer:
     return hs
 
 
-def format_config_error(e: ConfigError) -> Iterator[str]:
-    """
-    Formats a config error neatly
-
-    The idea is to format the immediate error, plus the "causes" of those errors,
-    hopefully in a way that makes sense to the user. For example:
-
-        Error in configuration at 'oidc_config.user_mapping_provider.config.display_name_template':
-          Failed to parse config for module 'JinjaOidcMappingProvider':
-            invalid jinja template:
-              unexpected end of template, expected 'end of print statement'.
-
-    Args:
-        e: the error to be formatted
-
-    Returns: An iterator which yields string fragments to be formatted
-    """
-    yield "Error in configuration"
-
-    if e.path:
-        yield " at '%s'" % (".".join(e.path),)
-
-    yield ":\n  %s" % (e.msg,)
-
-    parent_e = e.__cause__
-    indent = 1
-    while parent_e:
-        indent += 1
-        yield ":\n%s%s" % ("  " * indent, str(parent_e))
-        parent_e = parent_e.__cause__
-
-
 def run(hs: HomeServer) -> None:
     _base.start_reactor(
         "synapse-homeserver",
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 179aa7ff88..42364fc133 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -16,14 +16,18 @@
 
 import argparse
 import errno
+import logging
 import os
 from collections import OrderedDict
 from hashlib import sha256
 from textwrap import dedent
 from typing import (
     Any,
+    ClassVar,
+    Collection,
     Dict,
     Iterable,
+    Iterator,
     List,
     MutableMapping,
     Optional,
@@ -40,6 +44,8 @@ import yaml
 
 from synapse.util.templates import _create_mxc_to_http_filter, _format_ts_filter
 
+logger = logging.getLogger(__name__)
+
 
 class ConfigError(Exception):
     """Represents a problem parsing the configuration
@@ -55,6 +61,38 @@ class ConfigError(Exception):
         self.path = path
 
 
+def format_config_error(e: ConfigError) -> Iterator[str]:
+    """
+    Formats a config error neatly
+
+    The idea is to format the immediate error, plus the "causes" of those errors,
+    hopefully in a way that makes sense to the user. For example:
+
+        Error in configuration at 'oidc_config.user_mapping_provider.config.display_name_template':
+          Failed to parse config for module 'JinjaOidcMappingProvider':
+            invalid jinja template:
+              unexpected end of template, expected 'end of print statement'.
+
+    Args:
+        e: the error to be formatted
+
+    Returns: An iterator which yields string fragments to be formatted
+    """
+    yield "Error in configuration"
+
+    if e.path:
+        yield " at '%s'" % (".".join(e.path),)
+
+    yield ":\n  %s" % (e.msg,)
+
+    parent_e = e.__cause__
+    indent = 1
+    while parent_e:
+        indent += 1
+        yield ":\n%s%s" % ("  " * indent, str(parent_e))
+        parent_e = parent_e.__cause__
+
+
 # We split these messages out to allow packages to override with package
 # specific instructions.
 MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS = """\
@@ -119,7 +157,7 @@ class Config:
             defined in subclasses.
     """
 
-    section: str
+    section: ClassVar[str]
 
     def __init__(self, root_config: "RootConfig" = None):
         self.root = root_config
@@ -309,9 +347,12 @@ class RootConfig:
     class, lower-cased and with "Config" removed.
     """
 
-    config_classes = []
+    config_classes: List[Type[Config]] = []
+
+    def __init__(self, config_files: Collection[str] = ()):
+        # Capture absolute paths here, so we can reload config after we daemonize.
+        self.config_files = [os.path.abspath(path) for path in config_files]
 
-    def __init__(self):
         for config_class in self.config_classes:
             if config_class.section is None:
                 raise ValueError("%r requires a section name" % (config_class,))
@@ -512,12 +553,10 @@ class RootConfig:
             object from parser.parse_args(..)`
         """
 
-        obj = cls()
-
         config_args = parser.parse_args(argv)
 
         config_files = find_config_files(search_paths=config_args.config_path)
-
+        obj = cls(config_files)
         if not config_files:
             parser.error("Must supply a config file.")
 
@@ -627,7 +666,7 @@ class RootConfig:
 
         generate_missing_configs = config_args.generate_missing_configs
 
-        obj = cls()
+        obj = cls(config_files)
 
         if config_args.generate_config:
             if config_args.report_stats is None:
@@ -727,6 +766,34 @@ class RootConfig:
     ) -> None:
         self.invoke_all("generate_files", config_dict, config_dir_path)
 
+    def reload_config_section(self, section_name: str) -> Config:
+        """Reconstruct the given config section, leaving all others unchanged.
+
+        This works in three steps:
+
+        1. Create a new instance of the relevant `Config` subclass.
+        2. Call `read_config` on that instance to parse the new config.
+        3. Replace the existing config instance with the new one.
+
+        :raises ValueError: if the given `section_name` does not exist.
+        :raises ConfigError: for any other problems reloading config.
+
+        :returns: the previous config object, which no longer has a reference to this
+            RootConfig.
+        """
+        existing_config: Optional[Config] = getattr(self, section_name, None)
+        if existing_config is None:
+            raise ValueError(f"Unknown config section '{section_name}'")
+        logger.info("Reloading config section '%s'", section_name)
+
+        new_config_data = read_config_files(self.config_files)
+        new_config = type(existing_config)(self)
+        new_config.read_config(new_config_data)
+        setattr(self, section_name, new_config)
+
+        existing_config.root = None
+        return existing_config
+
 
 def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]:
     """Read the config files into a dict
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index bd092f956d..71d6655fda 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -1,15 +1,19 @@
 import argparse
 from typing import (
     Any,
+    Collection,
     Dict,
     Iterable,
+    Iterator,
     List,
+    Literal,
     MutableMapping,
     Optional,
     Tuple,
     Type,
     TypeVar,
     Union,
+    overload,
 )
 
 import jinja2
@@ -64,6 +68,8 @@ class ConfigError(Exception):
         self.msg = msg
         self.path = path
 
+def format_config_error(e: ConfigError) -> Iterator[str]: ...
+
 MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
 MISSING_REPORT_STATS_SPIEL: str
 MISSING_SERVER_NAME: str
@@ -117,7 +123,8 @@ class RootConfig:
     background_updates: background_updates.BackgroundUpdateConfig
 
     config_classes: List[Type["Config"]] = ...
-    def __init__(self) -> None: ...
+    config_files: List[str]
+    def __init__(self, config_files: Collection[str] = ...) -> None: ...
     def invoke_all(
         self, func_name: str, *args: Any, **kwargs: Any
     ) -> MutableMapping[str, Any]: ...
@@ -157,6 +164,12 @@ class RootConfig:
     def generate_missing_files(
         self, config_dict: dict, config_dir_path: str
     ) -> None: ...
+    @overload
+    def reload_config_section(
+        self, section_name: Literal["caches"]
+    ) -> cache.CacheConfig: ...
+    @overload
+    def reload_config_section(self, section_name: str) -> Config: ...
 
 class Config:
     root: RootConfig
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 94d852f413..d2f55534d7 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -69,11 +69,11 @@ def _canonicalise_cache_name(cache_name: str) -> str:
 def add_resizable_cache(
     cache_name: str, cache_resize_callback: Callable[[float], None]
 ) -> None:
-    """Register a cache that's size can dynamically change
+    """Register a cache whose size can dynamically change
 
     Args:
         cache_name: A reference to the cache
-        cache_resize_callback: A callback function that will be ran whenever
+        cache_resize_callback: A callback function that will run whenever
             the cache needs to be resized
     """
     # Some caches have '*' in them which we strip out.
@@ -96,6 +96,13 @@ class CacheConfig(Config):
     section = "caches"
     _environ = os.environ
 
+    event_cache_size: int
+    cache_factors: Dict[str, float]
+    global_factor: float
+    track_memory_usage: bool
+    expiry_time_msec: Optional[int]
+    sync_response_cache_duration: int
+
     @staticmethod
     def reset() -> None:
         """Resets the caches to their defaults. Used for tests."""
@@ -115,6 +122,12 @@ class CacheConfig(Config):
         # A cache 'factor' is a multiplier that can be applied to each of
         # Synapse's caches in order to increase or decrease the maximum
         # number of entries that can be stored.
+        #
+        # The configuration for cache factors (caches.global_factor and
+        # caches.per_cache_factors) can be reloaded while the application is running,
+        # by sending a SIGHUP signal to the Synapse process. Changes to other parts of
+        # the caching config will NOT be applied after a SIGHUP is received; a restart
+        # is necessary.
 
         # The number of events to cache in memory. Not affected by
         # caches.global_factor.
@@ -163,6 +176,24 @@ class CacheConfig(Config):
           #
           #cache_entry_ttl: 30m
 
+          # This flag enables cache autotuning, and is further specified by the sub-options `max_cache_memory_usage`,
+          # `target_cache_memory_usage`, and `min_cache_ttl`. These flags work in conjunction with each other to maintain
+          # a balance between cache memory usage and cache entry availability. You must be using jemalloc to utilize
+          # this option, and all three of the options must be specified for this feature to work.
+          #cache_autotuning:
+            # This flag sets a ceiling on how much memory the caches can use before they begin to be continuously evicted.
+            # They will continue to be evicted until the memory usage drops below the `target_cache_memory_usage`, set in
+            # the flag below, or until the `min_cache_ttl` is hit.
+            #max_cache_memory_usage: 1024M
+
+            # This flag sets a rough target for the desired memory usage of the caches.
+            #target_cache_memory_usage: 758M
+
+            # `min_cache_ttl` sets a limit under which newer cache entries are not evicted; it only applies while
+            # caches are actively being evicted, i.e. while `max_cache_memory_usage` is exceeded. This is to protect hot
+            # caches from being emptied while Synapse is evicting due to memory pressure.
+            #min_cache_ttl: 5m
+
           # Controls how long the results of a /sync request are cached for after
           # a successful response is returned. A higher duration can help clients with
           # intermittent connections, at the cost of higher memory usage.
@@ -174,21 +205,21 @@ class CacheConfig(Config):
         """
 
     def read_config(self, config: JsonDict, **kwargs: Any) -> None:
+        """Populate this config object with values from `config`.
+
+        This method does NOT resize existing or future caches: use `resize_all_caches`.
+        We use two separate methods so that we can reject bad config before applying it.
+        """
         self.event_cache_size = self.parse_size(
             config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE)
         )
-        self.cache_factors: Dict[str, float] = {}
+        self.cache_factors = {}
 
         cache_config = config.get("caches") or {}
-        self.global_factor = cache_config.get(
-            "global_factor", properties.default_factor_size
-        )
+        self.global_factor = cache_config.get("global_factor", _DEFAULT_FACTOR_SIZE)
         if not isinstance(self.global_factor, (int, float)):
             raise ConfigError("caches.global_factor must be a number.")
 
-        # Set the global one so that it's reflected in new caches
-        properties.default_factor_size = self.global_factor
-
         # Load cache factors from the config
         individual_factors = cache_config.get("per_cache_factors") or {}
         if not isinstance(individual_factors, dict):
@@ -230,7 +261,7 @@ class CacheConfig(Config):
         cache_entry_ttl = cache_config.get("cache_entry_ttl", "30m")
 
         if expire_caches:
-            self.expiry_time_msec: Optional[int] = self.parse_duration(cache_entry_ttl)
+            self.expiry_time_msec = self.parse_duration(cache_entry_ttl)
         else:
             self.expiry_time_msec = None
 
@@ -250,23 +281,38 @@ class CacheConfig(Config):
             )
             self.expiry_time_msec = self.parse_duration(expiry_time)
 
+        self.cache_autotuning = cache_config.get("cache_autotuning")
+        if self.cache_autotuning:
+            max_memory_usage = self.cache_autotuning.get("max_cache_memory_usage")
+            self.cache_autotuning["max_cache_memory_usage"] = self.parse_size(
+                max_memory_usage
+            )
+
+            target_mem_size = self.cache_autotuning.get("target_cache_memory_usage")
+            self.cache_autotuning["target_cache_memory_usage"] = self.parse_size(
+                target_mem_size
+            )
+
+            min_cache_ttl = self.cache_autotuning.get("min_cache_ttl")
+            self.cache_autotuning["min_cache_ttl"] = self.parse_duration(min_cache_ttl)
+
         self.sync_response_cache_duration = self.parse_duration(
             cache_config.get("sync_response_cache_duration", 0)
         )
 
-        # Resize all caches (if necessary) with the new factors we've loaded
-        self.resize_all_caches()
-
-        # Store this function so that it can be called from other classes without
-        # needing an instance of Config
-        properties.resize_all_caches_func = self.resize_all_caches
-
     def resize_all_caches(self) -> None:
-        """Ensure all cache sizes are up to date
+        """Ensure all cache sizes are up-to-date.
 
         For each cache, run the mapped callback function with either
         a specific cache factor or the default, global one.
         """
+        # Set the global factor size, so that new caches are appropriately sized.
+        properties.default_factor_size = self.global_factor
+
+        # Store this function so that it can be called from other classes without
+        # needing an instance of CacheConfig
+        properties.resize_all_caches_func = self.resize_all_caches
+
         # block other threads from modifying _CACHES while we iterate it.
         with _CACHES_LOCK:
             for cache_name, callback in _CACHES.items():
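
Splitting read_config from resize_all_caches means a bad reload can be rejected before any live cache is touched: parse and validate first, apply second. A hedged sketch of that two-step flow using the CacheConfig class from this file (the config values are illustrative, and no RootConfig is attached):

    from synapse.config.cache import CacheConfig

    new_section = CacheConfig()  # standalone instance, root_config left as None

    # Step 1: parse and validate. Bad values raise ConfigError here, before any
    # registered cache has been resized.
    new_section.read_config(
        {
            "event_cache_size": "20K",
            "caches": {
                "global_factor": 1.5,
                "per_cache_factors": {"get_users_in_room": 3.0},
            },
        }
    )

    # Step 2: only once parsing succeeded, push the new factors out to every
    # registered cache and record the new default for caches created later.
    new_section.resize_all_caches()
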
diff --git a/synapse/config/oembed.py b/synapse/config/oembed.py
index 690ffb5296..e9edea0731 100644
--- a/synapse/config/oembed.py
+++ b/synapse/config/oembed.py
@@ -57,9 +57,9 @@ class OembedConfig(Config):
         """
         # Whether to use the packaged providers.json file.
         if not oembed_config.get("disable_default_providers") or False:
-            providers = json.load(
-                pkg_resources.resource_stream("synapse", "res/providers.json")
-            )
+            with pkg_resources.resource_stream("synapse", "res/providers.json") as s:
+                providers = json.load(s)
+
             yield from self._parse_and_validate_provider(
                 providers, config_path=("oembed",)
             )
diff --git a/synapse/config/room.py b/synapse/config/room.py
index e18a87ea37..462d85ac1d 100644
--- a/synapse/config/room.py
+++ b/synapse/config/room.py
@@ -63,6 +63,19 @@ class RoomConfig(Config):
                 "Invalid value for encryption_enabled_by_default_for_room_type"
             )
 
+        self.default_power_level_content_override = config.get(
+            "default_power_level_content_override",
+            None,
+        )
+        if self.default_power_level_content_override is not None:
+            for preset in self.default_power_level_content_override:
+                if preset not in vars(RoomCreationPreset).values():
+                    raise ConfigError(
+                        "Unrecognised room preset %s in default_power_level_content_override"
+                        % preset
+                    )
+                # We validate the actual overrides when we try to apply them.
+
     def generate_config_section(self, **kwargs: Any) -> str:
         return """\
         ## Rooms ##
@@ -83,4 +96,38 @@ class RoomConfig(Config):
         # will also not affect rooms created by other servers.
         #
         #encryption_enabled_by_default_for_room_type: invite
+
+        # Override the default power levels for rooms created on this server, per
+        # room creation preset.
+        #
+        # The appropriate dictionary for the room preset will be applied on top
+        # of the existing power levels content.
+        #
+        # Useful if you know that your users need special permissions in rooms
+        # that they create (e.g. to send particular types of state events without
+        # needing an elevated power level).  This takes the same shape as the
+        # `power_level_content_override` parameter in the /createRoom API, but
+        # is applied before that parameter.
+        #
+        # Valid keys are some or all of `private_chat`, `trusted_private_chat`
+        # and `public_chat`. Inside each of those should be any of the
+        # properties allowed in `power_level_content_override` in the
+        # /createRoom API. If any property is missing, its default value will
+        # continue to be used. If any property is present, it will overwrite
+        # the existing default completely (so if the `events` property exists,
+        # the default event power levels will be ignored).
+        #
+        #default_power_level_content_override:
+        #    private_chat:
+        #        "events":
+        #            "com.example.myeventtype" : 0
+        #            "m.room.avatar": 50
+        #            "m.room.canonical_alias": 50
+        #            "m.room.encryption": 100
+        #            "m.room.history_visibility": 100
+        #            "m.room.name": 50
+        #            "m.room.power_levels": 100
+        #            "m.room.server_acl": 100
+        #            "m.room.tombstone": 100
+        #        "events_default": 1
         """
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 005a3ee48c..f73d5e1f66 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -996,7 +996,7 @@ class ServerConfig(Config):
         #   federation: the server-server API (/_matrix/federation). Also implies
         #       'media', 'keys', 'openid'
         #
-        #   keys: the key discovery API (/_matrix/keys).
+        #   keys: the key discovery API (/_matrix/key).
         #
         #   media: the media API (/_matrix/media).
         #
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 621a3efccc..4c0b587a76 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -414,7 +414,12 @@ def _is_membership_change_allowed(
             raise AuthError(403, "You are banned from this room")
         elif join_rule == JoinRules.PUBLIC:
             pass
-        elif room_version.msc3083_join_rules and join_rule == JoinRules.RESTRICTED:
+        elif (
+            room_version.msc3083_join_rules and join_rule == JoinRules.RESTRICTED
+        ) or (
+            room_version.msc3787_knock_restricted_join_rule
+            and join_rule == JoinRules.KNOCK_RESTRICTED
+        ):
             # This is the same as public, but the event must contain a reference
             # to the server who authorised the join. If the event does not contain
             # the proper content it is rejected.
@@ -440,8 +445,13 @@ def _is_membership_change_allowed(
                 if authorising_user_level < invite_level:
                     raise AuthError(403, "Join event authorised by invalid server.")
 
-        elif join_rule == JoinRules.INVITE or (
-            room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
+        elif (
+            join_rule == JoinRules.INVITE
+            or (room_version.msc2403_knocking and join_rule == JoinRules.KNOCK)
+            or (
+                room_version.msc3787_knock_restricted_join_rule
+                and join_rule == JoinRules.KNOCK_RESTRICTED
+            )
         ):
             if not caller_in_room and not caller_invited:
                 raise AuthError(403, "You are not invited to this room.")
@@ -462,7 +472,10 @@ def _is_membership_change_allowed(
         if user_level < ban_level or user_level <= target_level:
             raise AuthError(403, "You don't have permission to ban")
     elif room_version.msc2403_knocking and Membership.KNOCK == membership:
-        if join_rule != JoinRules.KNOCK:
+        if join_rule != JoinRules.KNOCK and (
+            not room_version.msc3787_knock_restricted_join_rule
+            or join_rule != JoinRules.KNOCK_RESTRICTED
+        ):
             raise AuthError(403, "You don't have permission to knock")
         elif target_user_id != event.user_id:
             raise AuthError(403, "You cannot knock for other users")
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index c238376caf..39ad2793d9 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import abc
+import collections.abc
 import os
 from typing import (
     TYPE_CHECKING,
@@ -32,9 +33,11 @@ from typing import (
     overload,
 )
 
+import attr
 from typing_extensions import Literal
 from unpaddedbase64 import encode_base64
 
+from synapse.api.constants import RelationTypes
 from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
 from synapse.types import JsonDict, RoomStreamToken
 from synapse.util.caches import intern_dict
@@ -615,3 +618,45 @@ def make_event_from_dict(
     return event_type(
         event_dict, room_version, internal_metadata_dict or {}, rejected_reason
     )
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventRelation:
+    # The target event of the relation.
+    parent_id: str
+    # The relation type.
+    rel_type: str
+    # The aggregation key. Will be None if the rel_type is not m.annotation or is
+    # not a string.
+    aggregation_key: Optional[str]
+
+
+def relation_from_event(event: EventBase) -> Optional[_EventRelation]:
+    """
+    Attempt to parse relation information from an event.
+
+    Returns:
+        The event relation information, if it is valid; None otherwise.
+    """
+    relation = event.content.get("m.relates_to")
+    if not relation or not isinstance(relation, collections.abc.Mapping):
+        # No relation information.
+        return None
+
+    # Relations must have a type and parent event ID.
+    rel_type = relation.get("rel_type")
+    if not isinstance(rel_type, str):
+        return None
+
+    parent_id = relation.get("event_id")
+    if not isinstance(parent_id, str):
+        return None
+
+    # Annotations have a key field.
+    aggregation_key = None
+    if rel_type == RelationTypes.ANNOTATION:
+        aggregation_key = relation.get("key")
+        if not isinstance(aggregation_key, str):
+            aggregation_key = None
+
+    return _EventRelation(parent_id, rel_type, aggregation_key)
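
relation_from_event centralises the parsing of m.relates_to so callers stop poking at event content directly. A stand-alone sketch of the same validation rules applied to a bare content dict, with a NamedTuple approximating _EventRelation (no Synapse import needed):

    import collections.abc
    from typing import NamedTuple, Optional


    class EventRelation(NamedTuple):
        parent_id: str
        rel_type: str
        aggregation_key: Optional[str]


    def relation_from_content(content: dict) -> Optional[EventRelation]:
        """Mirror of the checks in relation_from_event, on raw event content."""
        relation = content.get("m.relates_to")
        if not relation or not isinstance(relation, collections.abc.Mapping):
            return None  # no usable relation information

        rel_type = relation.get("rel_type")
        parent_id = relation.get("event_id")
        if not isinstance(rel_type, str) or not isinstance(parent_id, str):
            return None  # relations must carry a string type and parent event ID

        # Only m.annotation relations carry a (string) aggregation key.
        aggregation_key = relation.get("key") if rel_type == "m.annotation" else None
        if not isinstance(aggregation_key, str):
            aggregation_key = None

        return EventRelation(parent_id, rel_type, aggregation_key)


    print(
        relation_from_content(
            {"m.relates_to": {"rel_type": "m.annotation", "event_id": "$abc", "key": "👍"}}
        )
    )
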
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 46042b2bf7..7a91544119 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -15,17 +15,16 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import attr
 from frozendict import frozendict
-
-from twisted.internet.defer import Deferred
+from typing_extensions import Literal
 
 from synapse.appservice import ApplicationService
 from synapse.events import EventBase
-from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.types import JsonDict, StateMap
 
 if TYPE_CHECKING:
     from synapse.storage import Storage
     from synapse.storage.databases.main import DataStore
+    from synapse.storage.state import StateFilter
 
 
 @attr.s(slots=True, auto_attribs=True)
@@ -60,6 +59,9 @@ class EventContext:
             If ``state_group`` is None (ie, the event is an outlier),
             ``state_group_before_event`` will always also be ``None``.
 
+        state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None
+            then this is the delta of the state between the two groups.
+
         prev_group: If it is known, ``state_group``'s prev_group. Note that this being
             None does not necessarily mean that ``state_group`` does not have
             a prev_group!
@@ -78,73 +80,47 @@ class EventContext:
         app_service: If this event is being sent by a (local) application service, that
             app service.
 
-        _current_state_ids: The room state map, including this event - ie, the state
-            in ``state_group``.
-
-            (type, state_key) -> event_id
-
-            For an outlier, this is {}
-
-            Note that this is a private attribute: it should be accessed via
-            ``get_current_state_ids``. _AsyncEventContext impl calculates this
-            on-demand: it will be None until that happens.
-
-        _prev_state_ids: The room state map, excluding this event - ie, the state
-            in ``state_group_before_event``. For a non-state
-            event, this will be the same as _current_state_events.
-
-            Note that it is a completely different thing to prev_group!
-
-            (type, state_key) -> event_id
-
-            For an outlier, this is {}
-
-            As with _current_state_ids, this is a private attribute. It should be
-            accessed via get_prev_state_ids.
-
         partial_state: if True, we may be storing this event with a temporary,
             incomplete state.
     """
 
-    rejected: Union[bool, str] = False
+    _storage: "Storage"
+    rejected: Union[Literal[False], str] = False
     _state_group: Optional[int] = None
     state_group_before_event: Optional[int] = None
+    _state_delta_due_to_event: Optional[StateMap[str]] = None
     prev_group: Optional[int] = None
     delta_ids: Optional[StateMap[str]] = None
     app_service: Optional[ApplicationService] = None
 
-    _current_state_ids: Optional[StateMap[str]] = None
-    _prev_state_ids: Optional[StateMap[str]] = None
-
     partial_state: bool = False
 
     @staticmethod
     def with_state(
+        storage: "Storage",
         state_group: Optional[int],
         state_group_before_event: Optional[int],
-        current_state_ids: Optional[StateMap[str]],
-        prev_state_ids: Optional[StateMap[str]],
+        state_delta_due_to_event: Optional[StateMap[str]],
         partial_state: bool,
         prev_group: Optional[int] = None,
         delta_ids: Optional[StateMap[str]] = None,
     ) -> "EventContext":
         return EventContext(
-            current_state_ids=current_state_ids,
-            prev_state_ids=prev_state_ids,
+            storage=storage,
             state_group=state_group,
             state_group_before_event=state_group_before_event,
+            state_delta_due_to_event=state_delta_due_to_event,
             prev_group=prev_group,
             delta_ids=delta_ids,
             partial_state=partial_state,
         )
 
     @staticmethod
-    def for_outlier() -> "EventContext":
+    def for_outlier(
+        storage: "Storage",
+    ) -> "EventContext":
         """Return an EventContext instance suitable for persisting an outlier event"""
-        return EventContext(
-            current_state_ids={},
-            prev_state_ids={},
-        )
+        return EventContext(storage=storage)
 
     async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
         """Converts self to a type that can be serialized as JSON, and then
@@ -157,24 +133,14 @@ class EventContext:
             The serialized event.
         """
 
-        # We don't serialize the full state dicts, instead they get pulled out
-        # of the DB on the other side. However, the other side can't figure out
-        # the prev_state_ids, so if we're a state event we include the event
-        # id that we replaced in the state.
-        if event.is_state():
-            prev_state_ids = await self.get_prev_state_ids()
-            prev_state_id = prev_state_ids.get((event.type, event.state_key))
-        else:
-            prev_state_id = None
-
         return {
-            "prev_state_id": prev_state_id,
-            "event_type": event.type,
-            "event_state_key": event.get_state_key(),
             "state_group": self._state_group,
             "state_group_before_event": self.state_group_before_event,
             "rejected": self.rejected,
             "prev_group": self.prev_group,
+            "state_delta_due_to_event": _encode_state_dict(
+                self._state_delta_due_to_event
+            ),
             "delta_ids": _encode_state_dict(self.delta_ids),
             "app_service_id": self.app_service.id if self.app_service else None,
             "partial_state": self.partial_state,
@@ -192,16 +158,16 @@ class EventContext:
         Returns:
             The event context.
         """
-        context = _AsyncEventContextImpl(
+        context = EventContext(
             # We use the state_group and prev_state_id stuff to pull the
             # current_state_ids out of the DB and construct prev_state_ids.
             storage=storage,
-            prev_state_id=input["prev_state_id"],
-            event_type=input["event_type"],
-            event_state_key=input["event_state_key"],
             state_group=input["state_group"],
             state_group_before_event=input["state_group_before_event"],
             prev_group=input["prev_group"],
+            state_delta_due_to_event=_decode_state_dict(
+                input["state_delta_due_to_event"]
+            ),
             delta_ids=_decode_state_dict(input["delta_ids"]),
             rejected=input["rejected"],
             partial_state=input.get("partial_state", False),
@@ -231,7 +197,9 @@ class EventContext:
 
         return self._state_group
 
-    async def get_current_state_ids(self) -> Optional[StateMap[str]]:
+    async def get_current_state_ids(
+        self, state_filter: Optional["StateFilter"] = None
+    ) -> Optional[StateMap[str]]:
         """
         Gets the room state map, including this event - ie, the state in ``state_group``
 
@@ -239,6 +207,9 @@ class EventContext:
         not make it into the room state. This method will raise an exception if
         ``rejected`` is set.
 
+        Args:
+            state_filter: specifies the type of state events to fetch from the DB, for example EventTypes.JoinRules
+
         Returns:
             Returns None if state_group is None, which happens when the associated
             event is an outlier.
@@ -249,15 +220,27 @@ class EventContext:
         if self.rejected:
             raise RuntimeError("Attempt to access state_ids of rejected event")
 
-        await self._ensure_fetched()
-        return self._current_state_ids
+        assert self._state_delta_due_to_event is not None
+
+        prev_state_ids = await self.get_prev_state_ids(state_filter)
+
+        if self._state_delta_due_to_event:
+            prev_state_ids = dict(prev_state_ids)
+            prev_state_ids.update(self._state_delta_due_to_event)
 
-    async def get_prev_state_ids(self) -> StateMap[str]:
+        return prev_state_ids
+
+    async def get_prev_state_ids(
+        self, state_filter: Optional["StateFilter"] = None
+    ) -> StateMap[str]:
         """
         Gets the room state map, excluding this event.
 
         For a non-state event, this will be the same as get_current_state_ids().
 
+        Args:
+            state_filter: specifies the type of state events to fetch from the DB, for example EventTypes.JoinRules
+
         Returns:
             Returns {} if state_group is None, which happens when the associated
             event is an outlier.
@@ -265,94 +248,10 @@ class EventContext:
             Maps a (type, state_key) to the event ID of the state event matching
             this tuple.
         """
-        await self._ensure_fetched()
-        # There *should* be previous state IDs now.
-        assert self._prev_state_ids is not None
-        return self._prev_state_ids
-
-    def get_cached_current_state_ids(self) -> Optional[StateMap[str]]:
-        """Gets the current state IDs if we have them already cached.
-
-        It is an error to access this for a rejected event, since rejected state should
-        not make it into the room state. This method will raise an exception if
-        ``rejected`` is set.
-
-        Returns:
-            Returns None if we haven't cached the state or if state_group is None
-            (which happens when the associated event is an outlier).
-
-            Otherwise, returns the the current state IDs.
-        """
-        if self.rejected:
-            raise RuntimeError("Attempt to access state_ids of rejected event")
-
-        return self._current_state_ids
-
-    async def _ensure_fetched(self) -> None:
-        return None
-
-
-@attr.s(slots=True)
-class _AsyncEventContextImpl(EventContext):
-    """
-    An implementation of EventContext which fetches _current_state_ids and
-    _prev_state_ids from the database on demand.
-
-    Attributes:
-
-        _storage
-
-        _fetching_state_deferred: Resolves when *_state_ids have been calculated.
-            None if we haven't started calculating yet
-
-        _event_type: The type of the event the context is associated with.
-
-        _event_state_key: The state_key of the event the context is associated with.
-
-        _prev_state_id: If the event associated with the context is a state event,
-            then `_prev_state_id` is the event_id of the state that was replaced.
-    """
-
-    # This needs to have a default as we're inheriting
-    _storage: "Storage" = attr.ib(default=None)
-    _prev_state_id: Optional[str] = attr.ib(default=None)
-    _event_type: str = attr.ib(default=None)
-    _event_state_key: Optional[str] = attr.ib(default=None)
-    _fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None)
-
-    async def _ensure_fetched(self) -> None:
-        if not self._fetching_state_deferred:
-            self._fetching_state_deferred = run_in_background(self._fill_out_state)
-
-        await make_deferred_yieldable(self._fetching_state_deferred)
-
-    async def _fill_out_state(self) -> None:
-        """Called to populate the _current_state_ids and _prev_state_ids
-        attributes by loading from the database.
-        """
-        if self.state_group is None:
-            # No state group means the event is an outlier. Usually the state_ids dicts are also
-            # pre-set to empty dicts, but they get reset when the context is serialized, so set
-            # them to empty dicts again here.
-            self._current_state_ids = {}
-            self._prev_state_ids = {}
-            return
-
-        current_state_ids = await self._storage.state.get_state_ids_for_group(
-            self.state_group
+        assert self.state_group_before_event is not None
+        return await self._storage.state.get_state_ids_for_group(
+            self.state_group_before_event, state_filter
         )
-        # Set this separately so mypy knows current_state_ids is not None.
-        self._current_state_ids = current_state_ids
-        if self._event_state_key is not None:
-            self._prev_state_ids = dict(current_state_ids)
-
-            key = (self._event_type, self._event_state_key)
-            if self._prev_state_id:
-                self._prev_state_ids[key] = self._prev_state_id
-            else:
-                self._prev_state_ids.pop(key, None)
-        else:
-            self._prev_state_ids = current_state_ids
 
 
 def _encode_state_dict(
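
The refactor drops the cached full state maps in favour of a single state_delta_due_to_event, so the "current" state is always reconstructed as the state before the event plus that delta. A toy illustration of that reconstruction on plain dicts (the state maps here are invented):

    from typing import Dict, Tuple

    StateMap = Dict[Tuple[str, str], str]

    # State of the room before the event, keyed by (event type, state key).
    state_before_event: StateMap = {
        ("m.room.create", ""): "$create",
        ("m.room.member", "@alice:example.org"): "$alice_join",
    }

    # Delta introduced by the event itself (here: Bob joins).
    state_delta_due_to_event: StateMap = {
        ("m.room.member", "@bob:example.org"): "$bob_join",
    }

    # What get_current_state_ids now does: start from the previous state and
    # overlay the delta, rather than caching a second full copy per event.
    current_state_ids = dict(state_before_event)
    current_state_ids.update(state_delta_due_to_event)

    assert current_state_ids[("m.room.member", "@bob:example.org")] == "$bob_join"
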
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 3b6795d40f..7984874e21 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -27,11 +27,13 @@ from typing import (
     Union,
 )
 
+from synapse.api.errors import Codes
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.media_storage import ReadableFileWrapper
-from synapse.spam_checker_api import RegistrationBehaviour
+from synapse.spam_checker_api import Allow, Decision, RegistrationBehaviour
 from synapse.types import RoomAlias, UserProfile
 from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
+from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
     import synapse.events
@@ -39,8 +41,22 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+
 CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[
     ["synapse.events.EventBase"],
+    Awaitable[
+        Union[
+            Allow,
+            Codes,
+            # Deprecated
+            bool,
+            # Deprecated
+            str,
+        ]
+    ],
+]
+SHOULD_DROP_FEDERATED_EVENT_CALLBACK = Callable[
+    ["synapse.events.EventBase"],
     Awaitable[Union[bool, str]],
 ]
 USER_MAY_JOIN_ROOM_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
@@ -162,8 +178,14 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
 
 
 class SpamChecker:
-    def __init__(self) -> None:
+    def __init__(self, hs: "synapse.server.HomeServer") -> None:
+        self.hs = hs
+        self.clock = hs.get_clock()
+
         self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
+        self._should_drop_federated_event_callbacks: List[
+            SHOULD_DROP_FEDERATED_EVENT_CALLBACK
+        ] = []
         self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
         self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
         self._user_may_send_3pid_invite_callbacks: List[
@@ -187,6 +209,9 @@ class SpamChecker:
     def register_callbacks(
         self,
         check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None,
+        should_drop_federated_event: Optional[
+            SHOULD_DROP_FEDERATED_EVENT_CALLBACK
+        ] = None,
         user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None,
         user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None,
         user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None,
@@ -205,6 +230,11 @@ class SpamChecker:
         if check_event_for_spam is not None:
             self._check_event_for_spam_callbacks.append(check_event_for_spam)
 
+        if should_drop_federated_event is not None:
+            self._should_drop_federated_event_callbacks.append(
+                should_drop_federated_event
+            )
+
         if user_may_join_room is not None:
             self._user_may_join_room_callbacks.append(user_may_join_room)
 
@@ -240,7 +270,7 @@ class SpamChecker:
 
     async def check_event_for_spam(
         self, event: "synapse.events.EventBase"
-    ) -> Union[bool, str]:
+    ) -> Union[Decision, str]:
         """Checks if a given event is considered "spammy" by this server.
 
         If the server considers an event spammy, then it will be rejected if
@@ -251,11 +281,57 @@ class SpamChecker:
             event: the event to be checked
 
         Returns:
-            True or a string if the event is spammy. If a string is returned it
-            will be used as the error message returned to the user.
+            - on `ALLOW`, the event is considered good (non-spammy) and should
+                be let through. Other spamcheck filters may still reject it.
+            - on `Codes`, the event is considered spammy and is rejected with a specific
+                error message/code.
+            - on `str`, the event is considered spammy and the string is used as error
+                message. This usage is generally discouraged as it doesn't support
+                internationalization.
         """
         for callback in self._check_event_for_spam_callbacks:
-            res: Union[bool, str] = await delay_cancellation(callback(event))
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                res: Union[Decision, str, bool] = await delay_cancellation(
+                    callback(event)
+                )
+                if res is False or res is Allow.ALLOW:
+                    # This spam-checker accepts the event.
+                    # Other spam-checkers may reject it, though.
+                    continue
+                elif res is True:
+                    # This spam-checker rejects the event with the deprecated
+                    # return value `True`.
+                    return Codes.FORBIDDEN
+                else:
+                    # This spam-checker rejects the event either with a `str`
+                    # or with a `Codes`. In either case, we stop here.
+                    return res
+
+        # No spam-checker has rejected the event, let it pass.
+        return Allow.ALLOW
+
+    async def should_drop_federated_event(
+        self, event: "synapse.events.EventBase"
+    ) -> Union[bool, str]:
+        """Checks if a given federated event is considered "spammy" by this
+        server.
+
+        If the server considers an event spammy, it will be silently dropped,
+        and in doing so will split-brain our view of the room's DAG.
+
+        Args:
+            event: the event to be checked
+
+        Returns:
+            True if the event should be silently dropped
+        """
+        for callback in self._should_drop_federated_event_callbacks:
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                res: Union[bool, str] = await delay_cancellation(callback(event))
             if res:
                 return res
 
@@ -276,9 +352,12 @@ class SpamChecker:
             Whether the user may join the room
         """
         for callback in self._user_may_join_room_callbacks:
-            may_join_room = await delay_cancellation(
-                callback(user_id, room_id, is_invited)
-            )
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                may_join_room = await delay_cancellation(
+                    callback(user_id, room_id, is_invited)
+                )
             if may_join_room is False:
                 return False
 
@@ -300,9 +379,12 @@ class SpamChecker:
             True if the user may send an invite, otherwise False
         """
         for callback in self._user_may_invite_callbacks:
-            may_invite = await delay_cancellation(
-                callback(inviter_userid, invitee_userid, room_id)
-            )
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                may_invite = await delay_cancellation(
+                    callback(inviter_userid, invitee_userid, room_id)
+                )
             if may_invite is False:
                 return False
 
@@ -328,9 +410,12 @@ class SpamChecker:
             True if the user may send the invite, otherwise False
         """
         for callback in self._user_may_send_3pid_invite_callbacks:
-            may_send_3pid_invite = await delay_cancellation(
-                callback(inviter_userid, medium, address, room_id)
-            )
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                may_send_3pid_invite = await delay_cancellation(
+                    callback(inviter_userid, medium, address, room_id)
+                )
             if may_send_3pid_invite is False:
                 return False
 
@@ -348,7 +433,10 @@ class SpamChecker:
             True if the user may create a room, otherwise False
         """
         for callback in self._user_may_create_room_callbacks:
-            may_create_room = await delay_cancellation(callback(userid))
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                may_create_room = await delay_cancellation(callback(userid))
             if may_create_room is False:
                 return False
 
@@ -369,9 +457,12 @@ class SpamChecker:
             True if the user may create a room alias, otherwise False
         """
         for callback in self._user_may_create_room_alias_callbacks:
-            may_create_room_alias = await delay_cancellation(
-                callback(userid, room_alias)
-            )
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                may_create_room_alias = await delay_cancellation(
+                    callback(userid, room_alias)
+                )
             if may_create_room_alias is False:
                 return False
 
@@ -390,7 +481,10 @@ class SpamChecker:
             True if the user may publish the room, otherwise False
         """
         for callback in self._user_may_publish_room_callbacks:
-            may_publish_room = await delay_cancellation(callback(userid, room_id))
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                may_publish_room = await delay_cancellation(callback(userid, room_id))
             if may_publish_room is False:
                 return False
 
@@ -412,9 +506,13 @@ class SpamChecker:
             True if the user is spammy.
         """
         for callback in self._check_username_for_spam_callbacks:
-            # Make a copy of the user profile object to ensure the spam checker cannot
-            # modify it.
-            if await delay_cancellation(callback(user_profile.copy())):
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                # Make a copy of the user profile object to ensure the spam checker cannot
+                # modify it.
+                res = await delay_cancellation(callback(user_profile.copy()))
+            if res:
                 return True
 
         return False
@@ -442,9 +540,12 @@ class SpamChecker:
         """
 
         for callback in self._check_registration_for_spam_callbacks:
-            behaviour = await delay_cancellation(
-                callback(email_threepid, username, request_info, auth_provider_id)
-            )
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                behaviour = await delay_cancellation(
+                    callback(email_threepid, username, request_info, auth_provider_id)
+                )
             assert isinstance(behaviour, RegistrationBehaviour)
             if behaviour != RegistrationBehaviour.ALLOW:
                 return behaviour
@@ -486,7 +587,10 @@ class SpamChecker:
         """
 
         for callback in self._check_media_file_for_spam_callbacks:
-            spam = await delay_cancellation(callback(file_wrapper, file_info))
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                spam = await delay_cancellation(callback(file_wrapper, file_info))
             if spam:
                 return True
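
For reference, a module-side callback written against the new `check_event_for_spam` contract could look roughly like the sketch below. It is not part of this change: the callback name and the ":example.org" test are hypothetical, and the imports are inferred from the `synapse.spam_checker_api.Allow` and `Codes` references elsewhere in this diff.

# Hypothetical spam-checker callback using the new Allow/Codes return values.
from typing import Union

import synapse.events
from synapse.api.errors import Codes
from synapse.spam_checker_api import Allow


async def my_check_event_for_spam(
    event: "synapse.events.EventBase",
) -> Union[Allow, Codes]:
    if event.sender.endswith(":example.org"):
        # Reject with a well-known error code rather than the deprecated
        # `True` / free-form string return values.
        return Codes.FORBIDDEN
    # Accept the event; other registered spam checkers may still reject it.
    return Allow.ALLOW


# Registered with the SpamChecker shown above, e.g.:
#     hs.get_spam_checker().register_callbacks(
#         check_event_for_spam=my_check_event_for_spam,
#     )
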
 
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 41ac49fdc8..1e866b19d8 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -15,6 +15,7 @@
 import logging
 from typing import TYPE_CHECKING
 
+import synapse
 from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership
 from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import EventFormatVersions, RoomVersion
@@ -98,9 +99,9 @@ class FederationBase:
                 )
             return redacted_event
 
-        result = await self.spam_checker.check_event_for_spam(pdu)
+        spam_check = await self.spam_checker.check_event_for_spam(pdu)
 
-        if result:
+        if spam_check is not synapse.spam_checker_api.Allow.ALLOW:
             logger.warning("Event contains spam, soft-failing %s", pdu.event_id)
             # we redact (to save disk space) as well as soft-failing (to stop
             # using the event in prev_events).
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 884b5d60b4..b8232e5257 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -110,6 +110,7 @@ class FederationServer(FederationBase):
 
         self.handler = hs.get_federation_handler()
         self.storage = hs.get_storage()
+        self._spam_checker = hs.get_spam_checker()
         self._federation_event_handler = hs.get_federation_event_handler()
         self.state = hs.get_state_handler()
         self._event_auth_handler = hs.get_event_auth_handler()
@@ -1019,6 +1020,12 @@ class FederationServer(FederationBase):
         except SynapseError as e:
             raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id)
 
+        if await self._spam_checker.should_drop_federated_event(pdu):
+            logger.warning(
+                "Unstaged federated event contains spam, dropping %s", pdu.event_id
+            )
+            return
+
         # Add the event to our staging area
         await self.store.insert_received_event_to_staging(origin, pdu)
 
@@ -1032,6 +1039,41 @@ class FederationServer(FederationBase):
                 pdu.room_id, room_version, lock, origin, pdu
             )
 
+    async def _get_next_nonspam_staged_event_for_room(
+        self, room_id: str, room_version: RoomVersion
+    ) -> Optional[Tuple[str, EventBase]]:
+        """Fetch the first non-spam event from the staging queue.
+
+        Args:
+            room_id: the room to fetch the first non-spam event in.
+            room_version: the version of the room.
+
+        Returns:
+            The origin and the first non-spam event from the staging queue for
+            that room, or None if there are no further (non-spam) events to process.
+        """
+
+        while True:
+            # We need to do this check outside the lock to avoid a race between
+            # a new event being inserted by another instance and it attempting
+            # to acquire the lock.
+            next = await self.store.get_next_staged_event_for_room(
+                room_id, room_version
+            )
+
+            if next is None:
+                return None
+
+            origin, event = next
+
+            if await self._spam_checker.should_drop_federated_event(event):
+                logger.warning(
+                    "Staged federated event contains spam, dropping %s",
+                    event.event_id,
+                )
+                continue
+
+            return next
+
     @wrap_as_background_process("_process_incoming_pdus_in_room_inner")
     async def _process_incoming_pdus_in_room_inner(
         self,
@@ -1109,12 +1151,10 @@ class FederationServer(FederationBase):
                         (self._clock.time_msec() - received_ts) / 1000
                     )
 
-            # We need to do this check outside the lock to avoid a race between
-            # a new event being inserted by another instance and it attempting
-            # to acquire the lock.
-            next = await self.store.get_next_staged_event_for_room(
+            next = await self._get_next_nonspam_staged_event_for_room(
                 room_id, room_version
             )
+
             if not next:
                 break
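
A `should_drop_federated_event` callback, as consumed by the two call sites above, might be implemented along the following lines. This is a sketch only: the denylist and the server name are hypothetical, and registration reuses the `register_callbacks` signature shown earlier in this diff.

# Hypothetical callback: silently drop federated events from denylisted servers.
import synapse.events

BLOCKED_SERVERS = {"spam.example.com"}  # hypothetical denylist


async def my_should_drop_federated_event(
    event: "synapse.events.EventBase",
) -> bool:
    # The sender's server name is everything after the first colon of the user ID.
    sender_server = event.sender.split(":", 1)[1]
    return sender_server in BLOCKED_SERVERS


# Registered alongside the other spam-checker callbacks:
#     hs.get_spam_checker().register_callbacks(
#         should_drop_federated_event=my_should_drop_federated_event,
#     )
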
 
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 6d2f46318b..dbe303ed9b 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -15,7 +15,17 @@
 import abc
 import logging
 from collections import OrderedDict
-from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Hashable,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+)
 
 import attr
 from prometheus_client import Counter
@@ -409,7 +419,7 @@ class FederationSender(AbstractFederationSender):
                         )
                         return
 
-                    destinations: Optional[Set[str]] = None
+                    destinations: Optional[Collection[str]] = None
                     if not event.prev_event_ids():
                         # If there are no prev event IDs then the state is empty
                         # and so no remote servers in the room
@@ -444,7 +454,7 @@ class FederationSender(AbstractFederationSender):
                             )
                             return
 
-                    destinations = {
+                    sharded_destinations = {
                         d
                         for d in destinations
                         if self._federation_shard_config.should_handle(
@@ -456,12 +466,12 @@ class FederationSender(AbstractFederationSender):
                         # If we are sending the event on behalf of another server
                         # then it already has the event and there is no reason to
                         # send the event to it.
-                        destinations.discard(send_on_behalf_of)
+                        sharded_destinations.discard(send_on_behalf_of)
 
-                    logger.debug("Sending %s to %r", event, destinations)
+                    logger.debug("Sending %s to %r", event, sharded_destinations)
 
-                    if destinations:
-                        await self._send_pdu(event, destinations)
+                    if sharded_destinations:
+                        await self._send_pdu(event, sharded_destinations)
 
                         now = self.clock.time_msec()
                         ts = await self.store.get_received_ts(event.event_id)
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index d629a3ecb5..84100a5a52 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tupl
 
 from synapse.api.errors import Codes, FederationDeniedError, SynapseError
 from synapse.api.urls import FEDERATION_V1_PREFIX
-from synapse.http.server import HttpServer, ServletCallback
+from synapse.http.server import HttpServer, ServletCallback, is_method_cancellable
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import run_in_background
@@ -169,14 +169,16 @@ def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str, Optional[str
     """
     try:
         header_str = header_bytes.decode("utf-8")
-        params = header_str.split(" ")[1].split(",")
+        params = re.split(" +", header_str)[1].split(",")
         param_dict: Dict[str, str] = {
-            k: v for k, v in [param.split("=", maxsplit=1) for param in params]
+            k.lower(): v for k, v in [param.split("=", maxsplit=1) for param in params]
         }
 
         def strip_quotes(value: str) -> str:
             if value.startswith('"'):
-                return value[1:-1]
+                return re.sub(
+                    "\\\\(.)", lambda matchobj: matchobj.group(1), value[1:-1]
+                )
             else:
                 return value
 
@@ -373,6 +375,17 @@ class BaseFederationServlet:
             if code is None:
                 continue
 
+            if is_method_cancellable(code):
+                # The wrapper added by `self._wrap` will inherit the cancellable flag,
+                # but the wrapper itself does not support cancellation yet.
+                # Once resolved, the cancellation tests in
+                # `tests/federation/transport/server/test__base.py` can be re-enabled.
+                raise Exception(
+                    f"{self.__class__.__name__}.on_{method} has been marked as "
+                    "cancellable, but federation servlets do not support cancellation "
+                    "yet."
+                )
+
             server.register_paths(
                 method,
                 (pattern,),
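
To make the three parsing tweaks above concrete (tolerating runs of spaces, lower-casing parameter names, and unescaping quoted values), here is a standalone sketch that mirrors the new `_parse_auth_header` behaviour rather than calling it; the header value is an arbitrary example.

# Standalone sketch of the updated X-Matrix parameter parsing (illustrative only).
import re
from typing import Dict


def parse_x_matrix_params(header: str) -> Dict[str, str]:
    # Take the parameter list after the "X-Matrix" scheme, tolerating runs of
    # spaces, then split it into name=value pairs with lower-cased names.
    params = re.split(" +", header)[1].split(",")
    param_dict = {
        k.lower(): v for k, v in (p.split("=", maxsplit=1) for p in params)
    }

    def strip_quotes(value: str) -> str:
        # Drop surrounding quotes and unescape backslash-escaped characters.
        if value.startswith('"'):
            return re.sub(r"\\(.)", lambda m: m.group(1), value[1:-1])
        return value

    return {k: strip_quotes(v) for k, v in param_dict.items()}


print(parse_x_matrix_params('X-Matrix   origin=a.example,Key="ed25519:1",sig="AB\\"CD"'))
# -> {'origin': 'a.example', 'key': 'ed25519:1', 'sig': 'AB"CD'}
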
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 4c3a5a6e24..dfd24af695 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -934,7 +934,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         # Before deleting the group lets kick everyone out of it
         users = await self.store.get_users_in_group(group_id, include_private=True)
 
-        async def _kick_user_from_group(user_id):
+        async def _kick_user_from_group(user_id: str) -> None:
             if self.hs.is_mine_id(user_id):
                 groups_local = self.hs.get_groups_local_handler()
                 assert isinstance(
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 4af9fbc5d1..0478448b47 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -23,7 +23,7 @@ from synapse.replication.http.account_data import (
     ReplicationUserAccountDataRestServlet,
 )
 from synapse.streams import EventSource
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, StreamKeyType, UserID
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -105,7 +105,7 @@ class AccountDataHandler:
             )
 
             self._notifier.on_new_event(
-                "account_data_key", max_stream_id, users=[user_id]
+                StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
             )
 
             await self._notify_modules(user_id, room_id, account_data_type, content)
@@ -141,7 +141,7 @@ class AccountDataHandler:
             )
 
             self._notifier.on_new_event(
-                "account_data_key", max_stream_id, users=[user_id]
+                StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
             )
 
             await self._notify_modules(user_id, None, account_data_type, content)
@@ -176,7 +176,7 @@ class AccountDataHandler:
             )
 
             self._notifier.on_new_event(
-                "account_data_key", max_stream_id, users=[user_id]
+                StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
             )
             return max_stream_id
         else:
@@ -201,7 +201,7 @@ class AccountDataHandler:
             )
 
             self._notifier.on_new_event(
-                "account_data_key", max_stream_id, users=[user_id]
+                StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
             )
             return max_stream_id
         else:
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 85bd5e4768..1da7bcc85b 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -38,6 +38,7 @@ from synapse.types import (
     JsonDict,
     RoomAlias,
     RoomStreamToken,
+    StreamKeyType,
     UserID,
 )
 from synapse.util.async_helpers import Linearizer
@@ -213,8 +214,8 @@ class ApplicationServicesHandler:
         Args:
             stream_key: The stream the event came from.
 
-                `stream_key` can be "typing_key", "receipt_key", "presence_key",
-                "to_device_key" or "device_list_key". Any other value for `stream_key`
+                `stream_key` can be StreamKeyType.TYPING, StreamKeyType.RECEIPT, StreamKeyType.PRESENCE,
+                StreamKeyType.TO_DEVICE or StreamKeyType.DEVICE_LIST. Any other value for `stream_key`
                 will cause this function to return early.
 
                 Ephemeral events will only be pushed to appservices that have opted into
@@ -235,11 +236,11 @@ class ApplicationServicesHandler:
         # Only the following streams are currently supported.
         # FIXME: We should use constants for these values.
         if stream_key not in (
-            "typing_key",
-            "receipt_key",
-            "presence_key",
-            "to_device_key",
-            "device_list_key",
+            StreamKeyType.TYPING,
+            StreamKeyType.RECEIPT,
+            StreamKeyType.PRESENCE,
+            StreamKeyType.TO_DEVICE,
+            StreamKeyType.DEVICE_LIST,
         ):
             return
 
@@ -258,14 +259,14 @@ class ApplicationServicesHandler:
 
         # Ignore to-device messages if the feature flag is not enabled
         if (
-            stream_key == "to_device_key"
+            stream_key == StreamKeyType.TO_DEVICE
             and not self._msc2409_to_device_messages_enabled
         ):
             return
 
         # Ignore device lists if the feature flag is not enabled
         if (
-            stream_key == "device_list_key"
+            stream_key == StreamKeyType.DEVICE_LIST
             and not self._msc3202_transaction_extensions_enabled
         ):
             return
@@ -283,15 +284,15 @@ class ApplicationServicesHandler:
             if (
                 stream_key
                 in (
-                    "typing_key",
-                    "receipt_key",
-                    "presence_key",
-                    "to_device_key",
+                    StreamKeyType.TYPING,
+                    StreamKeyType.RECEIPT,
+                    StreamKeyType.PRESENCE,
+                    StreamKeyType.TO_DEVICE,
                 )
                 and service.supports_ephemeral
             )
             or (
-                stream_key == "device_list_key"
+                stream_key == StreamKeyType.DEVICE_LIST
                 and service.msc3202_transaction_extensions
             )
         ]
@@ -317,7 +318,7 @@ class ApplicationServicesHandler:
         logger.debug("Checking interested services for %s", stream_key)
         with Measure(self.clock, "notify_interested_services_ephemeral"):
             for service in services:
-                if stream_key == "typing_key":
+                if stream_key == StreamKeyType.TYPING:
                     # Note that we don't persist the token (via set_appservice_stream_type_pos)
                     # for typing_key due to performance reasons and due to their highly
                     # ephemeral nature.
@@ -333,7 +334,7 @@ class ApplicationServicesHandler:
                 async with self._ephemeral_events_linearizer.queue(
                     (service.id, stream_key)
                 ):
-                    if stream_key == "receipt_key":
+                    if stream_key == StreamKeyType.RECEIPT:
                         events = await self._handle_receipts(service, new_token)
                         self.scheduler.enqueue_for_appservice(service, ephemeral=events)
 
@@ -342,7 +343,7 @@ class ApplicationServicesHandler:
                             service, "read_receipt", new_token
                         )
 
-                    elif stream_key == "presence_key":
+                    elif stream_key == StreamKeyType.PRESENCE:
                         events = await self._handle_presence(service, users, new_token)
                         self.scheduler.enqueue_for_appservice(service, ephemeral=events)
 
@@ -351,7 +352,7 @@ class ApplicationServicesHandler:
                             service, "presence", new_token
                         )
 
-                    elif stream_key == "to_device_key":
+                    elif stream_key == StreamKeyType.TO_DEVICE:
                         # Retrieve a list of to-device message events, as well as the
                         # maximum stream token of the messages we were able to retrieve.
                         to_device_messages = await self._get_to_device_messages(
@@ -366,7 +367,7 @@ class ApplicationServicesHandler:
                             service, "to_device", new_token
                         )
 
-                    elif stream_key == "device_list_key":
+                    elif stream_key == StreamKeyType.DEVICE_LIST:
                         device_list_summary = await self._get_device_list_summary(
                             service, new_token
                         )
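
The `StreamKeyType` constants used in these handler changes are imported from `synapse.types` and are not defined in this portion of the diff. Purely for illustration, a stand-in compatible with the literal string keys they replace could look like the following; the real definition may differ in form (for example it may use `Final` annotations or an enum).

# Illustrative stand-in only; the values are the literal strings replaced above
# and in the handlers below ("typing_key", "receipt_key", "room_key", ...).
class StreamKeyType:
    ROOM = "room_key"
    PRESENCE = "presence_key"
    TYPING = "typing_key"
    RECEIPT = "receipt_key"
    ACCOUNT_DATA = "account_data_key"
    TO_DEVICE = "to_device_key"
    DEVICE_LIST = "device_list_key"
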
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index a91b1ee4d5..1d6d1f8a92 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -43,6 +43,7 @@ from synapse.metrics.background_process_metrics import (
 )
 from synapse.types import (
     JsonDict,
+    StreamKeyType,
     StreamToken,
     UserID,
     get_domain_from_id,
@@ -502,7 +503,7 @@ class DeviceHandler(DeviceWorkerHandler):
         # specify the user ID too since the user should always get their own device list
         # updates, even if they aren't in any rooms.
         self.notifier.on_new_event(
-            "device_list_key", position, users={user_id}, rooms=room_ids
+            StreamKeyType.DEVICE_LIST, position, users={user_id}, rooms=room_ids
         )
 
         # We may need to do some processing asynchronously for local user IDs.
@@ -523,7 +524,9 @@ class DeviceHandler(DeviceWorkerHandler):
             from_user_id, user_ids
         )
 
-        self.notifier.on_new_event("device_list_key", position, users=[from_user_id])
+        self.notifier.on_new_event(
+            StreamKeyType.DEVICE_LIST, position, users=[from_user_id]
+        )
 
     async def user_left_room(self, user: UserID, room_id: str) -> None:
         user_id = user.to_string()
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 4cb725d027..53668cce3b 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -26,7 +26,7 @@ from synapse.logging.opentracing import (
     set_tag,
 )
 from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
-from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.stringutils import random_string
 
@@ -151,7 +151,7 @@ class DeviceMessageHandler:
         # Notify listeners that there are new to-device messages to process,
         # handing them the latest stream id.
         self.notifier.on_new_event(
-            "to_device_key", last_stream_id, users=local_messages.keys()
+            StreamKeyType.TO_DEVICE, last_stream_id, users=local_messages.keys()
         )
 
     async def _check_for_unknown_devices(
@@ -285,7 +285,7 @@ class DeviceMessageHandler:
         # Notify listeners that there are new to-device messages to process,
         # handing them the latest stream id.
         self.notifier.on_new_event(
-            "to_device_key", last_stream_id, users=local_messages.keys()
+            StreamKeyType.TO_DEVICE, last_stream_id, users=local_messages.keys()
         )
 
         if self.federation_sender:
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 33d827a45b..4aa33df884 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -71,6 +71,9 @@ class DirectoryHandler:
             if wchar in room_alias.localpart:
                 raise SynapseError(400, "Invalid characters in room alias")
 
+        if ":" in room_alias.localpart:
+            raise SynapseError(400, "Invalid character in room alias localpart: ':'.")
+
         if not self.hs.is_mine(room_alias):
             raise SynapseError(400, "Room alias must be local")
             # TODO(erikj): Change this.
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d6714228ef..e6c2cfb8c8 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -1105,22 +1105,19 @@ class E2eKeysHandler:
             # can request over federation
             raise NotFoundError("No %s key found for %s" % (key_type, user_id))
 
-        (
-            key,
-            key_id,
-            verify_key,
-        ) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
-
-        if key is None:
+        cross_signing_keys = await self._retrieve_cross_signing_keys_for_remote_user(
+            user, key_type
+        )
+        if cross_signing_keys is None:
             raise NotFoundError("No %s key found for %s" % (key_type, user_id))
 
-        return key, key_id, verify_key
+        return cross_signing_keys
 
     async def _retrieve_cross_signing_keys_for_remote_user(
         self,
         user: UserID,
         desired_key_type: str,
-    ) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
+    ) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]:
         """Queries cross-signing keys for a remote user and saves them to the database
 
         Only the key specified by `key_type` will be returned, while all retrieved keys
@@ -1146,12 +1143,10 @@ class E2eKeysHandler:
                 type(e),
                 e,
             )
-            return None, None, None
+            return None
 
         # Process each of the retrieved cross-signing keys
-        desired_key = None
-        desired_key_id = None
-        desired_verify_key = None
+        desired_key_data = None
         retrieved_device_ids = []
         for key_type in ["master", "self_signing"]:
             key_content = remote_result.get(key_type + "_key")
@@ -1196,9 +1191,7 @@ class E2eKeysHandler:
 
             # If this is the desired key type, save it and its ID/VerifyKey
             if key_type == desired_key_type:
-                desired_key = key_content
-                desired_verify_key = verify_key
-                desired_key_id = key_id
+                desired_key_data = key_content, key_id, verify_key
 
             # At the same time, store this key in the db for subsequent queries
             await self.store.set_e2e_cross_signing_key(
@@ -1212,7 +1205,7 @@ class E2eKeysHandler:
                 user.to_string(), retrieved_device_ids
             )
 
-        return desired_key, desired_key_id, desired_verify_key
+        return desired_key_data
 
 
 def _check_cross_signing_key(
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index d441ebb0ab..6bed464351 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -241,7 +241,15 @@ class EventAuthHandler:
 
         # If the join rule is not restricted, this doesn't apply.
         join_rules_event = await self._store.get_event(join_rules_event_id)
-        return join_rules_event.content.get("join_rule") == JoinRules.RESTRICTED
+        content_join_rule = join_rules_event.content.get("join_rule")
+        if content_join_rule == JoinRules.RESTRICTED:
+            return True
+
+        # also check for MSC3787 behaviour
+        if room_version.msc3787_knock_restricted_join_rule:
+            return content_join_rule == JoinRules.KNOCK_RESTRICTED
+
+        return False
 
     async def get_rooms_that_allow_join(
         self, state_ids: StateMap[str]
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 38dc5b1f6e..0386d0a07b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -54,6 +54,7 @@ from synapse.replication.http.federation import (
     ReplicationStoreRoomOnOutlierMembershipRestServlet,
 )
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.storage.state import StateFilter
 from synapse.types import JsonDict, StateMap, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.retryutils import NotRetryingDestination
@@ -659,7 +660,7 @@ class FederationHandler:
         # in the invitee's sync stream. It is stripped out for all other local users.
         event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
 
-        context = EventContext.for_outlier()
+        context = EventContext.for_outlier(self.storage)
         stream_id = await self._federation_event_handler.persist_events_and_notify(
             event.room_id, [(event, context)]
         )
@@ -848,7 +849,7 @@ class FederationHandler:
             )
         )
 
-        context = EventContext.for_outlier()
+        context = EventContext.for_outlier(self.storage)
         await self._federation_event_handler.persist_events_and_notify(
             event.room_id, [(event, context)]
         )
@@ -877,7 +878,7 @@ class FederationHandler:
 
         await self.federation_client.send_leave(host_list, event)
 
-        context = EventContext.for_outlier()
+        context = EventContext.for_outlier(self.storage)
         stream_id = await self._federation_event_handler.persist_events_and_notify(
             event.room_id, [(event, context)]
         )
@@ -1259,7 +1260,9 @@ class FederationHandler:
             event.content["third_party_invite"]["signed"]["token"],
         )
         original_invite = None
-        prev_state_ids = await context.get_prev_state_ids()
+        prev_state_ids = await context.get_prev_state_ids(
+            StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)])
+        )
         original_invite_id = prev_state_ids.get(key)
         if original_invite_id:
             original_invite = await self.store.get_event(
@@ -1308,7 +1311,9 @@ class FederationHandler:
         signed = event.content["third_party_invite"]["signed"]
         token = signed["token"]
 
-        prev_state_ids = await context.get_prev_state_ids()
+        prev_state_ids = await context.get_prev_state_ids(
+            StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)])
+        )
         invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
 
         invite_event = None
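
The `StateFilter.from_types` calls added above (and in the handlers further down) restrict the state fetched by `get_prev_state_ids` to the event types that are actually read, instead of pulling the full room state. A minimal sketch of building such a filter, assuming the usual `(event_type, state_key)` pair convention where a `None` state key means "any state key of that type":

# Sketch: fetch only third-party-invite state for the checks above.
from synapse.api.constants import EventTypes
from synapse.storage.state import StateFilter

state_filter = StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)])

# Inside the handler this would then be passed through, e.g.:
#     prev_state_ids = await context.get_prev_state_ids(state_filter)
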
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 6cf927e4ff..ca82df8a6d 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -30,6 +30,7 @@ from typing import (
 
 from prometheus_client import Counter
 
+from synapse import event_auth
 from synapse.api.constants import (
     EventContentFields,
     EventTypes,
@@ -63,6 +64,7 @@ from synapse.replication.http.federation import (
 )
 from synapse.state import StateResolutionStore
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.storage.state import StateFilter
 from synapse.types import (
     PersistedEventPosition,
     RoomStreamToken,
@@ -103,7 +105,7 @@ class FederationEventHandler:
         self._event_creation_handler = hs.get_event_creation_handler()
         self._event_auth_handler = hs.get_event_auth_handler()
         self._message_handler = hs.get_message_handler()
-        self._action_generator = hs.get_action_generator()
+        self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
         self._state_resolution_handler = hs.get_state_resolution_handler()
         # avoid a circular dependency by deferring execution here
         self._get_room_member_handler = hs.get_room_member_handler
@@ -475,7 +477,23 @@ class FederationEventHandler:
             # and discover that we do not have it.
             event.internal_metadata.proactively_send = False
 
-            return await self.persist_events_and_notify(room_id, [(event, context)])
+            stream_id_after_persist = await self.persist_events_and_notify(
+                room_id, [(event, context)]
+            )
+
+            # If we're joining the room again, check if there is new marker
+            # state indicating that there is new history imported somewhere in
+            # the DAG. Multiple markers can exist in the current state with
+            # unique state_keys.
+            #
+            # Do this after the state from the remote join was persisted (via
+            # `persist_events_and_notify`). Otherwise we can run into a
+            # situation where the create event doesn't exist yet in the
+            # `current_state_events`
+            for e in state:
+                await self._handle_marker_event(origin, e)
+
+            return stream_id_after_persist
 
     async def update_state_for_partial_state_event(
         self, destination: str, event: EventBase
@@ -1228,6 +1246,14 @@ class FederationEventHandler:
             # Nothing to retrieve then (invalid marker)
             return
 
+        already_seen_insertion_event = await self._store.have_seen_event(
+            marker_event.room_id, insertion_event_id
+        )
+        if already_seen_insertion_event:
+            # No need to process a marker again if we have already seen the
+            # insertion event that it was pointing to
+            return
+
         logger.debug(
             "_handle_marker_event: backfilling insertion event %s", insertion_event_id
         )
@@ -1423,7 +1449,7 @@ class FederationEventHandler:
                 # we're not bothering about room state, so flag the event as an outlier.
                 event.internal_metadata.outlier = True
 
-                context = EventContext.for_outlier()
+                context = EventContext.for_outlier(self._storage)
                 try:
                     validate_event_for_room_version(room_version_obj, event)
                     check_auth_rules_for_event(room_version_obj, event, auth)
@@ -1500,7 +1526,11 @@ class FederationEventHandler:
             return context
 
         # now check auth against what we think the auth events *should* be.
-        prev_state_ids = await context.get_prev_state_ids()
+        event_types = event_auth.auth_types_for_event(event.room_version, event)
+        prev_state_ids = await context.get_prev_state_ids(
+            StateFilter.from_types(event_types)
+        )
+
         auth_events_ids = self._event_auth_handler.compute_auth_events(
             event, prev_state_ids, for_verification=True
         )
@@ -1874,10 +1904,10 @@ class FederationEventHandler:
         )
 
         return EventContext.with_state(
+            storage=self._storage,
             state_group=state_group,
             state_group_before_event=context.state_group_before_event,
-            current_state_ids=current_state_ids,
-            prev_state_ids=prev_state_ids,
+            state_delta_due_to_event=state_updates,
             prev_group=prev_group,
             delta_ids=state_updates,
             partial_state=context.partial_state,
@@ -1913,7 +1943,7 @@ class FederationEventHandler:
                     min_depth,
                 )
             else:
-                await self._action_generator.handle_push_actions_for_event(
+                await self._bulk_push_rule_evaluator.action_for_event_by_user(
                     event, context
                 )
 
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 7b94770f97..d79248ad90 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -30,6 +30,7 @@ from synapse.types import (
     Requester,
     RoomStreamToken,
     StateMap,
+    StreamKeyType,
     StreamToken,
     UserID,
 )
@@ -143,7 +144,7 @@ class InitialSyncHandler:
             to_key=int(now_token.receipt_key),
         )
         if self.hs.config.experimental.msc2285_enabled:
-            receipt = ReceiptEventSource.filter_out_private(receipt, user_id)
+            receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
 
         tags_by_room = await self.store.get_tags_for_user(user_id)
 
@@ -220,8 +221,10 @@ class InitialSyncHandler:
                     self.storage, user_id, messages
                 )
 
-                start_token = now_token.copy_and_replace("room_key", token)
-                end_token = now_token.copy_and_replace("room_key", room_end_token)
+                start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
+                end_token = now_token.copy_and_replace(
+                    StreamKeyType.ROOM, room_end_token
+                )
                 time_now = self.clock.time_msec()
 
                 d["messages"] = {
@@ -369,8 +372,8 @@ class InitialSyncHandler:
             self.storage, user_id, messages, is_peeking=is_peeking
         )
 
-        start_token = StreamToken.START.copy_and_replace("room_key", token)
-        end_token = StreamToken.START.copy_and_replace("room_key", stream_token)
+        start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token)
+        end_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, stream_token)
 
         time_now = self.clock.time_msec()
 
@@ -449,7 +452,9 @@ class InitialSyncHandler:
             if not receipts:
                 return []
             if self.hs.config.experimental.msc2285_enabled:
-                receipts = ReceiptEventSource.filter_out_private(receipts, user_id)
+                receipts = ReceiptEventSource.filter_out_private_receipts(
+                    receipts, user_id
+                )
             return receipts
 
         presence, receipts, (messages, token) = await make_deferred_yieldable(
@@ -472,7 +477,7 @@ class InitialSyncHandler:
             self.storage, user_id, messages, is_peeking=is_peeking
         )
 
-        start_token = now_token.copy_and_replace("room_key", token)
+        start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
         end_token = now_token
 
         time_now = self.clock.time_msec()
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index e1082954cc..25eab0c5d4 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -23,6 +23,7 @@ from canonicaljson import encode_canonical_json
 
 from twisted.internet.interfaces import IDelayedCall
 
+import synapse
 from synapse import event_auth
 from synapse.api.constants import (
     EventContentFields,
@@ -44,7 +45,7 @@ from synapse.api.errors import (
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
 from synapse.api.urls import ConsentURIBuilder
 from synapse.event_auth import validate_event_for_room_version
-from synapse.events import EventBase
+from synapse.events import EventBase, relation_from_event
 from synapse.events.builder import EventBuilder
 from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
@@ -426,7 +427,7 @@ class EventCreationHandler:
         # This is to stop us from diverging history *too* much.
         self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
 
-        self.action_generator = hs.get_action_generator()
+        self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
 
         self.spam_checker = hs.get_spam_checker()
         self.third_party_event_rules: "ThirdPartyEventRules" = (
@@ -634,7 +635,9 @@ class EventCreationHandler:
             # federation as well as those created locally. As of room v3, aliases events
             # can be created by users that are not in the room, therefore we have to
             # tolerate them in event_auth.check().
-            prev_state_ids = await context.get_prev_state_ids()
+            prev_state_ids = await context.get_prev_state_ids(
+                StateFilter.from_types([(EventTypes.Member, None)])
+            )
             prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
             prev_event = (
                 await self.store.get_event(prev_event_id, allow_none=True)
@@ -757,7 +760,13 @@ class EventCreationHandler:
             The previous version of the event is returned, if it is found in the
             event context. Otherwise, None is returned.
         """
-        prev_state_ids = await context.get_prev_state_ids()
+        if event.internal_metadata.is_outlier():
+            # This can happen due to out of band memberships
+            return None
+
+        prev_state_ids = await context.get_prev_state_ids(
+            StateFilter.from_types([(event.type, None)])
+        )
         prev_event_id = prev_state_ids.get((event.type, event.state_key))
         if not prev_event_id:
             return None
@@ -877,11 +886,11 @@ class EventCreationHandler:
                 event.sender,
             )
 
-            spam_error = await self.spam_checker.check_event_for_spam(event)
-            if spam_error:
-                if not isinstance(spam_error, str):
-                    spam_error = "Spam is not permitted here"
-                raise SynapseError(403, spam_error, Codes.FORBIDDEN)
+            spam_check = await self.spam_checker.check_event_for_spam(event)
+            if spam_check is not synapse.spam_checker_api.Allow.ALLOW:
+                raise SynapseError(
+                    403, "This message has been rejected as probable spam", spam_check
+                )
 
             ev = await self.handle_new_client_event(
                 requester=requester,
@@ -1001,7 +1010,7 @@ class EventCreationHandler:
         # after it is created
         if builder.internal_metadata.outlier:
             event.internal_metadata.outlier = True
-            context = EventContext.for_outlier()
+            context = EventContext.for_outlier(self.storage)
         elif (
             event.type == EventTypes.MSC2716_INSERTION
             and state_event_ids
@@ -1056,20 +1065,11 @@ class EventCreationHandler:
             SynapseError if the event is invalid.
         """
 
-        relation = event.content.get("m.relates_to")
+        relation = relation_from_event(event)
         if not relation:
             return
 
-        relation_type = relation.get("rel_type")
-        if not relation_type:
-            return
-
-        # Ensure the parent is real.
-        relates_to = relation.get("event_id")
-        if not relates_to:
-            return
-
-        parent_event = await self.store.get_event(relates_to, allow_none=True)
+        parent_event = await self.store.get_event(relation.parent_id, allow_none=True)
         if parent_event:
             # And in the same room.
             if parent_event.room_id != event.room_id:
@@ -1078,28 +1078,31 @@ class EventCreationHandler:
         else:
             # There must be some reason that the client knows the event exists,
             # see if there are existing relations. If so, assume everything is fine.
-            if not await self.store.event_is_target_of_relation(relates_to):
+            if not await self.store.event_is_target_of_relation(relation.parent_id):
                 # Otherwise, the client can't know about the parent event!
                 raise SynapseError(400, "Can't send relation to unknown event")
 
         # If this event is an annotation then we check that the sender
         # can't annotate the same way twice (e.g. stops users from liking an
         # event multiple times).
-        if relation_type == RelationTypes.ANNOTATION:
-            aggregation_key = relation["key"]
+        if relation.rel_type == RelationTypes.ANNOTATION:
+            aggregation_key = relation.aggregation_key
+
+            if aggregation_key is None:
+                raise SynapseError(400, "Missing aggregation key")
 
             if len(aggregation_key) > 500:
                 raise SynapseError(400, "Aggregation key is too long")
 
             already_exists = await self.store.has_user_annotated_event(
-                relates_to, event.type, aggregation_key, event.sender
+                relation.parent_id, event.type, aggregation_key, event.sender
             )
             if already_exists:
                 raise SynapseError(400, "Can't send same reaction twice")
 
         # Don't attempt to start a thread if the parent event is a relation.
-        elif relation_type == RelationTypes.THREAD:
-            if await self.store.event_includes_relation(relates_to):
+        elif relation.rel_type == RelationTypes.THREAD:
+            if await self.store.event_includes_relation(relation.parent_id):
                 raise SynapseError(
                     400, "Cannot start threads from an event with a relation"
                 )
@@ -1245,7 +1248,9 @@ class EventCreationHandler:
         # and `state_groups` because they have `prev_events` that aren't persisted yet
         # (historical messages persisted in reverse-chronological order).
         if not event.internal_metadata.is_historical():
-            await self.action_generator.handle_push_actions_for_event(event, context)
+            await self._bulk_push_rule_evaluator.action_for_event_by_user(
+                event, context
+            )
 
         try:
             # If we're a worker we need to hit out to the master.
@@ -1547,7 +1552,11 @@ class EventCreationHandler:
                         "Redacting MSC2716 events is not supported in this room version",
                     )
 
-            prev_state_ids = await context.get_prev_state_ids()
+            event_types = event_auth.auth_types_for_event(event.room_version, event)
+            prev_state_ids = await context.get_prev_state_ids(
+                StateFilter.from_types(event_types)
+            )
+
             auth_events_ids = self._event_auth_handler.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
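
The `relation_from_event` helper replaces the ad-hoc reads of `m.relates_to` above. Based on the fields the old code accessed, the relation data it is assumed to expose maps onto the event content like this (the event ID and reaction key are placeholders):

# Shape of an annotation relation in event content, and the attributes the
# helper is assumed to surface for it:
example_content = {
    "m.relates_to": {
        "rel_type": "m.annotation",        # -> relation.rel_type
        "event_id": "$parent_event_id",    # -> relation.parent_id
        "key": "👍",                       # -> relation.aggregation_key (annotations only)
    }
}
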
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index f6ffb7d18d..9de61d554f 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -224,7 +224,7 @@ class OidcHandler:
             self._sso_handler.render_error(request, "invalid_session", str(e))
             return
         except MacaroonInvalidSignatureException as e:
-            logger.exception("Could not verify session for OIDC callback")
+            logger.warning("Could not verify session for OIDC callback: %s", e)
             self._sso_handler.render_error(request, "mismatching_session", str(e))
             return
 
@@ -827,7 +827,7 @@ class OidcProvider:
             logger.debug("Exchanging OAuth2 code for a token")
             token = await self._exchange_code(code)
         except OidcError as e:
-            logger.exception("Could not exchange OAuth2 code")
+            logger.warning("Could not exchange OAuth2 code: %s", e)
             self._sso_handler.render_error(request, e.error, e.error_description)
             return
 
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 7ee3340373..19a4407050 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -27,7 +27,7 @@ from synapse.handlers.room import ShutdownRoomResponse
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester
+from synapse.types import JsonDict, Requester, StreamKeyType
 from synapse.util.async_helpers import ReadWriteLock
 from synapse.util.stringutils import random_string
 from synapse.visibility import filter_events_for_client
@@ -239,7 +239,7 @@ class PaginationHandler:
             # defined in the server's configuration, we can safely assume that's the
             # case and use it for this room.
             max_lifetime = (
-                retention_policy["max_lifetime"] or self._retention_default_max_lifetime
+                retention_policy.max_lifetime or self._retention_default_max_lifetime
             )
 
             # Cap the effective max_lifetime to be within the range allowed in the
@@ -448,7 +448,7 @@ class PaginationHandler:
             )
             # We expect `/messages` to use historic pagination tokens by default but
             # `/messages` should still work with live tokens when manually provided.
-            assert from_token.room_key.topological
+            assert from_token.room_key.topological is not None
 
         if pagin_config.limit is None:
             # This shouldn't happen as we've set a default limit before this
@@ -491,7 +491,7 @@ class PaginationHandler:
 
                     if leave_token.topological < curr_topo:
                         from_token = from_token.copy_and_replace(
-                            "room_key", leave_token
+                            StreamKeyType.ROOM, leave_token
                         )
 
                 await self.hs.get_federation_handler().maybe_backfill(
@@ -513,7 +513,7 @@ class PaginationHandler:
                 event_filter=event_filter,
             )
 
-            next_token = from_token.copy_and_replace("room_key", next_key)
+            next_token = from_token.copy_and_replace(StreamKeyType.ROOM, next_key)
 
         if events:
             if event_filter:
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 268481ec19..dd84e6c88b 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -66,7 +66,7 @@ from synapse.replication.tcp.commands import ClearUserSyncsCommand
 from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
 from synapse.storage.databases.main import DataStore
 from synapse.streams import EventSource
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.descriptors import _CacheContext, cached
 from synapse.util.metrics import Measure
@@ -522,7 +522,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
         room_ids_to_states, users_to_states = parties
 
         self.notifier.on_new_event(
-            "presence_key",
+            StreamKeyType.PRESENCE,
             stream_id,
             rooms=room_ids_to_states.keys(),
             users=users_to_states.keys(),
@@ -1145,7 +1145,7 @@ class PresenceHandler(BasePresenceHandler):
         room_ids_to_states, users_to_states = parties
 
         self.notifier.on_new_event(
-            "presence_key",
+            StreamKeyType.PRESENCE,
             stream_id,
             rooms=room_ids_to_states.keys(),
             users=[UserID.from_string(u) for u in users_to_states],
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 43d615357b..e6a35f1d09 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -17,7 +17,13 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 from synapse.api.constants import ReceiptTypes
 from synapse.appservice import ApplicationService
 from synapse.streams import EventSource
-from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
+from synapse.types import (
+    JsonDict,
+    ReadReceipt,
+    StreamKeyType,
+    UserID,
+    get_domain_from_id,
+)
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -129,7 +135,9 @@ class ReceiptsHandler:
 
         affected_room_ids = list({r.room_id for r in receipts})
 
-        self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
+        self.notifier.on_new_event(
+            StreamKeyType.RECEIPT, max_batch_id, rooms=affected_room_ids
+        )
         # Note that the min here shouldn't be relied upon to be accurate.
         await self.hs.get_pusherpool().on_new_receipts(
             min_batch_id, max_batch_id, affected_room_ids
@@ -165,43 +173,69 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
         self.config = hs.config
 
     @staticmethod
-    def filter_out_private(events: List[JsonDict], user_id: str) -> List[JsonDict]:
-        """
-        This method takes in what is returned by
-        get_linearized_receipts_for_rooms() and goes through read receipts
-        filtering out m.read.private receipts if they were not sent by the
-        current user.
+    def filter_out_private_receipts(
+        rooms: List[JsonDict], user_id: str
+    ) -> List[JsonDict]:
         """
+        Filters a list of serialized receipts (as returned by /sync and /initialSync)
+        and removes private read receipts of other users.
 
-        visible_events = []
-
-        # filter out private receipts the user shouldn't see
-        for event in events:
-            content = event.get("content", {})
-            new_event = event.copy()
-            new_event["content"] = {}
-
-            for event_id, event_content in content.items():
-                receipt_event = {}
-                for receipt_type, receipt_content in event_content.items():
-                    if receipt_type == ReceiptTypes.READ_PRIVATE:
-                        user_rr = receipt_content.get(user_id, None)
-                        if user_rr:
-                            receipt_event[ReceiptTypes.READ_PRIVATE] = {
-                                user_id: user_rr.copy()
-                            }
-                    else:
-                        receipt_event[receipt_type] = receipt_content.copy()
+        This operates on the return value of get_linearized_receipts_for_rooms(),
+        which is wrapped in a cache. Care must be taken to ensure that the input
+        values are not modified.
 
-                # Only include the receipt event if it is non-empty.
-                if receipt_event:
-                    new_event["content"][event_id] = receipt_event
+        Args:
+            rooms: A list of mappings, each of which has a `content` field that
+                maps event ID -> receipt type -> user ID -> receipt information.
 
-            # Append new_event to visible_events unless empty
-            if len(new_event["content"].keys()) > 0:
-                visible_events.append(new_event)
+        Returns:
+            The same as rooms, but filtered.
+        """
 
-        return visible_events
+        result = []
+
+        # Iterate through each room's receipt content.
+        for room in rooms:
+            # The receipt content with other users' private read receipts removed.
+            content = {}
+
+            # Iterate over each event ID / receipts for that event.
+            for event_id, orig_event_content in room.get("content", {}).items():
+                event_content = orig_event_content
+                # If there are private read receipts, additional logic is necessary.
+                if ReceiptTypes.READ_PRIVATE in event_content:
+                    # Make a copy without private read receipts to avoid leaking
+                    # other users' private read receipts.
+                    event_content = {
+                        receipt_type: receipt_value
+                        for receipt_type, receipt_value in event_content.items()
+                        if receipt_type != ReceiptTypes.READ_PRIVATE
+                    }
+
+                    # Copy the current user's private read receipt from the
+                    # original content, if it exists.
+                    user_private_read_receipt = orig_event_content[
+                        ReceiptTypes.READ_PRIVATE
+                    ].get(user_id, None)
+                    if user_private_read_receipt:
+                        event_content[ReceiptTypes.READ_PRIVATE] = {
+                            user_id: user_private_read_receipt
+                        }
+
+                # Include the event if there is at least one non-private read
+                # receipt or the current user has a private read receipt.
+                if event_content:
+                    content[event_id] = event_content
+
+            # Include the room only if at least one of its events still has
+            # receipts to show after filtering.
+            if content:
+                # Build a new event to avoid mutating the cache.
+                new_room = {k: v for k, v in room.items() if k != "content"}
+                new_room["content"] = content
+                result.append(new_room)
+
+        return result
 
     async def get_new_events(
         self,
@@ -223,7 +257,9 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
         )
 
         if self.config.experimental.msc2285_enabled:
-            events = ReceiptEventSource.filter_out_private(events, user.to_string())
+            events = ReceiptEventSource.filter_out_private_receipts(
+                events, user.to_string()
+            )
 
         return events, to_key
 
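
The docstring above describes the nested shape of each room entry; the following worked example makes the filtering concrete. It is an illustrative sketch only: the room, event and user IDs are invented, and the constants come from the imports this file already uses.

    # Example input/output for ReceiptEventSource.filter_out_private_receipts.
    # All identifiers are made up for illustration.
    from synapse.api.constants import ReceiptTypes
    from synapse.handlers.receipts import ReceiptEventSource

    rooms = [
        {
            "room_id": "!room:example.com",
            "type": "m.receipt",
            "content": {
                "$event1": {
                    ReceiptTypes.READ: {"@alice:example.com": {"ts": 1}},
                    ReceiptTypes.READ_PRIVATE: {
                        "@alice:example.com": {"ts": 1},
                        "@bob:example.com": {"ts": 2},
                    },
                },
            },
        },
    ]

    filtered = ReceiptEventSource.filter_out_private_receipts(
        rooms, "@alice:example.com"
    )

    # Bob's private read receipt is dropped; Alice keeps her own private read
    # receipt and all non-private receipts. The cached input is not mutated,
    # because the function builds fresh dicts rather than editing in place.
    private = filtered[0]["content"]["$event1"][ReceiptTypes.READ_PRIVATE]
    assert "@bob:example.com" not in private
    assert "@bob:example.com" in rooms[0]["content"]["$event1"][ReceiptTypes.READ_PRIVATE]
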
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index c2754ec918..ab7e54857d 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import collections.abc
 import logging
 from typing import (
     TYPE_CHECKING,
@@ -28,7 +27,7 @@ import attr
 
 from synapse.api.constants import RelationTypes
 from synapse.api.errors import SynapseError
-from synapse.events import EventBase
+from synapse.events import EventBase, relation_from_event
 from synapse.storage.databases.main.relations import _RelatedEvent
 from synapse.types import JsonDict, Requester, StreamToken, UserID
 from synapse.visibility import filter_events_for_client
@@ -373,20 +372,21 @@ class RelationsHandler:
             if event.is_state():
                 continue
 
-            relates_to = event.content.get("m.relates_to")
-            relation_type = None
-            if isinstance(relates_to, collections.abc.Mapping):
-                relation_type = relates_to.get("rel_type")
+            relates_to = relation_from_event(event)
+            if relates_to:
                 # An event which is a replacement (ie edit) or annotation (ie,
                 # reaction) may not have any other event related to it.
-                if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+                if relates_to.rel_type in (
+                    RelationTypes.ANNOTATION,
+                    RelationTypes.REPLACE,
+                ):
                     continue
 
+                # Track the event's relation information for later.
+                relations_by_id[event.event_id] = relates_to.rel_type
+
             # The event should get bundled aggregations.
             events_by_id[event.event_id] = event
-            # Track the event's relation information for later.
-            if isinstance(relation_type, str):
-                relations_by_id[event.event_id] = relation_type
 
         # event ID -> bundled aggregation in non-serialized form.
         results: Dict[str, BundledAggregations] = {}
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 604eb6ec15..92e1de0500 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -33,6 +33,7 @@ from typing import (
 import attr
 from typing_extensions import TypedDict
 
+import synapse.events.snapshot
 from synapse.api.constants import (
     EventContentFields,
     EventTypes,
@@ -72,12 +73,12 @@ from synapse.types import (
     RoomID,
     RoomStreamToken,
     StateMap,
+    StreamKeyType,
     StreamToken,
     UserID,
     create_requester,
 )
 from synapse.util import stringutils
-from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.stringutils import parse_and_validate_server_name
 from synapse.visibility import filter_events_for_client
@@ -149,10 +150,11 @@ class RoomCreationHandler:
             )
             preset_config["encrypted"] = encrypted
 
-        self._replication = hs.get_replication_data_handler()
+        self._default_power_level_content_override = (
+            self.config.room.default_power_level_content_override
+        )
 
-        # linearizer to stop two upgrades happening at once
-        self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")
+        self._replication = hs.get_replication_data_handler()
 
         # If a user tries to update the same room multiple times in quick
         # succession, only process the first attempt and return its result to
@@ -196,6 +198,39 @@ class RoomCreationHandler:
                     400, "An upgrade for this room is currently in progress"
                 )
 
+        # Check whether the room exists and 404 if it doesn't.
+        # We could go straight for the auth check, but that will raise a 403 instead.
+        old_room = await self.store.get_room(old_room_id)
+        if old_room is None:
+            raise NotFoundError("Unknown room id %s" % (old_room_id,))
+
+        new_room_id = self._generate_room_id()
+
+        # Check whether the user has the power level to carry out the upgrade.
+        # `check_auth_rules_from_context` will check that they are in the room and have
+        # the required power level to send the tombstone event.
+        (
+            tombstone_event,
+            tombstone_context,
+        ) = await self.event_creation_handler.create_event(
+            requester,
+            {
+                "type": EventTypes.Tombstone,
+                "state_key": "",
+                "room_id": old_room_id,
+                "sender": user_id,
+                "content": {
+                    "body": "This room has been replaced",
+                    "replacement_room": new_room_id,
+                },
+            },
+        )
+        old_room_version = await self.store.get_room_version(old_room_id)
+        validate_event_for_room_version(old_room_version, tombstone_event)
+        await self._event_auth_handler.check_auth_rules_from_context(
+            old_room_version, tombstone_event, tombstone_context
+        )
+
         # Upgrade the room
         #
         # If this user has sent multiple upgrade requests for the same room
@@ -206,19 +241,35 @@ class RoomCreationHandler:
             self._upgrade_room,
             requester,
             old_room_id,
-            new_version,  # args for _upgrade_room
+            old_room,  # args for _upgrade_room
+            new_room_id,
+            new_version,
+            tombstone_event,
+            tombstone_context,
         )
 
         return ret
 
     async def _upgrade_room(
-        self, requester: Requester, old_room_id: str, new_version: RoomVersion
+        self,
+        requester: Requester,
+        old_room_id: str,
+        old_room: Dict[str, Any],
+        new_room_id: str,
+        new_version: RoomVersion,
+        tombstone_event: EventBase,
+        tombstone_context: synapse.events.snapshot.EventContext,
     ) -> str:
         """
         Args:
             requester: the user requesting the upgrade
             old_room_id: the id of the room to be replaced
-            new_versions: the version to upgrade the room to
+            old_room: a dict containing room information for the room to be replaced,
+                as returned by `RoomWorkerStore.get_room`.
+            new_room_id: the id of the replacement room
+            new_version: the version to upgrade the room to
+            tombstone_event: the tombstone event to send to the old room
+            tombstone_context: the context for the tombstone event
 
         Raises:
             ShadowBanError if the requester is shadow-banned.
@@ -226,40 +277,15 @@ class RoomCreationHandler:
         user_id = requester.user.to_string()
         assert self.hs.is_mine_id(user_id), "User must be our own: %s" % (user_id,)
 
-        # start by allocating a new room id
-        r = await self.store.get_room(old_room_id)
-        if r is None:
-            raise NotFoundError("Unknown room id %s" % (old_room_id,))
-        new_room_id = await self._generate_room_id(
-            creator_id=user_id,
-            is_public=r["is_public"],
-            room_version=new_version,
-        )
-
         logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
 
-        # we create and auth the tombstone event before properly creating the new
-        # room, to check our user has perms in the old room.
-        (
-            tombstone_event,
-            tombstone_context,
-        ) = await self.event_creation_handler.create_event(
-            requester,
-            {
-                "type": EventTypes.Tombstone,
-                "state_key": "",
-                "room_id": old_room_id,
-                "sender": user_id,
-                "content": {
-                    "body": "This room has been replaced",
-                    "replacement_room": new_room_id,
-                },
-            },
-        )
-        old_room_version = await self.store.get_room_version(old_room_id)
-        validate_event_for_room_version(old_room_version, tombstone_event)
-        await self._event_auth_handler.check_auth_rules_from_context(
-            old_room_version, tombstone_event, tombstone_context
+        # Create the new room. May raise a `StoreError` in the exceedingly unlikely
+        # event of a room ID collision.
+        await self.store.store_room(
+            room_id=new_room_id,
+            room_creator_user_id=user_id,
+            is_public=old_room["is_public"],
+            room_version=new_version,
         )
 
         await self.clone_existing_room(
@@ -277,7 +303,10 @@ class RoomCreationHandler:
             context=tombstone_context,
         )
 
-        old_room_state = await tombstone_context.get_current_state_ids()
+        state_filter = StateFilter.from_types(
+            [(EventTypes.CanonicalAlias, ""), (EventTypes.PowerLevels, "")]
+        )
+        old_room_state = await tombstone_context.get_current_state_ids(state_filter)
 
         # We know the tombstone event isn't an outlier so it has current state.
         assert old_room_state is not None
@@ -401,7 +430,7 @@ class RoomCreationHandler:
             requester: the user requesting the upgrade
             old_room_id : the id of the room to be replaced
             new_room_id: the id to give the new room (should already have been
-                created with _gemerate_room_id())
+                created with _generate_room_id())
             new_room_version: the new room version to use
             tombstone_event_id: the ID of the tombstone event in the old room.
         """
@@ -443,14 +472,14 @@ class RoomCreationHandler:
             (EventTypes.PowerLevels, ""),
         ]
 
-        # If the old room was a space, copy over the room type and the rooms in
-        # the space.
-        if (
-            old_room_create_event.content.get(EventContentFields.ROOM_TYPE)
-            == RoomTypes.SPACE
-        ):
-            creation_content[EventContentFields.ROOM_TYPE] = RoomTypes.SPACE
-            types_to_copy.append((EventTypes.SpaceChild, None))
+        # Copy the room type as per MSC3818.
+        room_type = old_room_create_event.content.get(EventContentFields.ROOM_TYPE)
+        if room_type is not None:
+            creation_content[EventContentFields.ROOM_TYPE] = room_type
+
+            # If the old room was a space, copy over the rooms in the space.
+            if room_type == RoomTypes.SPACE:
+                types_to_copy.append((EventTypes.SpaceChild, None))
 
         old_room_state_ids = await self.store.get_filtered_current_state_ids(
             old_room_id, StateFilter.from_types(types_to_copy)
@@ -725,6 +754,21 @@ class RoomCreationHandler:
                 if wchar in config["room_alias_name"]:
                     raise SynapseError(400, "Invalid characters in room alias")
 
+            if ":" in config["room_alias_name"]:
+                # Prevent someone from trying to pass in a full alias here.
+                # Note that it's permissible for a room alias to have multiple
+                # hash symbols at the start (notably bridged over from IRC, too),
+                # but the first colon in the alias is defined to separate the local
+                # part from the server name.
+                # (remember server names can contain port numbers, also separated
+                # by a colon. But under no circumstances should the local part be
+                # allowed to contain a colon!)
+                raise SynapseError(
+                    400,
+                    "':' is not permitted in the room alias name. "
+                    "Please note this expects a local part — 'wombat', not '#wombat:example.com'.",
+                )
+
             room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname)
             mapping = await self.store.get_association_from_room_alias(room_alias)
 
@@ -778,7 +822,7 @@ class RoomCreationHandler:
         visibility = config.get("visibility", "private")
         is_public = visibility == "public"
 
-        room_id = await self._generate_room_id(
+        room_id = await self._generate_and_create_room_id(
             creator_id=user_id,
             is_public=is_public,
             room_version=room_version,
@@ -1042,9 +1086,19 @@ class RoomCreationHandler:
                 for invitee in invite_list:
                     power_level_content["users"][invitee] = 100
 
-            # Power levels overrides are defined per chat preset
+            # If the user supplied a preset name, e.g. "private_chat",
+            # apply that preset's power level overrides.
             power_level_content.update(config["power_level_content_override"])
 
+            # If the server config contains default_power_level_content_override,
+            # and that contains information for this room preset, apply it.
+            if self._default_power_level_content_override:
+                override = self._default_power_level_content_override.get(preset_config)
+                if override is not None:
+                    power_level_content.update(override)
+
+            # Finally, if the user supplied specific permissions for this room,
+            # apply those.
             if power_level_content_override:
                 power_level_content.update(power_level_content_override)
 
@@ -1090,7 +1144,26 @@ class RoomCreationHandler:
 
         return last_sent_stream_id
 
-    async def _generate_room_id(
+    def _generate_room_id(self) -> str:
+        """Generates a random room ID.
+
+        Room IDs look like "!opaque_id:domain" and are case-sensitive as per the spec
+        at https://spec.matrix.org/v1.2/appendices/#room-ids-and-event-ids.
+
+        Does not check for collisions with existing rooms or prevent future calls from
+        returning the same room ID. To ensure the uniqueness of a new room ID, use
+        `_generate_and_create_room_id` instead.
+
+        Synapse's room IDs are 18 [a-zA-Z] characters long, which comes out to around
+        102 bits.
+
+        Returns:
+            A random room ID of the form "!opaque_id:domain".
+        """
+        random_string = stringutils.random_string(18)
+        return RoomID(random_string, self.hs.hostname).to_string()
+
+    async def _generate_and_create_room_id(
         self,
         creator_id: str,
         is_public: bool,
@@ -1101,8 +1174,7 @@ class RoomCreationHandler:
         attempts = 0
         while attempts < 5:
             try:
-                random_string = stringutils.random_string(18)
-                gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
+                gen_room_id = self._generate_room_id()
                 await self.store.store_room(
                     room_id=gen_room_id,
                     room_creator_user_id=creator_id,
@@ -1239,10 +1311,10 @@ class RoomContextHandler:
             events_after=events_after,
             state=await filter_evts(state_events),
             aggregations=aggregations,
-            start=await token.copy_and_replace("room_key", results.start).to_string(
-                self.store
-            ),
-            end=await token.copy_and_replace("room_key", results.end).to_string(
+            start=await token.copy_and_replace(
+                StreamKeyType.ROOM, results.start
+            ).to_string(self.store),
+            end=await token.copy_and_replace(StreamKeyType.ROOM, results.end).to_string(
                 self.store
             ),
         )
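
The power-level comments above describe three layers applied in a fixed order: the chosen preset's overrides, then the homeserver's default_power_level_content_override for that preset, then the per-request "power_level_content_override". A minimal sketch with invented values, showing why the later layers win:

    # Later layers win because each is applied with dict.update(); values invented.
    preset_override = {"events_default": 0}       # from the chosen preset
    server_override = {"events_default": 50}      # default_power_level_content_override
    request_override = {"ban": 100}               # supplied in the createRoom request

    power_level_content = {"users_default": 0}
    power_level_content.update(preset_override)   # 1. preset
    power_level_content.update(server_override)   # 2. server config, per preset
    power_level_content.update(request_override)  # 3. per-request override, applied last

    assert power_level_content == {"users_default": 0, "events_default": 50, "ban": 100}
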
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 29de7e5bed..fbfd748406 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -53,6 +53,7 @@ class RoomBatchHandler:
         # We want to use the successor event depth so they appear after `prev_event` because
         # it has a larger `depth` but before the successor event because the `stream_ordering`
         # is negative before the successor event.
+        assert most_recent_prev_event_id is not None
         successor_event_ids = await self.store.get_successor_events(
             most_recent_prev_event_id
         )
@@ -139,6 +140,7 @@ class RoomBatchHandler:
             _,
         ) = await self.store.get_max_depth_of(event_ids)
         # mapping from (type, state_key) -> state_event_id
+        assert most_recent_event_id is not None
         prev_state_map = await self.state_store.get_state_ids_for_event(
             most_recent_event_id
         )
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 0cd8833326..056f0cdade 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -38,6 +38,7 @@ from synapse.event_auth import get_named_level, get_power_level_event
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
+from synapse.storage.state import StateFilter
 from synapse.types import (
     JsonDict,
     Requester,
@@ -362,7 +363,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             historical=historical,
         )
 
-        prev_state_ids = await context.get_prev_state_ids()
+        prev_state_ids = await context.get_prev_state_ids(
+            StateFilter.from_types([(EventTypes.Member, None)])
+        )
 
         prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
 
@@ -1174,7 +1177,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         else:
             requester = types.create_requester(target_user)
 
-        prev_state_ids = await context.get_prev_state_ids()
+        prev_state_ids = await context.get_prev_state_ids(
+            StateFilter.from_types([(EventTypes.GuestAccess, None)])
+        )
         if event.membership == Membership.JOIN:
             if requester.is_guest:
                 guest_can_join = await self._can_guest_join(prev_state_ids)
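
Both hunks above narrow the state fetched from the event context to the single event type they actually need. As a brief note on the filters being built (identifiers taken from this diff): StateFilter.from_types accepts (event type, state_key) pairs, and a state_key of None acts as a wildcard for that type.

    # Sketch of the filters built above; a None state_key matches every state_key
    # of the given event type, so only membership / guest-access state is fetched.
    from synapse.api.constants import EventTypes
    from synapse.storage.state import StateFilter

    members_only = StateFilter.from_types([(EventTypes.Member, None)])
    guest_access_only = StateFilter.from_types([(EventTypes.GuestAccess, None)])
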
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index ff24ec8063..af83de3193 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -562,8 +562,13 @@ class RoomSummaryHandler:
         if join_rules_event_id:
             join_rules_event = await self._store.get_event(join_rules_event_id)
             join_rule = join_rules_event.content.get("join_rule")
-            if join_rule == JoinRules.PUBLIC or (
-                room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
+            if (
+                join_rule == JoinRules.PUBLIC
+                or (room_version.msc2403_knocking and join_rule == JoinRules.KNOCK)
+                or (
+                    room_version.msc3787_knock_restricted_join_rule
+                    and join_rule == JoinRules.KNOCK_RESTRICTED
+                )
             ):
                 return True
 
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 5619f8f50e..cd1c47dae8 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -24,7 +24,7 @@ from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
 from synapse.events import EventBase
 from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, StreamKeyType, UserID
 from synapse.visibility import filter_events_for_client
 
 if TYPE_CHECKING:
@@ -655,11 +655,11 @@ class SearchHandler:
                 "events_before": events_before,
                 "events_after": events_after,
                 "start": await now_token.copy_and_replace(
-                    "room_key", res.start
+                    StreamKeyType.ROOM, res.start
+                ).to_string(self.store),
+                "end": await now_token.copy_and_replace(
+                    StreamKeyType.ROOM, res.end
                 ).to_string(self.store),
-                "end": await now_token.copy_and_replace("room_key", res.end).to_string(
-                    self.store
-                ),
             }
 
             if include_profile:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 2c555a66d0..59b5d497be 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -37,6 +37,7 @@ from synapse.types import (
     Requester,
     RoomStreamToken,
     StateMap,
+    StreamKeyType,
     StreamToken,
     UserID,
 )
@@ -410,10 +411,10 @@ class SyncHandler:
             set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
             return sync_result
 
-    async def push_rules_for_user(self, user: UserID) -> JsonDict:
+    async def push_rules_for_user(self, user: UserID) -> Dict[str, Dict[str, list]]:
         user_id = user.to_string()
-        rules = await self.store.get_push_rules_for_user(user_id)
-        rules = format_push_rules_for_user(user, rules)
+        rules_raw = await self.store.get_push_rules_for_user(user_id)
+        rules = format_push_rules_for_user(user, rules_raw)
         return rules
 
     async def ephemeral_by_room(
@@ -449,7 +450,7 @@ class SyncHandler:
                 room_ids=room_ids,
                 is_guest=sync_config.is_guest,
             )
-            now_token = now_token.copy_and_replace("typing_key", typing_key)
+            now_token = now_token.copy_and_replace(StreamKeyType.TYPING, typing_key)
 
             ephemeral_by_room: JsonDict = {}
 
@@ -471,7 +472,7 @@ class SyncHandler:
                 room_ids=room_ids,
                 is_guest=sync_config.is_guest,
             )
-            now_token = now_token.copy_and_replace("receipt_key", receipt_key)
+            now_token = now_token.copy_and_replace(StreamKeyType.RECEIPT, receipt_key)
 
             for event in receipts:
                 room_id = event["room_id"]
@@ -537,7 +538,9 @@ class SyncHandler:
                 prev_batch_token = now_token
                 if recents:
                     room_key = recents[0].internal_metadata.before
-                    prev_batch_token = now_token.copy_and_replace("room_key", room_key)
+                    prev_batch_token = now_token.copy_and_replace(
+                        StreamKeyType.ROOM, room_key
+                    )
 
                 return TimelineBatch(
                     events=recents, prev_batch=prev_batch_token, limited=False
@@ -611,7 +614,7 @@ class SyncHandler:
                 recents = recents[-timeline_limit:]
                 room_key = recents[0].internal_metadata.before
 
-            prev_batch_token = now_token.copy_and_replace("room_key", room_key)
+            prev_batch_token = now_token.copy_and_replace(StreamKeyType.ROOM, room_key)
 
         # Don't bother to bundle aggregations if the timeline is unlimited,
         # as clients will have all the necessary information.
@@ -1398,7 +1401,7 @@ class SyncHandler:
                 now_token.to_device_key,
             )
             sync_result_builder.now_token = now_token.copy_and_replace(
-                "to_device_key", stream_id
+                StreamKeyType.TO_DEVICE, stream_id
             )
             sync_result_builder.to_device = messages
         else:
@@ -1503,7 +1506,7 @@ class SyncHandler:
         )
         assert presence_key
         sync_result_builder.now_token = now_token.copy_and_replace(
-            "presence_key", presence_key
+            StreamKeyType.PRESENCE, presence_key
         )
 
         extra_users_ids = set(newly_joined_or_invited_users)
@@ -1826,7 +1829,7 @@ class SyncHandler:
                 # stream token as it'll only be used in the context of this
                 # room. (c.f. the docstring of `to_room_stream_token`).
                 leave_token = since_token.copy_and_replace(
-                    "room_key", leave_position.to_room_stream_token()
+                    StreamKeyType.ROOM, leave_position.to_room_stream_token()
                 )
 
                 # If this is an out of band message, like a remote invite
@@ -1875,7 +1878,9 @@ class SyncHandler:
             if room_entry:
                 events, start_key = room_entry
 
-                prev_batch_token = now_token.copy_and_replace("room_key", start_key)
+                prev_batch_token = now_token.copy_and_replace(
+                    StreamKeyType.ROOM, start_key
+                )
 
                 entry = RoomSyncResultBuilder(
                     room_id=room_id,
@@ -1972,7 +1977,7 @@ class SyncHandler:
                             continue
 
                 leave_token = now_token.copy_and_replace(
-                    "room_key", RoomStreamToken(None, event.stream_ordering)
+                    StreamKeyType.ROOM, RoomStreamToken(None, event.stream_ordering)
                 )
                 room_entries.append(
                     RoomSyncResultBuilder(
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 6854428b7c..bb00750bfd 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -25,7 +25,7 @@ from synapse.metrics.background_process_metrics import (
 )
 from synapse.replication.tcp.streams import TypingStream
 from synapse.streams import EventSource
-from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
@@ -382,7 +382,7 @@ class TypingWriterHandler(FollowerTypingHandler):
         )
 
         self.notifier.on_new_event(
-            "typing_key", self._latest_room_serial, rooms=[member.room_id]
+            StreamKeyType.TYPING, self._latest_room_serial, rooms=[member.room_id]
         )
 
     async def get_all_typing_updates(
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 8310fb466a..084d0a5b84 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -43,8 +43,10 @@ from twisted.internet import defer, error as twisted_error, protocol, ssl
 from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.interfaces import (
     IAddress,
+    IDelayedCall,
     IHostResolution,
     IReactorPluggableNameResolver,
+    IReactorTime,
     IResolutionReceiver,
     ITCPTransport,
 )
@@ -121,13 +123,15 @@ def check_against_blacklist(
 _EPSILON = 0.00000001
 
 
-def _make_scheduler(reactor):
+def _make_scheduler(
+    reactor: IReactorTime,
+) -> Callable[[Callable[[], object]], IDelayedCall]:
     """Makes a schedular suitable for a Cooperator using the given reactor.
 
     (This is effectively just a copy from `twisted.internet.task`)
     """
 
-    def _scheduler(x):
+    def _scheduler(x: Callable[[], object]) -> IDelayedCall:
         return reactor.callLater(_EPSILON, x)
 
     return _scheduler
@@ -348,7 +352,7 @@ class SimpleHttpClient:
         # XXX: The justification for using the cache factor here is that larger instances
         # will need both more cache and more connections.
         # Still, this should probably be a separate dial
-        pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5))
+        pool.maxPersistentPerHost = max(int(100 * hs.config.caches.global_factor), 5)
         pool.cachedConnectionTimeout = 2 * 60
 
         self.agent: IAgent = ProxyAgent(
@@ -775,7 +779,7 @@ class SimpleHttpClient:
         )
 
 
-def _timeout_to_request_timed_out_error(f: Failure):
+def _timeout_to_request_timed_out_error(f: Failure) -> Failure:
     if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError):
         # The TCP connection has its own timeout (set by the 'connectTimeout' param
         # on the Agent), which raises twisted_error.TimeoutError exception.
@@ -809,7 +813,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
     def __init__(self, deferred: defer.Deferred):
         self.deferred = deferred
 
-    def _maybe_fail(self):
+    def _maybe_fail(self) -> None:
         """
         Report a max size exceed error and disconnect the first time this is called.
         """
@@ -933,12 +937,12 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
     Do not use this since it allows an attacker to intercept your communications.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._context = SSL.Context(SSL.SSLv23_METHOD)
         self._context.set_verify(VERIFY_NONE, lambda *_: False)
 
     def getContext(self, hostname=None, port=None):
         return self._context
 
-    def creatorForNetloc(self, hostname, port):
+    def creatorForNetloc(self, hostname: bytes, port: int):
         return self
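
_make_scheduler now carries a real signature and is reused by the federation client later in this diff, which builds its Cooperator from it. A minimal usage sketch, using the global reactor purely for brevity (the federation client passes its homeserver's reactor instead):

    # Build a Cooperator whose work is scheduled via reactor.callLater(_EPSILON, ...),
    # so long-running iterators yield control back to the reactor between steps.
    from twisted.internet import reactor
    from twisted.internet.task import Cooperator

    from synapse.http.client import _make_scheduler

    cooperator = Cooperator(scheduler=_make_scheduler(reactor))
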
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index 203e995bb7..23a60af171 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -14,15 +14,22 @@
 
 import base64
 import logging
-from typing import Optional
+from typing import Optional, Union
 
 import attr
 from zope.interface import implementer
 
 from twisted.internet import defer, protocol
 from twisted.internet.error import ConnectError
-from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
+from twisted.internet.interfaces import (
+    IAddress,
+    IConnector,
+    IProtocol,
+    IReactorCore,
+    IStreamClientEndpoint,
+)
 from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
+from twisted.python.failure import Failure
 from twisted.web import http
 
 logger = logging.getLogger(__name__)
@@ -81,14 +88,14 @@ class HTTPConnectProxyEndpoint:
         self._port = port
         self._proxy_creds = proxy_creds
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
 
     # Mypy encounters a false positive here: it complains that ClientFactory
     # is incompatible with IProtocolFactory. But ClientFactory inherits from
     # Factory, which implements IProtocolFactory. So I think this is a bug
     # in mypy-zope.
-    def connect(self, protocolFactory: ClientFactory):  # type: ignore[override]
+    def connect(self, protocolFactory: ClientFactory) -> "defer.Deferred[IProtocol]":  # type: ignore[override]
         f = HTTPProxiedClientFactory(
             self._host, self._port, protocolFactory, self._proxy_creds
         )
@@ -125,10 +132,10 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
         self.proxy_creds = proxy_creds
         self.on_connection: "defer.Deferred[None]" = defer.Deferred()
 
-    def startedConnecting(self, connector):
+    def startedConnecting(self, connector: IConnector) -> None:
         return self.wrapped_factory.startedConnecting(connector)
 
-    def buildProtocol(self, addr):
+    def buildProtocol(self, addr: IAddress) -> "HTTPConnectProtocol":
         wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
         if wrapped_protocol is None:
             raise TypeError("buildProtocol produced None instead of a Protocol")
@@ -141,13 +148,13 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
             self.proxy_creds,
         )
 
-    def clientConnectionFailed(self, connector, reason):
+    def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
         logger.debug("Connection to proxy failed: %s", reason)
         if not self.on_connection.called:
             self.on_connection.errback(reason)
         return self.wrapped_factory.clientConnectionFailed(connector, reason)
 
-    def clientConnectionLost(self, connector, reason):
+    def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
         logger.debug("Connection to proxy lost: %s", reason)
         if not self.on_connection.called:
             self.on_connection.errback(reason)
@@ -191,10 +198,10 @@ class HTTPConnectProtocol(protocol.Protocol):
         )
         self.http_setup_client.on_connected.addCallback(self.proxyConnected)
 
-    def connectionMade(self):
+    def connectionMade(self) -> None:
         self.http_setup_client.makeConnection(self.transport)
 
-    def connectionLost(self, reason=connectionDone):
+    def connectionLost(self, reason: Failure = connectionDone) -> None:
         if self.wrapped_protocol.connected:
             self.wrapped_protocol.connectionLost(reason)
 
@@ -203,7 +210,7 @@ class HTTPConnectProtocol(protocol.Protocol):
         if not self.connected_deferred.called:
             self.connected_deferred.errback(reason)
 
-    def proxyConnected(self, _):
+    def proxyConnected(self, _: Union[None, "defer.Deferred[None]"]) -> None:
         self.wrapped_protocol.makeConnection(self.transport)
 
         self.connected_deferred.callback(self.wrapped_protocol)
@@ -213,7 +220,7 @@ class HTTPConnectProtocol(protocol.Protocol):
         if buf:
             self.wrapped_protocol.dataReceived(buf)
 
-    def dataReceived(self, data: bytes):
+    def dataReceived(self, data: bytes) -> None:
         # if we've set up the HTTP protocol, we can send the data there
         if self.wrapped_protocol.connected:
             return self.wrapped_protocol.dataReceived(data)
@@ -243,7 +250,7 @@ class HTTPConnectSetupClient(http.HTTPClient):
         self.proxy_creds = proxy_creds
         self.on_connected: "defer.Deferred[None]" = defer.Deferred()
 
-    def connectionMade(self):
+    def connectionMade(self) -> None:
         logger.debug("Connected to proxy, sending CONNECT")
         self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
 
@@ -257,14 +264,14 @@ class HTTPConnectSetupClient(http.HTTPClient):
 
         self.endHeaders()
 
-    def handleStatus(self, version: bytes, status: bytes, message: bytes):
+    def handleStatus(self, version: bytes, status: bytes, message: bytes) -> None:
         logger.debug("Got Status: %s %s %s", status, message, version)
         if status != b"200":
             raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}")
 
-    def handleEndHeaders(self):
+    def handleEndHeaders(self) -> None:
         logger.debug("End Headers")
         self.on_connected.callback(None)
 
-    def handleResponse(self, body):
+    def handleResponse(self, body: bytes) -> None:
         pass
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index a8a520f809..2f0177f1e2 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -239,7 +239,7 @@ class MatrixHostnameEndpointFactory:
 
         self._srv_resolver = srv_resolver
 
-    def endpointForURI(self, parsed_uri: URI):
+    def endpointForURI(self, parsed_uri: URI) -> "MatrixHostnameEndpoint":
         return MatrixHostnameEndpoint(
             self._reactor,
             self._proxy_reactor,
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index f68646fd0d..de0e882b33 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -16,7 +16,7 @@
 import logging
 import random
 import time
-from typing import Callable, Dict, List
+from typing import Any, Callable, Dict, List
 
 import attr
 
@@ -109,7 +109,7 @@ class SrvResolver:
 
     def __init__(
         self,
-        dns_client=client,
+        dns_client: Any = client,
         cache: Dict[bytes, List[Server]] = SERVER_CACHE,
         get_time: Callable[[], float] = time.time,
     ):
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 43f2140429..71b685fade 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -74,9 +74,9 @@ _well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
 _had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known")
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class WellKnownLookupResult:
-    delegated_server = attr.ib()
+    delegated_server: Optional[bytes]
 
 
 class WellKnownResolver:
@@ -336,4 +336,4 @@ def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
 class _FetchWellKnownFailure(Exception):
     # True if we didn't get a non-5xx HTTP response, i.e. this may or may not be
     # a temporary failure.
-    temporary = attr.ib()
+    temporary: bool = attr.ib()
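
The WellKnownLookupResult change above is a mechanical conversion to attrs' auto_attribs style; both spellings declare the same single field. A small side-by-side sketch (class names invented):

    import attr
    from typing import Optional

    # Old style: fields declared with attr.ib(), annotations optional.
    @attr.s(slots=True, frozen=True)
    class LookupResultOld:
        delegated_server = attr.ib()

    # New style: auto_attribs=True turns annotated class attributes into fields,
    # so the annotation and the field declaration are one and the same.
    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class LookupResultNew:
        delegated_server: Optional[bytes] = None

    assert [f.name for f in attr.fields(LookupResultNew)] == ["delegated_server"]
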
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c2ec3caa0e..0b9475debd 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -23,6 +23,8 @@ from http import HTTPStatus
 from io import BytesIO, StringIO
 from typing import (
     TYPE_CHECKING,
+    Any,
+    BinaryIO,
     Callable,
     Dict,
     Generic,
@@ -44,7 +46,7 @@ from typing_extensions import Literal
 from twisted.internet import defer
 from twisted.internet.error import DNSLookupError
 from twisted.internet.interfaces import IReactorTime
-from twisted.internet.task import _EPSILON, Cooperator
+from twisted.internet.task import Cooperator
 from twisted.web.client import ResponseFailed
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IBodyProducer, IResponse
@@ -58,11 +60,13 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
+from synapse.crypto.context_factory import FederationPolicyForHTTPS
 from synapse.http import QuieterFileBodyProducer
 from synapse.http.client import (
     BlacklistingAgentWrapper,
     BodyExceededMaxSize,
     ByteWriteable,
+    _make_scheduler,
     encode_query_args,
     read_body_with_max_size,
 )
@@ -181,7 +185,7 @@ class JsonParser(ByteParser[Union[JsonDict, list]]):
 
     CONTENT_TYPE = "application/json"
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._buffer = StringIO()
         self._binary_wrapper = BinaryIOWrapper(self._buffer)
 
@@ -299,7 +303,9 @@ async def _handle_response(
 class BinaryIOWrapper:
     """A wrapper for a TextIO which converts from bytes on the fly."""
 
-    def __init__(self, file: typing.TextIO, encoding="utf-8", errors="strict"):
+    def __init__(
+        self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
+    ):
         self.decoder = codecs.getincrementaldecoder(encoding)(errors)
         self.file = file
 
@@ -317,7 +323,11 @@ class MatrixFederationHttpClient:
             requests.
     """
 
-    def __init__(self, hs: "HomeServer", tls_client_options_factory):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+    ):
         self.hs = hs
         self.signing_key = hs.signing_key
         self.server_name = hs.hostname
@@ -348,10 +358,7 @@ class MatrixFederationHttpClient:
         self.version_string_bytes = hs.version_string.encode("ascii")
         self.default_timeout = 60
 
-        def schedule(x):
-            self.reactor.callLater(_EPSILON, x)
-
-        self._cooperator = Cooperator(scheduler=schedule)
+        self._cooperator = Cooperator(scheduler=_make_scheduler(self.reactor))
 
         self._sleeper = AwakenableSleeper(self.reactor)
 
@@ -364,7 +371,7 @@ class MatrixFederationHttpClient:
         self,
         request: MatrixFederationRequest,
         try_trailing_slash_on_400: bool = False,
-        **send_request_args,
+        **send_request_args: Any,
     ) -> IResponse:
         """Wrapper for _send_request which can optionally retry the request
         upon receiving a combination of a 400 HTTP response code and a
@@ -740,7 +747,7 @@ class MatrixFederationHttpClient:
         for key, sig in request["signatures"][self.server_name].items():
             auth_headers.append(
                 (
-                    'X-Matrix origin=%s,key="%s",sig="%s",destination="%s"'
+                    'X-Matrix origin="%s",key="%s",sig="%s",destination="%s"'
                     % (
                         self.server_name,
                         key,
@@ -1159,7 +1166,7 @@ class MatrixFederationHttpClient:
         self,
         destination: str,
         path: str,
-        output_stream,
+        output_stream: BinaryIO,
         args: Optional[QueryParams] = None,
         retry_on_dns_fail: bool = True,
         max_size: Optional[int] = None,
@@ -1250,10 +1257,10 @@ class MatrixFederationHttpClient:
         return length, headers
 
 
-def _flatten_response_never_received(e):
+def _flatten_response_never_received(e: BaseException) -> str:
     if hasattr(e, "reasons"):
         reasons = ", ".join(
-            _flatten_response_never_received(f.value) for f in e.reasons
+            _flatten_response_never_received(f.value) for f in e.reasons  # type: ignore[attr-defined]
         )
 
         return "%s:[%s]" % (type(e).__name__, reasons)
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index a16dde2380..b2a50c9105 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -245,7 +245,7 @@ def http_proxy_endpoint(
     proxy: Optional[bytes],
     reactor: IReactorCore,
     tls_options_factory: Optional[IPolicyForHTTPS],
-    **kwargs,
+    **kwargs: object,
 ) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]:
     """Parses an http proxy setting and returns an endpoint for the proxy
 
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 4886626d50..2b6d113544 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -162,7 +162,7 @@ class RequestMetrics:
         with _in_flight_requests_lock:
             _in_flight_requests.add(self)
 
-    def stop(self, time_sec, response_code, sent_bytes):
+    def stop(self, time_sec: float, response_code: int, sent_bytes: int) -> None:
         with _in_flight_requests_lock:
             _in_flight_requests.discard(self)
 
@@ -186,13 +186,13 @@ class RequestMetrics:
             )
             return
 
-        response_code = str(response_code)
+        response_code_str = str(response_code)
 
-        outgoing_responses_counter.labels(self.method, response_code).inc()
+        outgoing_responses_counter.labels(self.method, response_code_str).inc()
 
         response_count.labels(self.method, self.name, tag).inc()
 
-        response_timer.labels(self.method, self.name, tag, response_code).observe(
+        response_timer.labels(self.method, self.name, tag, response_code_str).observe(
             time_sec - self.start_ts
         )
 
@@ -221,7 +221,7 @@ class RequestMetrics:
         # flight.
         self.update_metrics()
 
-    def update_metrics(self):
+    def update_metrics(self) -> None:
         """Updates the in flight metrics with values from this request."""
         if not self.start_context:
             logger.error(
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 657bffcddd..e3dcc3f3dd 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -33,6 +33,7 @@ from typing import (
     Optional,
     Pattern,
     Tuple,
+    TypeVar,
     Union,
 )
 
@@ -92,6 +93,68 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
 HTTP_STATUS_REQUEST_CANCELLED = 499
 
 
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+_cancellable_method_names = frozenset(
+    {
+        # `RestServlet`, `BaseFederationServlet` and `BaseFederationServerServlet`
+        # methods
+        "on_GET",
+        "on_PUT",
+        "on_POST",
+        "on_DELETE",
+        # `_AsyncResource`, `DirectServeHtmlResource` and `DirectServeJsonResource`
+        # methods
+        "_async_render_GET",
+        "_async_render_PUT",
+        "_async_render_POST",
+        "_async_render_DELETE",
+        "_async_render_OPTIONS",
+        # `ReplicationEndpoint` methods
+        "_handle_request",
+    }
+)
+
+
+def cancellable(method: F) -> F:
+    """Marks a servlet method as cancellable.
+
+    Methods with this decorator will be cancelled if the client disconnects before we
+    finish processing the request.
+
+    During cancellation, `Deferred.cancel()` will be invoked on the `Deferred` wrapping
+    the method. The `cancel()` call will propagate down to the `Deferred` that is
+    currently being waited on. That `Deferred` will raise a `CancelledError`, which will
+    propagate up, as per normal exception handling.
+
+    Before applying this decorator to a new endpoint, you MUST recursively check
+    that all `await`s in the function are on `async` functions or `Deferred`s that
+    handle cancellation cleanly, otherwise a variety of bugs may occur, ranging from
+    premature logging context closure, to stuck requests, to database corruption.
+
+    Usage:
+        class SomeServlet(RestServlet):
+            @cancellable
+            async def on_GET(self, request: SynapseRequest) -> ...:
+                ...
+    """
+    if method.__name__ not in _cancellable_method_names and not any(
+        method.__name__.startswith(prefix) for prefix in _cancellable_method_names
+    ):
+        raise ValueError(
+            "@cancellable decorator can only be applied to servlet methods."
+        )
+
+    method.cancellable = True  # type: ignore[attr-defined]
+    return method
+
+
+def is_method_cancellable(method: Callable[..., Any]) -> bool:
+    """Checks whether a servlet method has the `@cancellable` flag."""
+    return getattr(method, "cancellable", False)
+
+
 def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
     """Sends a JSON error response to clients."""
 
@@ -253,6 +316,9 @@ class HttpServer(Protocol):
         If the regex contains groups these gets passed to the callback via
         an unpacked tuple.
 
+        The callback may be marked with the `@cancellable` decorator, which will
+        cause request processing to be cancelled when clients disconnect early.
+
         Args:
             method: The HTTP method to listen to.
             path_patterns: The regex used to match requests.
@@ -283,7 +349,9 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
 
     def render(self, request: SynapseRequest) -> int:
         """This gets called by twisted every time someone sends us a request."""
-        defer.ensureDeferred(self._async_render_wrapper(request))
+        request.render_deferred = defer.ensureDeferred(
+            self._async_render_wrapper(request)
+        )
         return NOT_DONE_YET
 
     @wrap_async_request_handler
@@ -319,6 +387,8 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
 
         method_handler = getattr(self, "_async_render_%s" % (request_method,), None)
         if method_handler:
+            request.is_render_cancellable = is_method_cancellable(method_handler)
+
             raw_callback_return = method_handler(request)
 
             # Is it synchronous? We'll allow this for now.
@@ -479,6 +549,8 @@ class JsonResource(DirectServeJsonResource):
     async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
         callback, servlet_classname, group_dict = self._get_handler_for_request(request)
 
+        request.is_render_cancellable = is_method_cancellable(callback)
+
         # Make sure we have an appropriate name for this handler in prometheus
         # (rather than the default of JsonResource).
         request.request_metrics.name = servlet_classname
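
The cancellable decorator above only sets an attribute on the method, and is_method_cancellable reads it back when the request is rendered. A short sketch of that round trip, using a hypothetical servlet:

    # Hypothetical servlet showing how the marker round-trips; on_GET is in the
    # allow-list of method names, so the decorator accepts it.
    from synapse.http.server import cancellable, is_method_cancellable
    from synapse.http.servlet import RestServlet

    class ExampleServlet(RestServlet):
        @cancellable
        async def on_GET(self, request):
            return 200, {}

        async def on_PUT(self, request):
            return 200, {}

    assert is_method_cancellable(ExampleServlet.on_GET)
    assert not is_method_cancellable(ExampleServlet.on_PUT)
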
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 0b85a57d77..eeec74b78a 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union
 import attr
 from zope.interface import implementer
 
+from twisted.internet.defer import Deferred
 from twisted.internet.interfaces import IAddress, IReactorTime
 from twisted.python.failure import Failure
 from twisted.web.http import HTTPChannel
@@ -91,6 +92,14 @@ class SynapseRequest(Request):
         # we can't yet create the logcontext, as we don't know the method.
         self.logcontext: Optional[LoggingContext] = None
 
+        # The `Deferred` to cancel if the client disconnects early and
+        # `is_render_cancellable` is set. Expected to be set by `Resource.render`.
+        self.render_deferred: Optional["Deferred[None]"] = None
+        # A boolean indicating whether `render_deferred` should be cancelled if the
+        # client disconnects early. Expected to be set by the coroutine started by
+        # `Resource.render`, if rendering is asynchronous.
+        self.is_render_cancellable = False
+
         global _next_request_seq
         self.request_seq = _next_request_seq
         _next_request_seq += 1
@@ -357,7 +366,21 @@ class SynapseRequest(Request):
                     {"event": "client connection lost", "reason": str(reason.value)}
                 )
 
-            if not self._is_processing:
+            if self._is_processing:
+                if self.is_render_cancellable:
+                    if self.render_deferred is not None:
+                        # Throw a cancellation into the request processing, in the hope
+                        # that it will finish up sooner than it normally would.
+                        # The `self.processing()` context manager will call
+                        # `_finished_processing()` when done.
+                        with PreserveLoggingContext():
+                            self.render_deferred.cancel()
+                    else:
+                        logger.error(
+                            "Connection from client lost, but have no Deferred to "
+                            "cancel even though the request is marked as cancellable."
+                        )
+            else:
                 self._finished_processing()
 
     def _started_processing(self, servlet_name: str) -> None:
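
From the handler's point of view, the cancellation added in connectionLost() above surfaces as a CancelledError raised at whatever await the coroutine is currently blocked on. A self-contained sketch of that mechanism, independent of Synapse:

    # Cancelling the Deferred that wraps a coroutine raises CancelledError inside it.
    from twisted.internet import defer

    async def handler() -> None:
        try:
            await defer.Deferred()  # stand-in for a pending database call
        except defer.CancelledError:
            print("request processing cancelled")
            raise

    d = defer.ensureDeferred(handler())
    d.cancel()  # what SynapseRequest.connectionLost() now does for cancellable renders
    d.addErrback(lambda f: f.trap(defer.CancelledError))  # swallow the expected failure
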
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index 475756f1db..5a61b21eaf 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -31,7 +31,11 @@ from twisted.internet.endpoints import (
     TCP4ClientEndpoint,
     TCP6ClientEndpoint,
 )
-from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
+from twisted.internet.interfaces import (
+    IPushProducer,
+    IReactorTCP,
+    IStreamClientEndpoint,
+)
 from twisted.internet.protocol import Factory, Protocol
 from twisted.internet.tcp import Connection
 from twisted.python.failure import Failure
@@ -59,14 +63,14 @@ class LogProducer:
     _buffer: Deque[logging.LogRecord]
     _paused: bool = attr.ib(default=False, init=False)
 
-    def pauseProducing(self):
+    def pauseProducing(self) -> None:
         self._paused = True
 
-    def stopProducing(self):
+    def stopProducing(self) -> None:
         self._paused = True
         self._buffer = deque()
 
-    def resumeProducing(self):
+    def resumeProducing(self) -> None:
         # If we're already producing, nothing to do.
         self._paused = False
 
@@ -102,8 +106,8 @@ class RemoteHandler(logging.Handler):
         host: str,
         port: int,
         maximum_buffer: int = 1000,
-        level=logging.NOTSET,
-        _reactor=None,
+        level: int = logging.NOTSET,
+        _reactor: Optional[IReactorTCP] = None,
     ):
         super().__init__(level=level)
         self.host = host
@@ -118,7 +122,7 @@ class RemoteHandler(logging.Handler):
         if _reactor is None:
             from twisted.internet import reactor
 
-            _reactor = reactor
+            _reactor = reactor  # type: ignore[assignment]
 
         try:
             ip = ip_address(self.host)
@@ -139,7 +143,7 @@ class RemoteHandler(logging.Handler):
         self._stopping = False
         self._connect()
 
-    def close(self):
+    def close(self) -> None:
         self._stopping = True
         self._service.stopService()
 
diff --git a/synapse/logging/formatter.py b/synapse/logging/formatter.py
index c0f12ecd15..c88b8ae545 100644
--- a/synapse/logging/formatter.py
+++ b/synapse/logging/formatter.py
@@ -16,6 +16,8 @@
 import logging
 import traceback
 from io import StringIO
+from types import TracebackType
+from typing import Optional, Tuple, Type
 
 
 class LogFormatter(logging.Formatter):
@@ -28,10 +30,14 @@ class LogFormatter(logging.Formatter):
     where it was caught are logged).
     """
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def formatException(self, ei):
+    def formatException(
+        self,
+        ei: Tuple[
+            Optional[Type[BaseException]],
+            Optional[BaseException],
+            Optional[TracebackType],
+        ],
+    ) -> str:
         sio = StringIO()
         (typ, val, tb) = ei
 
diff --git a/synapse/logging/handlers.py b/synapse/logging/handlers.py
index 478b527494..dec2a2c3dd 100644
--- a/synapse/logging/handlers.py
+++ b/synapse/logging/handlers.py
@@ -49,7 +49,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler):
         )
         self._flushing_thread.start()
 
-        def on_reactor_running():
+        def on_reactor_running() -> None:
             self._reactor_started = True
 
         reactor_to_use: IReactorCore
@@ -74,7 +74,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler):
         else:
             return True
 
-    def _flush_periodically(self):
+    def _flush_periodically(self) -> None:
         """
         Whilst this handler is active, flush the handler periodically.
         """
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index d57e7c5324..a26a1a58e7 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -13,6 +13,8 @@
 # limitations under the License.import logging
 
 import logging
+from types import TracebackType
+from typing import Optional, Type
 
 from opentracing import Scope, ScopeManager
 
@@ -107,19 +109,26 @@ class _LogContextScope(Scope):
         and - if enter_logcontext was set - the logcontext is finished too.
     """
 
-    def __init__(self, manager, span, logcontext, enter_logcontext, finish_on_close):
+    def __init__(
+        self,
+        manager: LogContextScopeManager,
+        span,
+        logcontext,
+        enter_logcontext: bool,
+        finish_on_close: bool,
+    ):
         """
         Args:
-            manager (LogContextScopeManager):
+            manager:
                 the manager that is responsible for this scope.
             span (Span):
                 the opentracing span which this scope represents the local
                 lifetime for.
             logcontext (LogContext):
                 the logcontext to which this scope is attached.
-            enter_logcontext (Boolean):
+            enter_logcontext:
                 if True the logcontext will be exited when the scope is finished
-            finish_on_close (Boolean):
+            finish_on_close:
                 if True finish the span when the scope is closed
         """
         super().__init__(manager, span)
@@ -127,16 +136,21 @@ class _LogContextScope(Scope):
         self._finish_on_close = finish_on_close
         self._enter_logcontext = enter_logcontext
 
-    def __exit__(self, exc_type, value, traceback):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
         if exc_type == twisted.internet.defer._DefGen_Return:
             # filter out defer.returnValue() calls
             exc_type = value = traceback = None
         super().__exit__(exc_type, value, traceback)
 
-    def __str__(self):
+    def __str__(self) -> str:
         return f"Scope<{self.span}>"
 
-    def close(self):
+    def close(self) -> None:
         active_scope = self.manager.active
         if active_scope is not self:
             logger.error(
diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py
index 6bc329f04a..1fc8a0e888 100644
--- a/synapse/metrics/jemalloc.py
+++ b/synapse/metrics/jemalloc.py
@@ -18,6 +18,7 @@ import os
 import re
 from typing import Iterable, Optional, overload
 
+import attr
 from prometheus_client import REGISTRY, Metric
 from typing_extensions import Literal
 
@@ -27,52 +28,24 @@ from synapse.metrics._types import Collector
 logger = logging.getLogger(__name__)
 
 
-def _setup_jemalloc_stats() -> None:
-    """Checks to see if jemalloc is loaded, and hooks up a collector to record
-    statistics exposed by jemalloc.
-    """
-
-    # Try to find the loaded jemalloc shared library, if any. We need to
-    # introspect into what is loaded, rather than loading whatever is on the
-    # path, as if we load a *different* jemalloc version things will seg fault.
-
-    # We look in `/proc/self/maps`, which only exists on linux.
-    if not os.path.exists("/proc/self/maps"):
-        logger.debug("Not looking for jemalloc as no /proc/self/maps exist")
-        return
-
-    # We're looking for a path at the end of the line that includes
-    # "libjemalloc".
-    regex = re.compile(r"/\S+/libjemalloc.*$")
-
-    jemalloc_path = None
-    with open("/proc/self/maps") as f:
-        for line in f:
-            match = regex.search(line.strip())
-            if match:
-                jemalloc_path = match.group()
-
-    if not jemalloc_path:
-        # No loaded jemalloc was found.
-        logger.debug("jemalloc not found")
-        return
-
-    logger.debug("Found jemalloc at %s", jemalloc_path)
-
-    jemalloc = ctypes.CDLL(jemalloc_path)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class JemallocStats:
+    jemalloc: ctypes.CDLL
 
     @overload
     def _mallctl(
-        name: str, read: Literal[True] = True, write: Optional[int] = None
+        self, name: str, read: Literal[True] = True, write: Optional[int] = None
     ) -> int:
         ...
 
     @overload
-    def _mallctl(name: str, read: Literal[False], write: Optional[int] = None) -> None:
+    def _mallctl(
+        self, name: str, read: Literal[False], write: Optional[int] = None
+    ) -> None:
         ...
 
     def _mallctl(
-        name: str, read: bool = True, write: Optional[int] = None
+        self, name: str, read: bool = True, write: Optional[int] = None
     ) -> Optional[int]:
         """Wrapper around `mallctl` for reading and writing integers to
         jemalloc.
@@ -120,7 +93,7 @@ def _setup_jemalloc_stats() -> None:
         # Where oldp/oldlenp is a buffer where the old value will be written to
         # (if not null), and newp/newlen is the buffer with the new value to set
         # (if not null). Note that they're all references *except* newlen.
-        result = jemalloc.mallctl(
+        result = self.jemalloc.mallctl(
             name.encode("ascii"),
             input_var_ref,
             input_len_ref,
@@ -136,21 +109,80 @@ def _setup_jemalloc_stats() -> None:
 
         return input_var.value
 
-    def _jemalloc_refresh_stats() -> None:
+    def refresh_stats(self) -> None:
         """Request that jemalloc updates its internal statistics. This needs to
         be called before querying for stats, otherwise it will return stale
         values.
         """
         try:
-            _mallctl("epoch", read=False, write=1)
+            self._mallctl("epoch", read=False, write=1)
         except Exception as e:
             logger.warning("Failed to reload jemalloc stats: %s", e)
 
+    def get_stat(self, name: str) -> int:
+        """Request the stat of the given name at the time of the last
+        `refresh_stats` call. This may throw if we fail to read
+        the stat.
+        """
+        return self._mallctl(f"stats.{name}")
+
+
+_JEMALLOC_STATS: Optional[JemallocStats] = None
+
+
+def get_jemalloc_stats() -> Optional[JemallocStats]:
+    """Returns an interface to jemalloc, if it is being used.
+
+    Note that this will always return None until `setup_jemalloc_stats` has been
+    called.
+    """
+    return _JEMALLOC_STATS
+
+
+def _setup_jemalloc_stats() -> None:
+    """Checks to see if jemalloc is loaded, and hooks up a collector to record
+    statistics exposed by jemalloc.
+    """
+
+    global _JEMALLOC_STATS
+
+    # Try to find the loaded jemalloc shared library, if any. We need to
+    # introspect into what is loaded, rather than loading whatever is on the
+    # path, as if we load a *different* jemalloc version things will seg fault.
+
+    # We look in `/proc/self/maps`, which only exists on linux.
+    if not os.path.exists("/proc/self/maps"):
+        logger.debug("Not looking for jemalloc as no /proc/self/maps exist")
+        return
+
+    # We're looking for a path at the end of the line that includes
+    # "libjemalloc".
+    regex = re.compile(r"/\S+/libjemalloc.*$")
+
+    jemalloc_path = None
+    with open("/proc/self/maps") as f:
+        for line in f:
+            match = regex.search(line.strip())
+            if match:
+                jemalloc_path = match.group()
+
+    if not jemalloc_path:
+        # No loaded jemalloc was found.
+        logger.debug("jemalloc not found")
+        return
+
+    logger.debug("Found jemalloc at %s", jemalloc_path)
+
+    jemalloc_dll = ctypes.CDLL(jemalloc_path)
+
+    stats = JemallocStats(jemalloc_dll)
+    _JEMALLOC_STATS = stats
+
     class JemallocCollector(Collector):
         """Metrics for internal jemalloc stats."""
 
         def collect(self) -> Iterable[Metric]:
-            _jemalloc_refresh_stats()
+            stats.refresh_stats()
 
             g = GaugeMetricFamily(
                 "jemalloc_stats_app_memory_bytes",
@@ -184,7 +216,7 @@ def _setup_jemalloc_stats() -> None:
                 "metadata",
             ):
                 try:
-                    value = _mallctl(f"stats.{t}")
+                    value = stats.get_stat(t)
                 except Exception as e:
                     # There was an error fetching the value, skip.
                     logger.warning("Failed to read jemalloc stats.%s: %s", t, e)
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 73f92d2df8..95f3b27927 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -35,6 +35,7 @@ from typing_extensions import ParamSpec
 from twisted.internet import defer
 from twisted.web.resource import Resource
 
+from synapse import spam_checker_api
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.events.presence_router import (
@@ -47,6 +48,7 @@ from synapse.events.spamcheck import (
     CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK,
     CHECK_REGISTRATION_FOR_SPAM_CALLBACK,
     CHECK_USERNAME_FOR_SPAM_CALLBACK,
+    SHOULD_DROP_FEDERATED_EVENT_CALLBACK,
     USER_MAY_CREATE_ROOM_ALIAS_CALLBACK,
     USER_MAY_CREATE_ROOM_CALLBACK,
     USER_MAY_INVITE_CALLBACK,
@@ -139,6 +141,9 @@ are loaded into Synapse.
 
 PRESENCE_ALL_USERS = PresenceRouter.ALL_USERS
 
+ALLOW = spam_checker_api.Allow.ALLOW
+# Singleton value used to mark a message as permitted.
+
 __all__ = [
     "errors",
     "make_deferred_yieldable",
@@ -146,6 +151,7 @@ __all__ = [
     "respond_with_html",
     "run_in_background",
     "cached",
+    "Allow",
     "UserID",
     "DatabasePool",
     "LoggingTransaction",
@@ -234,6 +240,9 @@ class ModuleApi:
         self,
         *,
         check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None,
+        should_drop_federated_event: Optional[
+            SHOULD_DROP_FEDERATED_EVENT_CALLBACK
+        ] = None,
         user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None,
         user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None,
         user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None,
@@ -254,6 +263,7 @@ class ModuleApi:
         """
         return self._spam_checker.register_callbacks(
             check_event_for_spam=check_event_for_spam,
+            should_drop_federated_event=should_drop_federated_event,
             user_may_join_room=user_may_join_room,
             user_may_invite=user_may_invite,
             user_may_send_3pid_invite=user_may_send_3pid_invite,
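Illustrative only: a third-party module could wire up the new should_drop_federated_event hook through register_spam_checker_callbacks. The callback signature below (an event in, a bool out) is an assumption based on the callback's name, not something this patch defines.

    from synapse.events import EventBase
    from synapse.module_api import ModuleApi

    class ExampleAntiSpamModule:
        def __init__(self, config: dict, api: ModuleApi):
            self._api = api
            api.register_spam_checker_callbacks(
                should_drop_federated_event=self.should_drop_federated_event,
            )

        async def should_drop_federated_event(self, event: EventBase) -> bool:
            # Hypothetical policy: never drop anything.
            return False
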
diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py
index e58e0e60fe..bedd045d6f 100644
--- a/synapse/module_api/errors.py
+++ b/synapse/module_api/errors.py
@@ -15,6 +15,7 @@
 """Exception types which are exposed as part of the stable module API"""
 
 from synapse.api.errors import (
+    Codes,
     InvalidClientCredentialsError,
     RedirectException,
     SynapseError,
@@ -24,6 +25,7 @@ from synapse.handlers.push_rules import InvalidRuleException
 from synapse.storage.push_rule import RuleNotFoundException
 
 __all__ = [
+    "Codes",
     "InvalidClientCredentialsError",
     "RedirectException",
     "SynapseError",
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 01a50b9d62..ba23257f54 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -46,6 +46,7 @@ from synapse.types import (
     JsonDict,
     PersistedEventPosition,
     RoomStreamToken,
+    StreamKeyType,
     StreamToken,
     UserID,
 )
@@ -370,7 +371,7 @@ class Notifier:
 
         if users or rooms:
             self.on_new_event(
-                "room_key",
+                StreamKeyType.ROOM,
                 max_room_stream_token,
                 users=users,
                 rooms=rooms,
@@ -440,7 +441,7 @@ class Notifier:
             for room in rooms:
                 user_streams |= self.room_to_user_streams.get(room, set())
 
-            if stream_key == "to_device_key":
+            if stream_key == StreamKeyType.TO_DEVICE:
                 issue9533_logger.debug(
                     "to-device messages stream id %s, awaking streams for %s",
                     new_token,
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index a1b7711098..57c4d70466 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -12,6 +12,80 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""
+This module implements the push rules & notifications portion of the Matrix
+specification.
+
+There's a few related features:
+
+* Push notifications (i.e. email or outgoing requests to a Push Gateway).
+* Calculation of unread notifications (for /sync and /notifications).
+
+When Synapse receives a new event (locally, via the Client-Server API, or via
+federation), the following occurs:
+
+1. The push rules get evaluated to generate a set of per-user actions.
+2. The event is persisted into the database.
+3. (In the background) The notifier is notified about the new event.
+
+The per-user actions are initially stored in the event_push_actions_staging table,
+before getting moved into the event_push_actions table when the event is persisted.
+The event_push_actions table is periodically summarised into the event_push_summary
+and event_push_summary_stream_ordering tables.
+
+Since push actions block an event from being persisted, the generation of push
+actions is performance sensitive.
+
+The general interaction of the classes is:
+
+        +---------------------------------------------+
+        | FederationEventHandler/EventCreationHandler |
+        +---------------------------------------------+
+                |
+                v
+        +-----------------------+     +---------------------------+
+        | BulkPushRuleEvaluator |---->| PushRuleEvaluatorForEvent |
+        +-----------------------+     +---------------------------+
+                |
+                v
+        +-----------------------------+
+        | EventPushActionsWorkerStore |
+        +-----------------------------+
+
+The notifier notifies the pusher pool of the new event, which checks for affected
+users. Each user-configured pusher of the affected users then performs the
+previously calculated action.
+
+The general interaction of the classes is:
+
+        +----------+
+        | Notifier |
+        +----------+
+                |
+                v
+        +------------+     +--------------+
+        | PusherPool |---->| PusherConfig |
+        +------------+     +--------------+
+                |
+                |     +---------------+
+                +<--->| PusherFactory |
+                |     +---------------+
+                v
+        +------------------------+     +-----------------------------------------------+
+        | EmailPusher/HttpPusher |---->| EventPushActionsWorkerStore/PusherWorkerStore |
+        +------------------------+     +-----------------------------------------------+
+                |
+                v
+        +-------------------------+
+        | Mailer/SimpleHttpClient |
+        +-------------------------+
+
+The Pusher instance also calls out to various utilities for generating payloads
+(or email templates), but those interactions are not detailed in this diagram
+(and are specific to the type of pusher).
+
+"""
+
 import abc
 from typing import TYPE_CHECKING, Any, Dict, Optional
 
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
deleted file mode 100644
index 60758df016..0000000000
--- a/synapse/push/action_generator.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Copyright 2015 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import TYPE_CHECKING
-
-from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
-from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
-from synapse.util.metrics import Measure
-
-if TYPE_CHECKING:
-    from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-class ActionGenerator:
-    def __init__(self, hs: "HomeServer"):
-        self.clock = hs.get_clock()
-        self.bulk_evaluator = BulkPushRuleEvaluator(hs)
-        # really we want to get all user ids and all profile tags too,
-        # since we want the actions for each profile tag for every user and
-        # also actions for a client with no profile tag for each user.
-        # Currently the event stream doesn't support profile tags on an
-        # event stream, so we just run the rules for a client with no profile
-        # tag (ie. we just need all the users).
-
-    async def handle_push_actions_for_event(
-        self, event: EventBase, context: EventContext
-    ) -> None:
-        with Measure(self.clock, "action_for_event_by_user"):
-            await self.bulk_evaluator.action_for_event_by_user(event, context)
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index b07cf2eee7..4cc8a2ecca 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -20,8 +20,8 @@ import attr
 from prometheus_client import Counter
 
 from synapse.api.constants import EventTypes, Membership, RelationTypes
-from synapse.event_auth import get_user_power_level
-from synapse.events import EventBase
+from synapse.event_auth import auth_types_for_event, get_user_power_level
+from synapse.events import EventBase, relation_from_event
 from synapse.events.snapshot import EventContext
 from synapse.state import POWER_KEY
 from synapse.storage.databases.main.roommember import EventIdMembership
@@ -29,7 +29,9 @@ from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import CacheMetric, register_cache
 from synapse.util.caches.descriptors import lru_cache
 from synapse.util.caches.lrucache import LruCache
+from synapse.util.metrics import measure_func
 
+from ..storage.state import StateFilter
 from .push_rule_evaluator import PushRuleEvaluatorForEvent
 
 if TYPE_CHECKING:
@@ -77,8 +79,8 @@ def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
         return False
 
     # Exclude edits.
-    relates_to = event.content.get("m.relates_to", {})
-    if relates_to.get("rel_type") == RelationTypes.REPLACE:
+    relates_to = relation_from_event(event)
+    if relates_to and relates_to.rel_type == RelationTypes.REPLACE:
         return False
 
     # Mark events that have a non-empty string body as unread.
@@ -105,6 +107,7 @@ class BulkPushRuleEvaluator:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastores().main
+        self.clock = hs.get_clock()
         self._event_auth_handler = hs.get_event_auth_handler()
 
         # Used by `RulesForRoom` to ensure only one thing mutates the cache at a
@@ -166,8 +169,12 @@ class BulkPushRuleEvaluator:
     async def _get_power_levels_and_sender_level(
         self, event: EventBase, context: EventContext
     ) -> Tuple[dict, int]:
-        prev_state_ids = await context.get_prev_state_ids()
+        event_types = auth_types_for_event(event.room_version, event)
+        prev_state_ids = await context.get_prev_state_ids(
+            StateFilter.from_types(event_types)
+        )
         pl_event_id = prev_state_ids.get(POWER_KEY)
+
         if pl_event_id:
             # fastpath: if there's a power level event, that's all we need, and
             # not having a power level event is an extreme edge case
@@ -185,6 +192,7 @@ class BulkPushRuleEvaluator:
 
         return pl_event.content if pl_event else {}, sender_level
 
+    @measure_func("action_for_event_by_user")
     async def action_for_event_by_user(
         self, event: EventBase, context: EventContext
     ) -> None:
@@ -192,6 +200,10 @@ class BulkPushRuleEvaluator:
         should increment the unread count, and insert the results into the
         event_push_actions_staging table.
         """
+        if event.internal_metadata.is_outlier():
+            # This can happen due to out of band memberships
+            return
+
         count_as_unread = _should_count_as_unread(event, context)
 
         rules_by_user = await self._get_rules_for_event(event, context)
@@ -208,8 +220,6 @@ class BulkPushRuleEvaluator:
             event, len(room_members), sender_power_level, power_levels
         )
 
-        condition_cache: Dict[str, bool] = {}
-
         # If the event is not a state event check if any users ignore the sender.
         if not event.is_state():
             ignorers = await self.store.ignored_by(event.sender)
@@ -247,8 +257,8 @@ class BulkPushRuleEvaluator:
                 if "enabled" in rule and not rule["enabled"]:
                     continue
 
-                matches = _condition_checker(
-                    evaluator, rule["conditions"], uid, display_name, condition_cache
+                matches = evaluator.check_conditions(
+                    rule["conditions"], uid, display_name
                 )
                 if matches:
                     actions = [x for x in rule["actions"] if x != "dont_notify"]
@@ -267,32 +277,6 @@ class BulkPushRuleEvaluator:
         )
 
 
-def _condition_checker(
-    evaluator: PushRuleEvaluatorForEvent,
-    conditions: List[dict],
-    uid: str,
-    display_name: Optional[str],
-    cache: Dict[str, bool],
-) -> bool:
-    for cond in conditions:
-        _cache_key = cond.get("_cache_key", None)
-        if _cache_key:
-            res = cache.get(_cache_key, None)
-            if res is False:
-                return False
-            elif res is True:
-                continue
-
-        res = evaluator.matches(cond, uid, display_name)
-        if _cache_key:
-            cache[_cache_key] = bool(res)
-
-        if not res:
-            return False
-
-    return True
-
-
 MemberMap = Dict[str, Optional[EventIdMembership]]
 Rule = Dict[str, dict]
 RulesByUser = Dict[str, List[Rule]]
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index f3c4419932..d4e07cb97b 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -410,7 +410,7 @@ class HttpPusher(Pusher):
         rejected = []
         if "rejected" in resp:
             rejected = resp["rejected"]
-        else:
+        if not rejected:
             self.badge_count_last_call = badge
         return rejected
 
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index f617c759e6..54db6b5612 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -129,9 +129,55 @@ class PushRuleEvaluatorForEvent:
         # Maps strings of e.g. 'content.body' -> event["content"]["body"]
         self._value_cache = _flatten_dict(event)
 
+        # Maps cache keys to final values.
+        self._condition_cache: Dict[str, bool] = {}
+
+    def check_conditions(
+        self, conditions: List[dict], uid: str, display_name: Optional[str]
+    ) -> bool:
+        """
+        Returns true if a user's conditions/user ID/display name match the event.
+
+        Args:
+            conditions: The user's conditions to match.
+            uid: The user's MXID.
+            display_name: The display name.
+
+        Returns:
+             True if all conditions match the event, False otherwise.
+        """
+        for cond in conditions:
+            _cache_key = cond.get("_cache_key", None)
+            if _cache_key:
+                res = self._condition_cache.get(_cache_key, None)
+                if res is False:
+                    return False
+                elif res is True:
+                    continue
+
+            res = self.matches(cond, uid, display_name)
+            if _cache_key:
+                self._condition_cache[_cache_key] = bool(res)
+
+            if not res:
+                return False
+
+        return True
+
     def matches(
         self, condition: Dict[str, Any], user_id: str, display_name: Optional[str]
     ) -> bool:
+        """
+        Returns true if a user's condition/user ID/display name match the event.
+
+        Args:
+            condition: The user's condition to match.
+            user_id: The user's MXID.
+            display_name: The display name, or None if there is not one.
+
+        Returns:
+             True if the condition matches the event, False otherwise.
+        """
         if condition["kind"] == "event_match":
             return self._event_match(condition, user_id)
         elif condition["kind"] == "contains_display_name":
@@ -146,6 +192,16 @@ class PushRuleEvaluatorForEvent:
             return True
 
     def _event_match(self, condition: dict, user_id: str) -> bool:
+        """
+        Check an "event_match" push rule condition.
+
+        Args:
+            condition: The "event_match" push rule condition to match.
+            user_id: The user's MXID.
+
+        Returns:
+             True if the condition matches the event, False otherwise.
+        """
         pattern = condition.get("pattern", None)
 
         if not pattern:
@@ -167,13 +223,22 @@ class PushRuleEvaluatorForEvent:
 
             return _glob_matches(pattern, body, word_boundary=True)
         else:
-            haystack = self._get_value(condition["key"])
+            haystack = self._value_cache.get(condition["key"], None)
             if haystack is None:
                 return False
 
             return _glob_matches(pattern, haystack)
 
     def _contains_display_name(self, display_name: Optional[str]) -> bool:
+        """
+        Check an "event_match" push rule condition.
+
+        Args:
+            display_name: The display name, or None if there is not one.
+
+        Returns:
+             True if the display name is found in the event body, False otherwise.
+        """
         if not display_name:
             return False
 
@@ -191,9 +256,6 @@ class PushRuleEvaluatorForEvent:
 
         return bool(r.search(body))
 
-    def _get_value(self, dotted_key: str) -> Optional[str]:
-        return self._value_cache.get(dotted_key, None)
-
 
 # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
 regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache(
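A self-contained sketch (deliberately not using the Synapse classes) of the condition-cache behaviour that the hunk above moves into the evaluator: conditions sharing a _cache_key are only matched once per event.

    from typing import Dict, List, Optional

    class TinyEvaluator:
        """Stand-in for PushRuleEvaluatorForEvent, for illustration only."""

        def __init__(self) -> None:
            self._condition_cache: Dict[str, bool] = {}
            self.match_calls = 0

        def matches(self, condition: dict) -> bool:
            # Pretend this is the expensive glob/regex matching step.
            self.match_calls += 1
            return condition.get("pattern") == "hello"

        def check_conditions(self, conditions: List[dict]) -> bool:
            for cond in conditions:
                cache_key: Optional[str] = cond.get("_cache_key")
                if cache_key is not None and cache_key in self._condition_cache:
                    if not self._condition_cache[cache_key]:
                        return False
                    continue
                res = self.matches(cond)
                if cache_key is not None:
                    self._condition_cache[cache_key] = res
                if not res:
                    return False
            return True

    evaluator = TinyEvaluator()
    shared = {"kind": "event_match", "pattern": "hello", "_cache_key": "k1"}
    assert evaluator.check_conditions([shared])
    assert evaluator.check_conditions([shared])  # answered from the cache
    assert evaluator.match_calls == 1
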
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 2bd244ed79..a4ae4040c3 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -26,7 +26,8 @@ from twisted.web.server import Request
 
 from synapse.api.errors import HttpResponseException, SynapseError
 from synapse.http import RequestTimedOutError
-from synapse.http.server import HttpServer
+from synapse.http.server import HttpServer, is_method_cancellable
+from synapse.http.site import SynapseRequest
 from synapse.logging import opentracing
 from synapse.logging.opentracing import trace
 from synapse.types import JsonDict
@@ -310,6 +311,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         url_args = list(self.PATH_ARGS)
         method = self.METHOD
 
+        if self.CACHE and is_method_cancellable(self._handle_request):
+            raise Exception(
+                f"{self.__class__.__name__} has been marked as cancellable, but CACHE "
+                "is set. The cancellable flag would have no effect."
+            )
+
         if self.CACHE:
             url_args.append("txn_id")
 
@@ -324,7 +331,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         )
 
     async def _check_auth_and_handle(
-        self, request: Request, **kwargs: Any
+        self, request: SynapseRequest, **kwargs: Any
     ) -> Tuple[int, JsonDict]:
         """Called on new incoming requests when caching is enabled. Checks
         if there is a cached response for the request and returns that,
@@ -340,8 +347,18 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         if self.CACHE:
             txn_id = kwargs.pop("txn_id")
 
+            # We ignore the `@cancellable` flag, since cancellation wouldn't interrupt
+            # `_handle_request`, and `ResponseCache` does not handle cancellation
+            # correctly yet. In particular, there may be issues to do with logging
+            # context lifetimes.
+
             return await self.response_cache.wrap(
                 txn_id, self._handle_request, request, **kwargs
             )
 
+        # The `@cancellable` decorator may be applied to `_handle_request`. But we
+        # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
+        # so we have to set up the cancellable flag ourselves.
+        request.is_render_cancellable = is_method_cancellable(self._handle_request)
+
         return await self._handle_request(request, **kwargs)
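A rough sketch of the interplay being guarded against above. The tagging attribute used here is an assumption; the real cancellable/is_method_cancellable pair lives in synapse.http.server.

    from typing import Any, Callable

    def cancellable_sketch(method: Callable[..., Any]) -> Callable[..., Any]:
        # Tag the handler so the HTTP layer knows it may be cancelled mid-flight.
        method.cancellable = True  # attribute name is an assumption
        return method

    def is_method_cancellable_sketch(method: Callable[..., Any]) -> bool:
        return getattr(method, "cancellable", False)

    class ExampleEndpoint:
        CACHE = True  # caching and cancellation do not mix

        @cancellable_sketch
        async def _handle_request(self, request: Any) -> None:
            ...

    endpoint = ExampleEndpoint()
    if endpoint.CACHE and is_method_cancellable_sketch(endpoint._handle_request):
        print("Refusing to register: the cancellable flag would have no effect with CACHE")
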
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 350762f494..a52e25c1af 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -43,7 +43,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStreamEventRow,
     EventsStreamRow,
 )
-from synapse.types import PersistedEventPosition, ReadReceipt, UserID
+from synapse.types import PersistedEventPosition, ReadReceipt, StreamKeyType, UserID
 from synapse.util.async_helpers import Linearizer, timeout_deferred
 from synapse.util.metrics import Measure
 
@@ -153,19 +153,19 @@ class ReplicationDataHandler:
         if stream_name == TypingStream.NAME:
             self._typing_handler.process_replication_rows(token, rows)
             self.notifier.on_new_event(
-                "typing_key", token, rooms=[row.room_id for row in rows]
+                StreamKeyType.TYPING, token, rooms=[row.room_id for row in rows]
             )
         elif stream_name == PushRulesStream.NAME:
             self.notifier.on_new_event(
-                "push_rules_key", token, users=[row.user_id for row in rows]
+                StreamKeyType.PUSH_RULES, token, users=[row.user_id for row in rows]
             )
         elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
             self.notifier.on_new_event(
-                "account_data_key", token, users=[row.user_id for row in rows]
+                StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows]
             )
         elif stream_name == ReceiptsStream.NAME:
             self.notifier.on_new_event(
-                "receipt_key", token, rooms=[row.room_id for row in rows]
+                StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows]
             )
             await self._pusher_pool.on_new_receipts(
                 token, token, {row.room_id for row in rows}
@@ -173,14 +173,18 @@ class ReplicationDataHandler:
         elif stream_name == ToDeviceStream.NAME:
             entities = [row.entity for row in rows if row.entity.startswith("@")]
             if entities:
-                self.notifier.on_new_event("to_device_key", token, users=entities)
+                self.notifier.on_new_event(
+                    StreamKeyType.TO_DEVICE, token, users=entities
+                )
         elif stream_name == DeviceListsStream.NAME:
             all_room_ids: Set[str] = set()
             for row in rows:
                 if row.entity.startswith("@"):
                     room_ids = await self.store.get_rooms_for_user(row.entity)
                     all_room_ids.update(room_ids)
-            self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
+            self.notifier.on_new_event(
+                StreamKeyType.DEVICE_LIST, token, rooms=all_room_ids
+            )
         elif stream_name == GroupServerStream.NAME:
             self.notifier.on_new_event(
                 "groups_key", token, users=[row.user_id for row in rows]
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index fe34948168..32f52e54d8 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -58,6 +58,15 @@ class Command(metaclass=abc.ABCMeta):
         # by default, we just use the command name.
         return self.NAME
 
+    def redis_channel_name(self, prefix: str) -> str:
+        """
+        Returns the Redis channel name upon which to publish this command.
+
+        Args:
+            prefix: The prefix for the channel.
+        """
+        return prefix
+
 
 SC = TypeVar("SC", bound="_SimpleCommand")
 
@@ -395,6 +404,9 @@ class UserIpCommand(Command):
             f"{self.user_agent!r}, {self.device_id!r}, {self.last_seen})"
         )
 
+    def redis_channel_name(self, prefix: str) -> str:
+        return f"{prefix}/USER_IP"
+
 
 class RemoteServerUpCommand(_SimpleCommand):
     """Sent when a worker has detected that a remote server is no longer
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 9aba1cd451..e1cbfa50eb 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -1,5 +1,5 @@
 # Copyright 2017 Vector Creations Ltd
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020, 2022 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -101,6 +101,9 @@ class ReplicationCommandHandler:
         self._instance_id = hs.get_instance_id()
         self._instance_name = hs.get_instance_name()
 
+        # Additional Redis channel suffixes to subscribe to.
+        self._channels_to_subscribe_to: List[str] = []
+
         self._is_presence_writer = (
             hs.get_instance_name() in hs.config.worker.writers.presence
         )
@@ -243,6 +246,31 @@ class ReplicationCommandHandler:
             # If we're NOT using Redis, this must be handled by the master
             self._should_insert_client_ips = hs.get_instance_name() == "master"
 
+        if self._is_master or self._should_insert_client_ips:
+            self.subscribe_to_channel("USER_IP")
+
+    def subscribe_to_channel(self, channel_name: str) -> None:
+        """
+        Indicates that we wish to subscribe to a Redis channel by name.
+
+        (The name will later be prefixed with the server name; e.g. subscribing
+        to the 'ABC' channel actually subscribes to 'example.com/ABC' Redis-side.)
+
+        Raises:
+            RuntimeError: if replication has already started; it is then too late
+                to subscribe to new channels.
+        """
+
+        if self._factory is not None:
+            # We don't allow subscribing after the fact to avoid the chance
+            # of missing an important message because we didn't subscribe in time.
+            raise RuntimeError(
+                "Cannot subscribe to more channels after replication started."
+            )
+
+        if channel_name not in self._channels_to_subscribe_to:
+            self._channels_to_subscribe_to.append(channel_name)
+
     def _add_command_to_stream_queue(
         self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
     ) -> None:
@@ -321,7 +349,9 @@ class ReplicationCommandHandler:
 
             # Now create the factory/connection for the subscription stream.
             self._factory = RedisDirectTcpReplicationClientFactory(
-                hs, outbound_redis_connection
+                hs,
+                outbound_redis_connection,
+                channel_names=self._channels_to_subscribe_to,
             )
             hs.get_reactor().connectTCP(
                 hs.config.redis.redis_host,
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 989c5be032..fd1c0ec6af 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -14,7 +14,7 @@
 
 import logging
 from inspect import isawaitable
-from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Generic, List, Optional, Type, TypeVar, cast
 
 import attr
 import txredisapi
@@ -85,14 +85,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
 
     Attributes:
         synapse_handler: The command handler to handle incoming commands.
-        synapse_stream_name: The *redis* stream name to subscribe to and publish
+        synapse_stream_prefix: The prefix for the *redis* channels to subscribe to and publish
             from (not anything to do with Synapse replication streams).
         synapse_outbound_redis_connection: The connection to redis to use to send
             commands.
     """
 
     synapse_handler: "ReplicationCommandHandler"
-    synapse_stream_name: str
+    synapse_stream_prefix: str
+    synapse_channel_names: List[str]
     synapse_outbound_redis_connection: txredisapi.ConnectionHandler
 
     def __init__(self, *args: Any, **kwargs: Any):
@@ -117,8 +118,13 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
         # it's important to make sure that we only send the REPLICATE command once we
         # have successfully subscribed to the stream - otherwise we might miss the
         # POSITION response sent back by the other end.
-        logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
-        await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
+        fully_qualified_stream_names = [
+            f"{self.synapse_stream_prefix}/{stream_suffix}"
+            for stream_suffix in self.synapse_channel_names
+        ] + [self.synapse_stream_prefix]
+        logger.info("Sending redis SUBSCRIBE for %r", fully_qualified_stream_names)
+        await make_deferred_yieldable(self.subscribe(fully_qualified_stream_names))
+
         logger.info(
             "Successfully subscribed to redis stream, sending REPLICATE command"
         )
@@ -215,10 +221,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
         # remote instances.
         tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
 
+        channel_name = cmd.redis_channel_name(self.synapse_stream_prefix)
+
         await make_deferred_yieldable(
-            self.synapse_outbound_redis_connection.publish(
-                self.synapse_stream_name, encoded_string
-            )
+            self.synapse_outbound_redis_connection.publish(channel_name, encoded_string)
         )
 
 
@@ -300,20 +306,27 @@ def format_address(address: IAddress) -> str:
 
 class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
     """This is a reconnecting factory that connects to redis and immediately
-    subscribes to a stream.
+    subscribes to some streams.
 
     Args:
         hs
         outbound_redis_connection: A connection to redis that will be used to
             send outbound commands (this is separate to the redis connection
             used to subscribe).
+        channel_names: A list of channel names to append to the base channel name
+            to additionally subscribe to.
+            e.g. if ['ABC', 'DEF'] is specified then we'll listen to:
+            example.com; example.com/ABC; and example.com/DEF.
     """
 
     maxDelay = 5
     protocol = RedisSubscriber
 
     def __init__(
-        self, hs: "HomeServer", outbound_redis_connection: txredisapi.ConnectionHandler
+        self,
+        hs: "HomeServer",
+        outbound_redis_connection: txredisapi.ConnectionHandler,
+        channel_names: List[str],
     ):
 
         super().__init__(
@@ -326,7 +339,8 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
         )
 
         self.synapse_handler = hs.get_replication_command_handler()
-        self.synapse_stream_name = hs.hostname
+        self.synapse_stream_prefix = hs.hostname
+        self.synapse_channel_names = channel_names
 
         self.synapse_outbound_redis_connection = outbound_redis_connection
 
@@ -340,7 +354,8 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
         # protocol.
         p.synapse_handler = self.synapse_handler
         p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
-        p.synapse_stream_name = self.synapse_stream_name
+        p.synapse_stream_prefix = self.synapse_stream_prefix
+        p.synapse_channel_names = self.synapse_channel_names
 
         return p
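A small standalone sketch of how the channel names end up composed by the hunks above, using the example host from the docstring (the host name is illustrative):

    class UserIpCommandStub:
        """Stand-in mirroring UserIpCommand.redis_channel_name above."""

        NAME = "USER_IP"

        def redis_channel_name(self, prefix: str) -> str:
            return f"{prefix}/USER_IP"

    stream_prefix = "example.com"   # hs.hostname
    channel_names = ["USER_IP"]     # collected via subscribe_to_channel("USER_IP")

    # What RedisSubscriber subscribes to on connect:
    fully_qualified = [f"{stream_prefix}/{suffix}" for suffix in channel_names] + [
        stream_prefix
    ]
    assert fully_qualified == ["example.com/USER_IP", "example.com"]

    # Where an outbound UserIpCommand gets published:
    assert UserIpCommandStub().redis_channel_name(stream_prefix) == "example.com/USER_IP"
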
 
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index b98640b14a..8191b4e32c 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -148,9 +148,9 @@ class PushRuleRestServlet(RestServlet):
         # we build up the full structure and then decide which bits of it
         # to send which means doing unnecessary work sometimes but is
         # is probably not going to make a whole lot of difference
-        rules = await self.store.get_push_rules_for_user(user_id)
+        rules_raw = await self.store.get_push_rules_for_user(user_id)
 
-        rules = format_push_rules_for_user(requester.user, rules)
+        rules = format_push_rules_for_user(requester.user, rules_raw)
 
         path_parts = path.split("/")[1:]
 
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index f9caab6635..4b03eb876b 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -13,12 +13,10 @@
 # limitations under the License.
 
 import logging
-import re
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.constants import ReceiptTypes
 from synapse.api.errors import SynapseError
-from synapse.http import get_request_user_agent
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.http.site import SynapseRequest
@@ -26,8 +24,6 @@ from synapse.types import JsonDict
 
 from ._base import client_patterns
 
-pattern = re.compile(r"(?:Element|SchildiChat)/1\.[012]\.")
-
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
@@ -69,14 +65,7 @@ class ReceiptRestServlet(RestServlet):
         ):
             raise SynapseError(400, "Receipt type must be 'm.read'")
 
-        # Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body.
-        user_agent = get_request_user_agent(request)
-        allow_empty_body = False
-        if "Android" in user_agent:
-            if pattern.match(user_agent) or "Riot" in user_agent:
-                allow_empty_body = True
-        # This call makes sure possible empty body is handled correctly
-        parse_json_object_from_request(request, allow_empty_body)
+        parse_json_object_from_request(request, allow_empty_body=False)
 
         await self.presence_handler.bump_presence_active_time(requester.user)
 
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 12fed856a8..5a2361a2e6 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -34,7 +34,7 @@ from synapse.api.errors import (
 )
 from synapse.api.filtering import Filter
 from synapse.events.utils import format_event_for_client_v2
-from synapse.http.server import HttpServer
+from synapse.http.server import HttpServer, cancellable
 from synapse.http.servlet import (
     ResolveRoomIdMixin,
     RestServlet,
@@ -143,6 +143,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
             self.__class__.__name__,
         )
 
+    @cancellable
     def on_GET_no_state_key(
         self, request: SynapseRequest, room_id: str, event_type: str
     ) -> Awaitable[Tuple[int, JsonDict]]:
@@ -153,6 +154,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
     ) -> Awaitable[Tuple[int, JsonDict]]:
         return self.on_PUT(request, room_id, event_type, "")
 
+    @cancellable
     async def on_GET(
         self, request: SynapseRequest, room_id: str, event_type: str, state_key: str
     ) -> Tuple[int, JsonDict]:
@@ -481,6 +483,7 @@ class RoomMemberListRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
 
+    @cancellable
     async def on_GET(
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
@@ -602,6 +605,7 @@ class RoomStateRestServlet(RestServlet):
         self.message_handler = hs.get_message_handler()
         self.auth = hs.get_auth()
 
+    @cancellable
     async def on_GET(
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, List[JsonDict]]:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 50383bdbd1..2b2db63bf7 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -668,7 +668,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         logger.debug("Running url preview cache expiry")
 
         if not (await self.store.db_pool.updates.has_completed_background_updates()):
-            logger.info("Still running DB updates; skipping expiry")
+            logger.debug("Still running DB updates; skipping url preview cache expiry")
             return
 
         def try_remove_parent_dirs(dirs: Iterable[str]) -> None:
@@ -688,7 +688,9 @@ class PreviewUrlResource(DirectServeJsonResource):
                     # Failed, skip deleting the rest of the parent dirs
                     if e.errno != errno.ENOTEMPTY:
                         logger.warning(
-                            "Failed to remove media directory: %r: %s", dir, e
+                            "Failed to remove media directory while clearing url preview cache: %r: %s",
+                            dir,
+                            e,
                         )
                     break
 
@@ -703,7 +705,11 @@ class PreviewUrlResource(DirectServeJsonResource):
             except FileNotFoundError:
                 pass  # If the path doesn't exist, meh
             except OSError as e:
-                logger.warning("Failed to remove media: %r: %s", media_id, e)
+                logger.warning(
+                    "Failed to remove media while clearing url preview cache: %r: %s",
+                    media_id,
+                    e,
+                )
                 continue
 
             removed_media.append(media_id)
@@ -714,9 +720,11 @@ class PreviewUrlResource(DirectServeJsonResource):
         await self.store.delete_url_cache(removed_media)
 
         if removed_media:
-            logger.info("Deleted %d entries from url cache", len(removed_media))
+            logger.debug(
+                "Deleted %d entries from url preview cache", len(removed_media)
+            )
         else:
-            logger.debug("No entries removed from url cache")
+            logger.debug("No entries removed from url preview cache")
 
         # Now we delete old images associated with the url cache.
         # These may be cached for a bit on the client (i.e., they
@@ -733,7 +741,9 @@ class PreviewUrlResource(DirectServeJsonResource):
             except FileNotFoundError:
                 pass  # If the path doesn't exist, meh
             except OSError as e:
-                logger.warning("Failed to remove media: %r: %s", media_id, e)
+                logger.warning(
+                    "Failed to remove media from url preview cache: %r: %s", media_id, e
+                )
                 continue
 
             dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
@@ -745,7 +755,9 @@ class PreviewUrlResource(DirectServeJsonResource):
             except FileNotFoundError:
                 pass  # If the path doesn't exist, meh
             except OSError as e:
-                logger.warning("Failed to remove media: %r: %s", media_id, e)
+                logger.warning(
+                    "Failed to remove media from url preview cache: %r: %s", media_id, e
+                )
                 continue
 
             removed_media.append(media_id)
@@ -758,9 +770,9 @@ class PreviewUrlResource(DirectServeJsonResource):
         await self.store.delete_url_cache_media(removed_media)
 
         if removed_media:
-            logger.info("Deleted %d media from url cache", len(removed_media))
+            logger.debug("Deleted %d media from url preview cache", len(removed_media))
         else:
-            logger.debug("No media removed from url cache")
+            logger.debug("No media removed from url preview cache")
 
 
 def _is_media(content_type: str) -> bool:
diff --git a/synapse/server.py b/synapse/server.py
index d49c76518a..ee60cce8eb 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -119,7 +119,7 @@ from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpC
 from synapse.http.matrixfederationclient import MatrixFederationHttpClient
 from synapse.module_api import ModuleApi
 from synapse.notifier import Notifier
-from synapse.push.action_generator import ActionGenerator
+from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
 from synapse.push.pusherpool import PusherPool
 from synapse.replication.tcp.client import ReplicationDataHandler
 from synapse.replication.tcp.external_cache import ExternalCache
@@ -644,8 +644,8 @@ class HomeServer(metaclass=abc.ABCMeta):
         return ReplicationCommandHandler(self)
 
     @cache_in_self
-    def get_action_generator(self) -> ActionGenerator:
-        return ActionGenerator(self)
+    def get_bulk_push_rule_evaluator(self) -> BulkPushRuleEvaluator:
+        return BulkPushRuleEvaluator(self)
 
     @cache_in_self
     def get_user_directory_handler(self) -> UserDirectoryHandler:
@@ -681,7 +681,7 @@ class HomeServer(metaclass=abc.ABCMeta):
 
     @cache_in_self
     def get_spam_checker(self) -> SpamChecker:
-        return SpamChecker()
+        return SpamChecker(self)
 
     @cache_in_self
     def get_third_party_event_rules(self) -> ThirdPartyEventRules:
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index 015dd08f05..b5f3a0c74e 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -21,7 +21,6 @@ from synapse.api.constants import (
     ServerNoticeMsgType,
 )
 from synapse.api.errors import AuthError, ResourceLimitError, SynapseError
-from synapse.server_notices.server_notices_manager import SERVER_NOTICE_ROOM_TAG
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -71,18 +70,19 @@ class ResourceLimitsServerNotices:
             # In practice, not sure we can ever get here
             return
 
-        room_id = await self._server_notices_manager.get_or_create_notice_room_for_user(
+        # Check if there's a server notice room for this user.
+        room_id = await self._server_notices_manager.maybe_get_notice_room_for_user(
             user_id
         )
 
-        if not room_id:
-            logger.warning("Failed to get server notices room")
-            return
-
-        await self._check_and_set_tags(user_id, room_id)
-
-        # Determine current state of room
-        currently_blocked, ref_events = await self._is_room_currently_blocked(room_id)
+        if room_id is not None:
+            # Determine current state of room
+            currently_blocked, ref_events = await self._is_room_currently_blocked(
+                room_id
+            )
+        else:
+            currently_blocked = False
+            ref_events = []
 
         limit_msg = None
         limit_type = None
@@ -161,26 +161,6 @@ class ResourceLimitsServerNotices:
             user_id, content, EventTypes.Pinned, ""
         )
 
-    async def _check_and_set_tags(self, user_id: str, room_id: str) -> None:
-        """
-        Since server notices rooms were originally not with tags,
-        important to check that tags have been set correctly
-        Args:
-            user_id(str): the user in question
-            room_id(str): the server notices room for that user
-        """
-        tags = await self._store.get_tags_for_room(user_id, room_id)
-        need_to_set_tag = True
-        if tags:
-            if SERVER_NOTICE_ROOM_TAG in tags:
-                # tag already present, nothing to do here
-                need_to_set_tag = False
-        if need_to_set_tag:
-            max_id = await self._account_data_handler.add_tag_to_room(
-                user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
-            )
-            self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
-
     async def _is_room_currently_blocked(self, room_id: str) -> Tuple[bool, List[str]]:
         """
         Determines if the room is currently blocked
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 48eae5fa06..8ecab86ec7 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Optional
 
 from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
 from synapse.events import EventBase
-from synapse.types import Requester, UserID, create_requester
+from synapse.types import Requester, StreamKeyType, UserID, create_requester
 from synapse.util.caches.descriptors import cached
 
 if TYPE_CHECKING:
@@ -91,6 +91,35 @@ class ServerNoticesManager:
         return event
 
     @cached()
+    async def maybe_get_notice_room_for_user(self, user_id: str) -> Optional[str]:
+        """Try to look up the server notice room for this user if it exists.
+
+        Does not create one if none can be found.
+
+        Args:
+            user_id: the user we want a server notice room for.
+
+        Returns:
+            The room's ID, or None if no room could be found.
+        """
+        rooms = await self._store.get_rooms_for_local_user_where_membership_is(
+            user_id, [Membership.INVITE, Membership.JOIN]
+        )
+        for room in rooms:
+            # it's worth noting that there is an asymmetry here in that we
+            # expect the user to be invited or joined, but the system user must
+            # be joined. This is kinda deliberate, in that if somebody somehow
+            # manages to invite the system user to a room, that doesn't make it
+            # the server notices room.
+            user_ids = await self._store.get_users_in_room(room.room_id)
+            if len(user_ids) <= 2 and self.server_notices_mxid in user_ids:
+                # we found a room which our user shares with the system notice
+                # user
+                return room.room_id
+
+        return None
+
+    @cached()
     async def get_or_create_notice_room_for_user(self, user_id: str) -> str:
         """Get the room for notices for a given user
 
@@ -112,31 +141,20 @@ class ServerNoticesManager:
             self.server_notices_mxid, authenticated_entity=self._server_name
         )
 
-        rooms = await self._store.get_rooms_for_local_user_where_membership_is(
-            user_id, [Membership.INVITE, Membership.JOIN]
-        )
-        for room in rooms:
-            # it's worth noting that there is an asymmetry here in that we
-            # expect the user to be invited or joined, but the system user must
-            # be joined. This is kinda deliberate, in that if somebody somehow
-            # manages to invite the system user to a room, that doesn't make it
-            # the server notices room.
-            user_ids = await self._store.get_users_in_room(room.room_id)
-            if len(user_ids) <= 2 and self.server_notices_mxid in user_ids:
-                # we found a room which our user shares with the system notice
-                # user
-                logger.info(
-                    "Using existing server notices room %s for user %s",
-                    room.room_id,
-                    user_id,
-                )
-                await self._update_notice_user_profile_if_changed(
-                    requester,
-                    room.room_id,
-                    self._config.servernotices.server_notices_mxid_display_name,
-                    self._config.servernotices.server_notices_mxid_avatar_url,
-                )
-                return room.room_id
+        room_id = await self.maybe_get_notice_room_for_user(user_id)
+        if room_id is not None:
+            logger.info(
+                "Using existing server notices room %s for user %s",
+                room_id,
+                user_id,
+            )
+            await self._update_notice_user_profile_if_changed(
+                requester,
+                room_id,
+                self._config.servernotices.server_notices_mxid_display_name,
+                self._config.servernotices.server_notices_mxid_avatar_url,
+            )
+            return room_id
 
         # apparently no existing notice room: create a new one
         logger.info("Creating server notices room for %s", user_id)
@@ -166,10 +184,12 @@ class ServerNoticesManager:
         )
         room_id = info["room_id"]
 
+        self.maybe_get_notice_room_for_user.invalidate((user_id,))
+
         max_id = await self._account_data_handler.add_tag_to_room(
             user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
         )
-        self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        self._notifier.on_new_event(StreamKeyType.ACCOUNT_DATA, max_id, users=[user_id])
 
         logger.info("Created server notices room %s for %s", room_id, user_id)
         return room_id
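
A minimal, self-contained sketch of the room-detection heuristic used by the new maybe_get_notice_room_for_user above: the server notices room is any room the user shares with the notices account where the membership list contains at most the two of them. The helper names and in-memory data below are hypothetical stand-ins for the store calls in the diff; after get_or_create_notice_room_for_user creates a room, the cached lookup is invalidated so the new room is picked up.

    from typing import Dict, List, Optional, Set

    def find_notice_room(
        user_id: str,
        notices_mxid: str,
        rooms_for_user: Dict[str, List[str]],  # user_id -> rooms they are invited to or joined
        users_in_room: Dict[str, Set[str]],    # room_id -> users currently in the room
    ) -> Optional[str]:
        """Return the first room the user shares with only the notices account."""
        for room_id in rooms_for_user.get(user_id, []):
            members = users_in_room.get(room_id, set())
            # The user may merely be invited, but the notices account must be present,
            # and nobody else may be in the room.
            if len(members) <= 2 and notices_mxid in members:
                return room_id
        return None

    # Example: a 1:1 room with the notices account is found.
    assert find_notice_room(
        "@alice:example.org",
        "@server:example.org",
        {"@alice:example.org": ["!a:example.org"]},
        {"!a:example.org": {"@alice:example.org", "@server:example.org"}},
    ) == "!a:example.org"
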
diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py
index 73018f2d00..95132c80b7 100644
--- a/synapse/spam_checker_api/__init__.py
+++ b/synapse/spam_checker_api/__init__.py
@@ -12,13 +12,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from enum import Enum
+from typing import Union
+
+from synapse.api.errors import Codes
 
 
 class RegistrationBehaviour(Enum):
     """
-    Enum to define whether a registration request should allowed, denied, or shadow-banned.
+    Enum to define whether a registration request should be allowed, denied, or shadow-banned.
     """
 
     ALLOW = "allow"
     SHADOW_BAN = "shadow_ban"
     DENY = "deny"
+
+
+# We define the following singleton enum rather than a string to be able to
+# write `Union[Allow, ..., str]` in some of the callbacks for the spam-checker
+# API, where the `str` is required to maintain backwards compatibility with
+# previous versions of the API.
+class Allow(Enum):
+    """
+    Singleton to allow events to pass through in SpamChecker APIs.
+    """
+
+    ALLOW = "allow"
+
+
+Decision = Union[Allow, Codes]
+"""
+Union to define whether a request should be allowed or rejected.
+
+To accept a request, return `ALLOW`.
+
+To reject a request without any specific information, use `Codes.FORBIDDEN`.
+"""
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index cad3b42640..4b4ed42cff 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -130,6 +130,7 @@ class StateHandler:
         self.state_store = hs.get_storage().state
         self.hs = hs
         self._state_resolution_handler = hs.get_state_resolution_handler()
+        self._storage = hs.get_storage()
 
     @overload
     async def get_current_state(
@@ -238,13 +239,13 @@ class StateHandler:
         entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
         return await self.store.get_joined_users_from_state(room_id, entry)
 
-    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+    async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
         event_ids = await self.store.get_latest_event_ids_in_room(room_id)
         return await self.get_hosts_in_room_at_events(room_id, event_ids)
 
     async def get_hosts_in_room_at_events(
         self, room_id: str, event_ids: Collection[str]
-    ) -> Set[str]:
+    ) -> FrozenSet[str]:
         """Get the hosts that were in a room at the given event ids
 
         Args:
@@ -287,7 +288,6 @@ class StateHandler:
         #
         # first of all, figure out the state before the event
         #
-
         if old_state:
             # if we're given the state before the event, then we use that
             state_ids_before_event: StateMap[str] = {
@@ -361,10 +361,10 @@ class StateHandler:
 
         if not event.is_state():
             return EventContext.with_state(
+                storage=self._storage,
                 state_group_before_event=state_group_before_event,
                 state_group=state_group_before_event,
-                current_state_ids=state_ids_before_event,
-                prev_state_ids=state_ids_before_event,
+                state_delta_due_to_event={},
                 prev_group=state_group_before_event_prev_group,
                 delta_ids=deltas_to_state_group_before_event,
                 partial_state=partial_state,
@@ -393,10 +393,10 @@ class StateHandler:
         )
 
         return EventContext.with_state(
+            storage=self._storage,
             state_group=state_group_after_event,
             state_group_before_event=state_group_before_event,
-            current_state_ids=state_ids_after_event,
-            prev_state_ids=state_ids_before_event,
+            state_delta_due_to_event=delta_ids,
             prev_group=state_group_before_event,
             delta_ids=delta_ids,
             partial_state=partial_state,
@@ -418,33 +418,37 @@ class StateHandler:
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
-        # map from state group id to the state in that state group (where
-        # 'state' is a map from state key to event id)
-        # dict[int, dict[(str, str), str]]
-        state_groups_ids = await self.state_store.get_state_groups_ids(
-            room_id, event_ids
-        )
-
-        if len(state_groups_ids) == 0:
-            return _StateCacheEntry(state={}, state_group=None)
-        elif len(state_groups_ids) == 1:
-            name, state_list = list(state_groups_ids.items()).pop()
+        state_groups = await self.state_store.get_state_group_for_events(event_ids)
 
-            prev_group, delta_ids = await self.state_store.get_state_group_delta(name)
+        state_group_ids = state_groups.values()
 
+        # check if each event has the same state group id; if so, there's no state to resolve
+        state_group_ids_set = set(state_group_ids)
+        if len(state_group_ids_set) == 1:
+            (state_group_id,) = state_group_ids_set
+            state = await self.state_store.get_state_for_groups(state_group_ids_set)
+            prev_group, delta_ids = await self.state_store.get_state_group_delta(
+                state_group_id
+            )
             return _StateCacheEntry(
-                state=state_list,
-                state_group=name,
+                state=state[state_group_id],
+                state_group=state_group_id,
                 prev_group=prev_group,
                 delta_ids=delta_ids,
             )
+        elif len(state_group_ids_set) == 0:
+            return _StateCacheEntry(state={}, state_group=None)
 
         room_version = await self.store.get_room_version_id(room_id)
 
+        state_to_resolve = await self.state_store.get_state_for_groups(
+            state_group_ids_set
+        )
+
         result = await self._state_resolution_handler.resolve_state_groups(
             room_id,
             room_version,
-            state_groups_ids,
+            state_to_resolve,
             None,
             state_res_store=StateResolutionStore(self.store),
         )
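
The rewritten resolve_state_groups_for_events path above boils down to: map each event to its state group, short-circuit when there is only one distinct group, and otherwise load the state of each group and feed it to the resolver. A self-contained sketch of that control flow, with made-up callables standing in for the state store and resolution handler:

    from typing import Callable, Dict, Mapping, Set, Tuple

    StateMap = Dict[Tuple[str, str], str]  # (event type, state key) -> event_id

    def resolve_state_for_events(
        event_to_group: Mapping[str, int],                   # event_id -> state group id
        state_for_group: Callable[[int], StateMap],          # loads a group's full state
        resolve: Callable[[Dict[int, StateMap]], StateMap],  # full state resolution
    ) -> StateMap:
        group_ids: Set[int] = set(event_to_group.values())
        if not group_ids:
            # No events at all: empty state.
            return {}
        if len(group_ids) == 1:
            # Fast path: every event shares one state group, nothing to resolve.
            (only_group,) = group_ids
            return state_for_group(only_group)
        # Slow path: resolve across the distinct groups' states.
        return resolve({group: state_for_group(group) for group in group_ids})

    # Example: two events in the same group take the fast path.
    state = resolve_state_for_events(
        {"$a": 7, "$b": 7},
        state_for_group=lambda _gid: {("m.room.create", ""): "$create"},
        resolve=lambda groups: {},  # never called on the fast path
    )
    assert state == {("m.room.create", ""): "$create"}
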
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 08c6eabc6d..b1e5208c76 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,20 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from types import TracebackType
 from typing import (
     TYPE_CHECKING,
+    Any,
     AsyncContextManager,
     Awaitable,
     Callable,
     Dict,
     Iterable,
+    List,
     Optional,
+    Type,
 )
 
 import attr
 
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.types import Connection
+from synapse.storage.types import Connection, Cursor
 from synapse.types import JsonDict
 from synapse.util import Clock, json_encoder
 
@@ -74,7 +78,12 @@ class _BackgroundUpdateContextManager:
 
         return self._update_duration_ms
 
-    async def __aexit__(self, *exc) -> None:
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> None:
         pass
 
 
@@ -273,12 +282,20 @@ class BackgroundUpdater:
 
         self._running = True
 
+        back_to_back_failures = 0
+
         try:
             logger.info("Starting background schema updates")
             while self.enabled:
                 try:
                     result = await self.do_next_background_update(sleep)
+                    back_to_back_failures = 0
                 except Exception:
+                    back_to_back_failures += 1
+                    if back_to_back_failures >= 5:
+                        raise RuntimeError(
+                            "5 back-to-back background update failures; aborting."
+                        )
                     logger.exception("Error doing update")
                 else:
                     if result:
@@ -352,7 +369,7 @@ class BackgroundUpdater:
             True if we have finished running all the background updates, otherwise False
         """
 
-        def get_background_updates_txn(txn):
+        def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]:
             txn.execute(
                 """
                 SELECT update_name, depends_on FROM background_updates
@@ -469,7 +486,7 @@ class BackgroundUpdater:
         self,
         update_name: str,
         update_handler: Callable[[JsonDict, int], Awaitable[int]],
-    ):
+    ) -> None:
         """Register a handler for doing a background update.
 
         The handler should take two arguments:
@@ -518,6 +535,7 @@ class BackgroundUpdater:
         where_clause: Optional[str] = None,
         unique: bool = False,
         psql_only: bool = False,
+        replaces_index: Optional[str] = None,
     ) -> None:
         """Helper for store classes to do a background index addition
 
@@ -537,6 +555,8 @@ class BackgroundUpdater:
             unique: true to make a UNIQUE index
             psql_only: true to only create this index on psql databases (useful
                 for virtual sqlite tables)
+            replaces_index: The name of an index that this index replaces.
+                The named index will be dropped upon completion of the new index.
         """
 
         def create_index_psql(conn: Connection) -> None:
@@ -568,6 +588,12 @@ class BackgroundUpdater:
                 }
                 logger.debug("[SQL] %s", sql)
                 c.execute(sql)
+
+                if replaces_index is not None:
+                    # We drop the old index as the new index has now been created.
+                    sql = f"DROP INDEX IF EXISTS {replaces_index}"
+                    logger.debug("[SQL] %s", sql)
+                    c.execute(sql)
             finally:
                 conn.set_session(autocommit=False)  # type: ignore
 
@@ -596,6 +622,12 @@ class BackgroundUpdater:
             logger.debug("[SQL] %s", sql)
             c.execute(sql)
 
+            if replaces_index is not None:
+                # We drop the old index as the new index has now been created.
+                sql = f"DROP INDEX IF EXISTS {replaces_index}"
+                logger.debug("[SQL] %s", sql)
+                c.execute(sql)
+
         if isinstance(self.db_pool.engine, engines.PostgresEngine):
             runner: Optional[Callable[[Connection], None]] = create_index_psql
         elif psql_only:
@@ -603,7 +635,7 @@ class BackgroundUpdater:
         else:
             runner = create_index_sqlite
 
-        async def updater(progress, batch_size):
+        async def updater(progress: JsonDict, batch_size: int) -> int:
             if runner is not None:
                 logger.info("Adding index %s to %s", index_name, table)
                 await self.db_pool.runWithConnection(runner)
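
A hedged usage sketch for the new replaces_index argument of register_background_index_update: once the background job has built the new index, the named old index is dropped. The table and index names below are hypothetical; only the keyword arguments come from the helper above. In practice the update is typically also scheduled via a corresponding row in the background_updates table.

    self.db_pool.updates.register_background_index_update(
        update_name="some_table_reindex",
        index_name="some_table_new_idx",
        table="some_table",
        columns=("room_id", "stream_id"),
        unique=True,
        # Drop the superseded index once the new one exists.
        replaces_index="some_table_old_idx",
    )
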
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 41f566b648..a78d68a9d7 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -31,6 +31,7 @@ from typing import (
     List,
     Optional,
     Tuple,
+    Type,
     TypeVar,
     cast,
     overload,
@@ -41,6 +42,7 @@ from prometheus_client import Histogram
 from typing_extensions import Concatenate, Literal, ParamSpec
 
 from twisted.enterprise import adbapi
+from twisted.internet.interfaces import IReactorCore
 
 from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
@@ -88,11 +90,15 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
     "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
     "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
     "event_search": "event_search_event_id_idx",
+    "local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx",
+    "remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx",
 }
 
 
 def make_pool(
-    reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+    reactor: IReactorCore,
+    db_config: DatabaseConnectionConfig,
+    engine: BaseDatabaseEngine,
 ) -> adbapi.ConnectionPool:
     """Get the connection pool for the database."""
 
@@ -101,7 +107,7 @@ def make_pool(
     db_args = dict(db_config.config.get("args", {}))
     db_args.setdefault("cp_reconnect", True)
 
-    def _on_new_connection(conn):
+    def _on_new_connection(conn: Connection) -> None:
         # Ensure we have a logging context so we can correctly track queries,
         # etc.
         with LoggingContext("db.on_new_connection"):
@@ -157,7 +163,11 @@ class LoggingDatabaseConnection:
     default_txn_name: str
 
     def cursor(
-        self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
+        self,
+        *,
+        txn_name: Optional[str] = None,
+        after_callbacks: Optional[List["_CallbackListEntry"]] = None,
+        exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
     ) -> "LoggingTransaction":
         if not txn_name:
             txn_name = self.default_txn_name
@@ -183,11 +193,16 @@ class LoggingDatabaseConnection:
         self.conn.__enter__()
         return self
 
-    def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> Optional[bool]:
         return self.conn.__exit__(exc_type, exc_value, traceback)
 
     # Proxy through any unknown lookups to the DB conn class.
-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         return getattr(self.conn, name)
 
 
@@ -391,17 +406,22 @@ class LoggingTransaction:
     def __enter__(self) -> "LoggingTransaction":
         return self
 
-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> None:
         self.close()
 
 
 class PerformanceCounters:
-    def __init__(self):
-        self.current_counters = {}
-        self.previous_counters = {}
+    def __init__(self) -> None:
+        self.current_counters: Dict[str, Tuple[int, float]] = {}
+        self.previous_counters: Dict[str, Tuple[int, float]] = {}
 
     def update(self, key: str, duration_secs: float) -> None:
-        count, cum_time = self.current_counters.get(key, (0, 0))
+        count, cum_time = self.current_counters.get(key, (0, 0.0))
         count += 1
         cum_time += duration_secs
         self.current_counters[key] = (count, cum_time)
@@ -527,7 +547,7 @@ class DatabasePool:
     def start_profiling(self) -> None:
         self._previous_loop_ts = monotonic_time()
 
-        def loop():
+        def loop() -> None:
             curr = self._current_txn_total_time
             prev = self._previous_txn_total_time
             self._previous_txn_total_time = curr
@@ -1186,7 +1206,7 @@ class DatabasePool:
         if lock:
             self.engine.lock_table(txn, table)
 
-        def _getwhere(key):
+        def _getwhere(key: str) -> str:
             # If the value we're passing in is None (aka NULL), we need to use
             # IS, not =, as NULL = NULL equals NULL (False).
             if keyvalues[key] is None:
@@ -2258,7 +2278,7 @@ class DatabasePool:
         term: Optional[str],
         col: str,
         retcols: Collection[str],
-        desc="simple_search_list",
+        desc: str = "simple_search_list",
     ) -> Optional[List[Dict[str, Any]]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 5895b89202..d545a1c002 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -26,11 +26,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
 from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import (
-    IdGenerator,
-    MultiWriterIdGenerator,
-    StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -155,8 +151,6 @@ class DataStore(
             ],
         )
 
-        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
-        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
         self._group_updates_id_gen = StreamIdGenerator(
             db_conn, "local_group_updates", "stream_id"
         )
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 945707b0ec..e284454b66 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -203,19 +203,29 @@ class ApplicationServiceTransactionWorkerStore(
         """Get the application service state.
 
         Args:
-            service: The service whose state to set.
+            service: The service whose state to get.
         Returns:
-            An ApplicationServiceState or none.
+            An ApplicationServiceState, or None if we have yet to attempt any
+            transactions to the AS.
         """
-        result = await self.db_pool.simple_select_one(
+        # if we have created transactions for this AS but not yet attempted to send
+        # them, we will have a row in the table with state=NULL (recording the stream
+        # positions we have processed up to).
+        #
+        # On the other hand, if we have yet to create any transactions for this AS at
+        # all, then there will be no row for the AS.
+        #
+        # In either case, we return None to indicate "we don't yet know the state of
+        # this AS".
+        result = await self.db_pool.simple_select_one_onecol(
             "application_services_state",
             {"as_id": service.id},
-            ["state"],
+            retcol="state",
             allow_none=True,
             desc="get_appservice_state",
         )
         if result:
-            return ApplicationServiceState(result.get("state"))
+            return ApplicationServiceState(result)
         return None
 
     async def set_appservice_state(
@@ -296,14 +306,6 @@ class ApplicationServiceTransactionWorkerStore(
         """
 
         def _complete_appservice_txn(txn: LoggingTransaction) -> None:
-            # Set current txn_id for AS to 'txn_id'
-            self.db_pool.simple_upsert_txn(
-                txn,
-                "application_services_state",
-                {"as_id": service.id},
-                {"last_txn": txn_id},
-            )
-
             # Delete txn
             self.db_pool.simple_delete_txn(
                 txn,
@@ -452,16 +454,15 @@ class ApplicationServiceTransactionWorkerStore(
                 % (stream_type,)
             )
 
-        def set_appservice_stream_type_pos_txn(txn: LoggingTransaction) -> None:
-            stream_id_type = "%s_stream_id" % stream_type
-            txn.execute(
-                "UPDATE application_services_state SET %s = ? WHERE as_id=?"
-                % stream_id_type,
-                (pos, service.id),
-            )
-
-        await self.db_pool.runInteraction(
-            "set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn
+        # this may be the first time that we're recording any state for this AS, so
+        # we don't yet know if a row for it exists; hence we have to upsert here.
+        await self.db_pool.simple_upsert(
+            table="application_services_state",
+            keyvalues={"as_id": service.id},
+            values={f"{stream_type}_stream_id": pos},
+            # no need to lock when emulating upsert: as_id is a unique key
+            lock=False,
+            desc="set_appservice_stream_type_pos",
         )
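
A rough illustration of what the simple_upsert above amounts to on PostgreSQL, relying on the as_id uniqueness noted in the comment; the stream type and exact SQL form below are assumptions for illustration only.

    def set_stream_pos_sql(stream_type: str) -> str:
        # Build the equivalent "insert or update" statement for one stream type.
        column = f"{stream_type}_stream_id"
        return (
            f"INSERT INTO application_services_state (as_id, {column}) VALUES (?, ?) "
            f"ON CONFLICT (as_id) DO UPDATE SET {column} = EXCLUDED.{column}"
        )

    # e.g. "read_receipt" -> updates the read_receipt_stream_id column.
    print(set_stream_pos_sql("read_receipt"))
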
 
 
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index dd4e83a2ad..1653a6a9b6 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -57,6 +57,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         self._instance_name = hs.get_instance_name()
 
+        self.db_pool.updates.register_background_index_update(
+            update_name="cache_invalidation_index_by_instance",
+            index_name="cache_invalidation_stream_by_instance_instance_index",
+            table="cache_invalidation_stream_by_instance",
+            columns=("instance_name", "stream_id"),
+            psql_only=True,  # The table is only on postgres DBs.
+        )
+
     async def get_all_updated_caches(
         self, instance_name: str, last_id: int, current_id: int, limit: int
     ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index b789a588a5..af59be6b48 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -21,7 +21,7 @@ from synapse.api.errors import StoreError
 from synapse.logging.opentracing import log_kv, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import LoggingTransaction
-from synapse.types import JsonDict, JsonSerializable
+from synapse.types import JsonDict, JsonSerializable, StreamKeyType
 from synapse.util import json_encoder
 
 
@@ -126,7 +126,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                     "message": "Set room key",
                     "room_id": room_id,
                     "session_id": session_id,
-                    "room_key": room_key,
+                    StreamKeyType.ROOM: room_key,
                 }
             )
 
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 4710224708..dcfe8caf47 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -14,7 +14,17 @@
 import itertools
 import logging
 from queue import Empty, PriorityQueue
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    cast,
+)
 
 import attr
 from prometheus_client import Counter, Gauge
@@ -33,7 +43,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
@@ -135,7 +145,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         # Check if we have indexed the room so we can use the chain cover
         # algorithm.
-        room = await self.get_room(room_id)
+        room = await self.get_room(room_id)  # type: ignore[attr-defined]
         if room["has_auth_chain_index"]:
             try:
                 return await self.db_pool.runInteraction(
@@ -158,7 +168,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     def _get_auth_chain_ids_using_cover_index_txn(
-        self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        event_ids: Collection[str],
+        include_given: bool,
     ) -> Set[str]:
         """Calculates the auth chain IDs using the chain index."""
 
@@ -215,9 +229,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         chains: Dict[int, int] = {}
 
         # Add all linked chains reachable from initial set of chains.
-        for batch in batch_iter(event_chains, 1000):
+        for batch2 in batch_iter(event_chains, 1000):
             clause, args = make_in_list_sql_clause(
-                txn.database_engine, "origin_chain_id", batch
+                txn.database_engine, "origin_chain_id", batch2
             )
             txn.execute(sql % (clause,), args)
 
@@ -297,7 +311,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         front = set(event_ids)
         while front:
-            new_front = set()
+            new_front: Set[str] = set()
             for chunk in batch_iter(front, 100):
                 # Pull the auth events either from the cache or DB.
                 to_fetch: List[str] = []  # Event IDs to fetch from DB
@@ -316,7 +330,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
                     # Note we need to batch up the results by event ID before
                     # adding to the cache.
-                    to_cache = {}
+                    to_cache: Dict[str, List[Tuple[str, int]]] = {}
                     for event_id, auth_event_id, auth_event_depth in txn:
                         to_cache.setdefault(event_id, []).append(
                             (auth_event_id, auth_event_depth)
@@ -349,7 +363,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         # Check if we have indexed the room so we can use the chain cover
         # algorithm.
-        room = await self.get_room(room_id)
+        room = await self.get_room(room_id)  # type: ignore[attr-defined]
         if room["has_auth_chain_index"]:
             try:
                 return await self.db_pool.runInteraction(
@@ -370,7 +384,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     def _get_auth_chain_difference_using_cover_index_txn(
-        self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
+        self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]
     ) -> Set[str]:
         """Calculates the auth chain difference using the chain index.
 
@@ -444,9 +458,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         # (We need to take a copy of `seen_chains` as we want to mutate it in
         # the loop)
-        for batch in batch_iter(set(seen_chains), 1000):
+        for batch2 in batch_iter(set(seen_chains), 1000):
             clause, args = make_in_list_sql_clause(
-                txn.database_engine, "origin_chain_id", batch
+                txn.database_engine, "origin_chain_id", batch2
             )
             txn.execute(sql % (clause,), args)
 
@@ -529,7 +543,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         return result
 
     def _get_auth_chain_difference_txn(
-        self, txn, state_sets: List[Set[str]]
+        self, txn: LoggingTransaction, state_sets: List[Set[str]]
     ) -> Set[str]:
         """Calculates the auth chain difference using a breadth first search.
 
@@ -602,7 +616,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
             # I think building a temporary list with fetchall is more efficient than
             # just `search.extend(txn)`, but this is unconfirmed
-            search.extend(txn.fetchall())
+            search.extend(cast(List[Tuple[int, str]], txn.fetchall()))
 
         # sort by depth
         search.sort()
@@ -645,7 +659,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 # We parse the results and add the to the `found` set and the
                 # cache (note we need to batch up the results by event ID before
                 # adding to the cache).
-                to_cache = {}
+                to_cache: Dict[str, List[Tuple[str, int]]] = {}
                 for event_id, auth_event_id, auth_event_depth in txn:
                     to_cache.setdefault(event_id, []).append(
                         (auth_event_id, auth_event_depth)
@@ -696,7 +710,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         return {eid for eid, n in event_to_missing_sets.items() if n}
 
     async def get_oldest_event_ids_with_depth_in_room(
-        self, room_id
+        self, room_id: str
     ) -> List[Tuple[str, int]]:
         """Gets the oldest events(backwards extremities) in the room along with the
         aproximate depth.
@@ -713,7 +727,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             List of (event_id, depth) tuples
         """
 
-        def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id):
+        def get_oldest_event_ids_with_depth_in_room_txn(
+            txn: LoggingTransaction, room_id: str
+        ) -> List[Tuple[str, int]]:
             # Assemble a dictionary with event_id -> depth for the oldest events
             # we know of in the room. Backwards extremities are the oldest
             # events we know of in the room but we only know of them because
@@ -743,7 +759,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
             txn.execute(sql, (room_id, False))
 
-            return txn.fetchall()
+            return cast(List[Tuple[str, int]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
             "get_oldest_event_ids_with_depth_in_room",
@@ -752,7 +768,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     async def get_insertion_event_backward_extremities_in_room(
-        self, room_id
+        self, room_id: str
     ) -> List[Tuple[str, int]]:
         """Get the insertion events we know about that we haven't backfilled yet.
 
@@ -768,7 +784,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             List of (event_id, depth) tuples
         """
 
-        def get_insertion_event_backward_extremities_in_room_txn(txn, room_id):
+        def get_insertion_event_backward_extremities_in_room_txn(
+            txn: LoggingTransaction, room_id: str
+        ) -> List[Tuple[str, int]]:
             sql = """
                 SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
                 /* We only want insertion events that are also marked as backwards extremities */
@@ -780,7 +798,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             """
 
             txn.execute(sql, (room_id,))
-            return txn.fetchall()
+            return cast(List[Tuple[str, int]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
             "get_insertion_event_backward_extremities_in_room",
@@ -788,7 +806,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             room_id,
         )
 
-    async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
+    async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
         """Returns the event ID and depth for the event that has the max depth from a set of event IDs
 
         Args:
@@ -817,7 +835,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
             return max_depth_event_id, current_max_depth
 
-    async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
+    async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
         """Returns the event ID and depth for the event that has the min depth from a set of event IDs
 
         Args:
@@ -865,7 +883,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
         )
 
-    def _get_prev_events_for_room_txn(self, txn, room_id: str):
+    def _get_prev_events_for_room_txn(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> List[str]:
         # we just use the 10 newest events. Older events will become
         # prev_events of future events.
 
@@ -896,7 +916,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             sorted by extremity count.
         """
 
-        def _get_rooms_with_many_extremities_txn(txn):
+        def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]:
             where_clause = "1=1"
             if room_id_filter:
                 where_clause = "room_id NOT IN (%s)" % (
@@ -937,7 +957,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             "get_min_depth", self._get_min_depth_interaction, room_id
         )
 
-    def _get_min_depth_interaction(self, txn, room_id):
+    def _get_min_depth_interaction(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> Optional[int]:
         min_depth = self.db_pool.simple_select_one_onecol_txn(
             txn,
             table="room_depth",
@@ -966,22 +988,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         """
         # We want to make the cache more effective, so we clamp to the last
         # change before the given ordering.
-        last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
+        last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)  # type: ignore[attr-defined]
 
         # We don't always have a full stream_to_exterm_id table, e.g. after
         # the upgrade that introduced it, so we make sure we never ask for a
         # stream_ordering from before a restart
-        last_change = max(self._stream_order_on_start, last_change)
+        last_change = max(self._stream_order_on_start, last_change)  # type: ignore[attr-defined]
 
         # provided the last_change is recent enough, we now clamp the requested
         # stream_ordering to it.
-        if last_change > self.stream_ordering_month_ago:
+        if last_change > self.stream_ordering_month_ago:  # type: ignore[attr-defined]
             stream_ordering = min(last_change, stream_ordering)
 
         return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
 
     @cached(max_entries=5000, num_args=2)
-    async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
+    async def _get_forward_extremeties_for_room(
+        self, room_id: str, stream_ordering: int
+    ) -> List[str]:
         """For a given room_id and stream_ordering, return the forward
         extremities of the room at that point in "time".
 
@@ -989,7 +1013,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         stream_orderings from that point.
         """
 
-        if stream_ordering <= self.stream_ordering_month_ago:
+        if stream_ordering <= self.stream_ordering_month_ago:  # type: ignore[attr-defined]
             raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
 
         sql = """
@@ -1002,7 +1026,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 WHERE room_id = ?
         """
 
-        def get_forward_extremeties_for_room_txn(txn):
+        def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]:
             txn.execute(sql, (stream_ordering, room_id))
             return [event_id for event_id, in txn]
 
@@ -1104,8 +1128,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         ]
 
     async def get_backfill_events(
-        self, room_id: str, seed_event_id_list: list, limit: int
-    ):
+        self, room_id: str, seed_event_id_list: List[str], limit: int
+    ) -> List[EventBase]:
         """Get a list of Events for a given topic that occurred before (and
         including) the events in seed_event_id_list. Return a list of max size `limit`
 
@@ -1123,10 +1147,19 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
         events = await self.get_events_as_list(event_ids)
         return sorted(
-            events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering)
+            # type-ignore: mypy doesn't like negating the Optional[int] stream_ordering.
+            # But it's never None, because these events were previously persisted to the DB.
+            events,
+            key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering),  # type: ignore[operator]
         )
 
-    def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit):
+    def _get_backfill_events(
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        seed_event_id_list: List[str],
+        limit: int,
+    ) -> Set[str]:
         """
         We want to make sure that we do a breadth-first, "depth" ordered search.
         We also handle navigating historical branches of history connected by
@@ -1139,7 +1172,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             limit,
         )
 
-        event_id_results = set()
+        event_id_results: Set[str] = set()
 
         # In a PriorityQueue, the lowest valued entries are retrieved first.
         # We're using depth as the priority in the queue and tie-break based on
@@ -1147,7 +1180,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # highest and newest-in-time message. We add events to the queue with a
         # negative depth so that we process the newest-in-time messages first
         # going backwards in time. stream_ordering follows the same pattern.
-        queue = PriorityQueue()
+        queue: "PriorityQueue[Tuple[int, int, str, str]]" = PriorityQueue()
 
         for seed_event_id in seed_event_id_list:
             event_lookup_result = self.db_pool.simple_select_one_txn(
@@ -1253,7 +1286,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         return event_id_results
 
-    async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
+    async def get_missing_events(
+        self,
+        room_id: str,
+        earliest_events: List[str],
+        latest_events: List[str],
+        limit: int,
+    ) -> List[EventBase]:
         ids = await self.db_pool.runInteraction(
             "get_missing_events",
             self._get_missing_events,
@@ -1264,11 +1303,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
         return await self.get_events_as_list(ids)
 
-    def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
+    def _get_missing_events(
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        earliest_events: List[str],
+        latest_events: List[str],
+        limit: int,
+    ) -> List[str]:
 
         seen_events = set(earliest_events)
         front = set(latest_events) - seen_events
-        event_results = []
+        event_results: List[str] = []
 
         query = (
             "SELECT prev_event_id FROM event_edges "
@@ -1311,7 +1357,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
     @wrap_as_background_process("delete_old_forward_extrem_cache")
     async def _delete_old_forward_extrem_cache(self) -> None:
-        def _delete_old_forward_extrem_cache_txn(txn):
+        def _delete_old_forward_extrem_cache_txn(txn: LoggingTransaction) -> None:
             # Delete entries older than a month, while making sure we don't delete
             # the only entries for a room.
             sql = """
@@ -1324,7 +1370,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 ) AND stream_ordering < ?
             """
             txn.execute(
-                sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
+                sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)  # type: ignore[attr-defined]
             )
 
         await self.db_pool.runInteraction(
@@ -1382,7 +1428,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         """
         if self.db_pool.engine.supports_returning:
 
-            def _remove_received_event_from_staging_txn(txn):
+            def _remove_received_event_from_staging_txn(
+                txn: LoggingTransaction,
+            ) -> Optional[int]:
                 sql = """
                     DELETE FROM federation_inbound_events_staging
                     WHERE origin = ? AND event_id = ?
@@ -1390,21 +1438,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 """
 
                 txn.execute(sql, (origin, event_id))
-                return txn.fetchone()
+                row = cast(Optional[Tuple[int]], txn.fetchone())
 
-            row = await self.db_pool.runInteraction(
+                if row is None:
+                    return None
+
+                return row[0]
+
+            return await self.db_pool.runInteraction(
                 "remove_received_event_from_staging",
                 _remove_received_event_from_staging_txn,
                 db_autocommit=True,
             )
-            if row is None:
-                return None
-
-            return row[0]
 
         else:
 
-            def _remove_received_event_from_staging_txn(txn):
+            def _remove_received_event_from_staging_txn(
+                txn: LoggingTransaction,
+            ) -> Optional[int]:
                 received_ts = self.db_pool.simple_select_one_onecol_txn(
                     txn,
                     table="federation_inbound_events_staging",
@@ -1437,7 +1488,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
     ) -> Optional[Tuple[str, str]]:
         """Get the next event ID in the staging area for the given room."""
 
-        def _get_next_staged_event_id_for_room_txn(txn):
+        def _get_next_staged_event_id_for_room_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[Tuple[str, str]]:
             sql = """
                 SELECT origin, event_id
                 FROM federation_inbound_events_staging
@@ -1448,7 +1501,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
             txn.execute(sql, (room_id,))
 
-            return txn.fetchone()
+            return cast(Optional[Tuple[str, str]], txn.fetchone())
 
         return await self.db_pool.runInteraction(
             "get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn
@@ -1461,7 +1514,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
     ) -> Optional[Tuple[str, EventBase]]:
         """Get the next event in the staging area for the given room."""
 
-        def _get_next_staged_event_for_room_txn(txn):
+        def _get_next_staged_event_for_room_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[Tuple[str, str, str]]:
             sql = """
                 SELECT event_json, internal_metadata, origin
                 FROM federation_inbound_events_staging
@@ -1471,7 +1526,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             """
             txn.execute(sql, (room_id,))
 
-            return txn.fetchone()
+            return cast(Optional[Tuple[str, str, str]], txn.fetchone())
 
         row = await self.db_pool.runInteraction(
             "get_next_staged_event_for_room", _get_next_staged_event_for_room_txn
@@ -1599,18 +1654,20 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     @wrap_as_background_process("_get_stats_for_federation_staging")
-    async def _get_stats_for_federation_staging(self):
+    async def _get_stats_for_federation_staging(self) -> None:
         """Update the prometheus metrics for the inbound federation staging area."""
 
-        def _get_stats_for_federation_staging_txn(txn):
+        def _get_stats_for_federation_staging_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[int, int]:
             txn.execute("SELECT count(*) FROM federation_inbound_events_staging")
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
 
             txn.execute(
                 "SELECT min(received_ts) FROM federation_inbound_events_staging"
             )
 
-            (received_ts,) = txn.fetchone()
+            (received_ts,) = cast(Tuple[Optional[int]], txn.fetchone())
 
             # If there is nothing in the staging area default it to 0.
             age = 0
@@ -1651,19 +1708,21 @@ class EventFederationStore(EventFederationWorkerStore):
             self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
         )
 
-    async def clean_room_for_join(self, room_id):
-        return await self.db_pool.runInteraction(
+    async def clean_room_for_join(self, room_id: str) -> None:
+        await self.db_pool.runInteraction(
             "clean_room_for_join", self._clean_room_for_join_txn, room_id
         )
 
-    def _clean_room_for_join_txn(self, txn, room_id):
+    def _clean_room_for_join_txn(self, txn: LoggingTransaction, room_id: str) -> None:
         query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
 
         txn.execute(query, (room_id,))
         txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
 
-    async def _background_delete_non_state_event_auth(self, progress, batch_size):
-        def delete_event_auth(txn):
+    async def _background_delete_non_state_event_auth(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        def delete_event_auth(txn: LoggingTransaction) -> bool:
             target_min_stream_id = progress.get("target_min_stream_id_inclusive")
             max_stream_id = progress.get("max_stream_id_exclusive")
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 2c86a870cf..0df8ff5395 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -36,9 +36,8 @@ from prometheus_client import Counter
 import synapse.metrics
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.room_versions import RoomVersions
-from synapse.crypto.event_signing import compute_event_reference_hash
-from synapse.events import EventBase  # noqa: F401
-from synapse.events.snapshot import EventContext  # noqa: F401
+from synapse.events import EventBase, relation_from_event
+from synapse.events.snapshot import EventContext
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -50,7 +49,7 @@ from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.engines.postgres import PostgresEngine
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
-from synapse.types import StateMap, get_domain_from_id
+from synapse.types import JsonDict, StateMap, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
 from synapse.util.stringutils import non_null_str_or_none
@@ -130,7 +129,6 @@ class PersistEventsStore:
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         *,
-        current_state_for_room: Dict[str, StateMap[str]],
         state_delta_for_room: Dict[str, DeltaState],
         new_forward_extremities: Dict[str, Set[str]],
         use_negative_stream_ordering: bool = False,
@@ -141,8 +139,6 @@ class PersistEventsStore:
 
         Args:
             events_and_contexts:
-            current_state_for_room: Map from room_id to the current state of
-                the room based on forward extremities
             state_delta_for_room: Map from room_id to the delta to apply to
                 room state
             new_forward_extremities: Map from room_id to set of event IDs
@@ -217,9 +213,6 @@ class PersistEventsStore:
 
                 event_counter.labels(event.type, origin_type, origin_entity).inc()
 
-            for room_id, new_state in current_state_for_room.items():
-                self.store.get_current_state_ids.prefill((room_id,), new_state)
-
             for room_id, latest_event_ids in new_forward_extremities.items():
                 self.store.get_latest_event_ids_in_room.prefill(
                     (room_id,), list(latest_event_ids)
@@ -237,7 +230,9 @@ class PersistEventsStore:
         """
         results: List[str] = []
 
-        def _get_events_which_are_prevs_txn(txn, batch):
+        def _get_events_which_are_prevs_txn(
+            txn: LoggingTransaction, batch: Collection[str]
+        ) -> None:
             sql = """
             SELECT prev_event_id, internal_metadata
             FROM event_edges
@@ -287,7 +282,9 @@ class PersistEventsStore:
         # and their prev events.
         existing_prevs = set()
 
-        def _get_prevs_before_rejected_txn(txn, batch):
+        def _get_prevs_before_rejected_txn(
+            txn: LoggingTransaction, batch: Collection[str]
+        ) -> None:
             to_recursively_check = batch
 
             while to_recursively_check:
@@ -517,7 +514,7 @@ class PersistEventsStore:
     @classmethod
     def _add_chain_cover_index(
         cls,
-        txn,
+        txn: LoggingTransaction,
         db_pool: DatabasePool,
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
@@ -811,7 +808,7 @@ class PersistEventsStore:
 
     @staticmethod
     def _allocate_chain_ids(
-        txn,
+        txn: LoggingTransaction,
         db_pool: DatabasePool,
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
@@ -945,7 +942,7 @@ class PersistEventsStore:
         self,
         txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-    ):
+    ) -> None:
         """Persist the mapping from transaction IDs to event IDs (if defined)."""
 
         to_insert = []
@@ -999,7 +996,7 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         state_delta_by_room: Dict[str, DeltaState],
         stream_id: int,
-    ):
+    ) -> None:
         for room_id, delta_state in state_delta_by_room.items():
             to_delete = delta_state.to_delete
             to_insert = delta_state.to_insert
@@ -1157,7 +1154,7 @@ class PersistEventsStore:
                 txn, room_id, members_changed
             )
 
-    def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
+    def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
         """Update the room version in the database based off current state
         events.
 
@@ -1191,7 +1188,7 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         new_forward_extremities: Dict[str, Set[str]],
         max_stream_order: int,
-    ):
+    ) -> None:
         for room_id in new_forward_extremities.keys():
             self.db_pool.simple_delete_txn(
                 txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
@@ -1256,9 +1253,9 @@ class PersistEventsStore:
 
     def _update_room_depths_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-    ):
+    ) -> None:
         """Update min_depth for each room
 
         Args:
@@ -1387,7 +1384,7 @@ class PersistEventsStore:
             # nothing to do here
             return
 
-        def event_dict(event):
+        def event_dict(event: EventBase) -> JsonDict:
             d = event.get_dict()
             d.pop("redacted", None)
             d.pop("redacted_because", None)
@@ -1478,18 +1475,20 @@ class PersistEventsStore:
             ),
         )
 
-    def _store_rejected_events_txn(self, txn, events_and_contexts):
+    def _store_rejected_events_txn(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> List[Tuple[EventBase, EventContext]]:
         """Add rows to the 'rejections' table for received events which were
         rejected
 
         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
+            txn: db connection
+            events_and_contexts: events we are persisting
 
         Returns:
-            list[(EventBase, EventContext)] new list, without the rejected
-                events.
+            new list, without the rejected events.
         """
         # Remove the rejected events from the list now that we've added them
         # to the events table and the events_json table.
@@ -1510,7 +1509,7 @@ class PersistEventsStore:
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         all_events_and_contexts: List[Tuple[EventBase, EventContext]],
         inhibit_local_membership_updates: bool = False,
-    ):
+    ) -> None:
         """Update all the miscellaneous tables for new events
 
         Args:
@@ -1601,15 +1600,14 @@ class PersistEventsStore:
             inhibit_local_membership_updates=inhibit_local_membership_updates,
         )
 
-        # Insert event_reference_hashes table.
-        self._store_event_reference_hashes_txn(
-            txn, [event for event, _ in events_and_contexts]
-        )
-
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)
 
-    def _add_to_cache(self, txn, events_and_contexts):
+    def _add_to_cache(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> None:
         to_prefill = []
 
         rows = []
@@ -1640,7 +1638,7 @@ class PersistEventsStore:
             if not row["rejects"] and not row["redacts"]:
                 to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
-        def prefill():
+        def prefill() -> None:
             for cache_entry in to_prefill:
                 self.store._get_event_cache.set(
                     (cache_entry.event.event_id,), cache_entry
@@ -1670,19 +1668,24 @@ class PersistEventsStore:
         )
 
     def insert_labels_for_event_txn(
-        self, txn, event_id, labels, room_id, topological_ordering
-    ):
+        self,
+        txn: LoggingTransaction,
+        event_id: str,
+        labels: List[str],
+        room_id: str,
+        topological_ordering: int,
+    ) -> None:
         """Store the mapping between an event's ID and its labels, with one row per
         (event_id, label) tuple.
 
         Args:
-            txn (LoggingTransaction): The transaction to execute.
-            event_id (str): The event's ID.
-            labels (list[str]): A list of text labels.
-            room_id (str): The ID of the room the event was sent to.
-            topological_ordering (int): The position of the event in the room's topology.
+            txn: The transaction to execute.
+            event_id: The event's ID.
+            labels: A list of text labels.
+            room_id: The ID of the room the event was sent to.
+            topological_ordering: The position of the event in the room's topology.
         """
-        return self.db_pool.simple_insert_many_txn(
+        self.db_pool.simple_insert_many_txn(
             txn=txn,
             table="event_labels",
             keys=("event_id", "label", "room_id", "topological_ordering"),
@@ -1691,44 +1694,32 @@ class PersistEventsStore:
             ],
         )
 
-    def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+    def _insert_event_expiry_txn(
+        self, txn: LoggingTransaction, event_id: str, expiry_ts: int
+    ) -> None:
         """Save the expiry timestamp associated with a given event ID.
 
         Args:
-            txn (LoggingTransaction): The database transaction to use.
-            event_id (str): The event ID the expiry timestamp is associated with.
-            expiry_ts (int): The timestamp at which to expire (delete) the event.
+            txn: The database transaction to use.
+            event_id: The event ID the expiry timestamp is associated with.
+            expiry_ts: The timestamp at which to expire (delete) the event.
         """
-        return self.db_pool.simple_insert_txn(
+        self.db_pool.simple_insert_txn(
             txn=txn,
             table="event_expiry",
             values={"event_id": event_id, "expiry_ts": expiry_ts},
         )
 
-    def _store_event_reference_hashes_txn(self, txn, events):
-        """Store a hash for a PDU
-        Args:
-            txn (cursor):
-            events (list): list of Events.
-        """
-
-        vals = []
-        for event in events:
-            ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
-            vals.append((event.event_id, ref_alg, memoryview(ref_hash_bytes)))
-
-        self.db_pool.simple_insert_many_txn(
-            txn,
-            table="event_reference_hashes",
-            keys=("event_id", "algorithm", "hash"),
-            values=vals,
-        )
-
     def _store_room_members_txn(
-        self, txn, events, *, inhibit_local_membership_updates: bool = False
-    ):
+        self,
+        txn: LoggingTransaction,
+        events: List[EventBase],
+        *,
+        inhibit_local_membership_updates: bool = False,
+    ) -> None:
         """
         Store a room member in the database.
+
         Args:
             txn: The transaction to use.
             events: List of events to store.
@@ -1765,6 +1756,7 @@ class PersistEventsStore:
         )
 
         for event in events:
+            assert event.internal_metadata.stream_ordering is not None
             txn.call_after(
                 self.store._membership_stream_cache.entity_has_changed,
                 event.state_key,
@@ -1813,55 +1805,50 @@ class PersistEventsStore:
             txn: The current database transaction.
             event: The event which might have relations.
         """
-        relation = event.content.get("m.relates_to")
+        relation = relation_from_event(event)
         if not relation:
-            # No relations
-            return
-
-        # Relations must have a type and parent event ID.
-        rel_type = relation.get("rel_type")
-        if not isinstance(rel_type, str):
+            # No relation, nothing to do.
             return
 
-        parent_id = relation.get("event_id")
-        if not isinstance(parent_id, str):
-            return
-
-        # Annotations have a key field.
-        aggregation_key = None
-        if rel_type == RelationTypes.ANNOTATION:
-            aggregation_key = relation.get("key")
-
         self.db_pool.simple_insert_txn(
             txn,
             table="event_relations",
             values={
                 "event_id": event.event_id,
-                "relates_to_id": parent_id,
-                "relation_type": rel_type,
-                "aggregation_key": aggregation_key,
+                "relates_to_id": relation.parent_id,
+                "relation_type": relation.rel_type,
+                "aggregation_key": relation.aggregation_key,
             },
         )
 
-        txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,))
         txn.call_after(
-            self.store.get_aggregation_groups_for_event.invalidate, (parent_id,)
+            self.store.get_relations_for_event.invalidate, (relation.parent_id,)
+        )
+        txn.call_after(
+            self.store.get_aggregation_groups_for_event.invalidate,
+            (relation.parent_id,),
         )
 
-        if rel_type == RelationTypes.REPLACE:
-            txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+        if relation.rel_type == RelationTypes.REPLACE:
+            txn.call_after(
+                self.store.get_applicable_edit.invalidate, (relation.parent_id,)
+            )
 
-        if rel_type == RelationTypes.THREAD:
-            txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+        if relation.rel_type == RelationTypes.THREAD:
+            txn.call_after(
+                self.store.get_thread_summary.invalidate, (relation.parent_id,)
+            )
             # It should be safe to only invalidate the cache if the user has not
             # previously participated in the thread, but that's difficult (and
             # potentially error-prone) so it is always invalidated.
             txn.call_after(
                 self.store.get_thread_participated.invalidate,
-                (parent_id, event.sender),
+                (relation.parent_id, event.sender),
             )
 
-    def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
+    def _handle_insertion_event(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         """Handles keeping track of insertion events and edges/connections.
         Part of MSC2716.
 
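The hunk above replaces the inline parsing of `m.relates_to` with `relation_from_event`, which returns an object exposing `parent_id`, `rel_type` and `aggregation_key`. A minimal sketch of what that parsing amounts to, reconstructed from the removed lines; the class name, the standalone function and the literal "m.annotation" are illustrative stand-ins, not the helper Synapse actually defines:

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class EventRelation:
    """Illustrative container for a parsed m.relates_to block."""

    parent_id: str
    rel_type: str
    aggregation_key: Optional[str] = None


def relation_from_content(content: Dict[str, Any]) -> Optional[EventRelation]:
    """Mirror the validation the removed inline code performed."""
    relation = content.get("m.relates_to")
    if not isinstance(relation, dict):
        # No usable relation, nothing to do.
        return None

    # Relations must have a string rel_type and parent event ID.
    rel_type = relation.get("rel_type")
    parent_id = relation.get("event_id")
    if not isinstance(rel_type, str) or not isinstance(parent_id, str):
        return None

    # Annotations ("m.annotation", i.e. RelationTypes.ANNOTATION) carry a key.
    aggregation_key = None
    if rel_type == "m.annotation":
        key = relation.get("key")
        if isinstance(key, str):
            aggregation_key = key

    return EventRelation(parent_id, rel_type, aggregation_key)


# Example: the content of a reaction event.
rel = relation_from_content(
    {
        "m.relates_to": {
            "rel_type": "m.annotation",
            "event_id": "$parent:example.org",
            "key": "👍",
        }
    }
)
assert rel is not None and rel.aggregation_key == "👍"
```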
@@ -1922,7 +1909,7 @@ class PersistEventsStore:
                 },
             )
 
-    def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase):
+    def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None:
         """Handles inserting the batch edges/connections between the batch event
         and an insertion event. Part of MSC2716.
 
@@ -2022,25 +2009,29 @@ class PersistEventsStore:
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
         )
 
-    def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
         if isinstance(event.content.get("topic"), str):
             self.store_event_search_txn(
                 txn, event, "content.topic", event.content["topic"]
             )
 
-    def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
         if isinstance(event.content.get("name"), str):
             self.store_event_search_txn(
                 txn, event, "content.name", event.content["name"]
             )
 
-    def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_message_txn(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         if isinstance(event.content.get("body"), str):
             self.store_event_search_txn(
                 txn, event, "content.body", event.content["body"]
             )
 
-    def _store_retention_policy_for_room_txn(self, txn, event):
+    def _store_retention_policy_for_room_txn(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         if not event.is_state():
             logger.debug("Ignoring non-state m.room.retention event")
             return
@@ -2100,8 +2091,11 @@ class PersistEventsStore:
         )
 
     def _set_push_actions_for_event_and_users_txn(
-        self, txn, events_and_contexts, all_events_and_contexts
-    ):
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        all_events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> None:
         """Handles moving push actions from staging table to main
         event_push_actions table for all events in `events_and_contexts`.
 
@@ -2109,12 +2103,10 @@ class PersistEventsStore:
         from the push action staging area.
 
         Args:
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
-            all_events_and_contexts (list[(EventBase, EventContext)]): all
-                events that we were going to persist. This includes events
-                we've already persisted, etc, that wouldn't appear in
-                events_and_context.
+            events_and_contexts: events we are persisting
+            all_events_and_contexts: all events that we were going to persist.
+                This includes events we've already persisted, etc, that wouldn't
+                appear in events_and_context.
         """
 
         # Only non outlier events will have push actions associated with them,
@@ -2183,7 +2175,9 @@ class PersistEventsStore:
             ),
         )
 
-    def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
+    def _remove_push_actions_for_event_id_txn(
+        self, txn: LoggingTransaction, room_id: str, event_id: str
+    ) -> None:
         # Sad that we have to blow away the cache for the whole room here
         txn.call_after(
             self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
@@ -2194,7 +2188,9 @@ class PersistEventsStore:
             (room_id, event_id),
         )
 
-    def _store_rejections_txn(self, txn, event_id, reason):
+    def _store_rejections_txn(
+        self, txn: LoggingTransaction, event_id: str, reason: str
+    ) -> None:
         self.db_pool.simple_insert_txn(
             txn,
             table="rejections",
@@ -2206,8 +2202,10 @@ class PersistEventsStore:
         )
 
     def _store_event_state_mappings_txn(
-        self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
-    ):
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: Collection[Tuple[EventBase, EventContext]],
+    ) -> None:
         state_groups = {}
         for event, context in events_and_contexts:
             if event.internal_metadata.is_outlier():
@@ -2264,7 +2262,9 @@ class PersistEventsStore:
                 state_group_id,
             )
 
-    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+    def _update_min_depth_for_room_txn(
+        self, txn: LoggingTransaction, room_id: str, depth: int
+    ) -> None:
         min_depth = self.store._get_min_depth_interaction(txn, room_id)
 
         if min_depth is not None and depth >= min_depth:
@@ -2277,7 +2277,9 @@ class PersistEventsStore:
             values={"min_depth": depth},
         )
 
-    def _handle_mult_prev_events(self, txn, events):
+    def _handle_mult_prev_events(
+        self, txn: LoggingTransaction, events: List[EventBase]
+    ) -> None:
         """
         For the given event, update the event edges table and forward and
         backward extremities tables.
@@ -2295,7 +2297,9 @@ class PersistEventsStore:
 
         self._update_backward_extremeties(txn, events)
 
-    def _update_backward_extremeties(self, txn, events):
+    def _update_backward_extremeties(
+        self, txn: LoggingTransaction, events: List[EventBase]
+    ) -> None:
         """Updates the event_backward_extremities tables based on the new/updated
         events being persisted.
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index a4a604a499..5b22d6b452 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -14,6 +14,7 @@
 
 import logging
 import threading
+import weakref
 from enum import Enum, auto
 from typing import (
     TYPE_CHECKING,
@@ -23,6 +24,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    MutableMapping,
     Optional,
     Set,
     Tuple,
@@ -248,6 +250,12 @@ class EventsWorkerStore(SQLBaseStore):
             str, ObservableDeferred[Dict[str, EventCacheEntry]]
         ] = {}
 
+        # We keep track of the events we have currently loaded in memory so that
+        # we can reuse them even if they've been evicted from the cache. We only
+        # track events that don't need redacting in here (as then we don't need
+        # to track redaction status).
+        self._event_ref: MutableMapping[str, EventBase] = weakref.WeakValueDictionary()
+
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list: List[
             Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
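The `_event_ref` map added above leans on `weakref.WeakValueDictionary` semantics: an entry stays retrievable exactly as long as something else in the process still holds a strong reference to the event, and vanishes on its own once the event is garbage collected. A self-contained illustration of that behaviour; `Event` is a stand-in, not Synapse's `EventBase`:

```python
import gc
import weakref


class Event:
    """Stand-in for EventBase; instances support weak references."""

    def __init__(self, event_id: str) -> None:
        self.event_id = event_id


refs: "weakref.WeakValueDictionary[str, Event]" = weakref.WeakValueDictionary()

strong = Event("$abc:example.org")
refs["$abc:example.org"] = strong

# While a strong reference exists elsewhere, the mapping still yields the object.
assert refs.get("$abc:example.org") is strong

# Once the last strong reference is dropped, the entry silently disappears.
del strong
gc.collect()
assert refs.get("$abc:example.org") is None
```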
@@ -723,6 +731,8 @@ class EventsWorkerStore(SQLBaseStore):
 
     def _invalidate_get_event_cache(self, event_id: str) -> None:
         self._get_event_cache.invalidate((event_id,))
+        self._event_ref.pop(event_id, None)
+        self._current_event_fetches.pop(event_id, None)
 
     def _get_events_from_cache(
         self, events: Iterable[str], update_metrics: bool = True
@@ -738,13 +748,30 @@ class EventsWorkerStore(SQLBaseStore):
         event_map = {}
 
         for event_id in events:
+            # First check if it's in the event cache
             ret = self._get_event_cache.get(
                 (event_id,), None, update_metrics=update_metrics
             )
-            if not ret:
+            if ret:
+                event_map[event_id] = ret
                 continue
 
-            event_map[event_id] = ret
+            # Otherwise check if we still have the event in memory.
+            event = self._event_ref.get(event_id)
+            if event:
+                # Reconstruct an event cache entry
+
+                cache_entry = EventCacheEntry(
+                    event=event,
+                    # We don't cache weakrefs to redacted events, so we know
+                    # this is None.
+                    redacted_event=None,
+                )
+                event_map[event_id] = cache_entry
+
+                # We add the entry back into the cache as we want to keep
+                # recently queried events in the cache.
+                self._get_event_cache.set((event_id,), cache_entry)
 
         return event_map
 
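The lookup above is a two-tier read path: consult the bounded LRU cache first, fall back to the weak-reference map, and re-populate the LRU on a fallback hit so recently queried events stay cached; because only unredacted events are tracked in the weak map, the rebuilt entry can safely use `redacted_event=None`. A rough sketch of that policy under simplified assumptions (a plain dict stands in for Synapse's `LruCache`, and `CacheEntry` for `EventCacheEntry`):

```python
import weakref
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class CacheEntry:
    """Stand-in for EventCacheEntry: an event plus its optional redacted form."""

    event: object
    redacted_event: Optional[object] = None


def get_cached_event(
    event_id: str,
    lru: Dict[str, CacheEntry],
    refs: "weakref.WeakValueDictionary[str, object]",
) -> Optional[CacheEntry]:
    # 1. The bounded LRU cache is checked first.
    entry = lru.get(event_id)
    if entry is not None:
        return entry

    # 2. Fall back to the weak-reference map: the event may have been evicted
    #    from the LRU but still be alive because some request holds it.
    event = refs.get(event_id)
    if event is None:
        return None

    # Only unredacted events are tracked in the weak map, so rebuilding an
    # entry with redacted_event=None is safe; put it back in the LRU so the
    # recently queried event stays cheap to fetch.
    entry = CacheEntry(event=event, redacted_event=None)
    lru[event_id] = entry
    return entry
```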
@@ -1124,6 +1151,10 @@ class EventsWorkerStore(SQLBaseStore):
             self._get_event_cache.set((event_id,), cache_entry)
             result_map[event_id] = cache_entry
 
+            if not redacted_event:
+                # We only cache references to unredacted events.
+                self._event_ref[event_id] = original_ev
+
         return result_map
 
     async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 1480a0f048..14294a0bb8 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -14,12 +14,16 @@
 import calendar
 import logging
 import time
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, List, Tuple, cast
 
 from synapse.metrics import GaugeBucketCollector
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
@@ -71,8 +75,8 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         self._last_user_visit_update = self._get_start_of_day()
 
     @wrap_as_background_process("read_forward_extremities")
-    async def _read_forward_extremities(self):
-        def fetch(txn):
+    async def _read_forward_extremities(self) -> None:
+        def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]:
             txn.execute(
                 """
                 SELECT t1.c, t2.c
@@ -85,7 +89,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
                 ) t2 ON t1.room_id = t2.room_id
                 """
             )
-            return txn.fetchall()
+            return cast(List[Tuple[int, int]], txn.fetchall())
 
         res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
 
@@ -95,7 +99,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             (x[0] - 1) * x[1] for x in res if x[1]
         )
 
-    async def count_daily_e2ee_messages(self):
+    async def count_daily_e2ee_messages(self) -> int:
         """
         Returns an estimate of the number of messages sent in the last day.
 
@@ -103,20 +107,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         call to this function, it will return None.
         """
 
-        def _count_messages(txn):
+        def _count_messages(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT COUNT(*) FROM events
                 WHERE type = 'm.room.encrypted'
                 AND stream_ordering > ?
             """
             txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
 
-    async def count_daily_sent_e2ee_messages(self):
-        def _count_messages(txn):
+    async def count_daily_sent_e2ee_messages(self) -> int:
+        def _count_messages(txn: LoggingTransaction) -> int:
             # This is good enough as if you have silly characters in your own
             # hostname then that's your own fault.
             like_clause = "%:" + self.hs.hostname
@@ -129,29 +133,29 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             """
 
             txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction(
             "count_daily_sent_e2ee_messages", _count_messages
         )
 
-    async def count_daily_active_e2ee_rooms(self):
-        def _count(txn):
+    async def count_daily_active_e2ee_rooms(self) -> int:
+        def _count(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT COUNT(DISTINCT room_id) FROM events
                 WHERE type = 'm.room.encrypted'
                 AND stream_ordering > ?
             """
             txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction(
             "count_daily_active_e2ee_rooms", _count
         )
 
-    async def count_daily_messages(self):
+    async def count_daily_messages(self) -> int:
         """
         Returns an estimate of the number of messages sent in the last day.
 
@@ -159,20 +163,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         call to this function, it will return None.
         """
 
-        def _count_messages(txn):
+        def _count_messages(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT COUNT(*) FROM events
                 WHERE type = 'm.room.message'
                 AND stream_ordering > ?
             """
             txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction("count_messages", _count_messages)
 
-    async def count_daily_sent_messages(self):
-        def _count_messages(txn):
+    async def count_daily_sent_messages(self) -> int:
+        def _count_messages(txn: LoggingTransaction) -> int:
             # This is good enough as if you have silly characters in your own
             # hostname then that's your own fault.
             like_clause = "%:" + self.hs.hostname
@@ -185,22 +189,22 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             """
 
             txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction(
             "count_daily_sent_messages", _count_messages
         )
 
-    async def count_daily_active_rooms(self):
-        def _count(txn):
+    async def count_daily_active_rooms(self) -> int:
+        def _count(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT COUNT(DISTINCT room_id) FROM events
                 WHERE type = 'm.room.message'
                 AND stream_ordering > ?
             """
             txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
@@ -226,7 +230,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_monthly_users", self._count_users, thirty_days_ago
         )
 
-    def _count_users(self, txn, time_from):
+    def _count_users(self, txn: LoggingTransaction, time_from: int) -> int:
         """
         Returns number of users seen in the past time_from period
         """
@@ -238,7 +242,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             ) u
         """
         txn.execute(sql, (time_from,))
-        (count,) = txn.fetchone()
+        # Mypy knows that fetchone() might return None if there are no rows.
+        # We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
+        # returns exactly one row.
+        (count,) = cast(Tuple[int], txn.fetchone())
         return count
 
     async def count_r30_users(self) -> Dict[str, int]:
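The `cast(Tuple[int], txn.fetchone())` idiom used above recurs throughout this file: `fetchone()` is typed as optional, but an un-grouped `SELECT COUNT(...)` always returns exactly one row, so the cast records that invariant for mypy instead of adding a runtime None check. A minimal standalone example of the same idiom, with sqlite3 standing in for `LoggingTransaction` (the cast is a no-op at runtime):

```python
import sqlite3
from typing import Tuple, cast

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE events (stream_ordering INTEGER)")

cur.execute("SELECT COUNT(*) FROM events WHERE stream_ordering > ?", (0,))
# COUNT without GROUP BY yields exactly one row, so unpacking cannot fail.
(count,) = cast(Tuple[int], cur.fetchone())
assert count == 0

conn.close()
```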
@@ -252,7 +259,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
              A mapping of counts globally as well as broken out by platform.
         """
 
-        def _count_r30_users(txn):
+        def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]:
             thirty_days_in_secs = 86400 * 30
             now = int(self._clock.time())
             thirty_days_ago_in_secs = now - thirty_days_in_secs
@@ -317,7 +324,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
             txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
 
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             results["all"] = count
 
             return results
@@ -344,7 +351,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
               - "web" (any web application -- it's not possible to distinguish Element Web here)
         """
 
-        def _count_r30v2_users(txn):
+        def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]:
             thirty_days_in_secs = 86400 * 30
             now = int(self._clock.time())
             sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
@@ -441,11 +448,8 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
                     thirty_days_in_secs * 1000,
                 ),
             )
-            row = txn.fetchone()
-            if row is None:
-                results["all"] = 0
-            else:
-                results["all"] = row[0]
+            (count,) = cast(Tuple[int], txn.fetchone())
+            results["all"] = count
 
             return results
 
@@ -453,7 +457,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_r30v2_users", _count_r30v2_users
         )
 
-    def _get_start_of_day(self):
+    def _get_start_of_day(self) -> int:
         """
         Returns millisecond unixtime for start of UTC day.
         """
@@ -467,7 +471,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         Generates daily visit data for use in cohort/ retention analysis
         """
 
-        def _generate_user_daily_visits(txn):
+        def _generate_user_daily_visits(txn: LoggingTransaction) -> None:
             logger.info("Calling _generate_user_daily_visits")
             today_start = self._get_start_of_day()
             a_day_in_milliseconds = 24 * 60 * 60 * 1000
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index bfc85b3add..c94d5f9f81 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -69,7 +69,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         #     event_forward_extremities
         #     event_json
         #     event_push_actions
-        #     event_reference_hashes
         #     event_relations
         #     event_search
         #     event_to_state_groups
@@ -220,7 +219,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "event_auth",
             "event_edges",
             "event_forward_extremities",
-            "event_reference_hashes",
             "event_relations",
             "event_search",
             "rejections",
@@ -369,7 +367,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "event_edges",
             "event_json",
             "event_push_actions_staging",
-            "event_reference_hashes",
             "event_relations",
             "event_to_state_groups",
             "event_auth_chains",
@@ -420,6 +417,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "room_account_data",
             "room_tags",
             "local_current_membership",
+            "federation_inbound_events_staging",
         ):
             logger.info("[purge] removing %s from %s", room_id, table)
             txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 4ed913e248..ad67901cc1 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,14 +14,18 @@
 # limitations under the License.
 import abc
 import logging
-from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
 
 from synapse.api.errors import StoreError
 from synapse.config.homeserver import ExperimentalConfig
 from synapse.push.baserules import list_with_base_rules
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.pusher import PusherWorkerStore
@@ -30,9 +34,12 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
 from synapse.storage.util.id_generators import (
+    AbstractStreamIdGenerator,
     AbstractStreamIdTracker,
+    IdGenerator,
     StreamIdGenerator,
 )
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -57,7 +64,11 @@ def _is_experimental_rule_enabled(
     return True
 
 
-def _load_rules(rawrules, enabled_map, experimental_config: ExperimentalConfig):
+def _load_rules(
+    rawrules: List[JsonDict],
+    enabled_map: Dict[str, bool],
+    experimental_config: ExperimentalConfig,
+) -> List[JsonDict]:
     ruleslist = []
     for rawrule in rawrules:
         rule = dict(rawrule)
@@ -137,7 +148,7 @@ class PushRulesWorkerStore(
         )
 
     @abc.abstractmethod
-    def get_max_push_rules_stream_id(self):
+    def get_max_push_rules_stream_id(self) -> int:
         """Get the position of the push rules stream.
 
         Returns:
@@ -146,7 +157,7 @@ class PushRulesWorkerStore(
         raise NotImplementedError()
 
     @cached(max_entries=5000)
-    async def get_push_rules_for_user(self, user_id):
+    async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
         rows = await self.db_pool.simple_select_list(
             table="push_rules",
             keyvalues={"user_name": user_id},
@@ -168,7 +179,7 @@ class PushRulesWorkerStore(
         return _load_rules(rows, enabled_map, self.hs.config.experimental)
 
     @cached(max_entries=5000)
-    async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]:
+    async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
         results = await self.db_pool.simple_select_list(
             table="push_rules_enable",
             keyvalues={"user_name": user_id},
@@ -184,13 +195,13 @@ class PushRulesWorkerStore(
             return False
         else:
 
-            def have_push_rules_changed_txn(txn):
+            def have_push_rules_changed_txn(txn: LoggingTransaction) -> bool:
                 sql = (
                     "SELECT COUNT(stream_id) FROM push_rules_stream"
                     " WHERE user_id = ? AND ? < stream_id"
                 )
                 txn.execute(sql, (user_id, last_id))
-                (count,) = txn.fetchone()
+                (count,) = cast(Tuple[int], txn.fetchone())
                 return bool(count)
 
             return await self.db_pool.runInteraction(
@@ -202,11 +213,13 @@ class PushRulesWorkerStore(
         list_name="user_ids",
         num_args=1,
     )
-    async def bulk_get_push_rules(self, user_ids):
+    async def bulk_get_push_rules(
+        self, user_ids: Collection[str]
+    ) -> Dict[str, List[JsonDict]]:
         if not user_ids:
             return {}
 
-        results = {user_id: [] for user_id in user_ids}
+        results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
 
         rows = await self.db_pool.simple_select_many_batch(
             table="push_rules",
@@ -230,67 +243,18 @@ class PushRulesWorkerStore(
 
         return results
 
-    async def copy_push_rule_from_room_to_room(
-        self, new_room_id: str, user_id: str, rule: dict
-    ) -> None:
-        """Copy a single push rule from one room to another for a specific user.
-
-        Args:
-            new_room_id: ID of the new room.
-            user_id : ID of user the push rule belongs to.
-            rule: A push rule.
-        """
-        # Create new rule id
-        rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
-        new_rule_id = rule_id_scope + "/" + new_room_id
-
-        # Change room id in each condition
-        for condition in rule.get("conditions", []):
-            if condition.get("key") == "room_id":
-                condition["pattern"] = new_room_id
-
-        # Add the rule for the new room
-        await self.add_push_rule(
-            user_id=user_id,
-            rule_id=new_rule_id,
-            priority_class=rule["priority_class"],
-            conditions=rule["conditions"],
-            actions=rule["actions"],
-        )
-
-    async def copy_push_rules_from_room_to_room_for_user(
-        self, old_room_id: str, new_room_id: str, user_id: str
-    ) -> None:
-        """Copy all of the push rules from one room to another for a specific
-        user.
-
-        Args:
-            old_room_id: ID of the old room.
-            new_room_id: ID of the new room.
-            user_id: ID of user to copy push rules for.
-        """
-        # Retrieve push rules for this user
-        user_push_rules = await self.get_push_rules_for_user(user_id)
-
-        # Get rules relating to the old room and copy them to the new room
-        for rule in user_push_rules:
-            conditions = rule.get("conditions", [])
-            if any(
-                (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
-                for c in conditions
-            ):
-                await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
-
     @cachedList(
         cached_method_name="get_push_rules_enabled_for_user",
         list_name="user_ids",
         num_args=1,
     )
-    async def bulk_get_push_rules_enabled(self, user_ids):
+    async def bulk_get_push_rules_enabled(
+        self, user_ids: Collection[str]
+    ) -> Dict[str, Dict[str, bool]]:
         if not user_ids:
             return {}
 
-        results = {user_id: {} for user_id in user_ids}
+        results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}
 
         rows = await self.db_pool.simple_select_many_batch(
             table="push_rules_enable",
@@ -306,7 +270,7 @@ class PushRulesWorkerStore(
 
     async def get_all_push_rule_updates(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+    ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
         """Get updates for push_rules replication stream.
 
         Args:
@@ -331,7 +295,9 @@ class PushRulesWorkerStore(
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_push_rule_updates_txn(txn):
+        def get_all_push_rule_updates_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
             sql = """
                 SELECT stream_id, user_id
                 FROM push_rules_stream
@@ -340,7 +306,10 @@ class PushRulesWorkerStore(
                 LIMIT ?
             """
             txn.execute(sql, (last_id, current_id, limit))
-            updates = [(stream_id, (user_id,)) for stream_id, user_id in txn]
+            updates = cast(
+                List[Tuple[int, Tuple[str]]],
+                [(stream_id, (user_id,)) for stream_id, user_id in txn],
+            )
 
             limited = False
             upper_bound = current_id
@@ -356,15 +325,30 @@ class PushRulesWorkerStore(
 
 
 class PushRuleStore(PushRulesWorkerStore):
+    # Because we have write access, this will be a StreamIdGenerator
+    # (see PushRulesWorkerStore.__init__)
+    _push_rules_stream_id_gen: AbstractStreamIdGenerator
+
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
+
     async def add_push_rule(
         self,
-        user_id,
-        rule_id,
-        priority_class,
-        conditions,
-        actions,
-        before=None,
-        after=None,
+        user_id: str,
+        rule_id: str,
+        priority_class: int,
+        conditions: List[Dict[str, str]],
+        actions: List[Union[JsonDict, str]],
+        before: Optional[str] = None,
+        after: Optional[str] = None,
     ) -> None:
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)
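The class-level `_push_rules_stream_id_gen: AbstractStreamIdGenerator` annotation narrows the type the worker base class established, so the writer store can call generator-only methods without casts. A small sketch of that typing pattern with illustrative names, not Synapse's actual id-generator classes:

```python
from abc import ABC, abstractmethod


class IdTracker(ABC):
    """Read-only view of a stream position (cf. AbstractStreamIdTracker)."""

    @abstractmethod
    def get_current_token(self) -> int:
        ...


class IdGen(IdTracker):
    """A tracker that can also advance the stream (cf. AbstractStreamIdGenerator)."""

    def __init__(self) -> None:
        self._token = 0

    def get_current_token(self) -> int:
        return self._token

    def get_next(self) -> int:
        self._token += 1
        return self._token


class WorkerStore:
    def __init__(self, id_gen: IdTracker) -> None:
        # The base class only promises a read-only tracker.
        self._id_gen: IdTracker = id_gen


class WriterStore(WorkerStore):
    # The writer re-declares the inherited attribute with the narrower type
    # (its constructor always passes a real generator), so mypy accepts
    # get_next() calls without a cast.
    _id_gen: IdGen

    def __init__(self) -> None:
        super().__init__(IdGen())

    def advance(self) -> int:
        return self._id_gen.get_next()


assert WriterStore().advance() == 1
```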
@@ -400,17 +384,17 @@ class PushRuleStore(PushRulesWorkerStore):
 
     def _add_push_rule_relative_txn(
         self,
-        txn,
-        stream_id,
-        event_stream_ordering,
-        user_id,
-        rule_id,
-        priority_class,
-        conditions_json,
-        actions_json,
-        before,
-        after,
-    ):
+        txn: LoggingTransaction,
+        stream_id: int,
+        event_stream_ordering: int,
+        user_id: str,
+        rule_id: str,
+        priority_class: int,
+        conditions_json: str,
+        actions_json: str,
+        before: str,
+        after: str,
+    ) -> None:
         # Lock the table since otherwise we'll have annoying races between the
         # SELECT here and the UPSERT below.
         self.database_engine.lock_table(txn, "push_rules")
@@ -470,15 +454,15 @@ class PushRuleStore(PushRulesWorkerStore):
 
     def _add_push_rule_highest_priority_txn(
         self,
-        txn,
-        stream_id,
-        event_stream_ordering,
-        user_id,
-        rule_id,
-        priority_class,
-        conditions_json,
-        actions_json,
-    ):
+        txn: LoggingTransaction,
+        stream_id: int,
+        event_stream_ordering: int,
+        user_id: str,
+        rule_id: str,
+        priority_class: int,
+        conditions_json: str,
+        actions_json: str,
+    ) -> None:
         # Lock the table since otherwise we'll have annoying races between the
         # SELECT here and the UPSERT below.
         self.database_engine.lock_table(txn, "push_rules")
@@ -510,17 +494,17 @@ class PushRuleStore(PushRulesWorkerStore):
 
     def _upsert_push_rule_txn(
         self,
-        txn,
-        stream_id,
-        event_stream_ordering,
-        user_id,
-        rule_id,
-        priority_class,
-        priority,
-        conditions_json,
-        actions_json,
-        update_stream=True,
-    ):
+        txn: LoggingTransaction,
+        stream_id: int,
+        event_stream_ordering: int,
+        user_id: str,
+        rule_id: str,
+        priority_class: int,
+        priority: int,
+        conditions_json: str,
+        actions_json: str,
+        update_stream: bool = True,
+    ) -> None:
         """Specialised version of simple_upsert_txn that picks a push_rule_id
         using the _push_rule_id_gen if it needs to insert the rule. It assumes
         that the "push_rules" table is locked"""
@@ -600,7 +584,11 @@ class PushRuleStore(PushRulesWorkerStore):
             rule_id: The rule_id of the rule to be deleted
         """
 
-        def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
+        def delete_push_rule_txn(
+            txn: LoggingTransaction,
+            stream_id: int,
+            event_stream_ordering: int,
+        ) -> None:
             # we don't use simple_delete_one_txn because that would fail if the
             # user did not have a push_rule_enable row.
             self.db_pool.simple_delete_txn(
@@ -661,14 +649,14 @@ class PushRuleStore(PushRulesWorkerStore):
 
     def _set_push_rule_enabled_txn(
         self,
-        txn,
-        stream_id,
-        event_stream_ordering,
-        user_id,
-        rule_id,
-        enabled,
-        is_default_rule,
-    ):
+        txn: LoggingTransaction,
+        stream_id: int,
+        event_stream_ordering: int,
+        user_id: str,
+        rule_id: str,
+        enabled: bool,
+        is_default_rule: bool,
+    ) -> None:
         new_id = self._push_rules_enable_id_gen.get_next()
 
         if not is_default_rule:
@@ -740,7 +728,11 @@ class PushRuleStore(PushRulesWorkerStore):
         """
         actions_json = json_encoder.encode(actions)
 
-        def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
+        def set_push_rule_actions_txn(
+            txn: LoggingTransaction,
+            stream_id: int,
+            event_stream_ordering: int,
+        ) -> None:
             if is_default_rule:
                 # Add a dummy rule to the rules table with the user specified
                 # actions.
@@ -794,8 +786,15 @@ class PushRuleStore(PushRulesWorkerStore):
             )
 
     def _insert_push_rules_update_txn(
-        self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
-    ):
+        self,
+        txn: LoggingTransaction,
+        stream_id: int,
+        event_stream_ordering: int,
+        user_id: str,
+        rule_id: str,
+        op: str,
+        data: Optional[JsonDict] = None,
+    ) -> None:
         values = {
             "stream_id": stream_id,
             "event_stream_ordering": event_stream_ordering,
@@ -814,5 +813,56 @@ class PushRuleStore(PushRulesWorkerStore):
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )
 
-    def get_max_push_rules_stream_id(self):
+    def get_max_push_rules_stream_id(self) -> int:
         return self._push_rules_stream_id_gen.get_current_token()
+
+    async def copy_push_rule_from_room_to_room(
+        self, new_room_id: str, user_id: str, rule: dict
+    ) -> None:
+        """Copy a single push rule from one room to another for a specific user.
+
+        Args:
+            new_room_id: ID of the new room.
+            user_id: ID of user the push rule belongs to.
+            rule: A push rule.
+        """
+        # Create new rule id
+        rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
+        new_rule_id = rule_id_scope + "/" + new_room_id
+
+        # Change room id in each condition
+        for condition in rule.get("conditions", []):
+            if condition.get("key") == "room_id":
+                condition["pattern"] = new_room_id
+
+        # Add the rule for the new room
+        await self.add_push_rule(
+            user_id=user_id,
+            rule_id=new_rule_id,
+            priority_class=rule["priority_class"],
+            conditions=rule["conditions"],
+            actions=rule["actions"],
+        )
+
+    async def copy_push_rules_from_room_to_room_for_user(
+        self, old_room_id: str, new_room_id: str, user_id: str
+    ) -> None:
+        """Copy all of the push rules from one room to another for a specific
+        user.
+
+        Args:
+            old_room_id: ID of the old room.
+            new_room_id: ID of the new room.
+            user_id: ID of user to copy push rules for.
+        """
+        # Retrieve push rules for this user
+        user_push_rules = await self.get_push_rules_for_user(user_id)
+
+        # Get rules relating to the old room and copy them to the new room
+        for rule in user_push_rules:
+            conditions = rule.get("conditions", [])
+            if any(
+                (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
+                for c in conditions
+            ):
+                await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 484976ca6b..fe8fded88b 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -34,7 +34,7 @@ from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
 logger = logging.getLogger(__name__)
@@ -161,7 +161,9 @@ class RelationsWorkerStore(SQLBaseStore):
             if len(events) > limit and last_topo_id and last_stream_id:
                 next_key = RoomStreamToken(last_topo_id, last_stream_id)
                 if from_token:
-                    next_token = from_token.copy_and_replace("room_key", next_key)
+                    next_token = from_token.copy_and_replace(
+                        StreamKeyType.ROOM, next_key
+                    )
                 else:
                     next_token = StreamToken(
                         room_key=next_key,
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 87e9482c60..ded15b92ef 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -45,7 +45,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import IdGenerator
-from synapse.types import JsonDict, ThirdPartyInstanceID
+from synapse.types import JsonDict, RetentionPolicy, ThirdPartyInstanceID
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 from synapse.util.stringutils import MXC_REGEX
@@ -699,7 +699,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
 
     @cached()
-    async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]:
+    async def get_retention_policy_for_room(self, room_id: str) -> RetentionPolicy:
         """Get the retention policy for a given room.
 
         If no retention policy has been found for this room, returns a policy defined
@@ -707,12 +707,20 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         the 'max_lifetime' if no default policy has been defined in the server's
         configuration).
 
+        If support for retention policies is disabled, a policy with both
+        'min_lifetime' and 'max_lifetime' set to None is returned.
+
         Args:
             room_id: The ID of the room to get the retention policy of.
 
         Returns:
-            A dict containing "min_lifetime" and "max_lifetime" for this room.
+            A RetentionPolicy containing "min_lifetime" and "max_lifetime" for this room.
         """
+        # If the room retention feature is disabled, return a policy with no minimum nor
+        # maximum. This prevents incorrectly filtering out events when sending to
+        # the client.
+        if not self.config.retention.retention_enabled:
+            return RetentionPolicy()
 
         def get_retention_policy_for_room_txn(
             txn: LoggingTransaction,
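`RetentionPolicy` itself is defined in `synapse.types` and does not appear in this diff. Judging from its use above, a bare `RetentionPolicy()` stands for "no minimum, no maximum", so it behaves roughly like the following sketch (a frozen dataclass stand-in; the real type may well be an attrs class):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class RetentionPolicy:
    """Both lifetimes default to None, i.e. no lower or upper bound."""

    min_lifetime: Optional[int] = None
    max_lifetime: Optional[int] = None


# A room without a policy (or with retention support disabled) gets the empty policy...
assert RetentionPolicy() == RetentionPolicy(min_lifetime=None, max_lifetime=None)

# ...while a configured room carries explicit lifetimes in milliseconds.
week_ms = 7 * 24 * 60 * 60 * 1000
policy = RetentionPolicy(max_lifetime=week_ms)
assert policy.min_lifetime is None and policy.max_lifetime == week_ms
```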
@@ -736,10 +744,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         # If we don't know this room ID, ret will be None, in this case return the default
         # policy.
         if not ret:
-            return {
-                "min_lifetime": self.config.retention.retention_default_min_lifetime,
-                "max_lifetime": self.config.retention.retention_default_max_lifetime,
-            }
+            return RetentionPolicy(
+                min_lifetime=self.config.retention.retention_default_min_lifetime,
+                max_lifetime=self.config.retention.retention_default_max_lifetime,
+            )
 
         min_lifetime = ret[0]["min_lifetime"]
         max_lifetime = ret[0]["max_lifetime"]
@@ -754,10 +762,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         if max_lifetime is None:
             max_lifetime = self.config.retention.retention_default_max_lifetime
 
-        return {
-            "min_lifetime": min_lifetime,
-            "max_lifetime": max_lifetime,
-        }
+        return RetentionPolicy(
+            min_lifetime=min_lifetime,
+            max_lifetime=max_lifetime,
+        )
 
     async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
         """Retrieves all the local and remote media MXC URIs in a given room
@@ -994,7 +1002,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
     async def get_rooms_for_retention_period_in_range(
         self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
-    ) -> Dict[str, Dict[str, Optional[int]]]:
+    ) -> Dict[str, RetentionPolicy]:
         """Retrieves all of the rooms within the given retention range.
 
         Optionally includes the rooms which don't have a retention policy.
@@ -1016,7 +1024,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
         def get_rooms_for_retention_period_in_range_txn(
             txn: LoggingTransaction,
-        ) -> Dict[str, Dict[str, Optional[int]]]:
+        ) -> Dict[str, RetentionPolicy]:
             range_conditions = []
             args = []
 
@@ -1047,10 +1055,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             rooms_dict = {}
 
             for row in rows:
-                rooms_dict[row["room_id"]] = {
-                    "min_lifetime": row["min_lifetime"],
-                    "max_lifetime": row["max_lifetime"],
-                }
+                rooms_dict[row["room_id"]] = RetentionPolicy(
+                    min_lifetime=row["min_lifetime"],
+                    max_lifetime=row["max_lifetime"],
+                )
 
             if include_null:
                 # If required, do a second query that retrieves all of the rooms we know
@@ -1065,10 +1073,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                 # policy in its state), add it with a null policy.
                 for row in rows:
                     if row["room_id"] not in rooms_dict:
-                        rooms_dict[row["room_id"]] = {
-                            "min_lifetime": None,
-                            "max_lifetime": None,
-                        }
+                        rooms_dict[row["room_id"]] = RetentionPolicy()
 
             return rooms_dict
 
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 48e83592e7..cc528fcf2d 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -15,6 +15,7 @@
 import logging
 from typing import (
     TYPE_CHECKING,
+    Callable,
     Collection,
     Dict,
     FrozenSet,
@@ -37,7 +38,12 @@ from synapse.metrics.background_process_metrics import (
     wrap_as_background_process,
 )
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import Sqlite3Engine
 from synapse.storage.roommember import (
@@ -46,7 +52,7 @@ from synapse.storage.roommember import (
     ProfileInfo,
     RoomsForUser,
 )
-from synapse.types import PersistedEventPosition, get_domain_from_id
+from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -115,7 +121,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             )
 
     @wrap_as_background_process("_count_known_servers")
-    async def _count_known_servers(self):
+    async def _count_known_servers(self) -> int:
         """
         Count the servers that this server knows about.
 
@@ -123,7 +129,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         `synapse_federation_known_servers` LaterGauge to collect.
         """
 
-        def _transact(txn):
+        def _transact(txn: LoggingTransaction) -> int:
             if isinstance(self.database_engine, Sqlite3Engine):
                 query = """
                     SELECT COUNT(DISTINCT substr(out.user_id, pos+1))
@@ -150,7 +156,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         self._known_servers_count = max([count, 1])
         return self._known_servers_count
 
-    def _check_safe_current_state_events_membership_updated_txn(self, txn):
+    def _check_safe_current_state_events_membership_updated_txn(
+        self, txn: LoggingTransaction
+    ) -> None:
         """Checks if it is safe to assume the new current_state_events
         membership column is up to date
         """
@@ -182,7 +190,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             "get_users_in_room", self.get_users_in_room_txn, room_id
         )
 
-    def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
+    def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
         # If we can assume current_state_events.membership is up to date
         # then we can avoid a join, which is a Very Good Thing given how
         # frequently this function gets called.
@@ -222,7 +230,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             A mapping from user ID to ProfileInfo.
         """
 
-        def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]:
+        def _get_users_in_room_with_profiles(
+            txn: LoggingTransaction,
+        ) -> Dict[str, ProfileInfo]:
             sql = """
                 SELECT state_key, display_name, avatar_url FROM room_memberships as m
                 INNER JOIN current_state_events as c
@@ -250,7 +260,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             dict of membership states, pointing to a MemberSummary named tuple.
         """
 
-        def _get_room_summary_txn(txn):
+        def _get_room_summary_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, MemberSummary]:
             # first get counts.
             # We do this all in one transaction to keep the cache small.
             # FIXME: get rid of this when we have room_stats
@@ -279,7 +291,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 """
 
             txn.execute(sql, (room_id,))
-            res = {}
+            res: Dict[str, MemberSummary] = {}
             for count, membership in txn:
                 res.setdefault(membership, MemberSummary([], count))
 
@@ -400,7 +412,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     def _get_rooms_for_local_user_where_membership_is_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         user_id: str,
         membership_list: List[str],
     ) -> List[RoomsForUser]:
@@ -488,7 +500,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     def _get_rooms_for_user_with_stream_ordering_txn(
-        self, txn, user_id: str
+        self, txn: LoggingTransaction, user_id: str
     ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
         # We use `current_state_events` here and not `local_current_membership`
         # as a) this gets called with remote users and b) this only gets called
@@ -542,7 +554,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     def _get_rooms_for_users_with_stream_ordering_txn(
-        self, txn, user_ids: Collection[str]
+        self, txn: LoggingTransaction, user_ids: Collection[str]
     ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
 
         clause, args = make_in_list_sql_clause(
@@ -575,7 +587,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         txn.execute(sql, [Membership.JOIN] + args)
 
-        result = {user_id: set() for user_id in user_ids}
+        result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = {
+            user_id: set() for user_id in user_ids
+        }
         for user_id, room_id, instance, stream_id in txn:
             result[user_id].add(
                 GetRoomsForUserWithStreamOrdering(
@@ -595,7 +609,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         if not user_ids:
             return set()
 
-        def _get_users_server_still_shares_room_with_txn(txn):
+        def _get_users_server_still_shares_room_with_txn(
+            txn: LoggingTransaction,
+        ) -> Set[str]:
             sql = """
                 SELECT state_key FROM current_state_events
                 WHERE
@@ -619,7 +635,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     async def get_rooms_for_user(
-        self, user_id: str, on_invalidate=None
+        self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
     ) -> FrozenSet[str]:
         """Returns a set of room_ids the user is currently joined to.
 
@@ -657,7 +673,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     async def get_joined_users_from_context(
         self, event: EventBase, context: EventContext
     ) -> Dict[str, ProfileInfo]:
-        state_group = context.state_group
+        state_group: Union[object, int] = context.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -666,14 +682,16 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             state_group = object()
 
         current_state_ids = await context.get_current_state_ids()
+        assert current_state_ids is not None
+        assert state_group is not None
         return await self._get_joined_users_from_context(
             event.room_id, state_group, current_state_ids, event=event, context=context
         )
 
     async def get_joined_users_from_state(
-        self, room_id, state_entry
+        self, room_id: str, state_entry: "_StateCacheEntry"
     ) -> Dict[str, ProfileInfo]:
-        state_group = state_entry.state_group
+        state_group: Union[object, int] = state_entry.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -681,6 +699,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
+        assert state_group is not None
         with Measure(self._clock, "get_joined_users_from_state"):
             return await self._get_joined_users_from_context(
                 room_id, state_group, state_entry.state, context=state_entry
@@ -689,12 +708,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
     async def _get_joined_users_from_context(
         self,
-        room_id,
-        state_group,
-        current_state_ids,
-        cache_context,
-        event=None,
-        context=None,
+        room_id: str,
+        state_group: Union[object, int],
+        current_state_ids: StateMap[str],
+        cache_context: _CacheContext,
+        event: Optional[EventBase] = None,
+        context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
     ) -> Dict[str, ProfileInfo]:
         # We don't use `state_group`, it's there so that we can cache based
         # on it. However, it's important that it's never None, since two current_states
@@ -765,14 +784,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return users_in_room
 
     @cached(max_entries=10000)
-    def _get_joined_profile_from_event_id(self, event_id):
+    def _get_joined_profile_from_event_id(
+        self, event_id: str
+    ) -> Optional[Tuple[str, ProfileInfo]]:
         raise NotImplementedError()
 
     @cachedList(
         cached_method_name="_get_joined_profile_from_event_id",
         list_name="event_ids",
     )
-    async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
+    async def _get_joined_profiles_from_event_ids(
+        self, event_ids: Iterable[str]
+    ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
         """For given set of member event_ids check if they point to a join
         event and if so return the associated user and profile info.
 
@@ -780,8 +803,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             event_ids: The member event IDs to lookup
 
         Returns:
-            dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
-            to `user_id` and ProfileInfo (or None if not join event).
+            Map from event ID to `user_id` and ProfileInfo (or None if not join event).
         """
 
         rows = await self.db_pool.simple_select_many_batch(
@@ -847,8 +869,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return True
 
-    async def get_joined_hosts(self, room_id: str, state_entry):
-        state_group = state_entry.state_group
+    async def get_joined_hosts(
+        self, room_id: str, state_entry: "_StateCacheEntry"
+    ) -> FrozenSet[str]:
+        state_group: Union[object, int] = state_entry.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -856,6 +880,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
+        assert state_group is not None
         with Measure(self._clock, "get_joined_hosts"):
             return await self._get_joined_hosts(
                 room_id, state_group, state_entry=state_entry
@@ -863,7 +888,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     @cached(num_args=2, max_entries=10000, iterable=True)
     async def _get_joined_hosts(
-        self, room_id: str, state_group: int, state_entry: "_StateCacheEntry"
+        self,
+        room_id: str,
+        state_group: Union[object, int],
+        state_entry: "_StateCacheEntry",
     ) -> FrozenSet[str]:
         # We don't use `state_group`, it's there so that we can cache based on
         # it. However, it's important that it's never None, since two
@@ -881,7 +909,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # `get_joined_hosts` is called with the "current" state group for the
         # room, and so consecutive calls will be for consecutive state groups
         # which point to the previous state group.
-        cache = await self._get_joined_hosts_cache(room_id)
+        cache = await self._get_joined_hosts_cache(room_id)  # type: ignore[misc]
 
         # If the state group in the cache matches, we already have the data we need.
         if state_entry.state_group == cache.state_group:
@@ -897,6 +925,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             elif state_entry.prev_group == cache.state_group:
                 # The cached work is for the previous state group, so we work out
                 # the delta.
+                assert state_entry.delta_ids is not None
                 for (typ, state_key), event_id in state_entry.delta_ids.items():
                     if typ != EventTypes.Member:
                         continue
@@ -942,7 +971,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         Returns False if they have since re-joined."""
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT"
                 "  COUNT(*)"
@@ -973,7 +1002,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             The forgotten rooms.
         """
 
-        def _get_forgotten_rooms_for_user_txn(txn):
+        def _get_forgotten_rooms_for_user_txn(txn: LoggingTransaction) -> Set[str]:
             # This is a slightly convoluted query that first looks up all rooms
             # that the user has forgotten in the past, then rechecks that list
             # to see if any have subsequently been updated. This is done so that
@@ -1076,7 +1105,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             clause,
         )
 
-        def _is_local_host_in_room_ignoring_users_txn(txn):
+        def _is_local_host_in_room_ignoring_users_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
             txn.execute(sql, (room_id, Membership.JOIN, *args))
 
             return bool(txn.fetchone())
@@ -1110,15 +1141,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
             where_clause="forgotten = 1",
         )
 
-    async def _background_add_membership_profile(self, progress, batch_size):
+    async def _background_add_membership_profile(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         target_min_stream_id = progress.get(
-            "target_min_stream_id_inclusive", self._min_stream_order_on_start
+            "target_min_stream_id_inclusive", self._min_stream_order_on_start  # type: ignore[attr-defined]
         )
         max_stream_id = progress.get(
-            "max_stream_id_exclusive", self._stream_order_on_start + 1
+            "max_stream_id_exclusive", self._stream_order_on_start + 1  # type: ignore[attr-defined]
         )
 
-        def add_membership_profile_txn(txn):
+        def add_membership_profile_txn(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT stream_ordering, event_id, events.room_id, event_json.json
                 FROM events
@@ -1182,13 +1215,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
 
         return result
 
-    async def _background_current_state_membership(self, progress, batch_size):
+    async def _background_current_state_membership(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Update the new membership column on current_state_events.
 
         This works by iterating over all rooms in alphabetical order.
         """
 
-        def _background_current_state_membership_txn(txn, last_processed_room):
+        def _background_current_state_membership_txn(
+            txn: LoggingTransaction, last_processed_room: str
+        ) -> Tuple[int, bool]:
             processed = 0
             while processed < batch_size:
                 txn.execute(
@@ -1242,7 +1279,11 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
         return row_count
 
 
-class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
+class RoomMemberStore(
+    RoomMemberWorkerStore,
+    RoomMemberBackgroundUpdateStore,
+    CacheInvalidationWorkerStore,
+):
     def __init__(
         self,
         database: DatabasePool,
@@ -1254,7 +1295,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
     async def forget(self, user_id: str, room_id: str) -> None:
         """Indicate that user_id wishes to discard history for room_id."""
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> None:
             sql = (
                 "UPDATE"
                 "  room_memberships"
@@ -1288,5 +1329,5 @@ class _JoinedHostsCache:
     # equal to anything else).
     state_group: Union[object, int] = attr.Factory(object)
 
-    def __len__(self):
+    def __len__(self) -> int:
         return sum(len(v) for v in self.hosts_to_joined_users.values())
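A note on the repeated `state_group: Union[object, int]` annotations above: when an event context has not yet been assigned a state group, the code substitutes a fresh `object()` so that the `@cached` lookups keyed on it can never collide with one another. A minimal, standalone sketch of that idea (not Synapse's cache machinery):

    from typing import Dict, Optional, Tuple, Union

    _cache: Dict[Tuple[str, Union[int, object]], str] = {}

    def cached_joined_users(room_id: str, state_group: Optional[int]) -> str:
        # A real state group (an int) is a stable, reusable cache key; an
        # unassigned one is replaced by a fresh object(), which never compares
        # equal to anything else, so it cannot produce a spurious cache hit.
        key_group: Union[int, object] = state_group if state_group is not None else object()
        key = (room_id, key_group)
        if key not in _cache:
            _cache[key] = "computed for %s" % (room_id,)
        return _cache[key]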
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 67d3bb2b4b..7aa7126b69 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -14,7 +14,7 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple
 
 import attr
 
@@ -27,7 +27,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.types import JsonDict
 
 if TYPE_CHECKING:
@@ -149,7 +149,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
             self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings
         )
 
-    async def _background_reindex_search(self, progress, batch_size):
+    async def _background_reindex_search(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         # we work through the events table from highest stream id to lowest
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
@@ -157,7 +159,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
         TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
 
-        def reindex_search_txn(txn):
+        def reindex_search_txn(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT stream_ordering, event_id, room_id, type, json, "
                 " origin_server_ts FROM events"
@@ -255,12 +257,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
         return result
 
-    async def _background_reindex_gin_search(self, progress, batch_size):
+    async def _background_reindex_gin_search(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """This handles old synapses which used GIST indexes, if any;
         converting them back to be GIN as per the actual schema.
         """
 
-        def create_index(conn):
+        def create_index(conn: LoggingDatabaseConnection) -> None:
             conn.rollback()
 
             # we have to set autocommit, because postgres refuses to
@@ -299,7 +303,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
         )
         return 1
 
-    async def _background_reindex_search_order(self, progress, batch_size):
+    async def _background_reindex_search_order(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
@@ -307,7 +313,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
         if not have_added_index:
 
-            def create_index(conn):
+            def create_index(conn: LoggingDatabaseConnection) -> None:
                 conn.rollback()
                 conn.set_session(autocommit=True)
                 c = conn.cursor()
@@ -336,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
                 pg,
             )
 
-        def reindex_search_txn(txn):
+        def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
             sql = (
                 "UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
                 " origin_server_ts = e.origin_server_ts"
@@ -644,7 +650,8 @@ class SearchStore(SearchBackgroundUpdateStore):
         else:
             raise Exception("Unrecognized database engine")
 
-        args.append(limit)
+        # mypy expects to append only a `str`, not an `int`
+        args.append(limit)  # type: ignore[arg-type]
 
         results = await self.db_pool.execute(
             "search_rooms", self.db_pool.cursor_to_dict, sql, *args
@@ -705,7 +712,7 @@ class SearchStore(SearchBackgroundUpdateStore):
             A set of strings.
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> Set[str]:
             highlight_words = set()
             for event in events:
                 # As a hack we simply join values of all possible keys. This is
@@ -759,11 +766,11 @@ class SearchStore(SearchBackgroundUpdateStore):
         return await self.db_pool.runInteraction("_find_highlights", f)
 
 
-def _to_postgres_options(options_dict):
+def _to_postgres_options(options_dict: JsonDict) -> str:
     return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
 
 
-def _parse_query(database_engine, search_term):
+def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
     """Takes a plain unicode string from the user and converts it into a form
     that can be passed to database.
     We use this so that we can add prefix matching, which isn't something
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 0373af86c8..0e3a23a140 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -788,30 +788,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return None
 
     async def get_current_room_stream_token_for_room_id(
-        self, room_id: Optional[str] = None
+        self, room_id: str
     ) -> RoomStreamToken:
-        """Returns the current position of the rooms stream.
-
-        By default, it returns a live token with the current global stream
-        token. Specifying a `room_id` causes it to return a historic token with
-        the room specific topological token.
-        """
+        """Returns the current position of the rooms stream (historic token)."""
         stream_ordering = self.get_room_max_stream_ordering()
-        if room_id is None:
-            return RoomStreamToken(None, stream_ordering)
-        else:
-            topo = await self.db_pool.runInteraction(
-                "_get_max_topological_txn", self._get_max_topological_txn, room_id
-            )
-            return RoomStreamToken(topo, stream_ordering)
+        topo = await self.db_pool.runInteraction(
+            "_get_max_topological_txn", self._get_max_topological_txn, room_id
+        )
+        return RoomStreamToken(topo, stream_ordering)
 
     def get_stream_id_for_event_txn(
         self,
         txn: LoggingTransaction,
         event_id: str,
-        allow_none=False,
-    ) -> int:
-        return self.db_pool.simple_select_one_onecol_txn(
+        allow_none: bool = False,
+    ) -> Optional[int]:
+        # Type ignore: we pass keyvalues a Dict[str, str]; the function wants
+        # Dict[str, Any]. I think mypy is unhappy because Dict is invariant?
+        return self.db_pool.simple_select_one_onecol_txn(  # type: ignore[call-overload]
             txn=txn,
             table="events",
             keyvalues={"event_id": event_id},
@@ -873,7 +867,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
         rows = txn.fetchall()
-        return rows[0][0] if rows else 0
+        # An aggregate function like MAX() will always return one row per group
+        # so we can safely rely on the lookup here. For example, when we
+        # look up a `room_id` which does not exist, `rows` will look like
+        # `[(None,)]`.
+        return rows[0][0] if rows[0][0] is not None else 0
 
     @staticmethod
     def _set_before_and_after(
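The `_get_max_topological_txn` fix above hinges on an easily-missed detail: an aggregate like MAX() always returns a row (one per group; here there is a single implicit group), and for an unknown room that row is `(None,)` rather than no rows at all. A small sqlite3 sketch of the case being guarded against:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE events (room_id TEXT, topological_ordering INTEGER)")

    rows = conn.execute(
        "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
        ("!unknown:example.com",),
    ).fetchall()

    # rows == [(None,)], so the old `rows[0][0] if rows else 0` returned None;
    # checking the value itself is what makes the fallback to 0 kick in.
    max_topo = rows[0][0] if rows[0][0] is not None else 0
    print(rows, max_topo)  # [(None,)] 0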
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 5de70f31d2..fa9eadaca7 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -195,6 +195,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
     STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
     STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
+    STATE_GROUP_EDGES_UNIQUE_INDEX_UPDATE_NAME = "state_group_edges_unique_idx"
 
     def __init__(
         self,
@@ -217,6 +218,21 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
             columns=["room_id"],
         )
 
+        # `state_group_edges` can cause severe performance issues if duplicate
+        # rows are introduced, which can accidentally be done by well-meaning
+        # server admins when trying to restore a database dump, etc.
+        # See https://github.com/matrix-org/synapse/issues/11779.
+        # Introduce a unique index to guard against that.
+        self.db_pool.updates.register_background_index_update(
+            self.STATE_GROUP_EDGES_UNIQUE_INDEX_UPDATE_NAME,
+            index_name="state_group_edges_unique_idx",
+            table="state_group_edges",
+            columns=["state_group", "prev_state_group"],
+            unique=True,
+            # The old index was on (state_group) and was not unique.
+            replaces_index="state_group_edges_idx",
+        )
+
     async def _background_deduplicate_state(
         self, progress: dict, batch_size: int
     ) -> int:
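Because the new index is unique, the background update will fail if duplicate `state_group_edges` rows already exist (for example after a badly restored dump, as the linked issue describes). A hedged sketch of a query an operator might run to spot such duplicates up front; the table and column names come from the migration above, but the query itself is not part of Synapse:

    # Hypothetical helper, not Synapse code.
    DUPLICATE_EDGES_SQL = """
        SELECT state_group, prev_state_group, COUNT(*) AS copies
        FROM state_group_edges
        GROUP BY state_group, prev_state_group
        HAVING COUNT(*) > 1
    """

    def find_duplicate_edges(txn) -> list:
        txn.execute(DUPLICATE_EDGES_SQL)
        return txn.fetchall()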
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 7614d76ac6..609a2b88bf 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -189,7 +189,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         group: int,
         state_filter: StateFilter,
     ) -> Tuple[MutableStateMap[str], bool]:
-        """Checks if group is in cache. See `_get_state_for_groups`
+        """Checks if group is in cache. See `get_state_for_groups`
 
         Args:
             cache: the state group cache to use
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index afb7d5054d..f51b3d228e 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -11,25 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Mapping
 
 from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
 from .postgres import PostgresEngine
 from .sqlite import Sqlite3Engine
 
 
-def create_engine(database_config) -> BaseDatabaseEngine:
+def create_engine(database_config: Mapping[str, Any]) -> BaseDatabaseEngine:
     name = database_config["name"]
 
     if name == "sqlite3":
-        import sqlite3
-
-        return Sqlite3Engine(sqlite3, database_config)
+        return Sqlite3Engine(database_config)
 
     if name == "psycopg2":
-        # Note that psycopg2cffi-compat provides the psycopg2 module on pypy.
-        import psycopg2
-
-        return PostgresEngine(psycopg2, database_config)
+        return PostgresEngine(database_config)
 
     raise RuntimeError("Unsupported database engine '%s'" % (name,))
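With the driver imports pushed down into the engine classes, callers of `create_engine` now only supply a plain config mapping; a minimal usage sketch (only the `name` key is inspected by the factory itself, and the `args` shown are illustrative):

    from synapse.storage.engines import create_engine

    # SQLite engine; sqlite3 is part of the standard library.
    sqlite_engine = create_engine({"name": "sqlite3", "args": {"database": ":memory:"}})

    # PostgreSQL would be selected with a config whose "name" is "psycopg2";
    # psycopg2 is only imported when that branch is taken.

    try:
        create_engine({"name": "mysql"})
    except RuntimeError as e:
        print(e)  # Unsupported database engine 'mysql'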
 
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index 143cd98ca2..971ff82693 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -13,9 +13,12 @@
 # limitations under the License.
 import abc
 from enum import IntEnum
-from typing import Generic, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, TypeVar
 
-from synapse.storage.types import Connection
+from synapse.storage.types import Connection, Cursor, DBAPI2Module
+
+if TYPE_CHECKING:
+    from synapse.storage.database import LoggingDatabaseConnection
 
 
 class IsolationLevel(IntEnum):
@@ -32,7 +35,7 @@ ConnectionType = TypeVar("ConnectionType", bound=Connection)
 
 
 class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
-    def __init__(self, module, database_config: dict):
+    def __init__(self, module: DBAPI2Module, config: Mapping[str, Any]):
         self.module = module
 
     @property
@@ -69,7 +72,7 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
         ...
 
     @abc.abstractmethod
-    def check_new_database(self, txn) -> None:
+    def check_new_database(self, txn: Cursor) -> None:
         """Gets called when setting up a brand new database. This allows us to
         apply stricter checks on new databases versus existing database.
         """
@@ -79,8 +82,11 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
     def convert_param_style(self, sql: str) -> str:
         ...
 
+    # This method would ideally take a plain ConnectionType, but it seems that
+    # the Sqlite engine expects to use LoggingDatabaseConnection.cursor
+    # instead of sqlite3.Connection.cursor: only the former takes a txn_name.
     @abc.abstractmethod
-    def on_new_connection(self, db_conn: ConnectionType) -> None:
+    def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
         ...
 
     @abc.abstractmethod
@@ -92,7 +98,7 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
         ...
 
     @abc.abstractmethod
-    def lock_table(self, txn, table: str) -> None:
+    def lock_table(self, txn: Cursor, table: str) -> None:
         ...
 
     @property
@@ -102,12 +108,12 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
         ...
 
     @abc.abstractmethod
-    def in_transaction(self, conn: Connection) -> bool:
+    def in_transaction(self, conn: ConnectionType) -> bool:
         """Whether the connection is currently in a transaction."""
         ...
 
     @abc.abstractmethod
-    def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
+    def attempt_to_set_autocommit(self, conn: ConnectionType, autocommit: bool) -> None:
         """Attempt to set the connections autocommit mode.
 
         When True queries are run outside of transactions.
@@ -119,8 +125,8 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
 
     @abc.abstractmethod
     def attempt_to_set_isolation_level(
-        self, conn: Connection, isolation_level: Optional[int]
-    ):
+        self, conn: ConnectionType, isolation_level: Optional[int]
+    ) -> None:
         """Attempt to set the connections isolation level.
 
         Note: This has no effect on SQLite3, as transactions are SERIALIZABLE by default.
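The signature changes here lean on the class already being `Generic[ConnectionType]`: by accepting `ConnectionType` rather than the base `Connection` protocol, each concrete engine's overrides take their own connection type. A standalone sketch of that pattern (not the Synapse classes themselves):

    import sqlite3
    from typing import Generic, TypeVar

    ConnT = TypeVar("ConnT")

    class Engine(Generic[ConnT]):
        def in_transaction(self, conn: ConnT) -> bool:
            raise NotImplementedError()

    class SqliteEngine(Engine[sqlite3.Connection]):
        def in_transaction(self, conn: sqlite3.Connection) -> bool:
            # sqlite3 connections expose this directly; a Postgres engine
            # would instead check the connection's status flags.
            return conn.in_transaction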
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index e8d29e2870..391f8ed24a 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -13,39 +13,47 @@
 # limitations under the License.
 
 import logging
-from typing import Mapping, Optional
+from typing import TYPE_CHECKING, Any, Mapping, NoReturn, Optional, Tuple, cast
 
 from synapse.storage.engines._base import (
     BaseDatabaseEngine,
     IncorrectDatabaseSetup,
     IsolationLevel,
 )
-from synapse.storage.types import Connection
+from synapse.storage.types import Cursor
+
+if TYPE_CHECKING:
+    import psycopg2  # noqa: F401
+
+    from synapse.storage.database import LoggingDatabaseConnection
+
 
 logger = logging.getLogger(__name__)
 
 
-class PostgresEngine(BaseDatabaseEngine):
-    def __init__(self, database_module, database_config):
-        super().__init__(database_module, database_config)
-        self.module.extensions.register_type(self.module.extensions.UNICODE)
+class PostgresEngine(BaseDatabaseEngine["psycopg2.connection"]):
+    def __init__(self, database_config: Mapping[str, Any]):
+        import psycopg2.extensions
+
+        super().__init__(psycopg2, database_config)
+        psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
 
         # Disables passing `bytes` to txn.execute, c.f. #6186. If you do
         # actually want to use bytes, then wrap it in `bytearray`.
-        def _disable_bytes_adapter(_):
+        def _disable_bytes_adapter(_: bytes) -> NoReturn:
             raise Exception("Passing bytes to DB is disabled.")
 
-        self.module.extensions.register_adapter(bytes, _disable_bytes_adapter)
-        self.synchronous_commit = database_config.get("synchronous_commit", True)
-        self._version = None  # unknown as yet
+        psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter)
+        self.synchronous_commit: bool = database_config.get("synchronous_commit", True)
+        self._version: Optional[int] = None  # unknown as yet
 
         self.isolation_level_map: Mapping[int, int] = {
-            IsolationLevel.READ_COMMITTED: self.module.extensions.ISOLATION_LEVEL_READ_COMMITTED,
-            IsolationLevel.REPEATABLE_READ: self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
-            IsolationLevel.SERIALIZABLE: self.module.extensions.ISOLATION_LEVEL_SERIALIZABLE,
+            IsolationLevel.READ_COMMITTED: psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED,
+            IsolationLevel.REPEATABLE_READ: psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
+            IsolationLevel.SERIALIZABLE: psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE,
         }
         self.default_isolation_level = (
-            self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
+            psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ
         )
         self.config = database_config
 
@@ -53,19 +61,21 @@ class PostgresEngine(BaseDatabaseEngine):
     def single_threaded(self) -> bool:
         return False
 
-    def get_db_locale(self, txn):
+    def get_db_locale(self, txn: Cursor) -> Tuple[str, str]:
         txn.execute(
             "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
         )
-        collation, ctype = txn.fetchone()
+        collation, ctype = cast(Tuple[str, str], txn.fetchone())
         return collation, ctype
 
-    def check_database(self, db_conn, allow_outdated_version: bool = False):
+    def check_database(
+        self, db_conn: "psycopg2.connection", allow_outdated_version: bool = False
+    ) -> None:
         # Get the version of PostgreSQL that we're using. As per the psycopg2
         # docs: The number is formed by converting the major, minor, and
         # revision numbers into two-decimal-digit numbers and appending them
         # together. For example, version 8.1.5 will be returned as 80105
-        self._version = db_conn.server_version
+        self._version = cast(int, db_conn.server_version)
         allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
 
         # Are we on a supported PostgreSQL version?
@@ -108,7 +118,7 @@ class PostgresEngine(BaseDatabaseEngine):
                         ctype,
                     )
 
-    def check_new_database(self, txn):
+    def check_new_database(self, txn: Cursor) -> None:
         """Gets called when setting up a brand new database. This allows us to
         apply stricter checks on new databases versus existing database.
         """
@@ -129,10 +139,10 @@ class PostgresEngine(BaseDatabaseEngine):
                 "See docs/postgres.md for more information." % ("\n".join(errors))
             )
 
-    def convert_param_style(self, sql):
+    def convert_param_style(self, sql: str) -> str:
         return sql.replace("?", "%s")
 
-    def on_new_connection(self, db_conn):
+    def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
         db_conn.set_isolation_level(self.default_isolation_level)
 
         # Set the bytea output to escape, vs the default of hex
@@ -149,14 +159,14 @@ class PostgresEngine(BaseDatabaseEngine):
         db_conn.commit()
 
     @property
-    def can_native_upsert(self):
+    def can_native_upsert(self) -> bool:
         """
         Can we use native UPSERTs?
         """
         return True
 
     @property
-    def supports_using_any_list(self):
+    def supports_using_any_list(self) -> bool:
         """Do we support using `a = ANY(?)` and passing a list"""
         return True
 
@@ -165,27 +175,25 @@ class PostgresEngine(BaseDatabaseEngine):
         """Do we support the `RETURNING` clause in insert/update/delete?"""
         return True
 
-    def is_deadlock(self, error):
-        if isinstance(error, self.module.DatabaseError):
+    def is_deadlock(self, error: Exception) -> bool:
+        import psycopg2.extensions
+
+        if isinstance(error, psycopg2.DatabaseError):
             # https://www.postgresql.org/docs/current/static/errcodes-appendix.html
             # "40001" serialization_failure
             # "40P01" deadlock_detected
             return error.pgcode in ["40001", "40P01"]
         return False
 
-    def is_connection_closed(self, conn):
+    def is_connection_closed(self, conn: "psycopg2.connection") -> bool:
         return bool(conn.closed)
 
-    def lock_table(self, txn, table):
+    def lock_table(self, txn: Cursor, table: str) -> None:
         txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
 
     @property
-    def server_version(self):
-        """Returns a string giving the server version. For example: '8.1.5'
-
-        Returns:
-            string
-        """
+    def server_version(self) -> str:
+        """Returns a string giving the server version. For example: '8.1.5'."""
         # note that this is a bit of a hack because it relies on check_database
         # having been called. Still, that should be a safe bet here.
         numver = self._version
@@ -197,17 +205,21 @@ class PostgresEngine(BaseDatabaseEngine):
         else:
             return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)
 
-    def in_transaction(self, conn: Connection) -> bool:
-        return conn.status != self.module.extensions.STATUS_READY  # type: ignore
+    def in_transaction(self, conn: "psycopg2.connection") -> bool:
+        import psycopg2.extensions
+
+        return conn.status != psycopg2.extensions.STATUS_READY
 
-    def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
-        return conn.set_session(autocommit=autocommit)  # type: ignore
+    def attempt_to_set_autocommit(
+        self, conn: "psycopg2.connection", autocommit: bool
+    ) -> None:
+        return conn.set_session(autocommit=autocommit)
 
     def attempt_to_set_isolation_level(
-        self, conn: Connection, isolation_level: Optional[int]
-    ):
+        self, conn: "psycopg2.connection", isolation_level: Optional[int]
+    ) -> None:
         if isolation_level is None:
             isolation_level = self.default_isolation_level
         else:
             isolation_level = self.isolation_level_map[isolation_level]
-        return conn.set_isolation_level(isolation_level)  # type: ignore
+        return conn.set_isolation_level(isolation_level)
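As the comment in `check_database` notes, psycopg2 packs the server version as two decimal digits per component (8.1.5 becomes 80105); the arithmetic in the `else` branch of `server_version` undoes that packing. A quick sanity-check sketch of just that decode:

    def decode_pg_version(numver: int) -> str:
        # Two decimal digits per component, so 80105 -> "8.1.5".
        return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)

    assert decode_pg_version(80105) == "8.1.5"
    assert decode_pg_version(90602) == "9.6.2"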
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 6c19e55999..621f2c5efe 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -12,21 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import platform
+import sqlite3
 import struct
 import threading
-import typing
-from typing import Optional
+from typing import TYPE_CHECKING, Any, List, Mapping, Optional
 
 from synapse.storage.engines import BaseDatabaseEngine
-from synapse.storage.types import Connection
+from synapse.storage.types import Cursor
 
-if typing.TYPE_CHECKING:
-    import sqlite3  # noqa: F401
+if TYPE_CHECKING:
+    from synapse.storage.database import LoggingDatabaseConnection
 
 
-class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
-    def __init__(self, database_module, database_config):
-        super().__init__(database_module, database_config)
+class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
+    def __init__(self, database_config: Mapping[str, Any]):
+        super().__init__(sqlite3, database_config)
 
         database = database_config.get("args", {}).get("database")
         self._is_in_memory = database in (
@@ -37,7 +37,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
         if platform.python_implementation() == "PyPy":
             # pypy's sqlite3 module doesn't handle bytearrays, convert them
             # back to bytes.
-            database_module.register_adapter(bytearray, lambda array: bytes(array))
+            sqlite3.register_adapter(bytearray, lambda array: bytes(array))
 
         # The current max state_group, or None if we haven't looked
         # in the DB yet.
@@ -49,41 +49,43 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
         return True
 
     @property
-    def can_native_upsert(self):
+    def can_native_upsert(self) -> bool:
         """
         Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
         more work we haven't done yet to tell what was inserted vs updated.
         """
-        return self.module.sqlite_version_info >= (3, 24, 0)
+        return sqlite3.sqlite_version_info >= (3, 24, 0)
 
     @property
-    def supports_using_any_list(self):
+    def supports_using_any_list(self) -> bool:
         """Do we support using `a = ANY(?)` and passing a list"""
         return False
 
     @property
     def supports_returning(self) -> bool:
         """Do we support the `RETURNING` clause in insert/update/delete?"""
-        return self.module.sqlite_version_info >= (3, 35, 0)
+        return sqlite3.sqlite_version_info >= (3, 35, 0)
 
-    def check_database(self, db_conn, allow_outdated_version: bool = False):
+    def check_database(
+        self, db_conn: sqlite3.Connection, allow_outdated_version: bool = False
+    ) -> None:
         if not allow_outdated_version:
-            version = self.module.sqlite_version_info
+            version = sqlite3.sqlite_version_info
             # Synapse is untested against older SQLite versions, and we don't want
             # to let users upgrade to a version of Synapse with broken support for their
             # sqlite version, because it risks leaving them with a half-upgraded db.
             if version < (3, 22, 0):
                 raise RuntimeError("Synapse requires sqlite 3.22 or above.")
 
-    def check_new_database(self, txn):
+    def check_new_database(self, txn: Cursor) -> None:
         """Gets called when setting up a brand new database. This allows us to
         apply stricter checks on new databases versus existing database.
         """
 
-    def convert_param_style(self, sql):
+    def convert_param_style(self, sql: str) -> str:
         return sql
 
-    def on_new_connection(self, db_conn):
+    def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
         # We need to import here to avoid an import loop.
         from synapse.storage.prepare_database import prepare_database
 
@@ -97,48 +99,46 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
         db_conn.execute("PRAGMA foreign_keys = ON;")
         db_conn.commit()
 
-    def is_deadlock(self, error):
+    def is_deadlock(self, error: Exception) -> bool:
         return False
 
-    def is_connection_closed(self, conn):
+    def is_connection_closed(self, conn: sqlite3.Connection) -> bool:
         return False
 
-    def lock_table(self, txn, table):
+    def lock_table(self, txn: Cursor, table: str) -> None:
         return
 
     @property
-    def server_version(self):
-        """Gets a string giving the server version. For example: '3.22.0'
+    def server_version(self) -> str:
+        """Gets a string giving the server version. For example: '3.22.0'."""
+        return "%i.%i.%i" % sqlite3.sqlite_version_info
 
-        Returns:
-            string
-        """
-        return "%i.%i.%i" % self.module.sqlite_version_info
-
-    def in_transaction(self, conn: Connection) -> bool:
-        return conn.in_transaction  # type: ignore
+    def in_transaction(self, conn: sqlite3.Connection) -> bool:
+        return conn.in_transaction
 
-    def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
+    def attempt_to_set_autocommit(
+        self, conn: sqlite3.Connection, autocommit: bool
+    ) -> None:
         # Twisted doesn't let us set attributes on the connections, so we can't
         # set the connection to autocommit mode.
         pass
 
     def attempt_to_set_isolation_level(
-        self, conn: Connection, isolation_level: Optional[int]
-    ):
-        # All transactions are SERIALIZABLE by default in sqllite
+        self, conn: sqlite3.Connection, isolation_level: Optional[int]
+    ) -> None:
+        # All transactions are SERIALIZABLE by default in sqlite
         pass
 
 
 # Following functions taken from: https://github.com/coleifer/peewee
 
 
-def _parse_match_info(buf):
+def _parse_match_info(buf: bytes) -> List[int]:
     bufsize = len(buf)
     return [struct.unpack("@I", buf[i : i + 4])[0] for i in range(0, bufsize, 4)]
 
 
-def _rank(raw_match_info):
+def _rank(raw_match_info: bytes) -> float:
     """Handle match_info called w/default args 'pcx' - based on the example rank
     function http://sqlite.org/fts3.html#appendix_a
     """
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 97118045a1..0fc282866b 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -25,6 +25,7 @@ from typing import (
     Collection,
     Deque,
     Dict,
+    Generator,
     Generic,
     Iterable,
     List,
@@ -207,7 +208,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
 
         return res
 
-    def _handle_queue(self, room_id):
+    def _handle_queue(self, room_id: str) -> None:
         """Attempts to handle the queue for a room if not already being handled.
 
         The queue's callback will be invoked for each item in the queue,
@@ -227,7 +228,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
 
         self._currently_persisting_rooms.add(room_id)
 
-        async def handle_queue_loop():
+        async def handle_queue_loop() -> None:
             try:
                 queue = self._get_drainining_queue(room_id)
                 for item in queue:
@@ -250,15 +251,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
                         with PreserveLoggingContext():
                             item.deferred.callback(ret)
             finally:
-                queue = self._event_persist_queues.pop(room_id, None)
-                if queue:
-                    self._event_persist_queues[room_id] = queue
+                remaining_queue = self._event_persist_queues.pop(room_id, None)
+                if remaining_queue:
+                    self._event_persist_queues[room_id] = remaining_queue
                 self._currently_persisting_rooms.discard(room_id)
 
         # set handle_queue_loop off in the background
         run_as_background_process("persist_events", handle_queue_loop)
 
-    def _get_drainining_queue(self, room_id):
+    def _get_drainining_queue(
+        self, room_id: str
+    ) -> Generator[_EventPersistQueueItem, None, None]:
         queue = self._event_persist_queues.setdefault(room_id, deque())
 
         try:
@@ -317,7 +320,9 @@ class EventsPersistenceStorage:
         for event, ctx in events_and_contexts:
             partitioned.setdefault(event.room_id, []).append((event, ctx))
 
-        async def enqueue(item):
+        async def enqueue(
+            item: Tuple[str, List[Tuple[EventBase, EventContext]]]
+        ) -> Dict[str, str]:
             room_id, evs_ctxs = item
             return await self._event_persist_queue.add_to_queue(
                 room_id, evs_ctxs, backfilled=backfilled
@@ -487,12 +492,6 @@ class EventsPersistenceStorage:
             # extremities in each room
             new_forward_extremities: Dict[str, Set[str]] = {}
 
-            # map room_id->(type,state_key)->event_id tracking the full
-            # state in each room after adding these events.
-            # This is simply used to prefill the get_current_state_ids
-            # cache
-            current_state_for_room: Dict[str, StateMap[str]] = {}
-
             # map room_id->(to_delete, to_insert) where to_delete is a list
             # of type/state keys to remove from current state, and to_insert
             # is a map (type,key)->event_id giving the state delta in each
@@ -628,14 +627,8 @@ class EventsPersistenceStorage:
 
                             state_delta_for_room[room_id] = delta
 
-                        # If we have the current_state then lets prefill
-                        # the cache with it.
-                        if current_state is not None:
-                            current_state_for_room[room_id] = current_state
-
             await self.persist_events_store._persist_events_and_state_updates(
                 chunk,
-                current_state_for_room=current_state_for_room,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremities=new_forward_extremities,
                 use_negative_stream_ordering=backfilled,
@@ -733,7 +726,8 @@ class EventsPersistenceStorage:
 
             The first state map is the full new current state and the second
             is the delta to the existing current state. If both are None then
-            there has been no change.
+            there has been no change. Either or neither can be None if there
+            has been a change.
 
             The function may prune some old entries from the set of new
             forward extremities if it's safe to do so.
@@ -743,9 +737,6 @@ class EventsPersistenceStorage:
             the new current state is only returned if we've already calculated
             it.
         """
-        # map from state_group to ((type, key) -> event_id) state map
-        state_groups_map = {}
-
         # Map from (prev state group, new state group) -> delta state dict
         state_group_deltas = {}
 
@@ -759,16 +750,6 @@ class EventsPersistenceStorage:
                     )
                 continue
 
-            if ctx.state_group in state_groups_map:
-                continue
-
-            # We're only interested in pulling out state that has already
-            # been cached in the context. We'll pull stuff out of the DB later
-            # if necessary.
-            current_state_ids = ctx.get_cached_current_state_ids()
-            if current_state_ids is not None:
-                state_groups_map[ctx.state_group] = current_state_ids
-
             if ctx.prev_group:
                 state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
 
@@ -826,18 +807,14 @@ class EventsPersistenceStorage:
             delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
             if delta_ids is not None:
                 # We have a delta from the existing to new current state,
-                # so lets just return that. If we happen to already have
-                # the current state in memory then lets also return that,
-                # but it doesn't matter if we don't.
-                new_state = state_groups_map.get(new_state_group)
-                return new_state, delta_ids, new_latest_event_ids
+                # so lets just return that.
+                return None, delta_ids, new_latest_event_ids
 
         # Now that we have calculated new_state_groups we need to get
         # their state IDs so we can resolve to a single state set.
-        missing_state = new_state_groups - set(state_groups_map)
-        if missing_state:
-            group_to_state = await self.state_store._get_state_for_groups(missing_state)
-            state_groups_map.update(group_to_state)
+        state_groups_map = await self.state_store._get_state_for_groups(
+            new_state_groups
+        )
 
         if len(new_state_groups) == 1:
             # If there is only one state group, then we know what the current
@@ -1130,7 +1107,7 @@ class EventsPersistenceStorage:
 
         return False
 
-    async def _handle_potentially_left_users(self, user_ids: Set[str]):
+    async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None:
         """Given a set of remote users check if the server still shares a room with
         them. If not then mark those users' device cache as stale.
         """
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 546d6bae6e..c33df42084 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -85,7 +85,7 @@ def prepare_database(
     database_engine: BaseDatabaseEngine,
     config: Optional[HomeServerConfig],
     databases: Collection[str] = ("main", "state"),
-):
+) -> None:
     """Prepares a physical database for usage. Will either create all necessary tables
     or upgrade from an older schema version.
 
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 871d4ace12..da98f05e03 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SCHEMA_VERSION = 69  # remember to update the list below when updating
+SCHEMA_VERSION = 70  # remember to update the list below when updating
 """Represents the expectations made by the codebase about the database schema
 
 This should be incremented whenever the codebase changes its requirements on the
@@ -61,13 +61,19 @@ Changes in SCHEMA_VERSION = 68:
 
 Changes in SCHEMA_VERSION = 69:
     - We now write to `device_lists_changes_in_room` table.
-    - Use sequence to generate future `application_services_txns.txn_id`s
+    - We now use a PostgreSQL sequence to generate future txn_ids for
+      `application_services_txns`. `application_services_state.last_txn` is no longer
+      updated.
+
+Changes in SCHEMA_VERSION = 70:
+    - event_reference_hashes is no longer written to.
 """
 
 
 SCHEMA_COMPAT_VERSION = (
     # We now assume that `device_lists_changes_in_room` has been filled out for
     # recent device_list_updates.
+    # ... and that `application_services_state.last_txn` is not used.
     69
 )
 """Limit on how far the synapse codebase can be rolled back without breaking db compat
diff --git a/synapse/storage/schema/main/delta/69/02cache_invalidation_index.sql b/synapse/storage/schema/main/delta/69/02cache_invalidation_index.sql
new file mode 100644
index 0000000000..22ae3b8c00
--- /dev/null
+++ b/synapse/storage/schema/main/delta/69/02cache_invalidation_index.sql
@@ -0,0 +1,18 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Background update to add an index to the cache invalidation stream, keyed by instance.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (6902, 'cache_invalidation_index_by_instance', '{}');
diff --git a/synapse/storage/schema/state/delta/70/08_state_group_edges_unique.sql b/synapse/storage/schema/state/delta/70/08_state_group_edges_unique.sql
new file mode 100644
index 0000000000..b8c0ee0fa0
--- /dev/null
+++ b/synapse/storage/schema/state/delta/70/08_state_group_edges_unique.sql
@@ -0,0 +1,17 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+    (7008, 'state_group_edges_unique_idx', '{}');
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index d1d5859214..ab630953ac 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -1,4 +1,5 @@
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2022 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,6 +16,7 @@ import logging
 from typing import (
     TYPE_CHECKING,
     Awaitable,
+    Callable,
     Collection,
     Dict,
     Iterable,
@@ -62,7 +64,7 @@ class StateFilter:
     types: "frozendict[str, Optional[FrozenSet[str]]]"
     include_others: bool = False
 
-    def __attrs_post_init__(self):
+    def __attrs_post_init__(self) -> None:
         # If `include_others` is set we canonicalise the filter by removing
         # wildcards from the types dictionary
         if self.include_others:
@@ -138,7 +140,9 @@ class StateFilter:
         )
 
     @staticmethod
-    def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
+    def freeze(
+        types: Mapping[str, Optional[Collection[str]]], include_others: bool
+    ) -> "StateFilter":
         """
         Returns a (frozen) StateFilter with the same contents as the parameters
         specified here, which can be made of mutable types.
@@ -530,6 +534,44 @@ class StateFilter:
             new_all, new_excludes, new_wildcards, new_concrete_keys
         )
 
+    def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool:
+        """Check if we need to wait for full state to complete to calculate this state
+
+        If we have a state filter which is completely satisfied even with partial
+        state, then we don't need to await_full_state before we can return it.
+
+        Args:
+            is_mine_id: a callable which confirms if a given state_key matches a mxid
+               of a local user
+        """
+
+        # TODO(faster_joins): it's not entirely clear that this is safe. In particular,
+        #  there may be circumstances in which we return a piece of state that, once we
+        #  resync the state, we discover is invalid. For example: if it turns out that
+        #  the sender of a piece of state wasn't actually in the room, then clearly that
+        #  state shouldn't have been returned.
+        #  We should at least add some tests around this to see what happens.
+
+        # if we haven't requested membership events, then it depends on the value of
+        # 'include_others'
+        if EventTypes.Member not in self.types:
+            return self.include_others
+
+        # if we're looking for *all* membership events, then we have to wait
+        member_state_keys = self.types[EventTypes.Member]
+        if member_state_keys is None:
+            return True
+
+        # otherwise, consider whose membership we are looking for. If it's entirely
+        # local users, then we don't need to wait.
+        for state_key in member_state_keys:
+            if not is_mine_id(state_key):
+                # remote user
+                return True
+
+        # local users only
+        return False
+
 
 _ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
 _ALL_NON_MEMBER_STATE_FILTER = StateFilter(
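A usage sketch of `must_await_full_state`, assuming the existing `StateFilter.from_types` constructor and a hypothetical `is_mine_id` that treats user IDs on example.com as local:

    from synapse.api.constants import EventTypes
    from synapse.storage.state import StateFilter

    def is_mine_id(user_id: str) -> bool:
        # Hypothetical stand-in for hs.is_mine_id.
        return user_id.endswith(":example.com")

    # Only local members requested: partial state is enough, no need to wait.
    local_only = StateFilter.from_types([(EventTypes.Member, "@alice:example.com")])
    assert not local_only.must_await_full_state(is_mine_id)

    # A remote member (or a wildcard over all members) forces a wait.
    remote = StateFilter.from_types([(EventTypes.Member, "@bob:elsewhere.org")])
    assert remote.must_await_full_state(is_mine_id)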
@@ -542,6 +584,7 @@ class StateGroupStorage:
     """High level interface to fetching state for event."""
 
     def __init__(self, hs: "HomeServer", stores: "Databases"):
+        self._is_mine_id = hs.is_mine_id
         self.stores = stores
         self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
 
@@ -584,23 +627,26 @@ class StateGroupStorage:
         if not event_ids:
             return {}
 
-        event_to_groups = await self._get_state_group_for_events(event_ids)
+        event_to_groups = await self.get_state_group_for_events(event_ids)
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(groups)
 
         return group_to_state
 
-    async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
+    async def get_state_ids_for_group(
+        self, state_group: int, state_filter: Optional[StateFilter] = None
+    ) -> StateMap[str]:
         """Get the event IDs of all the state in the given state group
 
         Args:
             state_group: A state group for which we want to get the state IDs.
+            state_filter: specifies the types of state events to fetch from the DB, e.g. EventTypes.JoinRules
 
         Returns:
             Resolves to a map of (type, state_key) -> event_id
         """
-        group_to_state = await self._get_state_for_groups((state_group,))
+        group_to_state = await self.get_state_for_groups((state_group,), state_filter)
 
         return group_to_state[state_group]
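A sketch of how the new optional `state_filter` argument might be used (assuming `storage` is the object returned by `hs.get_storage()`, so that `storage.state` is this `StateGroupStorage`):

    from synapse.api.constants import EventTypes
    from synapse.storage.state import StateFilter

    async def get_join_rules_event_id(storage, state_group: int):
        # Fetch only the m.room.join_rules entry from the state group instead
        # of materialising the full state dict.
        join_rules_filter = StateFilter.freeze(
            {EventTypes.JoinRules: {""}}, include_others=False
        )
        state_ids = await storage.state.get_state_ids_for_group(
            state_group, join_rules_filter
        )
        return state_ids.get((EventTypes.JoinRules, ""))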
 
@@ -673,7 +719,13 @@ class StateGroupStorage:
             RuntimeError if we don't have a state group for one or more of the events
                (ie they are outliers or unknown)
         """
-        event_to_groups = await self._get_state_group_for_events(event_ids)
+        await_full_state = True
+        if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+            await_full_state = False
+
+        event_to_groups = await self.get_state_group_for_events(
+            event_ids, await_full_state=await_full_state
+        )
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
@@ -697,7 +749,9 @@ class StateGroupStorage:
         return {event: event_to_state[event] for event in event_ids}
 
     async def get_state_ids_for_events(
-        self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
+        self,
+        event_ids: Collection[str],
+        state_filter: Optional[StateFilter] = None,
     ) -> Dict[str, StateMap[str]]:
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
@@ -714,7 +768,13 @@ class StateGroupStorage:
             RuntimeError if we don't have a state group for one or more of the events
                 (ie they are outliers or unknown)
         """
-        event_to_groups = await self._get_state_group_for_events(event_ids)
+        await_full_state = True
+        if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+            await_full_state = False
+
+        event_to_groups = await self.get_state_group_for_events(
+            event_ids, await_full_state=await_full_state
+        )
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
@@ -772,7 +832,7 @@ class StateGroupStorage:
         )
         return state_map[event_id]
 
-    def _get_state_for_groups(
+    def get_state_for_groups(
         self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
         """Gets the state at each of a list of state groups, optionally
@@ -790,7 +850,7 @@ class StateGroupStorage:
             groups, state_filter or StateFilter.all()
         )
 
-    async def _get_state_group_for_events(
+    async def get_state_group_for_events(
         self,
         event_ids: Collection[str],
         await_full_state: bool = True,
@@ -800,7 +860,7 @@ class StateGroupStorage:
         Args:
             event_ids: events to get state groups for
             await_full_state: if true, will block if we do not yet have complete
-               state at this event.
+               state at these events.
         """
         if await_full_state:
             await self._partial_state_events_tracker.await_full_state(event_ids)
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index d7d6f1d90e..0031df1e06 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
+from types import TracebackType
+from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
 
 from typing_extensions import Protocol
 
@@ -86,5 +87,80 @@ class Connection(Protocol):
     def __enter__(self) -> "Connection":
         ...
 
-    def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> Optional[bool]:
+        ...
+
+
+class DBAPI2Module(Protocol):
+    """The module-level attributes that we use from PEP 249.
+
+    This is NOT a comprehensive stub for the entire DBAPI2."""
+
+    __name__: str
+
+    # Exceptions. See https://peps.python.org/pep-0249/#exceptions
+
+    # For our specific drivers:
+    # - Python's sqlite3 module doesn't contain the same descriptions as the
+    #   DBAPI2 spec, see https://docs.python.org/3/library/sqlite3.html#exceptions
+    # - Psycopg2 maps every Postgres error code onto a unique exception class which
+    #   extends from this hierarchy. See
+    #     https://www.psycopg.org/docs/errors.html
+    #     https://www.postgresql.org/docs/current/errcodes-appendix.html#ERRCODES-TABLE
+    Warning: Type[Exception]
+    Error: Type[Exception]
+
+    # Errors are divided into `InterfaceError`s (something went wrong in the database
+    # driver) and `DatabaseError`s (something went wrong in the database). These are
+    # both subclasses of `Error`, but we can't currently express this in type
+    # annotations due to https://github.com/python/mypy/issues/8397
+    InterfaceError: Type[Exception]
+    DatabaseError: Type[Exception]
+
+    # Everything below is a subclass of `DatabaseError`.
+
+    # Roughly: the database rejected a nonsensical value. Examples:
+    # - An integer was too big for its data type.
+    # - An invalid date time was provided.
+    # - A string contained a null code point.
+    DataError: Type[Exception]
+
+    # Roughly: something went wrong in the database, but it's not within the application
+    # programmer's control. Examples:
+    # - We failed to establish a connection to the database.
+    # - The connection to the database was lost.
+    # - A deadlock was detected.
+    # - A serialisation failure occurred.
+    # - The database ran out of resources, such as storage, memory, connections, etc.
+    # - The database encountered an error from the operating system.
+    OperationalError: Type[Exception]
+
+    # Roughly: we've given the database data which breaks a rule we asked it to enforce.
+    # Examples:
+    # - Stop, criminal scum! You violated the foreign key constraint
+    # - Also check constraints, non-null constraints, etc.
+    IntegrityError: Type[Exception]
+
+    # Roughly: something went wrong within the database server itself.
+    InternalError: Type[Exception]
+
+    # Roughly: the application did something silly that needs to be fixed. Examples:
+    # - We don't have permissions to do something.
+    # - We tried to create a table with duplicate column names.
+    # - We tried to use a reserved name.
+    # - We referred to a column that doesn't exist.
+    ProgrammingError: Type[Exception]
+
+    # Roughly: we've tried to do something that this database doesn't support.
+    NotSupportedError: Type[Exception]
+
+    def connect(self, **parameters: object) -> Connection:
         ...
+
+
+__all__ = ["Cursor", "Connection", "DBAPI2Module"]
diff --git a/synapse/types.py b/synapse/types.py
index 9ac688b23b..6f7128ddd6 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -24,6 +24,7 @@ from typing import (
     Mapping,
     Match,
     MutableMapping,
+    NoReturn,
     Optional,
     Set,
     Tuple,
@@ -35,7 +36,8 @@ from typing import (
 import attr
 from frozendict import frozendict
 from signedjson.key import decode_verify_key_bytes
-from typing_extensions import TypedDict
+from signedjson.types import VerifyKey
+from typing_extensions import Final, TypedDict
 from unpaddedbase64 import decode_base64
 from zope.interface import Interface
 
@@ -55,6 +57,7 @@ from synapse.util.stringutils import parse_and_validate_server_name
 if TYPE_CHECKING:
     from synapse.appservice.api import ApplicationService
     from synapse.storage.databases.main import DataStore, PurgeEventsStore
+    from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
 
 # Define a state map type from type/state_key to T (usually an event ID or
 # event)
@@ -114,7 +117,7 @@ class Requester:
     app_service: Optional["ApplicationService"]
     authenticated_entity: str
 
-    def serialize(self):
+    def serialize(self) -> Dict[str, Any]:
         """Converts self to a type that can be serialized as JSON, and then
         deserialized by `deserialize`
 
@@ -132,7 +135,9 @@ class Requester:
         }
 
     @staticmethod
-    def deserialize(store, input):
+    def deserialize(
+        store: "ApplicationServiceWorkerStore", input: Dict[str, Any]
+    ) -> "Requester":
         """Converts a dict that was produced by `serialize` back into a
         Requester.
 
@@ -236,10 +241,10 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
     domain: str
 
     # Because this is a frozen class, it is deeply immutable.
-    def __copy__(self):
+    def __copy__(self: DS) -> DS:
         return self
 
-    def __deepcopy__(self, memo):
+    def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS:
         return self
 
     @classmethod
@@ -625,6 +630,22 @@ class RoomStreamToken:
             return "s%d" % (self.stream,)
 
 
+class StreamKeyType:
+    """Known stream types.
+
+    A stream is a list of entities ordered by an incrementing "stream token".
+    """
+
+    ROOM: Final = "room_key"
+    PRESENCE: Final = "presence_key"
+    TYPING: Final = "typing_key"
+    RECEIPT: Final = "receipt_key"
+    ACCOUNT_DATA: Final = "account_data_key"
+    PUSH_RULES: Final = "push_rules_key"
+    TO_DEVICE: Final = "to_device_key"
+    DEVICE_LIST: Final = "device_list_key"
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class StreamToken:
     """A collection of keys joined together by underscores in the following
@@ -729,16 +750,18 @@ class StreamToken:
         )
 
     @property
-    def room_stream_id(self):
+    def room_stream_id(self) -> int:
         return self.room_key.stream
 
-    def copy_and_advance(self, key, new_value) -> "StreamToken":
+    def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken":
         """Advance the given key in the token to a new value if and only if the
         new value is after the old value.
+
+        :raises TypeError: if `key` is not one of the keys tracked by a StreamToken.
         """
-        if key == "room_key":
+        if key == StreamKeyType.ROOM:
             new_token = self.copy_and_replace(
-                "room_key", self.room_key.copy_and_advance(new_value)
+                StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value)
             )
             return new_token
 
@@ -751,7 +774,7 @@ class StreamToken:
         else:
             return self
 
-    def copy_and_replace(self, key, new_value) -> "StreamToken":
+    def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken":
         return attr.evolve(self, **{key: new_value})
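A quick sketch of the intended use of the new `StreamKeyType` constants in place of bare strings (the token and stream id here are placeholders):

    from synapse.types import StreamKeyType, StreamToken

    def advance_receipts(token: StreamToken, new_receipt_id: int) -> StreamToken:
        # Equivalent to token.copy_and_advance("receipt_key", new_receipt_id),
        # but a misspelled constant is caught by linters, whereas a misspelled
        # string key only surfaces as a TypeError at runtime.
        return token.copy_and_advance(StreamKeyType.RECEIPT, new_receipt_id)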
 
 
@@ -793,14 +816,14 @@ class ThirdPartyInstanceID:
     # Deny iteration because it will bite you if you try to create a singleton
     # set by:
     #    users = set(user)
-    def __iter__(self):
+    def __iter__(self) -> NoReturn:
         raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
 
     # Because this class is a frozen class, it is deeply immutable.
-    def __copy__(self):
+    def __copy__(self) -> "ThirdPartyInstanceID":
         return self
 
-    def __deepcopy__(self, memo):
+    def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID":
         return self
 
     @classmethod
@@ -852,25 +875,28 @@ class DeviceListUpdates:
         return bool(self.changed or self.left)
 
 
-def get_verify_key_from_cross_signing_key(key_info):
+def get_verify_key_from_cross_signing_key(
+    key_info: Mapping[str, Any]
+) -> Tuple[str, VerifyKey]:
     """Get the key ID and signedjson verify key from a cross-signing key dict
 
     Args:
-        key_info (dict): a cross-signing key dict, which must have a "keys"
+        key_info: a cross-signing key dict, which must have a "keys"
             property that has exactly one item in it
 
     Returns:
-        (str, VerifyKey): the key ID and verify key for the cross-signing key
+        the key ID and verify key for the cross-signing key
     """
-    # make sure that exactly one key is provided
+    # make sure that a `keys` field is provided
     if "keys" not in key_info:
         raise ValueError("Invalid key")
     keys = key_info["keys"]
-    if len(keys) != 1:
-        raise ValueError("Invalid key")
-    # and return that one key
-    for key_id, key_data in keys.items():
+    # and that it contains exactly one key
+    if len(keys) == 1:
+        key_id, key_data = next(iter(keys.items()))
         return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))
+    else:
+        raise ValueError("Invalid key")
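For reference, a sketch of the input shape this helper expects; the signing key is a throwaway generated purely for the example, and the dict layout follows the Matrix cross-signing key format:

    from signedjson.key import (
        encode_verify_key_base64,
        generate_signing_key,
        get_verify_key,
    )

    from synapse.types import get_verify_key_from_cross_signing_key

    signing_key = generate_signing_key("1")
    key_id = "ed25519:%s" % (signing_key.version,)
    key_info = {
        "user_id": "@alice:example.com",
        "usage": ["master"],
        "keys": {key_id: encode_verify_key_base64(get_verify_key(signing_key))},
    }

    returned_id, verify_key = get_verify_key_from_cross_signing_key(key_info)
    assert returned_id == key_id

    # Zero keys (or more than one) is rejected.
    try:
        get_verify_key_from_cross_signing_key({"keys": {}})
    except ValueError:
        pass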
 
 
 @attr.s(auto_attribs=True, frozen=True, slots=True)
@@ -906,3 +932,9 @@ class UserProfile(TypedDict):
     user_id: str
     display_name: Optional[str]
     avatar_url: Optional[str]
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class RetentionPolicy:
+    min_lifetime: Optional[int] = None
+    max_lifetime: Optional[int] = None
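The new `RetentionPolicy` container replaces the plain dicts previously passed around for room retention settings; a minimal comparison (the lifetime value is an arbitrary example):

    from synapse.types import RetentionPolicy

    # Before: a plain dict, read with .get("max_lifetime").
    legacy = {"max_lifetime": 86400000}  # one day, in milliseconds

    # After: a frozen attrs class with explicit, typed fields.
    policy = RetentionPolicy(max_lifetime=86400000)
    assert policy.max_lifetime == legacy.get("max_lifetime")
    assert policy.min_lifetime is None  # unset fields default to None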
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 45ff0de638..a3b60578e3 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+import math
 import threading
 import weakref
 from enum import Enum
@@ -40,6 +41,7 @@ from twisted.internet.interfaces import IReactorTime
 
 from synapse.config import cache as cache_config
 from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.metrics.jemalloc import get_jemalloc_stats
 from synapse.util import Clock, caches
 from synapse.util.caches import CacheMetric, EvictionReason, register_cache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
@@ -106,10 +108,16 @@ GLOBAL_ROOT = ListNode["_Node"].create_root_node()
 
 
 @wrap_as_background_process("LruCache._expire_old_entries")
-async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
+async def _expire_old_entries(
+    clock: Clock, expiry_seconds: int, autotune_config: Optional[dict]
+) -> None:
     """Walks the global cache list to find cache entries that haven't been
-    accessed in the given number of seconds.
+    accessed in the given number of seconds, or to evict entries once a given memory threshold is breached.
     """
+    if autotune_config:
+        max_cache_memory_usage = autotune_config["max_cache_memory_usage"]
+        target_cache_memory_usage = autotune_config["target_cache_memory_usage"]
+        min_cache_ttl = autotune_config["min_cache_ttl"] / 1000
 
     now = int(clock.time())
     node = GLOBAL_ROOT.prev_node
@@ -119,11 +127,36 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
 
     logger.debug("Searching for stale caches")
 
+    evicting_due_to_memory = False
+
+    # determine if we're evicting due to memory
+    jemalloc_interface = get_jemalloc_stats()
+    if jemalloc_interface and autotune_config:
+        try:
+            jemalloc_interface.refresh_stats()
+            mem_usage = jemalloc_interface.get_stat("allocated")
+            if mem_usage > max_cache_memory_usage:
+                logger.info("Begin memory-based cache eviction.")
+                evicting_due_to_memory = True
+        except Exception:
+            logger.warning(
+                "Unable to read allocated memory, skipping memory-based cache eviction."
+            )
+
     while node is not GLOBAL_ROOT:
         # Only the root node isn't a `_TimedListNode`.
         assert isinstance(node, _TimedListNode)
 
-        if node.last_access_ts_secs > now - expiry_seconds:
+        # if node has not aged past expiry_seconds and we are not evicting due to memory usage, there's
+        # nothing to do here
+        if (
+            node.last_access_ts_secs > now - expiry_seconds
+            and not evicting_due_to_memory
+        ):
+            break
+
+        # when evicting due to memory, stop at the first entry newer than min_cache_ttl
+        if evicting_due_to_memory and now - node.last_access_ts_secs < min_cache_ttl:
             break
 
         cache_entry = node.get_cache_entry()
@@ -136,10 +169,29 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
         assert cache_entry is not None
         cache_entry.drop_from_cache()
 
+        # Check mem allocation periodically if we are evicting a bunch of caches
+        if jemalloc_interface and evicting_due_to_memory and (i + 1) % 100 == 0:
+            try:
+                jemalloc_interface.refresh_stats()
+                mem_usage = jemalloc_interface.get_stat("allocated")
+                if mem_usage < target_cache_memory_usage:
+                    evicting_due_to_memory = False
+                    logger.info("Stop memory-based cache eviction.")
+            except Exception:
+                logger.warning(
+                    "Unable to read allocated memory, this may affect memory-based cache eviction."
+                )
+                # If we've failed to read the current memory usage then we
+                # should stop trying to evict based on memory usage
+                evicting_due_to_memory = False
+
         # If we do lots of work at once we yield to allow other stuff to happen.
         if (i + 1) % 10000 == 0:
             logger.debug("Waiting during drop")
-            await clock.sleep(0)
+            if node.last_access_ts_secs > now - expiry_seconds:
+                await clock.sleep(0.5)
+            else:
+                await clock.sleep(0)
             logger.debug("Waking during drop")
 
         node = next_node
@@ -156,21 +208,28 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
 
 def setup_expire_lru_cache_entries(hs: "HomeServer") -> None:
     """Start a background job that expires all cache entries if they have not
-    been accessed for the given number of seconds.
+    been accessed for the given number of seconds, or if a given memory usage threshold has been
+    breached.
     """
-    if not hs.config.caches.expiry_time_msec:
+    if not hs.config.caches.expiry_time_msec and not hs.config.caches.cache_autotuning:
         return
 
-    logger.info(
-        "Expiring LRU caches after %d seconds", hs.config.caches.expiry_time_msec / 1000
-    )
+    if hs.config.caches.expiry_time_msec:
+        expiry_time = hs.config.caches.expiry_time_msec / 1000
+        logger.info("Expiring LRU caches after %d seconds", expiry_time)
+    else:
+        expiry_time = math.inf
 
     global USE_GLOBAL_LIST
     USE_GLOBAL_LIST = True
 
     clock = hs.get_clock()
     clock.looping_call(
-        _expire_old_entries, 30 * 1000, clock, hs.config.caches.expiry_time_msec / 1000
+        _expire_old_entries,
+        30 * 1000,
+        clock,
+        expiry_time,
+        hs.config.caches.cache_autotuning,
     )
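As a rough model of the per-entry decision the loop above now makes (the helper and config values below are invented for illustration; real values come from `hs.config.caches.cache_autotuning` and jemalloc's "allocated" statistic):

    import math

    # Invented config mirroring the keys read above (sizes in bytes, TTL in ms).
    autotune_config = {
        "max_cache_memory_usage": 1024 * 1024 * 1024,    # start evicting above 1 GiB
        "target_cache_memory_usage": 768 * 1024 * 1024,  # stop once back under 768 MiB
        "min_cache_ttl": 5 * 60 * 1000,                  # never evict entries < 5 min old
    }
    min_cache_ttl_secs = autotune_config["min_cache_ttl"] / 1000

    def keeps_evicting(age_secs: float, expiry_secs: float, memory_pressure: bool) -> bool:
        # Mirrors the two `break` conditions in _expire_old_entries.
        if age_secs < expiry_secs and not memory_pressure:
            return False  # entry isn't stale and memory isn't forcing eviction
        if memory_pressure and age_secs < min_cache_ttl_secs:
            return False  # too recently used to evict, even under memory pressure
        return True

    # With expiry_time unset (math.inf), only memory pressure drives eviction:
    assert keeps_evicting(60, math.inf, memory_pressure=True) is False   # 1 minute old
    assert keeps_evicting(600, math.inf, memory_pressure=True) is True   # 10 minutes old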
 
 
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 81bfed268e..d0a69ff843 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -16,8 +16,8 @@ import random
 from types import TracebackType
 from typing import TYPE_CHECKING, Any, Optional, Type
 
-import synapse.logging.context
 from synapse.api.errors import CodeMessageException
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage import DataStore
 from synapse.util import Clock
 
@@ -265,4 +265,4 @@ class RetryDestinationLimiter:
                 logger.exception("Failed to store destination_retry_timings")
 
         # we deliberately do this in the background.
-        synapse.logging.context.run_in_background(store_retry_timings)
+        run_as_background_process("store_retry_timings", store_retry_timings)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index de6d2ffc52..da4af02796 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -22,7 +22,7 @@ from synapse.events import EventBase
 from synapse.events.utils import prune_event
 from synapse.storage import Storage
 from synapse.storage.state import StateFilter
-from synapse.types import StateMap, get_domain_from_id
+from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
 
 logger = logging.getLogger(__name__)
 
@@ -94,7 +94,7 @@ async def filter_events_for_client(
 
     if filter_send_to_client:
         room_ids = {e.room_id for e in events}
-        retention_policies = {}
+        retention_policies: Dict[str, RetentionPolicy] = {}
 
         for room_id in room_ids:
             retention_policies[
@@ -137,7 +137,7 @@ async def filter_events_for_client(
             # events.
             if not event.is_state():
                 retention_policy = retention_policies[event.room_id]
-                max_lifetime = retention_policy.get("max_lifetime")
+                max_lifetime = retention_policy.max_lifetime
 
                 if max_lifetime is not None:
                     oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime