summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/filtering.py5
-rw-r--r--synapse/app/_base.py44
-rw-r--r--synapse/app/homeserver.py36
-rw-r--r--synapse/config/_base.py81
-rw-r--r--synapse/config/_base.pyi15
-rw-r--r--synapse/config/cache.py82
-rw-r--r--synapse/config/room.py47
-rw-r--r--synapse/config/server.py2
-rw-r--r--synapse/events/__init__.py45
-rw-r--r--synapse/events/snapshot.py180
-rw-r--r--synapse/events/spamcheck.py81
-rw-r--r--synapse/federation/transport/server/_base.py13
-rw-r--r--synapse/groups/groups_server.py2
-rw-r--r--synapse/handlers/account_data.py10
-rw-r--r--synapse/handlers/appservice.py39
-rw-r--r--synapse/handlers/device.py7
-rw-r--r--synapse/handlers/devicemessage.py6
-rw-r--r--synapse/handlers/e2e_keys.py29
-rw-r--r--synapse/handlers/federation.py6
-rw-r--r--synapse/handlers/federation_event.py10
-rw-r--r--synapse/handlers/initial_sync.py19
-rw-r--r--synapse/handlers/message.py42
-rw-r--r--synapse/handlers/pagination.py8
-rw-r--r--synapse/handlers/presence.py6
-rw-r--r--synapse/handlers/receipts.py106
-rw-r--r--synapse/handlers/relations.py20
-rw-r--r--synapse/handlers/room.py150
-rw-r--r--synapse/handlers/search.py10
-rw-r--r--synapse/handlers/sync.py23
-rw-r--r--synapse/handlers/typing.py4
-rw-r--r--synapse/http/client.py18
-rw-r--r--synapse/http/connectproxyclient.py39
-rw-r--r--synapse/http/federation/matrix_federation_agent.py2
-rw-r--r--synapse/http/federation/srv_resolver.py4
-rw-r--r--synapse/http/federation/well_known_resolver.py6
-rw-r--r--synapse/http/matrixfederationclient.py31
-rw-r--r--synapse/http/proxyagent.py2
-rw-r--r--synapse/http/request_metrics.py10
-rw-r--r--synapse/http/server.py74
-rw-r--r--synapse/http/site.py25
-rw-r--r--synapse/logging/_remote.py20
-rw-r--r--synapse/logging/formatter.py14
-rw-r--r--synapse/logging/handlers.py4
-rw-r--r--synapse/logging/scopecontextmanager.py28
-rw-r--r--synapse/metrics/jemalloc.py114
-rw-r--r--synapse/notifier.py5
-rw-r--r--synapse/push/__init__.py74
-rw-r--r--synapse/push/action_generator.py44
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py45
-rw-r--r--synapse/push/push_rule_evaluator.py70
-rw-r--r--synapse/replication/http/_base.py21
-rw-r--r--synapse/replication/tcp/client.py18
-rw-r--r--synapse/rest/client/receipts.py13
-rw-r--r--synapse/rest/client/room.py6
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py30
-rw-r--r--synapse/server.py8
-rw-r--r--synapse/server_notices/resource_limits_server_notices.py40
-rw-r--r--synapse/server_notices/server_notices_manager.py74
-rw-r--r--synapse/state/__init__.py9
-rw-r--r--synapse/storage/background_updates.py19
-rw-r--r--synapse/storage/database.py44
-rw-r--r--synapse/storage/databases/main/cache.py8
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py4
-rw-r--r--synapse/storage/databases/main/events.py236
-rw-r--r--synapse/storage/databases/main/events_worker.py35
-rw-r--r--synapse/storage/databases/main/metrics.py24
-rw-r--r--synapse/storage/databases/main/purge_events.py3
-rw-r--r--synapse/storage/databases/main/relations.py6
-rw-r--r--synapse/storage/databases/main/search.py33
-rw-r--r--synapse/storage/databases/main/stream.py34
-rw-r--r--synapse/storage/engines/__init__.py12
-rw-r--r--synapse/storage/engines/_base.py26
-rw-r--r--synapse/storage/engines/postgres.py92
-rw-r--r--synapse/storage/engines/sqlite.py72
-rw-r--r--synapse/storage/persist_events.py63
-rw-r--r--synapse/storage/prepare_database.py2
-rw-r--r--synapse/storage/schema/__init__.py5
-rw-r--r--synapse/storage/schema/main/delta/69/02cache_invalidation_index.sql18
-rw-r--r--synapse/storage/state.py6
-rw-r--r--synapse/storage/types.py80
-rw-r--r--synapse/types.py68
-rw-r--r--synapse/util/caches/lrucache.py79
82 files changed, 1864 insertions, 1081 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 4a808e33fe..b91ce06de7 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -19,6 +19,7 @@ from typing import (
     TYPE_CHECKING,
     Awaitable,
     Callable,
+    Collection,
     Dict,
     Iterable,
     List,
@@ -444,9 +445,9 @@ class Filter:
         return room_ids
 
     async def _check_event_relations(
-        self, events: Iterable[FilterEvent]
+        self, events: Collection[FilterEvent]
     ) -> List[FilterEvent]:
-        # The event IDs to check, mypy doesn't understand the ifinstance check.
+        # The event IDs to check, mypy doesn't understand the isinstance check.
         event_ids = [event.event_id for event in events if isinstance(event, EventBase)]  # type: ignore[attr-defined]
         event_ids_to_keep = set(
             await self._store.events_have_relations(
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 3623c1724d..a3446ac6e8 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -49,9 +49,12 @@ from twisted.logger import LoggingFile, LogLevel
 from twisted.protocols.tls import TLSMemoryBIOFactory
 from twisted.python.threadpool import ThreadPool
 
+import synapse.util.caches
 from synapse.api.constants import MAX_PDU_SIZE
 from synapse.app import check_bind_error
 from synapse.app.phone_stats_home import start_phone_stats_home
+from synapse.config import ConfigError
+from synapse.config._base import format_config_error
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import ManholeConfig
 from synapse.crypto import context_factory
@@ -432,6 +435,10 @@ async def start(hs: "HomeServer") -> None:
         signal.signal(signal.SIGHUP, run_sighup)
 
         register_sighup(refresh_certificate, hs)
+        register_sighup(reload_cache_config, hs.config)
+
+    # Apply the cache config.
+    hs.config.caches.resize_all_caches()
 
     # Load the certificate from disk.
     refresh_certificate(hs)
@@ -486,6 +493,43 @@ async def start(hs: "HomeServer") -> None:
         atexit.register(gc.freeze)
 
 
+def reload_cache_config(config: HomeServerConfig) -> None:
+    """Reload cache config from disk and immediately apply it.resize caches accordingly.
+
+    If the config is invalid, a `ConfigError` is logged and no changes are made.
+
+    Otherwise, this:
+        - replaces the `caches` section on the given `config` object,
+        - resizes all caches according to the new cache factors, and
+
+    Note that the following cache config keys are read, but not applied:
+        - event_cache_size: used to set a max_size and _original_max_size on
+              EventsWorkerStore._get_event_cache when it is created. We'd have to update
+              the _original_max_size (and maybe
+        - sync_response_cache_duration: would have to update the timeout_sec attribute on
+              HomeServer ->  SyncHandler -> ResponseCache.
+        - track_memory_usage. This affects synapse.util.caches.TRACK_MEMORY_USAGE which
+              influences Synapse's self-reported metrics.
+
+    Also, the HTTPConnectionPool in SimpleHTTPClient sets its maxPersistentPerHost
+    parameter based on the global_factor. This won't be applied on a config reload.
+    """
+    try:
+        previous_cache_config = config.reload_config_section("caches")
+    except ConfigError as e:
+        logger.warning("Failed to reload cache config")
+        for f in format_config_error(e):
+            logger.warning(f)
+    else:
+        logger.debug(
+            "New cache config. Was:\n %s\nNow:\n",
+            previous_cache_config.__dict__,
+            config.caches.__dict__,
+        )
+        synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage
+        config.caches.resize_all_caches()
+
+
 def setup_sentry(hs: "HomeServer") -> None:
     """Enable sentry integration, if enabled in configuration"""
 
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 0f75e7b9d4..4c6c0658ab 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -16,7 +16,7 @@
 import logging
 import os
 import sys
-from typing import Dict, Iterable, Iterator, List
+from typing import Dict, Iterable, List
 
 from matrix_common.versionstring import get_distribution_version_string
 
@@ -45,7 +45,7 @@ from synapse.app._base import (
     redirect_stdio_to_logs,
     register_start,
 )
-from synapse.config._base import ConfigError
+from synapse.config._base import ConfigError, format_config_error
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import ListenerConfig
@@ -399,38 +399,6 @@ def setup(config_options: List[str]) -> SynapseHomeServer:
     return hs
 
 
-def format_config_error(e: ConfigError) -> Iterator[str]:
-    """
-    Formats a config error neatly
-
-    The idea is to format the immediate error, plus the "causes" of those errors,
-    hopefully in a way that makes sense to the user. For example:
-
-        Error in configuration at 'oidc_config.user_mapping_provider.config.display_name_template':
-          Failed to parse config for module 'JinjaOidcMappingProvider':
-            invalid jinja template:
-              unexpected end of template, expected 'end of print statement'.
-
-    Args:
-        e: the error to be formatted
-
-    Returns: An iterator which yields string fragments to be formatted
-    """
-    yield "Error in configuration"
-
-    if e.path:
-        yield " at '%s'" % (".".join(e.path),)
-
-    yield ":\n  %s" % (e.msg,)
-
-    parent_e = e.__cause__
-    indent = 1
-    while parent_e:
-        indent += 1
-        yield ":\n%s%s" % ("  " * indent, str(parent_e))
-        parent_e = parent_e.__cause__
-
-
 def run(hs: HomeServer) -> None:
     _base.start_reactor(
         "synapse-homeserver",
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 179aa7ff88..42364fc133 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -16,14 +16,18 @@
 
 import argparse
 import errno
+import logging
 import os
 from collections import OrderedDict
 from hashlib import sha256
 from textwrap import dedent
 from typing import (
     Any,
+    ClassVar,
+    Collection,
     Dict,
     Iterable,
+    Iterator,
     List,
     MutableMapping,
     Optional,
@@ -40,6 +44,8 @@ import yaml
 
 from synapse.util.templates import _create_mxc_to_http_filter, _format_ts_filter
 
+logger = logging.getLogger(__name__)
+
 
 class ConfigError(Exception):
     """Represents a problem parsing the configuration
@@ -55,6 +61,38 @@ class ConfigError(Exception):
         self.path = path
 
 
+def format_config_error(e: ConfigError) -> Iterator[str]:
+    """
+    Formats a config error neatly
+
+    The idea is to format the immediate error, plus the "causes" of those errors,
+    hopefully in a way that makes sense to the user. For example:
+
+        Error in configuration at 'oidc_config.user_mapping_provider.config.display_name_template':
+          Failed to parse config for module 'JinjaOidcMappingProvider':
+            invalid jinja template:
+              unexpected end of template, expected 'end of print statement'.
+
+    Args:
+        e: the error to be formatted
+
+    Returns: An iterator which yields string fragments to be formatted
+    """
+    yield "Error in configuration"
+
+    if e.path:
+        yield " at '%s'" % (".".join(e.path),)
+
+    yield ":\n  %s" % (e.msg,)
+
+    parent_e = e.__cause__
+    indent = 1
+    while parent_e:
+        indent += 1
+        yield ":\n%s%s" % ("  " * indent, str(parent_e))
+        parent_e = parent_e.__cause__
+
+
 # We split these messages out to allow packages to override with package
 # specific instructions.
 MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS = """\
@@ -119,7 +157,7 @@ class Config:
             defined in subclasses.
     """
 
-    section: str
+    section: ClassVar[str]
 
     def __init__(self, root_config: "RootConfig" = None):
         self.root = root_config
@@ -309,9 +347,12 @@ class RootConfig:
     class, lower-cased and with "Config" removed.
     """
 
-    config_classes = []
+    config_classes: List[Type[Config]] = []
+
+    def __init__(self, config_files: Collection[str] = ()):
+        # Capture absolute paths here, so we can reload config after we daemonize.
+        self.config_files = [os.path.abspath(path) for path in config_files]
 
-    def __init__(self):
         for config_class in self.config_classes:
             if config_class.section is None:
                 raise ValueError("%r requires a section name" % (config_class,))
@@ -512,12 +553,10 @@ class RootConfig:
             object from parser.parse_args(..)`
         """
 
-        obj = cls()
-
         config_args = parser.parse_args(argv)
 
         config_files = find_config_files(search_paths=config_args.config_path)
-
+        obj = cls(config_files)
         if not config_files:
             parser.error("Must supply a config file.")
 
@@ -627,7 +666,7 @@ class RootConfig:
 
         generate_missing_configs = config_args.generate_missing_configs
 
-        obj = cls()
+        obj = cls(config_files)
 
         if config_args.generate_config:
             if config_args.report_stats is None:
@@ -727,6 +766,34 @@ class RootConfig:
     ) -> None:
         self.invoke_all("generate_files", config_dict, config_dir_path)
 
+    def reload_config_section(self, section_name: str) -> Config:
+        """Reconstruct the given config section, leaving all others unchanged.
+
+        This works in three steps:
+
+        1. Create a new instance of the relevant `Config` subclass.
+        2. Call `read_config` on that instance to parse the new config.
+        3. Replace the existing config instance with the new one.
+
+        :raises ValueError: if the given `section` does not exist.
+        :raises ConfigError: for any other problems reloading config.
+
+        :returns: the previous config object, which no longer has a reference to this
+            RootConfig.
+        """
+        existing_config: Optional[Config] = getattr(self, section_name, None)
+        if existing_config is None:
+            raise ValueError(f"Unknown config section '{section_name}'")
+        logger.info("Reloading config section '%s'", section_name)
+
+        new_config_data = read_config_files(self.config_files)
+        new_config = type(existing_config)(self)
+        new_config.read_config(new_config_data)
+        setattr(self, section_name, new_config)
+
+        existing_config.root = None
+        return existing_config
+
 
 def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]:
     """Read the config files into a dict
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index bd092f956d..71d6655fda 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -1,15 +1,19 @@
 import argparse
 from typing import (
     Any,
+    Collection,
     Dict,
     Iterable,
+    Iterator,
     List,
+    Literal,
     MutableMapping,
     Optional,
     Tuple,
     Type,
     TypeVar,
     Union,
+    overload,
 )
 
 import jinja2
@@ -64,6 +68,8 @@ class ConfigError(Exception):
         self.msg = msg
         self.path = path
 
+def format_config_error(e: ConfigError) -> Iterator[str]: ...
+
 MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
 MISSING_REPORT_STATS_SPIEL: str
 MISSING_SERVER_NAME: str
@@ -117,7 +123,8 @@ class RootConfig:
     background_updates: background_updates.BackgroundUpdateConfig
 
     config_classes: List[Type["Config"]] = ...
-    def __init__(self) -> None: ...
+    config_files: List[str]
+    def __init__(self, config_files: Collection[str] = ...) -> None: ...
     def invoke_all(
         self, func_name: str, *args: Any, **kwargs: Any
     ) -> MutableMapping[str, Any]: ...
@@ -157,6 +164,12 @@ class RootConfig:
     def generate_missing_files(
         self, config_dict: dict, config_dir_path: str
     ) -> None: ...
+    @overload
+    def reload_config_section(
+        self, section_name: Literal["caches"]
+    ) -> cache.CacheConfig: ...
+    @overload
+    def reload_config_section(self, section_name: str) -> Config: ...
 
 class Config:
     root: RootConfig
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 94d852f413..d2f55534d7 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -69,11 +69,11 @@ def _canonicalise_cache_name(cache_name: str) -> str:
 def add_resizable_cache(
     cache_name: str, cache_resize_callback: Callable[[float], None]
 ) -> None:
-    """Register a cache that's size can dynamically change
+    """Register a cache whose size can dynamically change
 
     Args:
         cache_name: A reference to the cache
-        cache_resize_callback: A callback function that will be ran whenever
+        cache_resize_callback: A callback function that will run whenever
             the cache needs to be resized
     """
     # Some caches have '*' in them which we strip out.
@@ -96,6 +96,13 @@ class CacheConfig(Config):
     section = "caches"
     _environ = os.environ
 
+    event_cache_size: int
+    cache_factors: Dict[str, float]
+    global_factor: float
+    track_memory_usage: bool
+    expiry_time_msec: Optional[int]
+    sync_response_cache_duration: int
+
     @staticmethod
     def reset() -> None:
         """Resets the caches to their defaults. Used for tests."""
@@ -115,6 +122,12 @@ class CacheConfig(Config):
         # A cache 'factor' is a multiplier that can be applied to each of
         # Synapse's caches in order to increase or decrease the maximum
         # number of entries that can be stored.
+        #
+        # The configuration for cache factors (caches.global_factor and
+        # caches.per_cache_factors) can be reloaded while the application is running,
+        # by sending a SIGHUP signal to the Synapse process. Changes to other parts of
+        # the caching config will NOT be applied after a SIGHUP is received; a restart
+        # is necessary.
 
         # The number of events to cache in memory. Not affected by
         # caches.global_factor.
@@ -163,6 +176,24 @@ class CacheConfig(Config):
           #
           #cache_entry_ttl: 30m
 
+          # This flag enables cache autotuning, and is further specified by the sub-options `max_cache_memory_usage`,
+          # `target_cache_memory_usage`, `min_cache_ttl`. These flags work in conjunction with each other to maintain
+          # a balance between cache memory usage and cache entry availability. You must be using jemalloc to utilize
+          # this option, and all three of the options must be specified for this feature to work.
+          #cache_autotuning:
+            # This flag sets a ceiling on much memory the cache can use before caches begin to be continuously evicted.
+            # They will continue to be evicted until the memory usage drops below the `target_memory_usage`, set in
+            # the flag below, or until the `min_cache_ttl` is hit.
+            #max_cache_memory_usage: 1024M
+
+            # This flag sets a rough target for the desired memory usage of the caches.
+            #target_cache_memory_usage: 758M
+
+            # 'min_cache_ttl` sets a limit under which newer cache entries are not evicted and is only applied when
+            # caches are actively being evicted/`max_cache_memory_usage` has been exceeded. This is to protect hot caches
+            # from being emptied while Synapse is evicting due to memory.
+            #min_cache_ttl: 5m
+
           # Controls how long the results of a /sync request are cached for after
           # a successful response is returned. A higher duration can help clients with
           # intermittent connections, at the cost of higher memory usage.
@@ -174,21 +205,21 @@ class CacheConfig(Config):
         """
 
     def read_config(self, config: JsonDict, **kwargs: Any) -> None:
+        """Populate this config object with values from `config`.
+
+        This method does NOT resize existing or future caches: use `resize_all_caches`.
+        We use two separate methods so that we can reject bad config before applying it.
+        """
         self.event_cache_size = self.parse_size(
             config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE)
         )
-        self.cache_factors: Dict[str, float] = {}
+        self.cache_factors = {}
 
         cache_config = config.get("caches") or {}
-        self.global_factor = cache_config.get(
-            "global_factor", properties.default_factor_size
-        )
+        self.global_factor = cache_config.get("global_factor", _DEFAULT_FACTOR_SIZE)
         if not isinstance(self.global_factor, (int, float)):
             raise ConfigError("caches.global_factor must be a number.")
 
-        # Set the global one so that it's reflected in new caches
-        properties.default_factor_size = self.global_factor
-
         # Load cache factors from the config
         individual_factors = cache_config.get("per_cache_factors") or {}
         if not isinstance(individual_factors, dict):
@@ -230,7 +261,7 @@ class CacheConfig(Config):
         cache_entry_ttl = cache_config.get("cache_entry_ttl", "30m")
 
         if expire_caches:
-            self.expiry_time_msec: Optional[int] = self.parse_duration(cache_entry_ttl)
+            self.expiry_time_msec = self.parse_duration(cache_entry_ttl)
         else:
             self.expiry_time_msec = None
 
@@ -250,23 +281,38 @@ class CacheConfig(Config):
             )
             self.expiry_time_msec = self.parse_duration(expiry_time)
 
+        self.cache_autotuning = cache_config.get("cache_autotuning")
+        if self.cache_autotuning:
+            max_memory_usage = self.cache_autotuning.get("max_cache_memory_usage")
+            self.cache_autotuning["max_cache_memory_usage"] = self.parse_size(
+                max_memory_usage
+            )
+
+            target_mem_size = self.cache_autotuning.get("target_cache_memory_usage")
+            self.cache_autotuning["target_cache_memory_usage"] = self.parse_size(
+                target_mem_size
+            )
+
+            min_cache_ttl = self.cache_autotuning.get("min_cache_ttl")
+            self.cache_autotuning["min_cache_ttl"] = self.parse_duration(min_cache_ttl)
+
         self.sync_response_cache_duration = self.parse_duration(
             cache_config.get("sync_response_cache_duration", 0)
         )
 
-        # Resize all caches (if necessary) with the new factors we've loaded
-        self.resize_all_caches()
-
-        # Store this function so that it can be called from other classes without
-        # needing an instance of Config
-        properties.resize_all_caches_func = self.resize_all_caches
-
     def resize_all_caches(self) -> None:
-        """Ensure all cache sizes are up to date
+        """Ensure all cache sizes are up-to-date.
 
         For each cache, run the mapped callback function with either
         a specific cache factor or the default, global one.
         """
+        # Set the global factor size, so that new caches are appropriately sized.
+        properties.default_factor_size = self.global_factor
+
+        # Store this function so that it can be called from other classes without
+        # needing an instance of CacheConfig
+        properties.resize_all_caches_func = self.resize_all_caches
+
         # block other threads from modifying _CACHES while we iterate it.
         with _CACHES_LOCK:
             for cache_name, callback in _CACHES.items():
diff --git a/synapse/config/room.py b/synapse/config/room.py
index e18a87ea37..462d85ac1d 100644
--- a/synapse/config/room.py
+++ b/synapse/config/room.py
@@ -63,6 +63,19 @@ class RoomConfig(Config):
                 "Invalid value for encryption_enabled_by_default_for_room_type"
             )
 
+        self.default_power_level_content_override = config.get(
+            "default_power_level_content_override",
+            None,
+        )
+        if self.default_power_level_content_override is not None:
+            for preset in self.default_power_level_content_override:
+                if preset not in vars(RoomCreationPreset).values():
+                    raise ConfigError(
+                        "Unrecognised room preset %s in default_power_level_content_override"
+                        % preset
+                    )
+                # We validate the actual overrides when we try to apply them.
+
     def generate_config_section(self, **kwargs: Any) -> str:
         return """\
         ## Rooms ##
@@ -83,4 +96,38 @@ class RoomConfig(Config):
         # will also not affect rooms created by other servers.
         #
         #encryption_enabled_by_default_for_room_type: invite
+
+        # Override the default power levels for rooms created on this server, per
+        # room creation preset.
+        #
+        # The appropriate dictionary for the room preset will be applied on top
+        # of the existing power levels content.
+        #
+        # Useful if you know that your users need special permissions in rooms
+        # that they create (e.g. to send particular types of state events without
+        # needing an elevated power level).  This takes the same shape as the
+        # `power_level_content_override` parameter in the /createRoom API, but
+        # is applied before that parameter.
+        #
+        # Valid keys are some or all of `private_chat`, `trusted_private_chat`
+        # and `public_chat`. Inside each of those should be any of the
+        # properties allowed in `power_level_content_override` in the
+        # /createRoom API. If any property is missing, its default value will
+        # continue to be used. If any property is present, it will overwrite
+        # the existing default completely (so if the `events` property exists,
+        # the default event power levels will be ignored).
+        #
+        #default_power_level_content_override:
+        #    private_chat:
+        #        "events":
+        #            "com.example.myeventtype" : 0
+        #            "m.room.avatar": 50
+        #            "m.room.canonical_alias": 50
+        #            "m.room.encryption": 100
+        #            "m.room.history_visibility": 100
+        #            "m.room.name": 50
+        #            "m.room.power_levels": 100
+        #            "m.room.server_acl": 100
+        #            "m.room.tombstone": 100
+        #        "events_default": 1
         """
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 005a3ee48c..f73d5e1f66 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -996,7 +996,7 @@ class ServerConfig(Config):
         #   federation: the server-server API (/_matrix/federation). Also implies
         #       'media', 'keys', 'openid'
         #
-        #   keys: the key discovery API (/_matrix/keys).
+        #   keys: the key discovery API (/_matrix/key).
         #
         #   media: the media API (/_matrix/media).
         #
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index c238376caf..39ad2793d9 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import abc
+import collections.abc
 import os
 from typing import (
     TYPE_CHECKING,
@@ -32,9 +33,11 @@ from typing import (
     overload,
 )
 
+import attr
 from typing_extensions import Literal
 from unpaddedbase64 import encode_base64
 
+from synapse.api.constants import RelationTypes
 from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
 from synapse.types import JsonDict, RoomStreamToken
 from synapse.util.caches import intern_dict
@@ -615,3 +618,45 @@ def make_event_from_dict(
     return event_type(
         event_dict, room_version, internal_metadata_dict or {}, rejected_reason
     )
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventRelation:
+    # The target event of the relation.
+    parent_id: str
+    # The relation type.
+    rel_type: str
+    # The aggregation key. Will be None if the rel_type is not m.annotation or is
+    # not a string.
+    aggregation_key: Optional[str]
+
+
+def relation_from_event(event: EventBase) -> Optional[_EventRelation]:
+    """
+    Attempt to parse relation information an event.
+
+    Returns:
+        The event relation information, if it is valid. None, otherwise.
+    """
+    relation = event.content.get("m.relates_to")
+    if not relation or not isinstance(relation, collections.abc.Mapping):
+        # No relation information.
+        return None
+
+    # Relations must have a type and parent event ID.
+    rel_type = relation.get("rel_type")
+    if not isinstance(rel_type, str):
+        return None
+
+    parent_id = relation.get("event_id")
+    if not isinstance(parent_id, str):
+        return None
+
+    # Annotations have a key field.
+    aggregation_key = None
+    if rel_type == RelationTypes.ANNOTATION:
+        aggregation_key = relation.get("key")
+        if not isinstance(aggregation_key, str):
+            aggregation_key = None
+
+    return _EventRelation(parent_id, rel_type, aggregation_key)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 46042b2bf7..9ccd24b298 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -15,12 +15,10 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import attr
 from frozendict import frozendict
-
-from twisted.internet.defer import Deferred
+from typing_extensions import Literal
 
 from synapse.appservice import ApplicationService
 from synapse.events import EventBase
-from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.types import JsonDict, StateMap
 
 if TYPE_CHECKING:
@@ -60,6 +58,9 @@ class EventContext:
             If ``state_group`` is None (ie, the event is an outlier),
             ``state_group_before_event`` will always also be ``None``.
 
+        state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None
+            then this is the delta of the state between the two groups.
+
         prev_group: If it is known, ``state_group``'s prev_group. Note that this being
             None does not necessarily mean that ``state_group`` does not have
             a prev_group!
@@ -78,73 +79,47 @@ class EventContext:
         app_service: If this event is being sent by a (local) application service, that
             app service.
 
-        _current_state_ids: The room state map, including this event - ie, the state
-            in ``state_group``.
-
-            (type, state_key) -> event_id
-
-            For an outlier, this is {}
-
-            Note that this is a private attribute: it should be accessed via
-            ``get_current_state_ids``. _AsyncEventContext impl calculates this
-            on-demand: it will be None until that happens.
-
-        _prev_state_ids: The room state map, excluding this event - ie, the state
-            in ``state_group_before_event``. For a non-state
-            event, this will be the same as _current_state_events.
-
-            Note that it is a completely different thing to prev_group!
-
-            (type, state_key) -> event_id
-
-            For an outlier, this is {}
-
-            As with _current_state_ids, this is a private attribute. It should be
-            accessed via get_prev_state_ids.
-
         partial_state: if True, we may be storing this event with a temporary,
             incomplete state.
     """
 
-    rejected: Union[bool, str] = False
+    _storage: "Storage"
+    rejected: Union[Literal[False], str] = False
     _state_group: Optional[int] = None
     state_group_before_event: Optional[int] = None
+    _state_delta_due_to_event: Optional[StateMap[str]] = None
     prev_group: Optional[int] = None
     delta_ids: Optional[StateMap[str]] = None
     app_service: Optional[ApplicationService] = None
 
-    _current_state_ids: Optional[StateMap[str]] = None
-    _prev_state_ids: Optional[StateMap[str]] = None
-
     partial_state: bool = False
 
     @staticmethod
     def with_state(
+        storage: "Storage",
         state_group: Optional[int],
         state_group_before_event: Optional[int],
-        current_state_ids: Optional[StateMap[str]],
-        prev_state_ids: Optional[StateMap[str]],
+        state_delta_due_to_event: Optional[StateMap[str]],
         partial_state: bool,
         prev_group: Optional[int] = None,
         delta_ids: Optional[StateMap[str]] = None,
     ) -> "EventContext":
         return EventContext(
-            current_state_ids=current_state_ids,
-            prev_state_ids=prev_state_ids,
+            storage=storage,
             state_group=state_group,
             state_group_before_event=state_group_before_event,
+            state_delta_due_to_event=state_delta_due_to_event,
             prev_group=prev_group,
             delta_ids=delta_ids,
             partial_state=partial_state,
         )
 
     @staticmethod
-    def for_outlier() -> "EventContext":
+    def for_outlier(
+        storage: "Storage",
+    ) -> "EventContext":
         """Return an EventContext instance suitable for persisting an outlier event"""
-        return EventContext(
-            current_state_ids={},
-            prev_state_ids={},
-        )
+        return EventContext(storage=storage)
 
     async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
         """Converts self to a type that can be serialized as JSON, and then
@@ -157,24 +132,14 @@ class EventContext:
             The serialized event.
         """
 
-        # We don't serialize the full state dicts, instead they get pulled out
-        # of the DB on the other side. However, the other side can't figure out
-        # the prev_state_ids, so if we're a state event we include the event
-        # id that we replaced in the state.
-        if event.is_state():
-            prev_state_ids = await self.get_prev_state_ids()
-            prev_state_id = prev_state_ids.get((event.type, event.state_key))
-        else:
-            prev_state_id = None
-
         return {
-            "prev_state_id": prev_state_id,
-            "event_type": event.type,
-            "event_state_key": event.get_state_key(),
             "state_group": self._state_group,
             "state_group_before_event": self.state_group_before_event,
             "rejected": self.rejected,
             "prev_group": self.prev_group,
+            "state_delta_due_to_event": _encode_state_dict(
+                self._state_delta_due_to_event
+            ),
             "delta_ids": _encode_state_dict(self.delta_ids),
             "app_service_id": self.app_service.id if self.app_service else None,
             "partial_state": self.partial_state,
@@ -192,16 +157,16 @@ class EventContext:
         Returns:
             The event context.
         """
-        context = _AsyncEventContextImpl(
+        context = EventContext(
             # We use the state_group and prev_state_id stuff to pull the
             # current_state_ids out of the DB and construct prev_state_ids.
             storage=storage,
-            prev_state_id=input["prev_state_id"],
-            event_type=input["event_type"],
-            event_state_key=input["event_state_key"],
             state_group=input["state_group"],
             state_group_before_event=input["state_group_before_event"],
             prev_group=input["prev_group"],
+            state_delta_due_to_event=_decode_state_dict(
+                input["state_delta_due_to_event"]
+            ),
             delta_ids=_decode_state_dict(input["delta_ids"]),
             rejected=input["rejected"],
             partial_state=input.get("partial_state", False),
@@ -249,8 +214,15 @@ class EventContext:
         if self.rejected:
             raise RuntimeError("Attempt to access state_ids of rejected event")
 
-        await self._ensure_fetched()
-        return self._current_state_ids
+        assert self._state_delta_due_to_event is not None
+
+        prev_state_ids = await self.get_prev_state_ids()
+
+        if self._state_delta_due_to_event:
+            prev_state_ids = dict(prev_state_ids)
+            prev_state_ids.update(self._state_delta_due_to_event)
+
+        return prev_state_ids
 
     async def get_prev_state_ids(self) -> StateMap[str]:
         """
@@ -265,94 +237,10 @@ class EventContext:
             Maps a (type, state_key) to the event ID of the state event matching
             this tuple.
         """
-        await self._ensure_fetched()
-        # There *should* be previous state IDs now.
-        assert self._prev_state_ids is not None
-        return self._prev_state_ids
-
-    def get_cached_current_state_ids(self) -> Optional[StateMap[str]]:
-        """Gets the current state IDs if we have them already cached.
-
-        It is an error to access this for a rejected event, since rejected state should
-        not make it into the room state. This method will raise an exception if
-        ``rejected`` is set.
-
-        Returns:
-            Returns None if we haven't cached the state or if state_group is None
-            (which happens when the associated event is an outlier).
-
-            Otherwise, returns the the current state IDs.
-        """
-        if self.rejected:
-            raise RuntimeError("Attempt to access state_ids of rejected event")
-
-        return self._current_state_ids
-
-    async def _ensure_fetched(self) -> None:
-        return None
-
-
-@attr.s(slots=True)
-class _AsyncEventContextImpl(EventContext):
-    """
-    An implementation of EventContext which fetches _current_state_ids and
-    _prev_state_ids from the database on demand.
-
-    Attributes:
-
-        _storage
-
-        _fetching_state_deferred: Resolves when *_state_ids have been calculated.
-            None if we haven't started calculating yet
-
-        _event_type: The type of the event the context is associated with.
-
-        _event_state_key: The state_key of the event the context is associated with.
-
-        _prev_state_id: If the event associated with the context is a state event,
-            then `_prev_state_id` is the event_id of the state that was replaced.
-    """
-
-    # This needs to have a default as we're inheriting
-    _storage: "Storage" = attr.ib(default=None)
-    _prev_state_id: Optional[str] = attr.ib(default=None)
-    _event_type: str = attr.ib(default=None)
-    _event_state_key: Optional[str] = attr.ib(default=None)
-    _fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None)
-
-    async def _ensure_fetched(self) -> None:
-        if not self._fetching_state_deferred:
-            self._fetching_state_deferred = run_in_background(self._fill_out_state)
-
-        await make_deferred_yieldable(self._fetching_state_deferred)
-
-    async def _fill_out_state(self) -> None:
-        """Called to populate the _current_state_ids and _prev_state_ids
-        attributes by loading from the database.
-        """
-        if self.state_group is None:
-            # No state group means the event is an outlier. Usually the state_ids dicts are also
-            # pre-set to empty dicts, but they get reset when the context is serialized, so set
-            # them to empty dicts again here.
-            self._current_state_ids = {}
-            self._prev_state_ids = {}
-            return
-
-        current_state_ids = await self._storage.state.get_state_ids_for_group(
-            self.state_group
+        assert self.state_group_before_event is not None
+        return await self._storage.state.get_state_ids_for_group(
+            self.state_group_before_event
         )
-        # Set this separately so mypy knows current_state_ids is not None.
-        self._current_state_ids = current_state_ids
-        if self._event_state_key is not None:
-            self._prev_state_ids = dict(current_state_ids)
-
-            key = (self._event_type, self._event_state_key)
-            if self._prev_state_id:
-                self._prev_state_ids[key] = self._prev_state_id
-            else:
-                self._prev_state_ids.pop(key, None)
-        else:
-            self._prev_state_ids = current_state_ids
 
 
 def _encode_state_dict(
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 3b6795d40f..f30207376a 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -32,6 +32,7 @@ from synapse.rest.media.v1.media_storage import ReadableFileWrapper
 from synapse.spam_checker_api import RegistrationBehaviour
 from synapse.types import RoomAlias, UserProfile
 from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
+from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
     import synapse.events
@@ -162,7 +163,10 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
 
 
 class SpamChecker:
-    def __init__(self) -> None:
+    def __init__(self, hs: "synapse.server.HomeServer") -> None:
+        self.hs = hs
+        self.clock = hs.get_clock()
+
         self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
         self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
         self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
@@ -255,7 +259,10 @@ class SpamChecker:
             will be used as the error message returned to the user.
         """
         for callback in self._check_event_for_spam_callbacks:
-            res: Union[bool, str] = await delay_cancellation(callback(event))
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                res: Union[bool, str] = await delay_cancellation(callback(event))
             if res:
                 return res
 
@@ -276,9 +283,12 @@ class SpamChecker:
             Whether the user may join the room
         """
         for callback in self._user_may_join_room_callbacks:
-            may_join_room = await delay_cancellation(
-                callback(user_id, room_id, is_invited)
-            )
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                may_join_room = await delay_cancellation(
+                    callback(user_id, room_id, is_invited)
+                )
             if may_join_room is False:
                 return False
 
@@ -300,9 +310,12 @@ class SpamChecker:
             True if the user may send an invite, otherwise False
         """
         for callback in self._user_may_invite_callbacks:
-            may_invite = await delay_cancellation(
-                callback(inviter_userid, invitee_userid, room_id)
-            )
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                may_invite = await delay_cancellation(
+                    callback(inviter_userid, invitee_userid, room_id)
+                )
             if may_invite is False:
                 return False
 
@@ -328,9 +341,12 @@ class SpamChecker:
             True if the user may send the invite, otherwise False
         """
         for callback in self._user_may_send_3pid_invite_callbacks:
-            may_send_3pid_invite = await delay_cancellation(
-                callback(inviter_userid, medium, address, room_id)
-            )
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                may_send_3pid_invite = await delay_cancellation(
+                    callback(inviter_userid, medium, address, room_id)
+                )
             if may_send_3pid_invite is False:
                 return False
 
@@ -348,7 +364,10 @@ class SpamChecker:
             True if the user may create a room, otherwise False
         """
         for callback in self._user_may_create_room_callbacks:
-            may_create_room = await delay_cancellation(callback(userid))
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                may_create_room = await delay_cancellation(callback(userid))
             if may_create_room is False:
                 return False
 
@@ -369,9 +388,12 @@ class SpamChecker:
             True if the user may create a room alias, otherwise False
         """
         for callback in self._user_may_create_room_alias_callbacks:
-            may_create_room_alias = await delay_cancellation(
-                callback(userid, room_alias)
-            )
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                may_create_room_alias = await delay_cancellation(
+                    callback(userid, room_alias)
+                )
             if may_create_room_alias is False:
                 return False
 
@@ -390,7 +412,10 @@ class SpamChecker:
             True if the user may publish the room, otherwise False
         """
         for callback in self._user_may_publish_room_callbacks:
-            may_publish_room = await delay_cancellation(callback(userid, room_id))
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                may_publish_room = await delay_cancellation(callback(userid, room_id))
             if may_publish_room is False:
                 return False
 
@@ -412,9 +437,13 @@ class SpamChecker:
             True if the user is spammy.
         """
         for callback in self._check_username_for_spam_callbacks:
-            # Make a copy of the user profile object to ensure the spam checker cannot
-            # modify it.
-            if await delay_cancellation(callback(user_profile.copy())):
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                # Make a copy of the user profile object to ensure the spam checker cannot
+                # modify it.
+                res = await delay_cancellation(callback(user_profile.copy()))
+            if res:
                 return True
 
         return False
@@ -442,9 +471,12 @@ class SpamChecker:
         """
 
         for callback in self._check_registration_for_spam_callbacks:
-            behaviour = await delay_cancellation(
-                callback(email_threepid, username, request_info, auth_provider_id)
-            )
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                behaviour = await delay_cancellation(
+                    callback(email_threepid, username, request_info, auth_provider_id)
+                )
             assert isinstance(behaviour, RegistrationBehaviour)
             if behaviour != RegistrationBehaviour.ALLOW:
                 return behaviour
@@ -486,7 +518,10 @@ class SpamChecker:
         """
 
         for callback in self._check_media_file_for_spam_callbacks:
-            spam = await delay_cancellation(callback(file_wrapper, file_info))
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                spam = await delay_cancellation(callback(file_wrapper, file_info))
             if spam:
                 return True
 
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index d629a3ecb5..103861644a 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tupl
 
 from synapse.api.errors import Codes, FederationDeniedError, SynapseError
 from synapse.api.urls import FEDERATION_V1_PREFIX
-from synapse.http.server import HttpServer, ServletCallback
+from synapse.http.server import HttpServer, ServletCallback, is_method_cancellable
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import run_in_background
@@ -373,6 +373,17 @@ class BaseFederationServlet:
             if code is None:
                 continue
 
+            if is_method_cancellable(code):
+                # The wrapper added by `self._wrap` will inherit the cancellable flag,
+                # but the wrapper itself does not support cancellation yet.
+                # Once resolved, the cancellation tests in
+                # `tests/federation/transport/server/test__base.py` can be re-enabled.
+                raise Exception(
+                    f"{self.__class__.__name__}.on_{method} has been marked as "
+                    "cancellable, but federation servlets do not support cancellation "
+                    "yet."
+                )
+
             server.register_paths(
                 method,
                 (pattern,),
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 4c3a5a6e24..dfd24af695 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -934,7 +934,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         # Before deleting the group lets kick everyone out of it
         users = await self.store.get_users_in_group(group_id, include_private=True)
 
-        async def _kick_user_from_group(user_id):
+        async def _kick_user_from_group(user_id: str) -> None:
             if self.hs.is_mine_id(user_id):
                 groups_local = self.hs.get_groups_local_handler()
                 assert isinstance(
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 4af9fbc5d1..0478448b47 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -23,7 +23,7 @@ from synapse.replication.http.account_data import (
     ReplicationUserAccountDataRestServlet,
 )
 from synapse.streams import EventSource
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, StreamKeyType, UserID
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -105,7 +105,7 @@ class AccountDataHandler:
             )
 
             self._notifier.on_new_event(
-                "account_data_key", max_stream_id, users=[user_id]
+                StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
             )
 
             await self._notify_modules(user_id, room_id, account_data_type, content)
@@ -141,7 +141,7 @@ class AccountDataHandler:
             )
 
             self._notifier.on_new_event(
-                "account_data_key", max_stream_id, users=[user_id]
+                StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
             )
 
             await self._notify_modules(user_id, None, account_data_type, content)
@@ -176,7 +176,7 @@ class AccountDataHandler:
             )
 
             self._notifier.on_new_event(
-                "account_data_key", max_stream_id, users=[user_id]
+                StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
             )
             return max_stream_id
         else:
@@ -201,7 +201,7 @@ class AccountDataHandler:
             )
 
             self._notifier.on_new_event(
-                "account_data_key", max_stream_id, users=[user_id]
+                StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
             )
             return max_stream_id
         else:
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 85bd5e4768..1da7bcc85b 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -38,6 +38,7 @@ from synapse.types import (
     JsonDict,
     RoomAlias,
     RoomStreamToken,
+    StreamKeyType,
     UserID,
 )
 from synapse.util.async_helpers import Linearizer
@@ -213,8 +214,8 @@ class ApplicationServicesHandler:
         Args:
             stream_key: The stream the event came from.
 
-                `stream_key` can be "typing_key", "receipt_key", "presence_key",
-                "to_device_key" or "device_list_key". Any other value for `stream_key`
+                `stream_key` can be StreamKeyType.TYPING, StreamKeyType.RECEIPT, StreamKeyType.PRESENCE,
+                StreamKeyType.TO_DEVICE or StreamKeyType.DEVICE_LIST. Any other value for `stream_key`
                 will cause this function to return early.
 
                 Ephemeral events will only be pushed to appservices that have opted into
@@ -235,11 +236,11 @@ class ApplicationServicesHandler:
         # Only the following streams are currently supported.
         # FIXME: We should use constants for these values.
         if stream_key not in (
-            "typing_key",
-            "receipt_key",
-            "presence_key",
-            "to_device_key",
-            "device_list_key",
+            StreamKeyType.TYPING,
+            StreamKeyType.RECEIPT,
+            StreamKeyType.PRESENCE,
+            StreamKeyType.TO_DEVICE,
+            StreamKeyType.DEVICE_LIST,
         ):
             return
 
@@ -258,14 +259,14 @@ class ApplicationServicesHandler:
 
         # Ignore to-device messages if the feature flag is not enabled
         if (
-            stream_key == "to_device_key"
+            stream_key == StreamKeyType.TO_DEVICE
             and not self._msc2409_to_device_messages_enabled
         ):
             return
 
         # Ignore device lists if the feature flag is not enabled
         if (
-            stream_key == "device_list_key"
+            stream_key == StreamKeyType.DEVICE_LIST
             and not self._msc3202_transaction_extensions_enabled
         ):
             return
@@ -283,15 +284,15 @@ class ApplicationServicesHandler:
             if (
                 stream_key
                 in (
-                    "typing_key",
-                    "receipt_key",
-                    "presence_key",
-                    "to_device_key",
+                    StreamKeyType.TYPING,
+                    StreamKeyType.RECEIPT,
+                    StreamKeyType.PRESENCE,
+                    StreamKeyType.TO_DEVICE,
                 )
                 and service.supports_ephemeral
             )
             or (
-                stream_key == "device_list_key"
+                stream_key == StreamKeyType.DEVICE_LIST
                 and service.msc3202_transaction_extensions
             )
         ]
@@ -317,7 +318,7 @@ class ApplicationServicesHandler:
         logger.debug("Checking interested services for %s", stream_key)
         with Measure(self.clock, "notify_interested_services_ephemeral"):
             for service in services:
-                if stream_key == "typing_key":
+                if stream_key == StreamKeyType.TYPING:
                     # Note that we don't persist the token (via set_appservice_stream_type_pos)
                     # for typing_key due to performance reasons and due to their highly
                     # ephemeral nature.
@@ -333,7 +334,7 @@ class ApplicationServicesHandler:
                 async with self._ephemeral_events_linearizer.queue(
                     (service.id, stream_key)
                 ):
-                    if stream_key == "receipt_key":
+                    if stream_key == StreamKeyType.RECEIPT:
                         events = await self._handle_receipts(service, new_token)
                         self.scheduler.enqueue_for_appservice(service, ephemeral=events)
 
@@ -342,7 +343,7 @@ class ApplicationServicesHandler:
                             service, "read_receipt", new_token
                         )
 
-                    elif stream_key == "presence_key":
+                    elif stream_key == StreamKeyType.PRESENCE:
                         events = await self._handle_presence(service, users, new_token)
                         self.scheduler.enqueue_for_appservice(service, ephemeral=events)
 
@@ -351,7 +352,7 @@ class ApplicationServicesHandler:
                             service, "presence", new_token
                         )
 
-                    elif stream_key == "to_device_key":
+                    elif stream_key == StreamKeyType.TO_DEVICE:
                         # Retrieve a list of to-device message events, as well as the
                         # maximum stream token of the messages we were able to retrieve.
                         to_device_messages = await self._get_to_device_messages(
@@ -366,7 +367,7 @@ class ApplicationServicesHandler:
                             service, "to_device", new_token
                         )
 
-                    elif stream_key == "device_list_key":
+                    elif stream_key == StreamKeyType.DEVICE_LIST:
                         device_list_summary = await self._get_device_list_summary(
                             service, new_token
                         )
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index a91b1ee4d5..1d6d1f8a92 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -43,6 +43,7 @@ from synapse.metrics.background_process_metrics import (
 )
 from synapse.types import (
     JsonDict,
+    StreamKeyType,
     StreamToken,
     UserID,
     get_domain_from_id,
@@ -502,7 +503,7 @@ class DeviceHandler(DeviceWorkerHandler):
         # specify the user ID too since the user should always get their own device list
         # updates, even if they aren't in any rooms.
         self.notifier.on_new_event(
-            "device_list_key", position, users={user_id}, rooms=room_ids
+            StreamKeyType.DEVICE_LIST, position, users={user_id}, rooms=room_ids
         )
 
         # We may need to do some processing asynchronously for local user IDs.
@@ -523,7 +524,9 @@ class DeviceHandler(DeviceWorkerHandler):
             from_user_id, user_ids
         )
 
-        self.notifier.on_new_event("device_list_key", position, users=[from_user_id])
+        self.notifier.on_new_event(
+            StreamKeyType.DEVICE_LIST, position, users=[from_user_id]
+        )
 
     async def user_left_room(self, user: UserID, room_id: str) -> None:
         user_id = user.to_string()
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 4cb725d027..53668cce3b 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -26,7 +26,7 @@ from synapse.logging.opentracing import (
     set_tag,
 )
 from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
-from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.stringutils import random_string
 
@@ -151,7 +151,7 @@ class DeviceMessageHandler:
         # Notify listeners that there are new to-device messages to process,
         # handing them the latest stream id.
         self.notifier.on_new_event(
-            "to_device_key", last_stream_id, users=local_messages.keys()
+            StreamKeyType.TO_DEVICE, last_stream_id, users=local_messages.keys()
         )
 
     async def _check_for_unknown_devices(
@@ -285,7 +285,7 @@ class DeviceMessageHandler:
         # Notify listeners that there are new to-device messages to process,
         # handing them the latest stream id.
         self.notifier.on_new_event(
-            "to_device_key", last_stream_id, users=local_messages.keys()
+            StreamKeyType.TO_DEVICE, last_stream_id, users=local_messages.keys()
         )
 
         if self.federation_sender:
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d6714228ef..e6c2cfb8c8 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -1105,22 +1105,19 @@ class E2eKeysHandler:
             # can request over federation
             raise NotFoundError("No %s key found for %s" % (key_type, user_id))
 
-        (
-            key,
-            key_id,
-            verify_key,
-        ) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
-
-        if key is None:
+        cross_signing_keys = await self._retrieve_cross_signing_keys_for_remote_user(
+            user, key_type
+        )
+        if cross_signing_keys is None:
             raise NotFoundError("No %s key found for %s" % (key_type, user_id))
 
-        return key, key_id, verify_key
+        return cross_signing_keys
 
     async def _retrieve_cross_signing_keys_for_remote_user(
         self,
         user: UserID,
         desired_key_type: str,
-    ) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
+    ) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]:
         """Queries cross-signing keys for a remote user and saves them to the database
 
         Only the key specified by `key_type` will be returned, while all retrieved keys
@@ -1146,12 +1143,10 @@ class E2eKeysHandler:
                 type(e),
                 e,
             )
-            return None, None, None
+            return None
 
         # Process each of the retrieved cross-signing keys
-        desired_key = None
-        desired_key_id = None
-        desired_verify_key = None
+        desired_key_data = None
         retrieved_device_ids = []
         for key_type in ["master", "self_signing"]:
             key_content = remote_result.get(key_type + "_key")
@@ -1196,9 +1191,7 @@ class E2eKeysHandler:
 
             # If this is the desired key type, save it and its ID/VerifyKey
             if key_type == desired_key_type:
-                desired_key = key_content
-                desired_verify_key = verify_key
-                desired_key_id = key_id
+                desired_key_data = key_content, key_id, verify_key
 
             # At the same time, store this key in the db for subsequent queries
             await self.store.set_e2e_cross_signing_key(
@@ -1212,7 +1205,7 @@ class E2eKeysHandler:
                 user.to_string(), retrieved_device_ids
             )
 
-        return desired_key, desired_key_id, desired_verify_key
+        return desired_key_data
 
 
 def _check_cross_signing_key(
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 38dc5b1f6e..be5099b507 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -659,7 +659,7 @@ class FederationHandler:
         # in the invitee's sync stream. It is stripped out for all other local users.
         event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
 
-        context = EventContext.for_outlier()
+        context = EventContext.for_outlier(self.storage)
         stream_id = await self._federation_event_handler.persist_events_and_notify(
             event.room_id, [(event, context)]
         )
@@ -848,7 +848,7 @@ class FederationHandler:
             )
         )
 
-        context = EventContext.for_outlier()
+        context = EventContext.for_outlier(self.storage)
         await self._federation_event_handler.persist_events_and_notify(
             event.room_id, [(event, context)]
         )
@@ -877,7 +877,7 @@ class FederationHandler:
 
         await self.federation_client.send_leave(host_list, event)
 
-        context = EventContext.for_outlier()
+        context = EventContext.for_outlier(self.storage)
         stream_id = await self._federation_event_handler.persist_events_and_notify(
             event.room_id, [(event, context)]
         )
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 6cf927e4ff..761caa04b7 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -103,7 +103,7 @@ class FederationEventHandler:
         self._event_creation_handler = hs.get_event_creation_handler()
         self._event_auth_handler = hs.get_event_auth_handler()
         self._message_handler = hs.get_message_handler()
-        self._action_generator = hs.get_action_generator()
+        self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
         self._state_resolution_handler = hs.get_state_resolution_handler()
         # avoid a circular dependency by deferring execution here
         self._get_room_member_handler = hs.get_room_member_handler
@@ -1423,7 +1423,7 @@ class FederationEventHandler:
                 # we're not bothering about room state, so flag the event as an outlier.
                 event.internal_metadata.outlier = True
 
-                context = EventContext.for_outlier()
+                context = EventContext.for_outlier(self._storage)
                 try:
                     validate_event_for_room_version(room_version_obj, event)
                     check_auth_rules_for_event(room_version_obj, event, auth)
@@ -1874,10 +1874,10 @@ class FederationEventHandler:
         )
 
         return EventContext.with_state(
+            storage=self._storage,
             state_group=state_group,
             state_group_before_event=context.state_group_before_event,
-            current_state_ids=current_state_ids,
-            prev_state_ids=prev_state_ids,
+            state_delta_due_to_event=state_updates,
             prev_group=prev_group,
             delta_ids=state_updates,
             partial_state=context.partial_state,
@@ -1913,7 +1913,7 @@ class FederationEventHandler:
                     min_depth,
                 )
             else:
-                await self._action_generator.handle_push_actions_for_event(
+                await self._bulk_push_rule_evaluator.action_for_event_by_user(
                     event, context
                 )
 
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 7b94770f97..d79248ad90 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -30,6 +30,7 @@ from synapse.types import (
     Requester,
     RoomStreamToken,
     StateMap,
+    StreamKeyType,
     StreamToken,
     UserID,
 )
@@ -143,7 +144,7 @@ class InitialSyncHandler:
             to_key=int(now_token.receipt_key),
         )
         if self.hs.config.experimental.msc2285_enabled:
-            receipt = ReceiptEventSource.filter_out_private(receipt, user_id)
+            receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
 
         tags_by_room = await self.store.get_tags_for_user(user_id)
 
@@ -220,8 +221,10 @@ class InitialSyncHandler:
                     self.storage, user_id, messages
                 )
 
-                start_token = now_token.copy_and_replace("room_key", token)
-                end_token = now_token.copy_and_replace("room_key", room_end_token)
+                start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
+                end_token = now_token.copy_and_replace(
+                    StreamKeyType.ROOM, room_end_token
+                )
                 time_now = self.clock.time_msec()
 
                 d["messages"] = {
@@ -369,8 +372,8 @@ class InitialSyncHandler:
             self.storage, user_id, messages, is_peeking=is_peeking
         )
 
-        start_token = StreamToken.START.copy_and_replace("room_key", token)
-        end_token = StreamToken.START.copy_and_replace("room_key", stream_token)
+        start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token)
+        end_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, stream_token)
 
         time_now = self.clock.time_msec()
 
@@ -449,7 +452,9 @@ class InitialSyncHandler:
             if not receipts:
                 return []
             if self.hs.config.experimental.msc2285_enabled:
-                receipts = ReceiptEventSource.filter_out_private(receipts, user_id)
+                receipts = ReceiptEventSource.filter_out_private_receipts(
+                    receipts, user_id
+                )
             return receipts
 
         presence, receipts, (messages, token) = await make_deferred_yieldable(
@@ -472,7 +477,7 @@ class InitialSyncHandler:
             self.storage, user_id, messages, is_peeking=is_peeking
         )
 
-        start_token = now_token.copy_and_replace("room_key", token)
+        start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
         end_token = now_token
 
         time_now = self.clock.time_msec()
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index c28b792e6f..0951b9c71f 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -44,7 +44,7 @@ from synapse.api.errors import (
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
 from synapse.api.urls import ConsentURIBuilder
 from synapse.event_auth import validate_event_for_room_version
-from synapse.events import EventBase
+from synapse.events import EventBase, relation_from_event
 from synapse.events.builder import EventBuilder
 from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
@@ -426,7 +426,7 @@ class EventCreationHandler:
         # This is to stop us from diverging history *too* much.
         self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
 
-        self.action_generator = hs.get_action_generator()
+        self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
 
         self.spam_checker = hs.get_spam_checker()
         self.third_party_event_rules: "ThirdPartyEventRules" = (
@@ -757,6 +757,10 @@ class EventCreationHandler:
             The previous version of the event is returned, if it is found in the
             event context. Otherwise, None is returned.
         """
+        if event.internal_metadata.is_outlier():
+            # This can happen due to out of band memberships
+            return None
+
         prev_state_ids = await context.get_prev_state_ids()
         prev_event_id = prev_state_ids.get((event.type, event.state_key))
         if not prev_event_id:
@@ -1001,7 +1005,7 @@ class EventCreationHandler:
         # after it is created
         if builder.internal_metadata.outlier:
             event.internal_metadata.outlier = True
-            context = EventContext.for_outlier()
+            context = EventContext.for_outlier(self.storage)
         elif (
             event.type == EventTypes.MSC2716_INSERTION
             and state_event_ids
@@ -1056,20 +1060,11 @@ class EventCreationHandler:
             SynapseError if the event is invalid.
         """
 
-        relation = event.content.get("m.relates_to")
+        relation = relation_from_event(event)
         if not relation:
             return
 
-        relation_type = relation.get("rel_type")
-        if not relation_type:
-            return
-
-        # Ensure the parent is real.
-        relates_to = relation.get("event_id")
-        if not relates_to:
-            return
-
-        parent_event = await self.store.get_event(relates_to, allow_none=True)
+        parent_event = await self.store.get_event(relation.parent_id, allow_none=True)
         if parent_event:
             # And in the same room.
             if parent_event.room_id != event.room_id:
@@ -1078,28 +1073,31 @@ class EventCreationHandler:
         else:
             # There must be some reason that the client knows the event exists,
             # see if there are existing relations. If so, assume everything is fine.
-            if not await self.store.event_is_target_of_relation(relates_to):
+            if not await self.store.event_is_target_of_relation(relation.parent_id):
                 # Otherwise, the client can't know about the parent event!
                 raise SynapseError(400, "Can't send relation to unknown event")
 
         # If this event is an annotation then we check that that the sender
         # can't annotate the same way twice (e.g. stops users from liking an
         # event multiple times).
-        if relation_type == RelationTypes.ANNOTATION:
-            aggregation_key = relation["key"]
+        if relation.rel_type == RelationTypes.ANNOTATION:
+            aggregation_key = relation.aggregation_key
+
+            if aggregation_key is None:
+                raise SynapseError(400, "Missing aggregation key")
 
             if len(aggregation_key) > 500:
                 raise SynapseError(400, "Aggregation key is too long")
 
             already_exists = await self.store.has_user_annotated_event(
-                relates_to, event.type, aggregation_key, event.sender
+                relation.parent_id, event.type, aggregation_key, event.sender
             )
             if already_exists:
                 raise SynapseError(400, "Can't send same reaction twice")
 
         # Don't attempt to start a thread if the parent event is a relation.
-        elif relation_type == RelationTypes.THREAD:
-            if await self.store.event_includes_relation(relates_to):
+        elif relation.rel_type == RelationTypes.THREAD:
+            if await self.store.event_includes_relation(relation.parent_id):
                 raise SynapseError(
                     400, "Cannot start threads from an event with a relation"
                 )
@@ -1245,7 +1243,9 @@ class EventCreationHandler:
         # and `state_groups` because they have `prev_events` that aren't persisted yet
         # (historical messages persisted in reverse-chronological order).
         if not event.internal_metadata.is_historical():
-            await self.action_generator.handle_push_actions_for_event(event, context)
+            await self._bulk_push_rule_evaluator.action_for_event_by_user(
+                event, context
+            )
 
         try:
             # If we're a worker we need to hit out to the master.
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 7ee3340373..6ae88add95 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -27,7 +27,7 @@ from synapse.handlers.room import ShutdownRoomResponse
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester
+from synapse.types import JsonDict, Requester, StreamKeyType
 from synapse.util.async_helpers import ReadWriteLock
 from synapse.util.stringutils import random_string
 from synapse.visibility import filter_events_for_client
@@ -448,7 +448,7 @@ class PaginationHandler:
             )
             # We expect `/messages` to use historic pagination tokens by default but
             # `/messages` should still works with live tokens when manually provided.
-            assert from_token.room_key.topological
+            assert from_token.room_key.topological is not None
 
         if pagin_config.limit is None:
             # This shouldn't happen as we've set a default limit before this
@@ -491,7 +491,7 @@ class PaginationHandler:
 
                     if leave_token.topological < curr_topo:
                         from_token = from_token.copy_and_replace(
-                            "room_key", leave_token
+                            StreamKeyType.ROOM, leave_token
                         )
 
                 await self.hs.get_federation_handler().maybe_backfill(
@@ -513,7 +513,7 @@ class PaginationHandler:
                 event_filter=event_filter,
             )
 
-            next_token = from_token.copy_and_replace("room_key", next_key)
+            next_token = from_token.copy_and_replace(StreamKeyType.ROOM, next_key)
 
         if events:
             if event_filter:
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 268481ec19..dd84e6c88b 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -66,7 +66,7 @@ from synapse.replication.tcp.commands import ClearUserSyncsCommand
 from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
 from synapse.storage.databases.main import DataStore
 from synapse.streams import EventSource
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.descriptors import _CacheContext, cached
 from synapse.util.metrics import Measure
@@ -522,7 +522,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
         room_ids_to_states, users_to_states = parties
 
         self.notifier.on_new_event(
-            "presence_key",
+            StreamKeyType.PRESENCE,
             stream_id,
             rooms=room_ids_to_states.keys(),
             users=users_to_states.keys(),
@@ -1145,7 +1145,7 @@ class PresenceHandler(BasePresenceHandler):
         room_ids_to_states, users_to_states = parties
 
         self.notifier.on_new_event(
-            "presence_key",
+            StreamKeyType.PRESENCE,
             stream_id,
             rooms=room_ids_to_states.keys(),
             users=[UserID.from_string(u) for u in users_to_states],
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 43d615357b..e6a35f1d09 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -17,7 +17,13 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 from synapse.api.constants import ReceiptTypes
 from synapse.appservice import ApplicationService
 from synapse.streams import EventSource
-from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
+from synapse.types import (
+    JsonDict,
+    ReadReceipt,
+    StreamKeyType,
+    UserID,
+    get_domain_from_id,
+)
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -129,7 +135,9 @@ class ReceiptsHandler:
 
         affected_room_ids = list({r.room_id for r in receipts})
 
-        self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
+        self.notifier.on_new_event(
+            StreamKeyType.RECEIPT, max_batch_id, rooms=affected_room_ids
+        )
         # Note that the min here shouldn't be relied upon to be accurate.
         await self.hs.get_pusherpool().on_new_receipts(
             min_batch_id, max_batch_id, affected_room_ids
@@ -165,43 +173,69 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
         self.config = hs.config
 
     @staticmethod
-    def filter_out_private(events: List[JsonDict], user_id: str) -> List[JsonDict]:
-        """
-        This method takes in what is returned by
-        get_linearized_receipts_for_rooms() and goes through read receipts
-        filtering out m.read.private receipts if they were not sent by the
-        current user.
+    def filter_out_private_receipts(
+        rooms: List[JsonDict], user_id: str
+    ) -> List[JsonDict]:
         """
+        Filters a list of serialized receipts (as returned by /sync and /initialSync)
+        and removes private read receipts of other users.
 
-        visible_events = []
-
-        # filter out private receipts the user shouldn't see
-        for event in events:
-            content = event.get("content", {})
-            new_event = event.copy()
-            new_event["content"] = {}
-
-            for event_id, event_content in content.items():
-                receipt_event = {}
-                for receipt_type, receipt_content in event_content.items():
-                    if receipt_type == ReceiptTypes.READ_PRIVATE:
-                        user_rr = receipt_content.get(user_id, None)
-                        if user_rr:
-                            receipt_event[ReceiptTypes.READ_PRIVATE] = {
-                                user_id: user_rr.copy()
-                            }
-                    else:
-                        receipt_event[receipt_type] = receipt_content.copy()
+        This operates on the return value of get_linearized_receipts_for_rooms(),
+        which is wrapped in a cache. Care must be taken to ensure that the input
+        values are not modified.
 
-                # Only include the receipt event if it is non-empty.
-                if receipt_event:
-                    new_event["content"][event_id] = receipt_event
+        Args:
+            rooms: A list of mappings, each mapping has a `content` field, which
+                is a map of event ID -> receipt type -> user ID -> receipt information.
 
-            # Append new_event to visible_events unless empty
-            if len(new_event["content"].keys()) > 0:
-                visible_events.append(new_event)
+        Returns:
+            The same as rooms, but filtered.
+        """
 
-        return visible_events
+        result = []
+
+        # Iterate through each room's receipt content.
+        for room in rooms:
+            # The receipt content with other user's private read receipts removed.
+            content = {}
+
+            # Iterate over each event ID / receipts for that event.
+            for event_id, orig_event_content in room.get("content", {}).items():
+                event_content = orig_event_content
+                # If there are private read receipts, additional logic is necessary.
+                if ReceiptTypes.READ_PRIVATE in event_content:
+                    # Make a copy without private read receipts to avoid leaking
+                    # other user's private read receipts..
+                    event_content = {
+                        receipt_type: receipt_value
+                        for receipt_type, receipt_value in event_content.items()
+                        if receipt_type != ReceiptTypes.READ_PRIVATE
+                    }
+
+                    # Copy the current user's private read receipt from the
+                    # original content, if it exists.
+                    user_private_read_receipt = orig_event_content[
+                        ReceiptTypes.READ_PRIVATE
+                    ].get(user_id, None)
+                    if user_private_read_receipt:
+                        event_content[ReceiptTypes.READ_PRIVATE] = {
+                            user_id: user_private_read_receipt
+                        }
+
+                # Include the event if there is at least one non-private read
+                # receipt or the current user has a private read receipt.
+                if event_content:
+                    content[event_id] = event_content
+
+            # Include the event if there is at least one non-private read receipt
+            # or the current user has a private read receipt.
+            if content:
+                # Build a new event to avoid mutating the cache.
+                new_room = {k: v for k, v in room.items() if k != "content"}
+                new_room["content"] = content
+                result.append(new_room)
+
+        return result
 
     async def get_new_events(
         self,
@@ -223,7 +257,9 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
         )
 
         if self.config.experimental.msc2285_enabled:
-            events = ReceiptEventSource.filter_out_private(events, user.to_string())
+            events = ReceiptEventSource.filter_out_private_receipts(
+                events, user.to_string()
+            )
 
         return events, to_key
 
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index c2754ec918..ab7e54857d 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import collections.abc
 import logging
 from typing import (
     TYPE_CHECKING,
@@ -28,7 +27,7 @@ import attr
 
 from synapse.api.constants import RelationTypes
 from synapse.api.errors import SynapseError
-from synapse.events import EventBase
+from synapse.events import EventBase, relation_from_event
 from synapse.storage.databases.main.relations import _RelatedEvent
 from synapse.types import JsonDict, Requester, StreamToken, UserID
 from synapse.visibility import filter_events_for_client
@@ -373,20 +372,21 @@ class RelationsHandler:
             if event.is_state():
                 continue
 
-            relates_to = event.content.get("m.relates_to")
-            relation_type = None
-            if isinstance(relates_to, collections.abc.Mapping):
-                relation_type = relates_to.get("rel_type")
+            relates_to = relation_from_event(event)
+            if relates_to:
                 # An event which is a replacement (ie edit) or annotation (ie,
                 # reaction) may not have any other event related to it.
-                if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+                if relates_to.rel_type in (
+                    RelationTypes.ANNOTATION,
+                    RelationTypes.REPLACE,
+                ):
                     continue
 
+                # Track the event's relation information for later.
+                relations_by_id[event.event_id] = relates_to.rel_type
+
             # The event should get bundled aggregations.
             events_by_id[event.event_id] = event
-            # Track the event's relation information for later.
-            if isinstance(relation_type, str):
-                relations_by_id[event.event_id] = relation_type
 
         # event ID -> bundled aggregation in non-serialized form.
         results: Dict[str, BundledAggregations] = {}
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 604eb6ec15..a2973109ad 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -33,6 +33,7 @@ from typing import (
 import attr
 from typing_extensions import TypedDict
 
+import synapse.events.snapshot
 from synapse.api.constants import (
     EventContentFields,
     EventTypes,
@@ -72,12 +73,12 @@ from synapse.types import (
     RoomID,
     RoomStreamToken,
     StateMap,
+    StreamKeyType,
     StreamToken,
     UserID,
     create_requester,
 )
 from synapse.util import stringutils
-from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.stringutils import parse_and_validate_server_name
 from synapse.visibility import filter_events_for_client
@@ -149,10 +150,11 @@ class RoomCreationHandler:
             )
             preset_config["encrypted"] = encrypted
 
-        self._replication = hs.get_replication_data_handler()
+        self._default_power_level_content_override = (
+            self.config.room.default_power_level_content_override
+        )
 
-        # linearizer to stop two upgrades happening at once
-        self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")
+        self._replication = hs.get_replication_data_handler()
 
         # If a user tries to update the same room multiple times in quick
         # succession, only process the first attempt and return its result to
@@ -196,6 +198,39 @@ class RoomCreationHandler:
                     400, "An upgrade for this room is currently in progress"
                 )
 
+        # Check whether the room exists and 404 if it doesn't.
+        # We could go straight for the auth check, but that will raise a 403 instead.
+        old_room = await self.store.get_room(old_room_id)
+        if old_room is None:
+            raise NotFoundError("Unknown room id %s" % (old_room_id,))
+
+        new_room_id = self._generate_room_id()
+
+        # Check whether the user has the power level to carry out the upgrade.
+        # `check_auth_rules_from_context` will check that they are in the room and have
+        # the required power level to send the tombstone event.
+        (
+            tombstone_event,
+            tombstone_context,
+        ) = await self.event_creation_handler.create_event(
+            requester,
+            {
+                "type": EventTypes.Tombstone,
+                "state_key": "",
+                "room_id": old_room_id,
+                "sender": user_id,
+                "content": {
+                    "body": "This room has been replaced",
+                    "replacement_room": new_room_id,
+                },
+            },
+        )
+        old_room_version = await self.store.get_room_version(old_room_id)
+        validate_event_for_room_version(old_room_version, tombstone_event)
+        await self._event_auth_handler.check_auth_rules_from_context(
+            old_room_version, tombstone_event, tombstone_context
+        )
+
         # Upgrade the room
         #
         # If this user has sent multiple upgrade requests for the same room
@@ -206,19 +241,35 @@ class RoomCreationHandler:
             self._upgrade_room,
             requester,
             old_room_id,
-            new_version,  # args for _upgrade_room
+            old_room,  # args for _upgrade_room
+            new_room_id,
+            new_version,
+            tombstone_event,
+            tombstone_context,
         )
 
         return ret
 
     async def _upgrade_room(
-        self, requester: Requester, old_room_id: str, new_version: RoomVersion
+        self,
+        requester: Requester,
+        old_room_id: str,
+        old_room: Dict[str, Any],
+        new_room_id: str,
+        new_version: RoomVersion,
+        tombstone_event: EventBase,
+        tombstone_context: synapse.events.snapshot.EventContext,
     ) -> str:
         """
         Args:
             requester: the user requesting the upgrade
             old_room_id: the id of the room to be replaced
-            new_versions: the version to upgrade the room to
+            old_room: a dict containing room information for the room to be replaced,
+                as returned by `RoomWorkerStore.get_room`.
+            new_room_id: the id of the replacement room
+            new_version: the version to upgrade the room to
+            tombstone_event: the tombstone event to send to the old room
+            tombstone_context: the context for the tombstone event
 
         Raises:
             ShadowBanError if the requester is shadow-banned.
@@ -226,40 +277,15 @@ class RoomCreationHandler:
         user_id = requester.user.to_string()
         assert self.hs.is_mine_id(user_id), "User must be our own: %s" % (user_id,)
 
-        # start by allocating a new room id
-        r = await self.store.get_room(old_room_id)
-        if r is None:
-            raise NotFoundError("Unknown room id %s" % (old_room_id,))
-        new_room_id = await self._generate_room_id(
-            creator_id=user_id,
-            is_public=r["is_public"],
-            room_version=new_version,
-        )
-
         logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
 
-        # we create and auth the tombstone event before properly creating the new
-        # room, to check our user has perms in the old room.
-        (
-            tombstone_event,
-            tombstone_context,
-        ) = await self.event_creation_handler.create_event(
-            requester,
-            {
-                "type": EventTypes.Tombstone,
-                "state_key": "",
-                "room_id": old_room_id,
-                "sender": user_id,
-                "content": {
-                    "body": "This room has been replaced",
-                    "replacement_room": new_room_id,
-                },
-            },
-        )
-        old_room_version = await self.store.get_room_version(old_room_id)
-        validate_event_for_room_version(old_room_version, tombstone_event)
-        await self._event_auth_handler.check_auth_rules_from_context(
-            old_room_version, tombstone_event, tombstone_context
+        # create the new room. may raise a `StoreError` in the exceedingly unlikely
+        # event of a room ID collision.
+        await self.store.store_room(
+            room_id=new_room_id,
+            room_creator_user_id=user_id,
+            is_public=old_room["is_public"],
+            room_version=new_version,
         )
 
         await self.clone_existing_room(
@@ -778,7 +804,7 @@ class RoomCreationHandler:
         visibility = config.get("visibility", "private")
         is_public = visibility == "public"
 
-        room_id = await self._generate_room_id(
+        room_id = await self._generate_and_create_room_id(
             creator_id=user_id,
             is_public=is_public,
             room_version=room_version,
@@ -1042,9 +1068,19 @@ class RoomCreationHandler:
                 for invitee in invite_list:
                     power_level_content["users"][invitee] = 100
 
-            # Power levels overrides are defined per chat preset
+            # If the user supplied a preset name e.g. "private_chat",
+            # we apply that preset
             power_level_content.update(config["power_level_content_override"])
 
+            # If the server config contains default_power_level_content_override,
+            # and that contains information for this room preset, apply it.
+            if self._default_power_level_content_override:
+                override = self._default_power_level_content_override.get(preset_config)
+                if override is not None:
+                    power_level_content.update(override)
+
+            # Finally, if the user supplied specific permissions for this room,
+            # apply those.
             if power_level_content_override:
                 power_level_content.update(power_level_content_override)
 
@@ -1090,7 +1126,26 @@ class RoomCreationHandler:
 
         return last_sent_stream_id
 
-    async def _generate_room_id(
+    def _generate_room_id(self) -> str:
+        """Generates a random room ID.
+
+        Room IDs look like "!opaque_id:domain" and are case-sensitive as per the spec
+        at https://spec.matrix.org/v1.2/appendices/#room-ids-and-event-ids.
+
+        Does not check for collisions with existing rooms or prevent future calls from
+        returning the same room ID. To ensure the uniqueness of a new room ID, use
+        `_generate_and_create_room_id` instead.
+
+        Synapse's room IDs are 18 [a-zA-Z] characters long, which comes out to around
+        102 bits.
+
+        Returns:
+            A random room ID of the form "!opaque_id:domain".
+        """
+        random_string = stringutils.random_string(18)
+        return RoomID(random_string, self.hs.hostname).to_string()
+
+    async def _generate_and_create_room_id(
         self,
         creator_id: str,
         is_public: bool,
@@ -1101,8 +1156,7 @@ class RoomCreationHandler:
         attempts = 0
         while attempts < 5:
             try:
-                random_string = stringutils.random_string(18)
-                gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
+                gen_room_id = self._generate_room_id()
                 await self.store.store_room(
                     room_id=gen_room_id,
                     room_creator_user_id=creator_id,
@@ -1239,10 +1293,10 @@ class RoomContextHandler:
             events_after=events_after,
             state=await filter_evts(state_events),
             aggregations=aggregations,
-            start=await token.copy_and_replace("room_key", results.start).to_string(
-                self.store
-            ),
-            end=await token.copy_and_replace("room_key", results.end).to_string(
+            start=await token.copy_and_replace(
+                StreamKeyType.ROOM, results.start
+            ).to_string(self.store),
+            end=await token.copy_and_replace(StreamKeyType.ROOM, results.end).to_string(
                 self.store
             ),
         )
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 5619f8f50e..cd1c47dae8 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -24,7 +24,7 @@ from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
 from synapse.events import EventBase
 from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, StreamKeyType, UserID
 from synapse.visibility import filter_events_for_client
 
 if TYPE_CHECKING:
@@ -655,11 +655,11 @@ class SearchHandler:
                 "events_before": events_before,
                 "events_after": events_after,
                 "start": await now_token.copy_and_replace(
-                    "room_key", res.start
+                    StreamKeyType.ROOM, res.start
+                ).to_string(self.store),
+                "end": await now_token.copy_and_replace(
+                    StreamKeyType.ROOM, res.end
                 ).to_string(self.store),
-                "end": await now_token.copy_and_replace("room_key", res.end).to_string(
-                    self.store
-                ),
             }
 
             if include_profile:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 2c555a66d0..4be08fe7cb 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -37,6 +37,7 @@ from synapse.types import (
     Requester,
     RoomStreamToken,
     StateMap,
+    StreamKeyType,
     StreamToken,
     UserID,
 )
@@ -449,7 +450,7 @@ class SyncHandler:
                 room_ids=room_ids,
                 is_guest=sync_config.is_guest,
             )
-            now_token = now_token.copy_and_replace("typing_key", typing_key)
+            now_token = now_token.copy_and_replace(StreamKeyType.TYPING, typing_key)
 
             ephemeral_by_room: JsonDict = {}
 
@@ -471,7 +472,7 @@ class SyncHandler:
                 room_ids=room_ids,
                 is_guest=sync_config.is_guest,
             )
-            now_token = now_token.copy_and_replace("receipt_key", receipt_key)
+            now_token = now_token.copy_and_replace(StreamKeyType.RECEIPT, receipt_key)
 
             for event in receipts:
                 room_id = event["room_id"]
@@ -537,7 +538,9 @@ class SyncHandler:
                 prev_batch_token = now_token
                 if recents:
                     room_key = recents[0].internal_metadata.before
-                    prev_batch_token = now_token.copy_and_replace("room_key", room_key)
+                    prev_batch_token = now_token.copy_and_replace(
+                        StreamKeyType.ROOM, room_key
+                    )
 
                 return TimelineBatch(
                     events=recents, prev_batch=prev_batch_token, limited=False
@@ -611,7 +614,7 @@ class SyncHandler:
                 recents = recents[-timeline_limit:]
                 room_key = recents[0].internal_metadata.before
 
-            prev_batch_token = now_token.copy_and_replace("room_key", room_key)
+            prev_batch_token = now_token.copy_and_replace(StreamKeyType.ROOM, room_key)
 
         # Don't bother to bundle aggregations if the timeline is unlimited,
         # as clients will have all the necessary information.
@@ -1398,7 +1401,7 @@ class SyncHandler:
                 now_token.to_device_key,
             )
             sync_result_builder.now_token = now_token.copy_and_replace(
-                "to_device_key", stream_id
+                StreamKeyType.TO_DEVICE, stream_id
             )
             sync_result_builder.to_device = messages
         else:
@@ -1503,7 +1506,7 @@ class SyncHandler:
         )
         assert presence_key
         sync_result_builder.now_token = now_token.copy_and_replace(
-            "presence_key", presence_key
+            StreamKeyType.PRESENCE, presence_key
         )
 
         extra_users_ids = set(newly_joined_or_invited_users)
@@ -1826,7 +1829,7 @@ class SyncHandler:
                 # stream token as it'll only be used in the context of this
                 # room. (c.f. the docstring of `to_room_stream_token`).
                 leave_token = since_token.copy_and_replace(
-                    "room_key", leave_position.to_room_stream_token()
+                    StreamKeyType.ROOM, leave_position.to_room_stream_token()
                 )
 
                 # If this is an out of band message, like a remote invite
@@ -1875,7 +1878,9 @@ class SyncHandler:
             if room_entry:
                 events, start_key = room_entry
 
-                prev_batch_token = now_token.copy_and_replace("room_key", start_key)
+                prev_batch_token = now_token.copy_and_replace(
+                    StreamKeyType.ROOM, start_key
+                )
 
                 entry = RoomSyncResultBuilder(
                     room_id=room_id,
@@ -1972,7 +1977,7 @@ class SyncHandler:
                             continue
 
                 leave_token = now_token.copy_and_replace(
-                    "room_key", RoomStreamToken(None, event.stream_ordering)
+                    StreamKeyType.ROOM, RoomStreamToken(None, event.stream_ordering)
                 )
                 room_entries.append(
                     RoomSyncResultBuilder(
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 6854428b7c..bb00750bfd 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -25,7 +25,7 @@ from synapse.metrics.background_process_metrics import (
 )
 from synapse.replication.tcp.streams import TypingStream
 from synapse.streams import EventSource
-from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
@@ -382,7 +382,7 @@ class TypingWriterHandler(FollowerTypingHandler):
         )
 
         self.notifier.on_new_event(
-            "typing_key", self._latest_room_serial, rooms=[member.room_id]
+            StreamKeyType.TYPING, self._latest_room_serial, rooms=[member.room_id]
         )
 
     async def get_all_typing_updates(
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 8310fb466a..084d0a5b84 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -43,8 +43,10 @@ from twisted.internet import defer, error as twisted_error, protocol, ssl
 from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.interfaces import (
     IAddress,
+    IDelayedCall,
     IHostResolution,
     IReactorPluggableNameResolver,
+    IReactorTime,
     IResolutionReceiver,
     ITCPTransport,
 )
@@ -121,13 +123,15 @@ def check_against_blacklist(
 _EPSILON = 0.00000001
 
 
-def _make_scheduler(reactor):
+def _make_scheduler(
+    reactor: IReactorTime,
+) -> Callable[[Callable[[], object]], IDelayedCall]:
     """Makes a schedular suitable for a Cooperator using the given reactor.
 
     (This is effectively just a copy from `twisted.internet.task`)
     """
 
-    def _scheduler(x):
+    def _scheduler(x: Callable[[], object]) -> IDelayedCall:
         return reactor.callLater(_EPSILON, x)
 
     return _scheduler
@@ -348,7 +352,7 @@ class SimpleHttpClient:
         # XXX: The justification for using the cache factor here is that larger instances
         # will need both more cache and more connections.
         # Still, this should probably be a separate dial
-        pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5))
+        pool.maxPersistentPerHost = max(int(100 * hs.config.caches.global_factor), 5)
         pool.cachedConnectionTimeout = 2 * 60
 
         self.agent: IAgent = ProxyAgent(
@@ -775,7 +779,7 @@ class SimpleHttpClient:
         )
 
 
-def _timeout_to_request_timed_out_error(f: Failure):
+def _timeout_to_request_timed_out_error(f: Failure) -> Failure:
     if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError):
         # The TCP connection has its own timeout (set by the 'connectTimeout' param
         # on the Agent), which raises twisted_error.TimeoutError exception.
@@ -809,7 +813,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
     def __init__(self, deferred: defer.Deferred):
         self.deferred = deferred
 
-    def _maybe_fail(self):
+    def _maybe_fail(self) -> None:
         """
         Report a max size exceed error and disconnect the first time this is called.
         """
@@ -933,12 +937,12 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
     Do not use this since it allows an attacker to intercept your communications.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._context = SSL.Context(SSL.SSLv23_METHOD)
         self._context.set_verify(VERIFY_NONE, lambda *_: False)
 
     def getContext(self, hostname=None, port=None):
         return self._context
 
-    def creatorForNetloc(self, hostname, port):
+    def creatorForNetloc(self, hostname: bytes, port: int):
         return self
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index 203e995bb7..23a60af171 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -14,15 +14,22 @@
 
 import base64
 import logging
-from typing import Optional
+from typing import Optional, Union
 
 import attr
 from zope.interface import implementer
 
 from twisted.internet import defer, protocol
 from twisted.internet.error import ConnectError
-from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
+from twisted.internet.interfaces import (
+    IAddress,
+    IConnector,
+    IProtocol,
+    IReactorCore,
+    IStreamClientEndpoint,
+)
 from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
+from twisted.python.failure import Failure
 from twisted.web import http
 
 logger = logging.getLogger(__name__)
@@ -81,14 +88,14 @@ class HTTPConnectProxyEndpoint:
         self._port = port
         self._proxy_creds = proxy_creds
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
 
     # Mypy encounters a false positive here: it complains that ClientFactory
     # is incompatible with IProtocolFactory. But ClientFactory inherits from
     # Factory, which implements IProtocolFactory. So I think this is a bug
     # in mypy-zope.
-    def connect(self, protocolFactory: ClientFactory):  # type: ignore[override]
+    def connect(self, protocolFactory: ClientFactory) -> "defer.Deferred[IProtocol]":  # type: ignore[override]
         f = HTTPProxiedClientFactory(
             self._host, self._port, protocolFactory, self._proxy_creds
         )
@@ -125,10 +132,10 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
         self.proxy_creds = proxy_creds
         self.on_connection: "defer.Deferred[None]" = defer.Deferred()
 
-    def startedConnecting(self, connector):
+    def startedConnecting(self, connector: IConnector) -> None:
         return self.wrapped_factory.startedConnecting(connector)
 
-    def buildProtocol(self, addr):
+    def buildProtocol(self, addr: IAddress) -> "HTTPConnectProtocol":
         wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
         if wrapped_protocol is None:
             raise TypeError("buildProtocol produced None instead of a Protocol")
@@ -141,13 +148,13 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
             self.proxy_creds,
         )
 
-    def clientConnectionFailed(self, connector, reason):
+    def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
         logger.debug("Connection to proxy failed: %s", reason)
         if not self.on_connection.called:
             self.on_connection.errback(reason)
         return self.wrapped_factory.clientConnectionFailed(connector, reason)
 
-    def clientConnectionLost(self, connector, reason):
+    def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
         logger.debug("Connection to proxy lost: %s", reason)
         if not self.on_connection.called:
             self.on_connection.errback(reason)
@@ -191,10 +198,10 @@ class HTTPConnectProtocol(protocol.Protocol):
         )
         self.http_setup_client.on_connected.addCallback(self.proxyConnected)
 
-    def connectionMade(self):
+    def connectionMade(self) -> None:
         self.http_setup_client.makeConnection(self.transport)
 
-    def connectionLost(self, reason=connectionDone):
+    def connectionLost(self, reason: Failure = connectionDone) -> None:
         if self.wrapped_protocol.connected:
             self.wrapped_protocol.connectionLost(reason)
 
@@ -203,7 +210,7 @@ class HTTPConnectProtocol(protocol.Protocol):
         if not self.connected_deferred.called:
             self.connected_deferred.errback(reason)
 
-    def proxyConnected(self, _):
+    def proxyConnected(self, _: Union[None, "defer.Deferred[None]"]) -> None:
         self.wrapped_protocol.makeConnection(self.transport)
 
         self.connected_deferred.callback(self.wrapped_protocol)
@@ -213,7 +220,7 @@ class HTTPConnectProtocol(protocol.Protocol):
         if buf:
             self.wrapped_protocol.dataReceived(buf)
 
-    def dataReceived(self, data: bytes):
+    def dataReceived(self, data: bytes) -> None:
         # if we've set up the HTTP protocol, we can send the data there
         if self.wrapped_protocol.connected:
             return self.wrapped_protocol.dataReceived(data)
@@ -243,7 +250,7 @@ class HTTPConnectSetupClient(http.HTTPClient):
         self.proxy_creds = proxy_creds
         self.on_connected: "defer.Deferred[None]" = defer.Deferred()
 
-    def connectionMade(self):
+    def connectionMade(self) -> None:
         logger.debug("Connected to proxy, sending CONNECT")
         self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
 
@@ -257,14 +264,14 @@ class HTTPConnectSetupClient(http.HTTPClient):
 
         self.endHeaders()
 
-    def handleStatus(self, version: bytes, status: bytes, message: bytes):
+    def handleStatus(self, version: bytes, status: bytes, message: bytes) -> None:
         logger.debug("Got Status: %s %s %s", status, message, version)
         if status != b"200":
             raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}")
 
-    def handleEndHeaders(self):
+    def handleEndHeaders(self) -> None:
         logger.debug("End Headers")
         self.on_connected.callback(None)
 
-    def handleResponse(self, body):
+    def handleResponse(self, body: bytes) -> None:
         pass
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index a8a520f809..2f0177f1e2 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -239,7 +239,7 @@ class MatrixHostnameEndpointFactory:
 
         self._srv_resolver = srv_resolver
 
-    def endpointForURI(self, parsed_uri: URI):
+    def endpointForURI(self, parsed_uri: URI) -> "MatrixHostnameEndpoint":
         return MatrixHostnameEndpoint(
             self._reactor,
             self._proxy_reactor,
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index f68646fd0d..de0e882b33 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -16,7 +16,7 @@
 import logging
 import random
 import time
-from typing import Callable, Dict, List
+from typing import Any, Callable, Dict, List
 
 import attr
 
@@ -109,7 +109,7 @@ class SrvResolver:
 
     def __init__(
         self,
-        dns_client=client,
+        dns_client: Any = client,
         cache: Dict[bytes, List[Server]] = SERVER_CACHE,
         get_time: Callable[[], float] = time.time,
     ):
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 43f2140429..71b685fade 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -74,9 +74,9 @@ _well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
 _had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known")
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class WellKnownLookupResult:
-    delegated_server = attr.ib()
+    delegated_server: Optional[bytes]
 
 
 class WellKnownResolver:
@@ -336,4 +336,4 @@ def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
 class _FetchWellKnownFailure(Exception):
     # True if we didn't get a non-5xx HTTP response, i.e. this may or may not be
     # a temporary failure.
-    temporary = attr.ib()
+    temporary: bool = attr.ib()
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c2ec3caa0e..725b5c33b8 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -23,6 +23,8 @@ from http import HTTPStatus
 from io import BytesIO, StringIO
 from typing import (
     TYPE_CHECKING,
+    Any,
+    BinaryIO,
     Callable,
     Dict,
     Generic,
@@ -44,7 +46,7 @@ from typing_extensions import Literal
 from twisted.internet import defer
 from twisted.internet.error import DNSLookupError
 from twisted.internet.interfaces import IReactorTime
-from twisted.internet.task import _EPSILON, Cooperator
+from twisted.internet.task import Cooperator
 from twisted.web.client import ResponseFailed
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IBodyProducer, IResponse
@@ -58,11 +60,13 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
+from synapse.crypto.context_factory import FederationPolicyForHTTPS
 from synapse.http import QuieterFileBodyProducer
 from synapse.http.client import (
     BlacklistingAgentWrapper,
     BodyExceededMaxSize,
     ByteWriteable,
+    _make_scheduler,
     encode_query_args,
     read_body_with_max_size,
 )
@@ -181,7 +185,7 @@ class JsonParser(ByteParser[Union[JsonDict, list]]):
 
     CONTENT_TYPE = "application/json"
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._buffer = StringIO()
         self._binary_wrapper = BinaryIOWrapper(self._buffer)
 
@@ -299,7 +303,9 @@ async def _handle_response(
 class BinaryIOWrapper:
     """A wrapper for a TextIO which converts from bytes on the fly."""
 
-    def __init__(self, file: typing.TextIO, encoding="utf-8", errors="strict"):
+    def __init__(
+        self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
+    ):
         self.decoder = codecs.getincrementaldecoder(encoding)(errors)
         self.file = file
 
@@ -317,7 +323,11 @@ class MatrixFederationHttpClient:
             requests.
     """
 
-    def __init__(self, hs: "HomeServer", tls_client_options_factory):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+    ):
         self.hs = hs
         self.signing_key = hs.signing_key
         self.server_name = hs.hostname
@@ -348,10 +358,7 @@ class MatrixFederationHttpClient:
         self.version_string_bytes = hs.version_string.encode("ascii")
         self.default_timeout = 60
 
-        def schedule(x):
-            self.reactor.callLater(_EPSILON, x)
-
-        self._cooperator = Cooperator(scheduler=schedule)
+        self._cooperator = Cooperator(scheduler=_make_scheduler(self.reactor))
 
         self._sleeper = AwakenableSleeper(self.reactor)
 
@@ -364,7 +371,7 @@ class MatrixFederationHttpClient:
         self,
         request: MatrixFederationRequest,
         try_trailing_slash_on_400: bool = False,
-        **send_request_args,
+        **send_request_args: Any,
     ) -> IResponse:
         """Wrapper for _send_request which can optionally retry the request
         upon receiving a combination of a 400 HTTP response code and a
@@ -1159,7 +1166,7 @@ class MatrixFederationHttpClient:
         self,
         destination: str,
         path: str,
-        output_stream,
+        output_stream: BinaryIO,
         args: Optional[QueryParams] = None,
         retry_on_dns_fail: bool = True,
         max_size: Optional[int] = None,
@@ -1250,10 +1257,10 @@ class MatrixFederationHttpClient:
         return length, headers
 
 
-def _flatten_response_never_received(e):
+def _flatten_response_never_received(e: BaseException) -> str:
     if hasattr(e, "reasons"):
         reasons = ", ".join(
-            _flatten_response_never_received(f.value) for f in e.reasons
+            _flatten_response_never_received(f.value) for f in e.reasons  # type: ignore[attr-defined]
         )
 
         return "%s:[%s]" % (type(e).__name__, reasons)
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index a16dde2380..b2a50c9105 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -245,7 +245,7 @@ def http_proxy_endpoint(
     proxy: Optional[bytes],
     reactor: IReactorCore,
     tls_options_factory: Optional[IPolicyForHTTPS],
-    **kwargs,
+    **kwargs: object,
 ) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]:
     """Parses an http proxy setting and returns an endpoint for the proxy
 
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 4886626d50..2b6d113544 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -162,7 +162,7 @@ class RequestMetrics:
         with _in_flight_requests_lock:
             _in_flight_requests.add(self)
 
-    def stop(self, time_sec, response_code, sent_bytes):
+    def stop(self, time_sec: float, response_code: int, sent_bytes: int) -> None:
         with _in_flight_requests_lock:
             _in_flight_requests.discard(self)
 
@@ -186,13 +186,13 @@ class RequestMetrics:
             )
             return
 
-        response_code = str(response_code)
+        response_code_str = str(response_code)
 
-        outgoing_responses_counter.labels(self.method, response_code).inc()
+        outgoing_responses_counter.labels(self.method, response_code_str).inc()
 
         response_count.labels(self.method, self.name, tag).inc()
 
-        response_timer.labels(self.method, self.name, tag, response_code).observe(
+        response_timer.labels(self.method, self.name, tag, response_code_str).observe(
             time_sec - self.start_ts
         )
 
@@ -221,7 +221,7 @@ class RequestMetrics:
         # flight.
         self.update_metrics()
 
-    def update_metrics(self):
+    def update_metrics(self) -> None:
         """Updates the in flight metrics with values from this request."""
         if not self.start_context:
             logger.error(
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 657bffcddd..e3dcc3f3dd 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -33,6 +33,7 @@ from typing import (
     Optional,
     Pattern,
     Tuple,
+    TypeVar,
     Union,
 )
 
@@ -92,6 +93,68 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
 HTTP_STATUS_REQUEST_CANCELLED = 499
 
 
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+_cancellable_method_names = frozenset(
+    {
+        # `RestServlet`, `BaseFederationServlet` and `BaseFederationServerServlet`
+        # methods
+        "on_GET",
+        "on_PUT",
+        "on_POST",
+        "on_DELETE",
+        # `_AsyncResource`, `DirectServeHtmlResource` and `DirectServeJsonResource`
+        # methods
+        "_async_render_GET",
+        "_async_render_PUT",
+        "_async_render_POST",
+        "_async_render_DELETE",
+        "_async_render_OPTIONS",
+        # `ReplicationEndpoint` methods
+        "_handle_request",
+    }
+)
+
+
+def cancellable(method: F) -> F:
+    """Marks a servlet method as cancellable.
+
+    Methods with this decorator will be cancelled if the client disconnects before we
+    finish processing the request.
+
+    During cancellation, `Deferred.cancel()` will be invoked on the `Deferred` wrapping
+    the method. The `cancel()` call will propagate down to the `Deferred` that is
+    currently being waited on. That `Deferred` will raise a `CancelledError`, which will
+    propagate up, as per normal exception handling.
+
+    Before applying this decorator to a new endpoint, you MUST recursively check
+    that all `await`s in the function are on `async` functions or `Deferred`s that
+    handle cancellation cleanly, otherwise a variety of bugs may occur, ranging from
+    premature logging context closure, to stuck requests, to database corruption.
+
+    Usage:
+        class SomeServlet(RestServlet):
+            @cancellable
+            async def on_GET(self, request: SynapseRequest) -> ...:
+                ...
+    """
+    if method.__name__ not in _cancellable_method_names and not any(
+        method.__name__.startswith(prefix) for prefix in _cancellable_method_names
+    ):
+        raise ValueError(
+            "@cancellable decorator can only be applied to servlet methods."
+        )
+
+    method.cancellable = True  # type: ignore[attr-defined]
+    return method
+
+
+def is_method_cancellable(method: Callable[..., Any]) -> bool:
+    """Checks whether a servlet method has the `@cancellable` flag."""
+    return getattr(method, "cancellable", False)
+
+
 def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
     """Sends a JSON error response to clients."""
 
@@ -253,6 +316,9 @@ class HttpServer(Protocol):
         If the regex contains groups these gets passed to the callback via
         an unpacked tuple.
 
+        The callback may be marked with the `@cancellable` decorator, which will
+        cause request processing to be cancelled when clients disconnect early.
+
         Args:
             method: The HTTP method to listen to.
             path_patterns: The regex used to match requests.
@@ -283,7 +349,9 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
 
     def render(self, request: SynapseRequest) -> int:
         """This gets called by twisted every time someone sends us a request."""
-        defer.ensureDeferred(self._async_render_wrapper(request))
+        request.render_deferred = defer.ensureDeferred(
+            self._async_render_wrapper(request)
+        )
         return NOT_DONE_YET
 
     @wrap_async_request_handler
@@ -319,6 +387,8 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
 
         method_handler = getattr(self, "_async_render_%s" % (request_method,), None)
         if method_handler:
+            request.is_render_cancellable = is_method_cancellable(method_handler)
+
             raw_callback_return = method_handler(request)
 
             # Is it synchronous? We'll allow this for now.
@@ -479,6 +549,8 @@ class JsonResource(DirectServeJsonResource):
     async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
         callback, servlet_classname, group_dict = self._get_handler_for_request(request)
 
+        request.is_render_cancellable = is_method_cancellable(callback)
+
         # Make sure we have an appropriate name for this handler in prometheus
         # (rather than the default of JsonResource).
         request.request_metrics.name = servlet_classname
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 0b85a57d77..eeec74b78a 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union
 import attr
 from zope.interface import implementer
 
+from twisted.internet.defer import Deferred
 from twisted.internet.interfaces import IAddress, IReactorTime
 from twisted.python.failure import Failure
 from twisted.web.http import HTTPChannel
@@ -91,6 +92,14 @@ class SynapseRequest(Request):
         # we can't yet create the logcontext, as we don't know the method.
         self.logcontext: Optional[LoggingContext] = None
 
+        # The `Deferred` to cancel if the client disconnects early and
+        # `is_render_cancellable` is set. Expected to be set by `Resource.render`.
+        self.render_deferred: Optional["Deferred[None]"] = None
+        # A boolean indicating whether `render_deferred` should be cancelled if the
+        # client disconnects early. Expected to be set by the coroutine started by
+        # `Resource.render`, if rendering is asynchronous.
+        self.is_render_cancellable = False
+
         global _next_request_seq
         self.request_seq = _next_request_seq
         _next_request_seq += 1
@@ -357,7 +366,21 @@ class SynapseRequest(Request):
                     {"event": "client connection lost", "reason": str(reason.value)}
                 )
 
-            if not self._is_processing:
+            if self._is_processing:
+                if self.is_render_cancellable:
+                    if self.render_deferred is not None:
+                        # Throw a cancellation into the request processing, in the hope
+                        # that it will finish up sooner than it normally would.
+                        # The `self.processing()` context manager will call
+                        # `_finished_processing()` when done.
+                        with PreserveLoggingContext():
+                            self.render_deferred.cancel()
+                    else:
+                        logger.error(
+                            "Connection from client lost, but have no Deferred to "
+                            "cancel even though the request is marked as cancellable."
+                        )
+            else:
                 self._finished_processing()
 
     def _started_processing(self, servlet_name: str) -> None:
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index 475756f1db..5a61b21eaf 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -31,7 +31,11 @@ from twisted.internet.endpoints import (
     TCP4ClientEndpoint,
     TCP6ClientEndpoint,
 )
-from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
+from twisted.internet.interfaces import (
+    IPushProducer,
+    IReactorTCP,
+    IStreamClientEndpoint,
+)
 from twisted.internet.protocol import Factory, Protocol
 from twisted.internet.tcp import Connection
 from twisted.python.failure import Failure
@@ -59,14 +63,14 @@ class LogProducer:
     _buffer: Deque[logging.LogRecord]
     _paused: bool = attr.ib(default=False, init=False)
 
-    def pauseProducing(self):
+    def pauseProducing(self) -> None:
         self._paused = True
 
-    def stopProducing(self):
+    def stopProducing(self) -> None:
         self._paused = True
         self._buffer = deque()
 
-    def resumeProducing(self):
+    def resumeProducing(self) -> None:
         # If we're already producing, nothing to do.
         self._paused = False
 
@@ -102,8 +106,8 @@ class RemoteHandler(logging.Handler):
         host: str,
         port: int,
         maximum_buffer: int = 1000,
-        level=logging.NOTSET,
-        _reactor=None,
+        level: int = logging.NOTSET,
+        _reactor: Optional[IReactorTCP] = None,
     ):
         super().__init__(level=level)
         self.host = host
@@ -118,7 +122,7 @@ class RemoteHandler(logging.Handler):
         if _reactor is None:
             from twisted.internet import reactor
 
-            _reactor = reactor
+            _reactor = reactor  # type: ignore[assignment]
 
         try:
             ip = ip_address(self.host)
@@ -139,7 +143,7 @@ class RemoteHandler(logging.Handler):
         self._stopping = False
         self._connect()
 
-    def close(self):
+    def close(self) -> None:
         self._stopping = True
         self._service.stopService()
 
diff --git a/synapse/logging/formatter.py b/synapse/logging/formatter.py
index c0f12ecd15..c88b8ae545 100644
--- a/synapse/logging/formatter.py
+++ b/synapse/logging/formatter.py
@@ -16,6 +16,8 @@
 import logging
 import traceback
 from io import StringIO
+from types import TracebackType
+from typing import Optional, Tuple, Type
 
 
 class LogFormatter(logging.Formatter):
@@ -28,10 +30,14 @@ class LogFormatter(logging.Formatter):
     where it was caught are logged).
     """
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def formatException(self, ei):
+    def formatException(
+        self,
+        ei: Tuple[
+            Optional[Type[BaseException]],
+            Optional[BaseException],
+            Optional[TracebackType],
+        ],
+    ) -> str:
         sio = StringIO()
         (typ, val, tb) = ei
 
diff --git a/synapse/logging/handlers.py b/synapse/logging/handlers.py
index 478b527494..dec2a2c3dd 100644
--- a/synapse/logging/handlers.py
+++ b/synapse/logging/handlers.py
@@ -49,7 +49,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler):
         )
         self._flushing_thread.start()
 
-        def on_reactor_running():
+        def on_reactor_running() -> None:
             self._reactor_started = True
 
         reactor_to_use: IReactorCore
@@ -74,7 +74,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler):
         else:
             return True
 
-    def _flush_periodically(self):
+    def _flush_periodically(self) -> None:
         """
         Whilst this handler is active, flush the handler periodically.
         """
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index d57e7c5324..a26a1a58e7 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -13,6 +13,8 @@
 # limitations under the License.import logging
 
 import logging
+from types import TracebackType
+from typing import Optional, Type
 
 from opentracing import Scope, ScopeManager
 
@@ -107,19 +109,26 @@ class _LogContextScope(Scope):
         and - if enter_logcontext was set - the logcontext is finished too.
     """
 
-    def __init__(self, manager, span, logcontext, enter_logcontext, finish_on_close):
+    def __init__(
+        self,
+        manager: LogContextScopeManager,
+        span,
+        logcontext,
+        enter_logcontext: bool,
+        finish_on_close: bool,
+    ):
         """
         Args:
-            manager (LogContextScopeManager):
+            manager:
                 the manager that is responsible for this scope.
             span (Span):
                 the opentracing span which this scope represents the local
                 lifetime for.
             logcontext (LogContext):
                 the logcontext to which this scope is attached.
-            enter_logcontext (Boolean):
+            enter_logcontext:
                 if True the logcontext will be exited when the scope is finished
-            finish_on_close (Boolean):
+            finish_on_close:
                 if True finish the span when the scope is closed
         """
         super().__init__(manager, span)
@@ -127,16 +136,21 @@ class _LogContextScope(Scope):
         self._finish_on_close = finish_on_close
         self._enter_logcontext = enter_logcontext
 
-    def __exit__(self, exc_type, value, traceback):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
         if exc_type == twisted.internet.defer._DefGen_Return:
             # filter out defer.returnValue() calls
             exc_type = value = traceback = None
         super().__exit__(exc_type, value, traceback)
 
-    def __str__(self):
+    def __str__(self) -> str:
         return f"Scope<{self.span}>"
 
-    def close(self):
+    def close(self) -> None:
         active_scope = self.manager.active
         if active_scope is not self:
             logger.error(
diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py
index 6bc329f04a..1fc8a0e888 100644
--- a/synapse/metrics/jemalloc.py
+++ b/synapse/metrics/jemalloc.py
@@ -18,6 +18,7 @@ import os
 import re
 from typing import Iterable, Optional, overload
 
+import attr
 from prometheus_client import REGISTRY, Metric
 from typing_extensions import Literal
 
@@ -27,52 +28,24 @@ from synapse.metrics._types import Collector
 logger = logging.getLogger(__name__)
 
 
-def _setup_jemalloc_stats() -> None:
-    """Checks to see if jemalloc is loaded, and hooks up a collector to record
-    statistics exposed by jemalloc.
-    """
-
-    # Try to find the loaded jemalloc shared library, if any. We need to
-    # introspect into what is loaded, rather than loading whatever is on the
-    # path, as if we load a *different* jemalloc version things will seg fault.
-
-    # We look in `/proc/self/maps`, which only exists on linux.
-    if not os.path.exists("/proc/self/maps"):
-        logger.debug("Not looking for jemalloc as no /proc/self/maps exist")
-        return
-
-    # We're looking for a path at the end of the line that includes
-    # "libjemalloc".
-    regex = re.compile(r"/\S+/libjemalloc.*$")
-
-    jemalloc_path = None
-    with open("/proc/self/maps") as f:
-        for line in f:
-            match = regex.search(line.strip())
-            if match:
-                jemalloc_path = match.group()
-
-    if not jemalloc_path:
-        # No loaded jemalloc was found.
-        logger.debug("jemalloc not found")
-        return
-
-    logger.debug("Found jemalloc at %s", jemalloc_path)
-
-    jemalloc = ctypes.CDLL(jemalloc_path)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class JemallocStats:
+    jemalloc: ctypes.CDLL
 
     @overload
     def _mallctl(
-        name: str, read: Literal[True] = True, write: Optional[int] = None
+        self, name: str, read: Literal[True] = True, write: Optional[int] = None
     ) -> int:
         ...
 
     @overload
-    def _mallctl(name: str, read: Literal[False], write: Optional[int] = None) -> None:
+    def _mallctl(
+        self, name: str, read: Literal[False], write: Optional[int] = None
+    ) -> None:
         ...
 
     def _mallctl(
-        name: str, read: bool = True, write: Optional[int] = None
+        self, name: str, read: bool = True, write: Optional[int] = None
     ) -> Optional[int]:
         """Wrapper around `mallctl` for reading and writing integers to
         jemalloc.
@@ -120,7 +93,7 @@ def _setup_jemalloc_stats() -> None:
         # Where oldp/oldlenp is a buffer where the old value will be written to
         # (if not null), and newp/newlen is the buffer with the new value to set
         # (if not null). Note that they're all references *except* newlen.
-        result = jemalloc.mallctl(
+        result = self.jemalloc.mallctl(
             name.encode("ascii"),
             input_var_ref,
             input_len_ref,
@@ -136,21 +109,80 @@ def _setup_jemalloc_stats() -> None:
 
         return input_var.value
 
-    def _jemalloc_refresh_stats() -> None:
+    def refresh_stats(self) -> None:
         """Request that jemalloc updates its internal statistics. This needs to
         be called before querying for stats, otherwise it will return stale
         values.
         """
         try:
-            _mallctl("epoch", read=False, write=1)
+            self._mallctl("epoch", read=False, write=1)
         except Exception as e:
             logger.warning("Failed to reload jemalloc stats: %s", e)
 
+    def get_stat(self, name: str) -> int:
+        """Request the stat of the given name at the time of the last
+        `refresh_stats` call. This may throw if we fail to read
+        the stat.
+        """
+        return self._mallctl(f"stats.{name}")
+
+
+_JEMALLOC_STATS: Optional[JemallocStats] = None
+
+
+def get_jemalloc_stats() -> Optional[JemallocStats]:
+    """Returns an interface to jemalloc, if it is being used.
+
+    Note that this will always return None until `setup_jemalloc_stats` has been
+    called.
+    """
+    return _JEMALLOC_STATS
+
+
+def _setup_jemalloc_stats() -> None:
+    """Checks to see if jemalloc is loaded, and hooks up a collector to record
+    statistics exposed by jemalloc.
+    """
+
+    global _JEMALLOC_STATS
+
+    # Try to find the loaded jemalloc shared library, if any. We need to
+    # introspect into what is loaded, rather than loading whatever is on the
+    # path, as if we load a *different* jemalloc version things will seg fault.
+
+    # We look in `/proc/self/maps`, which only exists on linux.
+    if not os.path.exists("/proc/self/maps"):
+        logger.debug("Not looking for jemalloc as no /proc/self/maps exist")
+        return
+
+    # We're looking for a path at the end of the line that includes
+    # "libjemalloc".
+    regex = re.compile(r"/\S+/libjemalloc.*$")
+
+    jemalloc_path = None
+    with open("/proc/self/maps") as f:
+        for line in f:
+            match = regex.search(line.strip())
+            if match:
+                jemalloc_path = match.group()
+
+    if not jemalloc_path:
+        # No loaded jemalloc was found.
+        logger.debug("jemalloc not found")
+        return
+
+    logger.debug("Found jemalloc at %s", jemalloc_path)
+
+    jemalloc_dll = ctypes.CDLL(jemalloc_path)
+
+    stats = JemallocStats(jemalloc_dll)
+    _JEMALLOC_STATS = stats
+
     class JemallocCollector(Collector):
         """Metrics for internal jemalloc stats."""
 
         def collect(self) -> Iterable[Metric]:
-            _jemalloc_refresh_stats()
+            stats.refresh_stats()
 
             g = GaugeMetricFamily(
                 "jemalloc_stats_app_memory_bytes",
@@ -184,7 +216,7 @@ def _setup_jemalloc_stats() -> None:
                 "metadata",
             ):
                 try:
-                    value = _mallctl(f"stats.{t}")
+                    value = stats.get_stat(t)
                 except Exception as e:
                     # There was an error fetching the value, skip.
                     logger.warning("Failed to read jemalloc stats.%s: %s", t, e)
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 01a50b9d62..ba23257f54 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -46,6 +46,7 @@ from synapse.types import (
     JsonDict,
     PersistedEventPosition,
     RoomStreamToken,
+    StreamKeyType,
     StreamToken,
     UserID,
 )
@@ -370,7 +371,7 @@ class Notifier:
 
         if users or rooms:
             self.on_new_event(
-                "room_key",
+                StreamKeyType.ROOM,
                 max_room_stream_token,
                 users=users,
                 rooms=rooms,
@@ -440,7 +441,7 @@ class Notifier:
             for room in rooms:
                 user_streams |= self.room_to_user_streams.get(room, set())
 
-            if stream_key == "to_device_key":
+            if stream_key == StreamKeyType.TO_DEVICE:
                 issue9533_logger.debug(
                     "to-device messages stream id %s, awaking streams for %s",
                     new_token,
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index a1b7711098..57c4d70466 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -12,6 +12,80 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""
+This module implements the push rules & notifications portion of the Matrix
+specification.
+
+There's a few related features:
+
+* Push notifications (i.e. email or outgoing requests to a Push Gateway).
+* Calculation of unread notifications (for /sync and /notifications).
+
+When Synapse receives a new event (locally, via the Client-Server API, or via
+federation), the following occurs:
+
+1. The push rules get evaluated to generate a set of per-user actions.
+2. The event is persisted into the database.
+3. (In the background) The notifier is notified about the new event.
+
+The per-user actions are initially stored in the event_push_actions_staging table,
+before getting moved into the event_push_actions table when the event is persisted.
+The event_push_actions table is periodically summarised into the event_push_summary
+and event_push_summary_stream_ordering tables.
+
+Since push actions block an event from being persisted the generation of push
+actions is performance sensitive.
+
+The general interaction of the classes are:
+
+        +---------------------------------------------+
+        | FederationEventHandler/EventCreationHandler |
+        +---------------------------------------------+
+                |
+                v
+        +-----------------------+     +---------------------------+
+        | BulkPushRuleEvaluator |---->| PushRuleEvaluatorForEvent |
+        +-----------------------+     +---------------------------+
+                |
+                v
+        +-----------------------------+
+        | EventPushActionsWorkerStore |
+        +-----------------------------+
+
+The notifier notifies the pusher pool of the new event, which checks for affected
+users. Each user-configured pusher of the affected users then performs the
+previously calculated action.
+
+The general interaction of the classes are:
+
+        +----------+
+        | Notifier |
+        +----------+
+                |
+                v
+        +------------+     +--------------+
+        | PusherPool |---->| PusherConfig |
+        +------------+     +--------------+
+                |
+                |     +---------------+
+                +<--->| PusherFactory |
+                |     +---------------+
+                v
+        +------------------------+     +-----------------------------------------------+
+        | EmailPusher/HttpPusher |---->| EventPushActionsWorkerStore/PusherWorkerStore |
+        +------------------------+     +-----------------------------------------------+
+                |
+                v
+        +-------------------------+
+        | Mailer/SimpleHttpClient |
+        +-------------------------+
+
+The Pusher instance also calls out to various utilities for generating payloads
+(or email templates), but those interactions are not detailed in this diagram
+(and are specific to the type of pusher).
+
+"""
+
 import abc
 from typing import TYPE_CHECKING, Any, Dict, Optional
 
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
deleted file mode 100644
index 60758df016..0000000000
--- a/synapse/push/action_generator.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Copyright 2015 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import TYPE_CHECKING
-
-from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
-from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
-from synapse.util.metrics import Measure
-
-if TYPE_CHECKING:
-    from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-class ActionGenerator:
-    def __init__(self, hs: "HomeServer"):
-        self.clock = hs.get_clock()
-        self.bulk_evaluator = BulkPushRuleEvaluator(hs)
-        # really we want to get all user ids and all profile tags too,
-        # since we want the actions for each profile tag for every user and
-        # also actions for a client with no profile tag for each user.
-        # Currently the event stream doesn't support profile tags on an
-        # event stream, so we just run the rules for a client with no profile
-        # tag (ie. we just need all the users).
-
-    async def handle_push_actions_for_event(
-        self, event: EventBase, context: EventContext
-    ) -> None:
-        with Measure(self.clock, "action_for_event_by_user"):
-            await self.bulk_evaluator.action_for_event_by_user(event, context)
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index b07cf2eee7..4ac2c546bf 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -21,7 +21,7 @@ from prometheus_client import Counter
 
 from synapse.api.constants import EventTypes, Membership, RelationTypes
 from synapse.event_auth import get_user_power_level
-from synapse.events import EventBase
+from synapse.events import EventBase, relation_from_event
 from synapse.events.snapshot import EventContext
 from synapse.state import POWER_KEY
 from synapse.storage.databases.main.roommember import EventIdMembership
@@ -29,6 +29,7 @@ from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import CacheMetric, register_cache
 from synapse.util.caches.descriptors import lru_cache
 from synapse.util.caches.lrucache import LruCache
+from synapse.util.metrics import measure_func
 
 from .push_rule_evaluator import PushRuleEvaluatorForEvent
 
@@ -77,8 +78,8 @@ def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
         return False
 
     # Exclude edits.
-    relates_to = event.content.get("m.relates_to", {})
-    if relates_to.get("rel_type") == RelationTypes.REPLACE:
+    relates_to = relation_from_event(event)
+    if relates_to and relates_to.rel_type == RelationTypes.REPLACE:
         return False
 
     # Mark events that have a non-empty string body as unread.
@@ -105,6 +106,7 @@ class BulkPushRuleEvaluator:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastores().main
+        self.clock = hs.get_clock()
         self._event_auth_handler = hs.get_event_auth_handler()
 
         # Used by `RulesForRoom` to ensure only one thing mutates the cache at a
@@ -185,6 +187,7 @@ class BulkPushRuleEvaluator:
 
         return pl_event.content if pl_event else {}, sender_level
 
+    @measure_func("action_for_event_by_user")
     async def action_for_event_by_user(
         self, event: EventBase, context: EventContext
     ) -> None:
@@ -192,6 +195,10 @@ class BulkPushRuleEvaluator:
         should increment the unread count, and insert the results into the
         event_push_actions_staging table.
         """
+        if event.internal_metadata.is_outlier():
+            # This can happen due to out of band memberships
+            return
+
         count_as_unread = _should_count_as_unread(event, context)
 
         rules_by_user = await self._get_rules_for_event(event, context)
@@ -208,8 +215,6 @@ class BulkPushRuleEvaluator:
             event, len(room_members), sender_power_level, power_levels
         )
 
-        condition_cache: Dict[str, bool] = {}
-
         # If the event is not a state event check if any users ignore the sender.
         if not event.is_state():
             ignorers = await self.store.ignored_by(event.sender)
@@ -247,8 +252,8 @@ class BulkPushRuleEvaluator:
                 if "enabled" in rule and not rule["enabled"]:
                     continue
 
-                matches = _condition_checker(
-                    evaluator, rule["conditions"], uid, display_name, condition_cache
+                matches = evaluator.check_conditions(
+                    rule["conditions"], uid, display_name
                 )
                 if matches:
                     actions = [x for x in rule["actions"] if x != "dont_notify"]
@@ -267,32 +272,6 @@ class BulkPushRuleEvaluator:
         )
 
 
-def _condition_checker(
-    evaluator: PushRuleEvaluatorForEvent,
-    conditions: List[dict],
-    uid: str,
-    display_name: Optional[str],
-    cache: Dict[str, bool],
-) -> bool:
-    for cond in conditions:
-        _cache_key = cond.get("_cache_key", None)
-        if _cache_key:
-            res = cache.get(_cache_key, None)
-            if res is False:
-                return False
-            elif res is True:
-                continue
-
-        res = evaluator.matches(cond, uid, display_name)
-        if _cache_key:
-            cache[_cache_key] = bool(res)
-
-        if not res:
-            return False
-
-    return True
-
-
 MemberMap = Dict[str, Optional[EventIdMembership]]
 Rule = Dict[str, dict]
 RulesByUser = Dict[str, List[Rule]]
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index f617c759e6..54db6b5612 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -129,9 +129,55 @@ class PushRuleEvaluatorForEvent:
         # Maps strings of e.g. 'content.body' -> event["content"]["body"]
         self._value_cache = _flatten_dict(event)
 
+        # Maps cache keys to final values.
+        self._condition_cache: Dict[str, bool] = {}
+
+    def check_conditions(
+        self, conditions: List[dict], uid: str, display_name: Optional[str]
+    ) -> bool:
+        """
+        Returns true if a user's conditions/user ID/display name match the event.
+
+        Args:
+            conditions: The user's conditions to match.
+            uid: The user's MXID.
+            display_name: The display name.
+
+        Returns:
+             True if all conditions match the event, False otherwise.
+        """
+        for cond in conditions:
+            _cache_key = cond.get("_cache_key", None)
+            if _cache_key:
+                res = self._condition_cache.get(_cache_key, None)
+                if res is False:
+                    return False
+                elif res is True:
+                    continue
+
+            res = self.matches(cond, uid, display_name)
+            if _cache_key:
+                self._condition_cache[_cache_key] = bool(res)
+
+            if not res:
+                return False
+
+        return True
+
     def matches(
         self, condition: Dict[str, Any], user_id: str, display_name: Optional[str]
     ) -> bool:
+        """
+        Returns true if a user's condition/user ID/display name match the event.
+
+        Args:
+            condition: The user's condition to match.
+            uid: The user's MXID.
+            display_name: The display name, or None if there is not one.
+
+        Returns:
+             True if the condition matches the event, False otherwise.
+        """
         if condition["kind"] == "event_match":
             return self._event_match(condition, user_id)
         elif condition["kind"] == "contains_display_name":
@@ -146,6 +192,16 @@ class PushRuleEvaluatorForEvent:
             return True
 
     def _event_match(self, condition: dict, user_id: str) -> bool:
+        """
+        Check an "event_match" push rule condition.
+
+        Args:
+            condition: The "event_match" push rule condition to match.
+            user_id: The user's MXID.
+
+        Returns:
+             True if the condition matches the event, False otherwise.
+        """
         pattern = condition.get("pattern", None)
 
         if not pattern:
@@ -167,13 +223,22 @@ class PushRuleEvaluatorForEvent:
 
             return _glob_matches(pattern, body, word_boundary=True)
         else:
-            haystack = self._get_value(condition["key"])
+            haystack = self._value_cache.get(condition["key"], None)
             if haystack is None:
                 return False
 
             return _glob_matches(pattern, haystack)
 
     def _contains_display_name(self, display_name: Optional[str]) -> bool:
+        """
+        Check an "event_match" push rule condition.
+
+        Args:
+            display_name: The display name, or None if there is not one.
+
+        Returns:
+             True if the display name is found in the event body, False otherwise.
+        """
         if not display_name:
             return False
 
@@ -191,9 +256,6 @@ class PushRuleEvaluatorForEvent:
 
         return bool(r.search(body))
 
-    def _get_value(self, dotted_key: str) -> Optional[str]:
-        return self._value_cache.get(dotted_key, None)
-
 
 # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
 regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache(
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 2bd244ed79..a4ae4040c3 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -26,7 +26,8 @@ from twisted.web.server import Request
 
 from synapse.api.errors import HttpResponseException, SynapseError
 from synapse.http import RequestTimedOutError
-from synapse.http.server import HttpServer
+from synapse.http.server import HttpServer, is_method_cancellable
+from synapse.http.site import SynapseRequest
 from synapse.logging import opentracing
 from synapse.logging.opentracing import trace
 from synapse.types import JsonDict
@@ -310,6 +311,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         url_args = list(self.PATH_ARGS)
         method = self.METHOD
 
+        if self.CACHE and is_method_cancellable(self._handle_request):
+            raise Exception(
+                f"{self.__class__.__name__} has been marked as cancellable, but CACHE "
+                "is set. The cancellable flag would have no effect."
+            )
+
         if self.CACHE:
             url_args.append("txn_id")
 
@@ -324,7 +331,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         )
 
     async def _check_auth_and_handle(
-        self, request: Request, **kwargs: Any
+        self, request: SynapseRequest, **kwargs: Any
     ) -> Tuple[int, JsonDict]:
         """Called on new incoming requests when caching is enabled. Checks
         if there is a cached response for the request and returns that,
@@ -340,8 +347,18 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         if self.CACHE:
             txn_id = kwargs.pop("txn_id")
 
+            # We ignore the `@cancellable` flag, since cancellation wouldn't interupt
+            # `_handle_request` and `ResponseCache` does not handle cancellation
+            # correctly yet. In particular, there may be issues to do with logging
+            # context lifetimes.
+
             return await self.response_cache.wrap(
                 txn_id, self._handle_request, request, **kwargs
             )
 
+        # The `@cancellable` decorator may be applied to `_handle_request`. But we
+        # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
+        # so we have to set up the cancellable flag ourselves.
+        request.is_render_cancellable = is_method_cancellable(self._handle_request)
+
         return await self._handle_request(request, **kwargs)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 350762f494..a52e25c1af 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -43,7 +43,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStreamEventRow,
     EventsStreamRow,
 )
-from synapse.types import PersistedEventPosition, ReadReceipt, UserID
+from synapse.types import PersistedEventPosition, ReadReceipt, StreamKeyType, UserID
 from synapse.util.async_helpers import Linearizer, timeout_deferred
 from synapse.util.metrics import Measure
 
@@ -153,19 +153,19 @@ class ReplicationDataHandler:
         if stream_name == TypingStream.NAME:
             self._typing_handler.process_replication_rows(token, rows)
             self.notifier.on_new_event(
-                "typing_key", token, rooms=[row.room_id for row in rows]
+                StreamKeyType.TYPING, token, rooms=[row.room_id for row in rows]
             )
         elif stream_name == PushRulesStream.NAME:
             self.notifier.on_new_event(
-                "push_rules_key", token, users=[row.user_id for row in rows]
+                StreamKeyType.PUSH_RULES, token, users=[row.user_id for row in rows]
             )
         elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
             self.notifier.on_new_event(
-                "account_data_key", token, users=[row.user_id for row in rows]
+                StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows]
             )
         elif stream_name == ReceiptsStream.NAME:
             self.notifier.on_new_event(
-                "receipt_key", token, rooms=[row.room_id for row in rows]
+                StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows]
             )
             await self._pusher_pool.on_new_receipts(
                 token, token, {row.room_id for row in rows}
@@ -173,14 +173,18 @@ class ReplicationDataHandler:
         elif stream_name == ToDeviceStream.NAME:
             entities = [row.entity for row in rows if row.entity.startswith("@")]
             if entities:
-                self.notifier.on_new_event("to_device_key", token, users=entities)
+                self.notifier.on_new_event(
+                    StreamKeyType.TO_DEVICE, token, users=entities
+                )
         elif stream_name == DeviceListsStream.NAME:
             all_room_ids: Set[str] = set()
             for row in rows:
                 if row.entity.startswith("@"):
                     room_ids = await self.store.get_rooms_for_user(row.entity)
                     all_room_ids.update(room_ids)
-            self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
+            self.notifier.on_new_event(
+                StreamKeyType.DEVICE_LIST, token, rooms=all_room_ids
+            )
         elif stream_name == GroupServerStream.NAME:
             self.notifier.on_new_event(
                 "groups_key", token, users=[row.user_id for row in rows]
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index f9caab6635..4b03eb876b 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -13,12 +13,10 @@
 # limitations under the License.
 
 import logging
-import re
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.constants import ReceiptTypes
 from synapse.api.errors import SynapseError
-from synapse.http import get_request_user_agent
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.http.site import SynapseRequest
@@ -26,8 +24,6 @@ from synapse.types import JsonDict
 
 from ._base import client_patterns
 
-pattern = re.compile(r"(?:Element|SchildiChat)/1\.[012]\.")
-
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
@@ -69,14 +65,7 @@ class ReceiptRestServlet(RestServlet):
         ):
             raise SynapseError(400, "Receipt type must be 'm.read'")
 
-        # Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body.
-        user_agent = get_request_user_agent(request)
-        allow_empty_body = False
-        if "Android" in user_agent:
-            if pattern.match(user_agent) or "Riot" in user_agent:
-                allow_empty_body = True
-        # This call makes sure possible empty body is handled correctly
-        parse_json_object_from_request(request, allow_empty_body)
+        parse_json_object_from_request(request, allow_empty_body=False)
 
         await self.presence_handler.bump_presence_active_time(requester.user)
 
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 906fe09e97..4b8bfbffcb 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -34,7 +34,7 @@ from synapse.api.errors import (
 )
 from synapse.api.filtering import Filter
 from synapse.events.utils import format_event_for_client_v2
-from synapse.http.server import HttpServer
+from synapse.http.server import HttpServer, cancellable
 from synapse.http.servlet import (
     ResolveRoomIdMixin,
     RestServlet,
@@ -143,6 +143,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
             self.__class__.__name__,
         )
 
+    @cancellable
     def on_GET_no_state_key(
         self, request: SynapseRequest, room_id: str, event_type: str
     ) -> Awaitable[Tuple[int, JsonDict]]:
@@ -153,6 +154,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
     ) -> Awaitable[Tuple[int, JsonDict]]:
         return self.on_PUT(request, room_id, event_type, "")
 
+    @cancellable
     async def on_GET(
         self, request: SynapseRequest, room_id: str, event_type: str, state_key: str
     ) -> Tuple[int, JsonDict]:
@@ -481,6 +483,7 @@ class RoomMemberListRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
 
+    @cancellable
     async def on_GET(
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
@@ -602,6 +605,7 @@ class RoomStateRestServlet(RestServlet):
         self.message_handler = hs.get_message_handler()
         self.auth = hs.get_auth()
 
+    @cancellable
     async def on_GET(
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, List[JsonDict]]:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 50383bdbd1..2b2db63bf7 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -668,7 +668,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         logger.debug("Running url preview cache expiry")
 
         if not (await self.store.db_pool.updates.has_completed_background_updates()):
-            logger.info("Still running DB updates; skipping expiry")
+            logger.debug("Still running DB updates; skipping url preview cache expiry")
             return
 
         def try_remove_parent_dirs(dirs: Iterable[str]) -> None:
@@ -688,7 +688,9 @@ class PreviewUrlResource(DirectServeJsonResource):
                     # Failed, skip deleting the rest of the parent dirs
                     if e.errno != errno.ENOTEMPTY:
                         logger.warning(
-                            "Failed to remove media directory: %r: %s", dir, e
+                            "Failed to remove media directory while clearing url preview cache: %r: %s",
+                            dir,
+                            e,
                         )
                     break
 
@@ -703,7 +705,11 @@ class PreviewUrlResource(DirectServeJsonResource):
             except FileNotFoundError:
                 pass  # If the path doesn't exist, meh
             except OSError as e:
-                logger.warning("Failed to remove media: %r: %s", media_id, e)
+                logger.warning(
+                    "Failed to remove media while clearing url preview cache: %r: %s",
+                    media_id,
+                    e,
+                )
                 continue
 
             removed_media.append(media_id)
@@ -714,9 +720,11 @@ class PreviewUrlResource(DirectServeJsonResource):
         await self.store.delete_url_cache(removed_media)
 
         if removed_media:
-            logger.info("Deleted %d entries from url cache", len(removed_media))
+            logger.debug(
+                "Deleted %d entries from url preview cache", len(removed_media)
+            )
         else:
-            logger.debug("No entries removed from url cache")
+            logger.debug("No entries removed from url preview cache")
 
         # Now we delete old images associated with the url cache.
         # These may be cached for a bit on the client (i.e., they
@@ -733,7 +741,9 @@ class PreviewUrlResource(DirectServeJsonResource):
             except FileNotFoundError:
                 pass  # If the path doesn't exist, meh
             except OSError as e:
-                logger.warning("Failed to remove media: %r: %s", media_id, e)
+                logger.warning(
+                    "Failed to remove media from url preview cache: %r: %s", media_id, e
+                )
                 continue
 
             dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
@@ -745,7 +755,9 @@ class PreviewUrlResource(DirectServeJsonResource):
             except FileNotFoundError:
                 pass  # If the path doesn't exist, meh
             except OSError as e:
-                logger.warning("Failed to remove media: %r: %s", media_id, e)
+                logger.warning(
+                    "Failed to remove media from url preview cache: %r: %s", media_id, e
+                )
                 continue
 
             removed_media.append(media_id)
@@ -758,9 +770,9 @@ class PreviewUrlResource(DirectServeJsonResource):
         await self.store.delete_url_cache_media(removed_media)
 
         if removed_media:
-            logger.info("Deleted %d media from url cache", len(removed_media))
+            logger.debug("Deleted %d media from url preview cache", len(removed_media))
         else:
-            logger.debug("No media removed from url cache")
+            logger.debug("No media removed from url preview cache")
 
 
 def _is_media(content_type: str) -> bool:
diff --git a/synapse/server.py b/synapse/server.py
index d49c76518a..ee60cce8eb 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -119,7 +119,7 @@ from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpC
 from synapse.http.matrixfederationclient import MatrixFederationHttpClient
 from synapse.module_api import ModuleApi
 from synapse.notifier import Notifier
-from synapse.push.action_generator import ActionGenerator
+from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
 from synapse.push.pusherpool import PusherPool
 from synapse.replication.tcp.client import ReplicationDataHandler
 from synapse.replication.tcp.external_cache import ExternalCache
@@ -644,8 +644,8 @@ class HomeServer(metaclass=abc.ABCMeta):
         return ReplicationCommandHandler(self)
 
     @cache_in_self
-    def get_action_generator(self) -> ActionGenerator:
-        return ActionGenerator(self)
+    def get_bulk_push_rule_evaluator(self) -> BulkPushRuleEvaluator:
+        return BulkPushRuleEvaluator(self)
 
     @cache_in_self
     def get_user_directory_handler(self) -> UserDirectoryHandler:
@@ -681,7 +681,7 @@ class HomeServer(metaclass=abc.ABCMeta):
 
     @cache_in_self
     def get_spam_checker(self) -> SpamChecker:
-        return SpamChecker()
+        return SpamChecker(self)
 
     @cache_in_self
     def get_third_party_event_rules(self) -> ThirdPartyEventRules:
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index 015dd08f05..b5f3a0c74e 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -21,7 +21,6 @@ from synapse.api.constants import (
     ServerNoticeMsgType,
 )
 from synapse.api.errors import AuthError, ResourceLimitError, SynapseError
-from synapse.server_notices.server_notices_manager import SERVER_NOTICE_ROOM_TAG
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -71,18 +70,19 @@ class ResourceLimitsServerNotices:
             # In practice, not sure we can ever get here
             return
 
-        room_id = await self._server_notices_manager.get_or_create_notice_room_for_user(
+        # Check if there's a server notice room for this user.
+        room_id = await self._server_notices_manager.maybe_get_notice_room_for_user(
             user_id
         )
 
-        if not room_id:
-            logger.warning("Failed to get server notices room")
-            return
-
-        await self._check_and_set_tags(user_id, room_id)
-
-        # Determine current state of room
-        currently_blocked, ref_events = await self._is_room_currently_blocked(room_id)
+        if room_id is not None:
+            # Determine current state of room
+            currently_blocked, ref_events = await self._is_room_currently_blocked(
+                room_id
+            )
+        else:
+            currently_blocked = False
+            ref_events = []
 
         limit_msg = None
         limit_type = None
@@ -161,26 +161,6 @@ class ResourceLimitsServerNotices:
             user_id, content, EventTypes.Pinned, ""
         )
 
-    async def _check_and_set_tags(self, user_id: str, room_id: str) -> None:
-        """
-        Since server notices rooms were originally not with tags,
-        important to check that tags have been set correctly
-        Args:
-            user_id(str): the user in question
-            room_id(str): the server notices room for that user
-        """
-        tags = await self._store.get_tags_for_room(user_id, room_id)
-        need_to_set_tag = True
-        if tags:
-            if SERVER_NOTICE_ROOM_TAG in tags:
-                # tag already present, nothing to do here
-                need_to_set_tag = False
-        if need_to_set_tag:
-            max_id = await self._account_data_handler.add_tag_to_room(
-                user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
-            )
-            self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
-
     async def _is_room_currently_blocked(self, room_id: str) -> Tuple[bool, List[str]]:
         """
         Determines if the room is currently blocked
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 48eae5fa06..8ecab86ec7 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Optional
 
 from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
 from synapse.events import EventBase
-from synapse.types import Requester, UserID, create_requester
+from synapse.types import Requester, StreamKeyType, UserID, create_requester
 from synapse.util.caches.descriptors import cached
 
 if TYPE_CHECKING:
@@ -91,6 +91,35 @@ class ServerNoticesManager:
         return event
 
     @cached()
+    async def maybe_get_notice_room_for_user(self, user_id: str) -> Optional[str]:
+        """Try to look up the server notice room for this user if it exists.
+
+        Does not create one if none can be found.
+
+        Args:
+            user_id: the user we want a server notice room for.
+
+        Returns:
+            The room's ID, or None if no room could be found.
+        """
+        rooms = await self._store.get_rooms_for_local_user_where_membership_is(
+            user_id, [Membership.INVITE, Membership.JOIN]
+        )
+        for room in rooms:
+            # it's worth noting that there is an asymmetry here in that we
+            # expect the user to be invited or joined, but the system user must
+            # be joined. This is kinda deliberate, in that if somebody somehow
+            # manages to invite the system user to a room, that doesn't make it
+            # the server notices room.
+            user_ids = await self._store.get_users_in_room(room.room_id)
+            if len(user_ids) <= 2 and self.server_notices_mxid in user_ids:
+                # we found a room which our user shares with the system notice
+                # user
+                return room.room_id
+
+        return None
+
+    @cached()
     async def get_or_create_notice_room_for_user(self, user_id: str) -> str:
         """Get the room for notices for a given user
 
@@ -112,31 +141,20 @@ class ServerNoticesManager:
             self.server_notices_mxid, authenticated_entity=self._server_name
         )
 
-        rooms = await self._store.get_rooms_for_local_user_where_membership_is(
-            user_id, [Membership.INVITE, Membership.JOIN]
-        )
-        for room in rooms:
-            # it's worth noting that there is an asymmetry here in that we
-            # expect the user to be invited or joined, but the system user must
-            # be joined. This is kinda deliberate, in that if somebody somehow
-            # manages to invite the system user to a room, that doesn't make it
-            # the server notices room.
-            user_ids = await self._store.get_users_in_room(room.room_id)
-            if len(user_ids) <= 2 and self.server_notices_mxid in user_ids:
-                # we found a room which our user shares with the system notice
-                # user
-                logger.info(
-                    "Using existing server notices room %s for user %s",
-                    room.room_id,
-                    user_id,
-                )
-                await self._update_notice_user_profile_if_changed(
-                    requester,
-                    room.room_id,
-                    self._config.servernotices.server_notices_mxid_display_name,
-                    self._config.servernotices.server_notices_mxid_avatar_url,
-                )
-                return room.room_id
+        room_id = await self.maybe_get_notice_room_for_user(user_id)
+        if room_id is not None:
+            logger.info(
+                "Using existing server notices room %s for user %s",
+                room_id,
+                user_id,
+            )
+            await self._update_notice_user_profile_if_changed(
+                requester,
+                room_id,
+                self._config.servernotices.server_notices_mxid_display_name,
+                self._config.servernotices.server_notices_mxid_avatar_url,
+            )
+            return room_id
 
         # apparently no existing notice room: create a new one
         logger.info("Creating server notices room for %s", user_id)
@@ -166,10 +184,12 @@ class ServerNoticesManager:
         )
         room_id = info["room_id"]
 
+        self.maybe_get_notice_room_for_user.invalidate((user_id,))
+
         max_id = await self._account_data_handler.add_tag_to_room(
             user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
         )
-        self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        self._notifier.on_new_event(StreamKeyType.ACCOUNT_DATA, max_id, users=[user_id])
 
         logger.info("Created server notices room %s for %s", room_id, user_id)
         return room_id
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index cad3b42640..54e41d5375 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -130,6 +130,7 @@ class StateHandler:
         self.state_store = hs.get_storage().state
         self.hs = hs
         self._state_resolution_handler = hs.get_state_resolution_handler()
+        self._storage = hs.get_storage()
 
     @overload
     async def get_current_state(
@@ -361,10 +362,10 @@ class StateHandler:
 
         if not event.is_state():
             return EventContext.with_state(
+                storage=self._storage,
                 state_group_before_event=state_group_before_event,
                 state_group=state_group_before_event,
-                current_state_ids=state_ids_before_event,
-                prev_state_ids=state_ids_before_event,
+                state_delta_due_to_event={},
                 prev_group=state_group_before_event_prev_group,
                 delta_ids=deltas_to_state_group_before_event,
                 partial_state=partial_state,
@@ -393,10 +394,10 @@ class StateHandler:
         )
 
         return EventContext.with_state(
+            storage=self._storage,
             state_group=state_group_after_event,
             state_group_before_event=state_group_before_event,
-            current_state_ids=state_ids_after_event,
-            prev_state_ids=state_ids_before_event,
+            state_delta_due_to_event=delta_ids,
             prev_group=state_group_before_event,
             delta_ids=delta_ids,
             partial_state=partial_state,
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 08c6eabc6d..c2bbbb574e 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,20 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from types import TracebackType
 from typing import (
     TYPE_CHECKING,
+    Any,
     AsyncContextManager,
     Awaitable,
     Callable,
     Dict,
     Iterable,
+    List,
     Optional,
+    Type,
 )
 
 import attr
 
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.types import Connection
+from synapse.storage.types import Connection, Cursor
 from synapse.types import JsonDict
 from synapse.util import Clock, json_encoder
 
@@ -74,7 +78,12 @@ class _BackgroundUpdateContextManager:
 
         return self._update_duration_ms
 
-    async def __aexit__(self, *exc) -> None:
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> None:
         pass
 
 
@@ -352,7 +361,7 @@ class BackgroundUpdater:
             True if we have finished running all the background updates, otherwise False
         """
 
-        def get_background_updates_txn(txn):
+        def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]:
             txn.execute(
                 """
                 SELECT update_name, depends_on FROM background_updates
@@ -469,7 +478,7 @@ class BackgroundUpdater:
         self,
         update_name: str,
         update_handler: Callable[[JsonDict, int], Awaitable[int]],
-    ):
+    ) -> None:
         """Register a handler for doing a background update.
 
         The handler should take two arguments:
@@ -603,7 +612,7 @@ class BackgroundUpdater:
         else:
             runner = create_index_sqlite
 
-        async def updater(progress, batch_size):
+        async def updater(progress: JsonDict, batch_size: int) -> int:
             if runner is not None:
                 logger.info("Adding index %s to %s", index_name, table)
                 await self.db_pool.runWithConnection(runner)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 41f566b648..5ddb58a8a2 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -31,6 +31,7 @@ from typing import (
     List,
     Optional,
     Tuple,
+    Type,
     TypeVar,
     cast,
     overload,
@@ -41,6 +42,7 @@ from prometheus_client import Histogram
 from typing_extensions import Concatenate, Literal, ParamSpec
 
 from twisted.enterprise import adbapi
+from twisted.internet.interfaces import IReactorCore
 
 from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
@@ -92,7 +94,9 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
 
 
 def make_pool(
-    reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+    reactor: IReactorCore,
+    db_config: DatabaseConnectionConfig,
+    engine: BaseDatabaseEngine,
 ) -> adbapi.ConnectionPool:
     """Get the connection pool for the database."""
 
@@ -101,7 +105,7 @@ def make_pool(
     db_args = dict(db_config.config.get("args", {}))
     db_args.setdefault("cp_reconnect", True)
 
-    def _on_new_connection(conn):
+    def _on_new_connection(conn: Connection) -> None:
         # Ensure we have a logging context so we can correctly track queries,
         # etc.
         with LoggingContext("db.on_new_connection"):
@@ -157,7 +161,11 @@ class LoggingDatabaseConnection:
     default_txn_name: str
 
     def cursor(
-        self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
+        self,
+        *,
+        txn_name: Optional[str] = None,
+        after_callbacks: Optional[List["_CallbackListEntry"]] = None,
+        exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
     ) -> "LoggingTransaction":
         if not txn_name:
             txn_name = self.default_txn_name
@@ -183,11 +191,16 @@ class LoggingDatabaseConnection:
         self.conn.__enter__()
         return self
 
-    def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> Optional[bool]:
         return self.conn.__exit__(exc_type, exc_value, traceback)
 
     # Proxy through any unknown lookups to the DB conn class.
-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         return getattr(self.conn, name)
 
 
@@ -391,17 +404,22 @@ class LoggingTransaction:
     def __enter__(self) -> "LoggingTransaction":
         return self
 
-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> None:
         self.close()
 
 
 class PerformanceCounters:
-    def __init__(self):
-        self.current_counters = {}
-        self.previous_counters = {}
+    def __init__(self) -> None:
+        self.current_counters: Dict[str, Tuple[int, float]] = {}
+        self.previous_counters: Dict[str, Tuple[int, float]] = {}
 
     def update(self, key: str, duration_secs: float) -> None:
-        count, cum_time = self.current_counters.get(key, (0, 0))
+        count, cum_time = self.current_counters.get(key, (0, 0.0))
         count += 1
         cum_time += duration_secs
         self.current_counters[key] = (count, cum_time)
@@ -527,7 +545,7 @@ class DatabasePool:
     def start_profiling(self) -> None:
         self._previous_loop_ts = monotonic_time()
 
-        def loop():
+        def loop() -> None:
             curr = self._current_txn_total_time
             prev = self._previous_txn_total_time
             self._previous_txn_total_time = curr
@@ -1186,7 +1204,7 @@ class DatabasePool:
         if lock:
             self.engine.lock_table(txn, table)
 
-        def _getwhere(key):
+        def _getwhere(key: str) -> str:
             # If the value we're passing in is None (aka NULL), we need to use
             # IS, not =, as NULL = NULL equals NULL (False).
             if keyvalues[key] is None:
@@ -2258,7 +2276,7 @@ class DatabasePool:
         term: Optional[str],
         col: str,
         retcols: Collection[str],
-        desc="simple_search_list",
+        desc: str = "simple_search_list",
     ) -> Optional[List[Dict[str, Any]]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index dd4e83a2ad..1653a6a9b6 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -57,6 +57,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         self._instance_name = hs.get_instance_name()
 
+        self.db_pool.updates.register_background_index_update(
+            update_name="cache_invalidation_index_by_instance",
+            index_name="cache_invalidation_stream_by_instance_instance_index",
+            table="cache_invalidation_stream_by_instance",
+            columns=("instance_name", "stream_id"),
+            psql_only=True,  # The table is only on postgres DBs.
+        )
+
     async def get_all_updated_caches(
         self, instance_name: str, last_id: int, current_id: int, limit: int
     ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index b789a588a5..af59be6b48 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -21,7 +21,7 @@ from synapse.api.errors import StoreError
 from synapse.logging.opentracing import log_kv, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import LoggingTransaction
-from synapse.types import JsonDict, JsonSerializable
+from synapse.types import JsonDict, JsonSerializable, StreamKeyType
 from synapse.util import json_encoder
 
 
@@ -126,7 +126,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                     "message": "Set room key",
                     "room_id": room_id,
                     "session_id": session_id,
-                    "room_key": room_key,
+                    StreamKeyType.ROOM: room_key,
                 }
             )
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index ed29a0a5e2..42d484dc98 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -36,9 +36,8 @@ from prometheus_client import Counter
 import synapse.metrics
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.room_versions import RoomVersions
-from synapse.crypto.event_signing import compute_event_reference_hash
-from synapse.events import EventBase  # noqa: F401
-from synapse.events.snapshot import EventContext  # noqa: F401
+from synapse.events import EventBase, relation_from_event
+from synapse.events.snapshot import EventContext
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -50,7 +49,7 @@ from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.engines.postgres import PostgresEngine
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
-from synapse.types import StateMap, get_domain_from_id
+from synapse.types import JsonDict, StateMap, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
 
@@ -129,7 +128,6 @@ class PersistEventsStore:
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         *,
-        current_state_for_room: Dict[str, StateMap[str]],
         state_delta_for_room: Dict[str, DeltaState],
         new_forward_extremities: Dict[str, Set[str]],
         use_negative_stream_ordering: bool = False,
@@ -140,8 +138,6 @@ class PersistEventsStore:
 
         Args:
             events_and_contexts:
-            current_state_for_room: Map from room_id to the current state of
-                the room based on forward extremities
             state_delta_for_room: Map from room_id to the delta to apply to
                 room state
             new_forward_extremities: Map from room_id to set of event IDs
@@ -216,9 +212,6 @@ class PersistEventsStore:
 
                 event_counter.labels(event.type, origin_type, origin_entity).inc()
 
-            for room_id, new_state in current_state_for_room.items():
-                self.store.get_current_state_ids.prefill((room_id,), new_state)
-
             for room_id, latest_event_ids in new_forward_extremities.items():
                 self.store.get_latest_event_ids_in_room.prefill(
                     (room_id,), list(latest_event_ids)
@@ -236,7 +229,9 @@ class PersistEventsStore:
         """
         results: List[str] = []
 
-        def _get_events_which_are_prevs_txn(txn, batch):
+        def _get_events_which_are_prevs_txn(
+            txn: LoggingTransaction, batch: Collection[str]
+        ) -> None:
             sql = """
             SELECT prev_event_id, internal_metadata
             FROM event_edges
@@ -286,7 +281,9 @@ class PersistEventsStore:
         # and their prev events.
         existing_prevs = set()
 
-        def _get_prevs_before_rejected_txn(txn, batch):
+        def _get_prevs_before_rejected_txn(
+            txn: LoggingTransaction, batch: Collection[str]
+        ) -> None:
             to_recursively_check = batch
 
             while to_recursively_check:
@@ -516,7 +513,7 @@ class PersistEventsStore:
     @classmethod
     def _add_chain_cover_index(
         cls,
-        txn,
+        txn: LoggingTransaction,
         db_pool: DatabasePool,
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
@@ -810,7 +807,7 @@ class PersistEventsStore:
 
     @staticmethod
     def _allocate_chain_ids(
-        txn,
+        txn: LoggingTransaction,
         db_pool: DatabasePool,
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
@@ -944,7 +941,7 @@ class PersistEventsStore:
         self,
         txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-    ):
+    ) -> None:
         """Persist the mapping from transaction IDs to event IDs (if defined)."""
 
         to_insert = []
@@ -998,7 +995,7 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         state_delta_by_room: Dict[str, DeltaState],
         stream_id: int,
-    ):
+    ) -> None:
         for room_id, delta_state in state_delta_by_room.items():
             to_delete = delta_state.to_delete
             to_insert = delta_state.to_insert
@@ -1156,7 +1153,7 @@ class PersistEventsStore:
                 txn, room_id, members_changed
             )
 
-    def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
+    def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
         """Update the room version in the database based off current state
         events.
 
@@ -1190,7 +1187,7 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         new_forward_extremities: Dict[str, Set[str]],
         max_stream_order: int,
-    ):
+    ) -> None:
         for room_id in new_forward_extremities.keys():
             self.db_pool.simple_delete_txn(
                 txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
@@ -1255,9 +1252,9 @@ class PersistEventsStore:
 
     def _update_room_depths_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-    ):
+    ) -> None:
         """Update min_depth for each room
 
         Args:
@@ -1386,7 +1383,7 @@ class PersistEventsStore:
             # nothing to do here
             return
 
-        def event_dict(event):
+        def event_dict(event: EventBase) -> JsonDict:
             d = event.get_dict()
             d.pop("redacted", None)
             d.pop("redacted_because", None)
@@ -1477,18 +1474,20 @@ class PersistEventsStore:
             ),
         )
 
-    def _store_rejected_events_txn(self, txn, events_and_contexts):
+    def _store_rejected_events_txn(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> List[Tuple[EventBase, EventContext]]:
         """Add rows to the 'rejections' table for received events which were
         rejected
 
         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
+            txn: db connection
+            events_and_contexts: events we are persisting
 
         Returns:
-            list[(EventBase, EventContext)] new list, without the rejected
-                events.
+            new list, without the rejected events.
         """
         # Remove the rejected events from the list now that we've added them
         # to the events table and the events_json table.
@@ -1509,7 +1508,7 @@ class PersistEventsStore:
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         all_events_and_contexts: List[Tuple[EventBase, EventContext]],
         inhibit_local_membership_updates: bool = False,
-    ):
+    ) -> None:
         """Update all the miscellaneous tables for new events
 
         Args:
@@ -1600,15 +1599,14 @@ class PersistEventsStore:
             inhibit_local_membership_updates=inhibit_local_membership_updates,
         )
 
-        # Insert event_reference_hashes table.
-        self._store_event_reference_hashes_txn(
-            txn, [event for event, _ in events_and_contexts]
-        )
-
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)
 
-    def _add_to_cache(self, txn, events_and_contexts):
+    def _add_to_cache(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> None:
         to_prefill = []
 
         rows = []
@@ -1639,7 +1637,7 @@ class PersistEventsStore:
             if not row["rejects"] and not row["redacts"]:
                 to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
-        def prefill():
+        def prefill() -> None:
             for cache_entry in to_prefill:
                 self.store._get_event_cache.set(
                     (cache_entry.event.event_id,), cache_entry
@@ -1669,19 +1667,24 @@ class PersistEventsStore:
         )
 
     def insert_labels_for_event_txn(
-        self, txn, event_id, labels, room_id, topological_ordering
-    ):
+        self,
+        txn: LoggingTransaction,
+        event_id: str,
+        labels: List[str],
+        room_id: str,
+        topological_ordering: int,
+    ) -> None:
         """Store the mapping between an event's ID and its labels, with one row per
         (event_id, label) tuple.
 
         Args:
-            txn (LoggingTransaction): The transaction to execute.
-            event_id (str): The event's ID.
-            labels (list[str]): A list of text labels.
-            room_id (str): The ID of the room the event was sent to.
-            topological_ordering (int): The position of the event in the room's topology.
+            txn: The transaction to execute.
+            event_id: The event's ID.
+            labels: A list of text labels.
+            room_id: The ID of the room the event was sent to.
+            topological_ordering: The position of the event in the room's topology.
         """
-        return self.db_pool.simple_insert_many_txn(
+        self.db_pool.simple_insert_many_txn(
             txn=txn,
             table="event_labels",
             keys=("event_id", "label", "room_id", "topological_ordering"),
@@ -1690,44 +1693,32 @@ class PersistEventsStore:
             ],
         )
 
-    def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+    def _insert_event_expiry_txn(
+        self, txn: LoggingTransaction, event_id: str, expiry_ts: int
+    ) -> None:
         """Save the expiry timestamp associated with a given event ID.
 
         Args:
-            txn (LoggingTransaction): The database transaction to use.
-            event_id (str): The event ID the expiry timestamp is associated with.
-            expiry_ts (int): The timestamp at which to expire (delete) the event.
+            txn: The database transaction to use.
+            event_id: The event ID the expiry timestamp is associated with.
+            expiry_ts: The timestamp at which to expire (delete) the event.
         """
-        return self.db_pool.simple_insert_txn(
+        self.db_pool.simple_insert_txn(
             txn=txn,
             table="event_expiry",
             values={"event_id": event_id, "expiry_ts": expiry_ts},
         )
 
-    def _store_event_reference_hashes_txn(self, txn, events):
-        """Store a hash for a PDU
-        Args:
-            txn (cursor):
-            events (list): list of Events.
-        """
-
-        vals = []
-        for event in events:
-            ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
-            vals.append((event.event_id, ref_alg, memoryview(ref_hash_bytes)))
-
-        self.db_pool.simple_insert_many_txn(
-            txn,
-            table="event_reference_hashes",
-            keys=("event_id", "algorithm", "hash"),
-            values=vals,
-        )
-
     def _store_room_members_txn(
-        self, txn, events, *, inhibit_local_membership_updates: bool = False
-    ):
+        self,
+        txn: LoggingTransaction,
+        events: List[EventBase],
+        *,
+        inhibit_local_membership_updates: bool = False,
+    ) -> None:
         """
         Store a room member in the database.
+
         Args:
             txn: The transaction to use.
             events: List of events to store.
@@ -1767,6 +1758,7 @@ class PersistEventsStore:
         )
 
         for event in events:
+            assert event.internal_metadata.stream_ordering is not None
             txn.call_after(
                 self.store._membership_stream_cache.entity_has_changed,
                 event.state_key,
@@ -1815,55 +1807,50 @@ class PersistEventsStore:
             txn: The current database transaction.
             event: The event which might have relations.
         """
-        relation = event.content.get("m.relates_to")
+        relation = relation_from_event(event)
         if not relation:
-            # No relations
-            return
-
-        # Relations must have a type and parent event ID.
-        rel_type = relation.get("rel_type")
-        if not isinstance(rel_type, str):
+            # No relation, nothing to do.
             return
 
-        parent_id = relation.get("event_id")
-        if not isinstance(parent_id, str):
-            return
-
-        # Annotations have a key field.
-        aggregation_key = None
-        if rel_type == RelationTypes.ANNOTATION:
-            aggregation_key = relation.get("key")
-
         self.db_pool.simple_insert_txn(
             txn,
             table="event_relations",
             values={
                 "event_id": event.event_id,
-                "relates_to_id": parent_id,
-                "relation_type": rel_type,
-                "aggregation_key": aggregation_key,
+                "relates_to_id": relation.parent_id,
+                "relation_type": relation.rel_type,
+                "aggregation_key": relation.aggregation_key,
             },
         )
 
-        txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,))
         txn.call_after(
-            self.store.get_aggregation_groups_for_event.invalidate, (parent_id,)
+            self.store.get_relations_for_event.invalidate, (relation.parent_id,)
+        )
+        txn.call_after(
+            self.store.get_aggregation_groups_for_event.invalidate,
+            (relation.parent_id,),
         )
 
-        if rel_type == RelationTypes.REPLACE:
-            txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+        if relation.rel_type == RelationTypes.REPLACE:
+            txn.call_after(
+                self.store.get_applicable_edit.invalidate, (relation.parent_id,)
+            )
 
-        if rel_type == RelationTypes.THREAD:
-            txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+        if relation.rel_type == RelationTypes.THREAD:
+            txn.call_after(
+                self.store.get_thread_summary.invalidate, (relation.parent_id,)
+            )
             # It should be safe to only invalidate the cache if the user has not
             # previously participated in the thread, but that's difficult (and
             # potentially error-prone) so it is always invalidated.
             txn.call_after(
                 self.store.get_thread_participated.invalidate,
-                (parent_id, event.sender),
+                (relation.parent_id, event.sender),
             )
 
-    def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
+    def _handle_insertion_event(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         """Handles keeping track of insertion events and edges/connections.
         Part of MSC2716.
 
@@ -1924,7 +1911,7 @@ class PersistEventsStore:
                 },
             )
 
-    def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase):
+    def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None:
         """Handles inserting the batch edges/connections between the batch event
         and an insertion event. Part of MSC2716.
 
@@ -2024,25 +2011,29 @@ class PersistEventsStore:
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
         )
 
-    def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
         if isinstance(event.content.get("topic"), str):
             self.store_event_search_txn(
                 txn, event, "content.topic", event.content["topic"]
             )
 
-    def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
         if isinstance(event.content.get("name"), str):
             self.store_event_search_txn(
                 txn, event, "content.name", event.content["name"]
             )
 
-    def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_message_txn(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         if isinstance(event.content.get("body"), str):
             self.store_event_search_txn(
                 txn, event, "content.body", event.content["body"]
             )
 
-    def _store_retention_policy_for_room_txn(self, txn, event):
+    def _store_retention_policy_for_room_txn(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         if not event.is_state():
             logger.debug("Ignoring non-state m.room.retention event")
             return
@@ -2102,8 +2093,11 @@ class PersistEventsStore:
         )
 
     def _set_push_actions_for_event_and_users_txn(
-        self, txn, events_and_contexts, all_events_and_contexts
-    ):
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        all_events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> None:
         """Handles moving push actions from staging table to main
         event_push_actions table for all events in `events_and_contexts`.
 
@@ -2111,12 +2105,10 @@ class PersistEventsStore:
         from the push action staging area.
 
         Args:
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
-            all_events_and_contexts (list[(EventBase, EventContext)]): all
-                events that we were going to persist. This includes events
-                we've already persisted, etc, that wouldn't appear in
-                events_and_context.
+            events_and_contexts: events we are persisting
+            all_events_and_contexts: all events that we were going to persist.
+                This includes events we've already persisted, etc, that wouldn't
+                appear in events_and_context.
         """
 
         # Only non outlier events will have push actions associated with them,
@@ -2185,7 +2177,9 @@ class PersistEventsStore:
             ),
         )
 
-    def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
+    def _remove_push_actions_for_event_id_txn(
+        self, txn: LoggingTransaction, room_id: str, event_id: str
+    ) -> None:
         # Sad that we have to blow away the cache for the whole room here
         txn.call_after(
             self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
@@ -2196,7 +2190,9 @@ class PersistEventsStore:
             (room_id, event_id),
         )
 
-    def _store_rejections_txn(self, txn, event_id, reason):
+    def _store_rejections_txn(
+        self, txn: LoggingTransaction, event_id: str, reason: str
+    ) -> None:
         self.db_pool.simple_insert_txn(
             txn,
             table="rejections",
@@ -2208,8 +2204,10 @@ class PersistEventsStore:
         )
 
     def _store_event_state_mappings_txn(
-        self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
-    ):
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: Collection[Tuple[EventBase, EventContext]],
+    ) -> None:
         state_groups = {}
         for event, context in events_and_contexts:
             if event.internal_metadata.is_outlier():
@@ -2266,7 +2264,9 @@ class PersistEventsStore:
                 state_group_id,
             )
 
-    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+    def _update_min_depth_for_room_txn(
+        self, txn: LoggingTransaction, room_id: str, depth: int
+    ) -> None:
         min_depth = self.store._get_min_depth_interaction(txn, room_id)
 
         if min_depth is not None and depth >= min_depth:
@@ -2279,7 +2279,9 @@ class PersistEventsStore:
             values={"min_depth": depth},
         )
 
-    def _handle_mult_prev_events(self, txn, events):
+    def _handle_mult_prev_events(
+        self, txn: LoggingTransaction, events: List[EventBase]
+    ) -> None:
         """
         For the given event, update the event edges table and forward and
         backward extremities tables.
@@ -2297,7 +2299,9 @@ class PersistEventsStore:
 
         self._update_backward_extremeties(txn, events)
 
-    def _update_backward_extremeties(self, txn, events):
+    def _update_backward_extremeties(
+        self, txn: LoggingTransaction, events: List[EventBase]
+    ) -> None:
         """Updates the event_backward_extremities tables based on the new/updated
         events being persisted.
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index a4a604a499..5b22d6b452 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -14,6 +14,7 @@
 
 import logging
 import threading
+import weakref
 from enum import Enum, auto
 from typing import (
     TYPE_CHECKING,
@@ -23,6 +24,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    MutableMapping,
     Optional,
     Set,
     Tuple,
@@ -248,6 +250,12 @@ class EventsWorkerStore(SQLBaseStore):
             str, ObservableDeferred[Dict[str, EventCacheEntry]]
         ] = {}
 
+        # We keep track of the events we have currently loaded in memory so that
+        # we can reuse them even if they've been evicted from the cache. We only
+        # track events that don't need redacting in here (as then we don't need
+        # to track redaction status).
+        self._event_ref: MutableMapping[str, EventBase] = weakref.WeakValueDictionary()
+
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list: List[
             Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
@@ -723,6 +731,8 @@ class EventsWorkerStore(SQLBaseStore):
 
     def _invalidate_get_event_cache(self, event_id: str) -> None:
         self._get_event_cache.invalidate((event_id,))
+        self._event_ref.pop(event_id, None)
+        self._current_event_fetches.pop(event_id, None)
 
     def _get_events_from_cache(
         self, events: Iterable[str], update_metrics: bool = True
@@ -738,13 +748,30 @@ class EventsWorkerStore(SQLBaseStore):
         event_map = {}
 
         for event_id in events:
+            # First check if it's in the event cache
             ret = self._get_event_cache.get(
                 (event_id,), None, update_metrics=update_metrics
             )
-            if not ret:
+            if ret:
+                event_map[event_id] = ret
                 continue
 
-            event_map[event_id] = ret
+            # Otherwise check if we still have the event in memory.
+            event = self._event_ref.get(event_id)
+            if event:
+                # Reconstruct an event cache entry
+
+                cache_entry = EventCacheEntry(
+                    event=event,
+                    # We don't cache weakrefs to redacted events, so we know
+                    # this is None.
+                    redacted_event=None,
+                )
+                event_map[event_id] = cache_entry
+
+                # We add the entry back into the cache as we want to keep
+                # recently queried events in the cache.
+                self._get_event_cache.set((event_id,), cache_entry)
 
         return event_map
 
@@ -1124,6 +1151,10 @@ class EventsWorkerStore(SQLBaseStore):
             self._get_event_cache.set((event_id,), cache_entry)
             result_map[event_id] = cache_entry
 
+            if not redacted_event:
+                # We only cache references to unredacted events.
+                self._event_ref[event_id] = original_ev
+
         return result_map
 
     async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 1480a0f048..d03555a585 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -23,6 +23,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
+from synapse.storage.types import Cursor
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -71,7 +72,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         self._last_user_visit_update = self._get_start_of_day()
 
     @wrap_as_background_process("read_forward_extremities")
-    async def _read_forward_extremities(self):
+    async def _read_forward_extremities(self) -> None:
         def fetch(txn):
             txn.execute(
                 """
@@ -95,7 +96,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             (x[0] - 1) * x[1] for x in res if x[1]
         )
 
-    async def count_daily_e2ee_messages(self):
+    async def count_daily_e2ee_messages(self) -> int:
         """
         Returns an estimate of the number of messages sent in the last day.
 
@@ -115,7 +116,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
         return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
 
-    async def count_daily_sent_e2ee_messages(self):
+    async def count_daily_sent_e2ee_messages(self) -> int:
         def _count_messages(txn):
             # This is good enough as if you have silly characters in your own
             # hostname then that's your own fault.
@@ -136,7 +137,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_daily_sent_e2ee_messages", _count_messages
         )
 
-    async def count_daily_active_e2ee_rooms(self):
+    async def count_daily_active_e2ee_rooms(self) -> int:
         def _count(txn):
             sql = """
                 SELECT COUNT(DISTINCT room_id) FROM events
@@ -151,7 +152,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_daily_active_e2ee_rooms", _count
         )
 
-    async def count_daily_messages(self):
+    async def count_daily_messages(self) -> int:
         """
         Returns an estimate of the number of messages sent in the last day.
 
@@ -171,7 +172,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
         return await self.db_pool.runInteraction("count_messages", _count_messages)
 
-    async def count_daily_sent_messages(self):
+    async def count_daily_sent_messages(self) -> int:
         def _count_messages(txn):
             # This is good enough as if you have silly characters in your own
             # hostname then that's your own fault.
@@ -192,7 +193,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_daily_sent_messages", _count_messages
         )
 
-    async def count_daily_active_rooms(self):
+    async def count_daily_active_rooms(self) -> int:
         def _count(txn):
             sql = """
                 SELECT COUNT(DISTINCT room_id) FROM events
@@ -226,7 +227,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_monthly_users", self._count_users, thirty_days_ago
         )
 
-    def _count_users(self, txn, time_from):
+    def _count_users(self, txn: Cursor, time_from: int) -> int:
         """
         Returns number of users seen in the past time_from period
         """
@@ -238,7 +239,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             ) u
         """
         txn.execute(sql, (time_from,))
-        (count,) = txn.fetchone()
+        # Mypy knows that fetchone() might return None if there are no rows.
+        # We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
+        # returns exactly one row.
+        (count,) = txn.fetchone()  # type: ignore[misc]
         return count
 
     async def count_r30_users(self) -> Dict[str, int]:
@@ -453,7 +457,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_r30v2_users", _count_r30v2_users
         )
 
-    def _get_start_of_day(self):
+    def _get_start_of_day(self) -> int:
         """
         Returns millisecond unixtime for start of UTC day.
         """
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index bfc85b3add..38ba91af4c 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -69,7 +69,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         #     event_forward_extremities
         #     event_json
         #     event_push_actions
-        #     event_reference_hashes
         #     event_relations
         #     event_search
         #     event_to_state_groups
@@ -220,7 +219,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "event_auth",
             "event_edges",
             "event_forward_extremities",
-            "event_reference_hashes",
             "event_relations",
             "event_search",
             "rejections",
@@ -369,7 +367,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "event_edges",
             "event_json",
             "event_push_actions_staging",
-            "event_reference_hashes",
             "event_relations",
             "event_to_state_groups",
             "event_auth_chains",
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 484976ca6b..fe8fded88b 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -34,7 +34,7 @@ from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
 logger = logging.getLogger(__name__)
@@ -161,7 +161,9 @@ class RelationsWorkerStore(SQLBaseStore):
             if len(events) > limit and last_topo_id and last_stream_id:
                 next_key = RoomStreamToken(last_topo_id, last_stream_id)
                 if from_token:
-                    next_token = from_token.copy_and_replace("room_key", next_key)
+                    next_token = from_token.copy_and_replace(
+                        StreamKeyType.ROOM, next_key
+                    )
                 else:
                     next_token = StreamToken(
                         room_key=next_key,
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 3c49e7ec98..78e0773b2a 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -14,7 +14,7 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple
 
 import attr
 
@@ -27,7 +27,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.types import JsonDict
 
 if TYPE_CHECKING:
@@ -149,7 +149,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
             self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings
         )
 
-    async def _background_reindex_search(self, progress, batch_size):
+    async def _background_reindex_search(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         # we work through the events table from highest stream id to lowest
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
@@ -157,7 +159,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
         TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
 
-        def reindex_search_txn(txn):
+        def reindex_search_txn(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT stream_ordering, event_id, room_id, type, json, "
                 " origin_server_ts FROM events"
@@ -255,12 +257,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
         return result
 
-    async def _background_reindex_gin_search(self, progress, batch_size):
+    async def _background_reindex_gin_search(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """This handles old synapses which used GIST indexes, if any;
         converting them back to be GIN as per the actual schema.
         """
 
-        def create_index(conn):
+        def create_index(conn: LoggingDatabaseConnection) -> None:
             conn.rollback()
 
             # we have to set autocommit, because postgres refuses to
@@ -299,7 +303,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
         )
         return 1
 
-    async def _background_reindex_search_order(self, progress, batch_size):
+    async def _background_reindex_search_order(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
@@ -307,7 +313,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
         if not have_added_index:
 
-            def create_index(conn):
+            def create_index(conn: LoggingDatabaseConnection) -> None:
                 conn.rollback()
                 conn.set_session(autocommit=True)
                 c = conn.cursor()
@@ -336,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
                 pg,
             )
 
-        def reindex_search_txn(txn):
+        def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
             sql = (
                 "UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
                 " origin_server_ts = e.origin_server_ts"
@@ -644,7 +650,8 @@ class SearchStore(SearchBackgroundUpdateStore):
         else:
             raise Exception("Unrecognized database engine")
 
-        args.append(limit)
+        # mypy expects to append only a `str`, not an `int`
+        args.append(limit)  # type: ignore[arg-type]
 
         results = await self.db_pool.execute(
             "search_rooms", self.db_pool.cursor_to_dict, sql, *args
@@ -705,7 +712,7 @@ class SearchStore(SearchBackgroundUpdateStore):
             A set of strings.
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> Set[str]:
             highlight_words = set()
             for event in events:
                 # As a hack we simply join values of all possible keys. This is
@@ -759,11 +766,11 @@ class SearchStore(SearchBackgroundUpdateStore):
         return await self.db_pool.runInteraction("_find_highlights", f)
 
 
-def _to_postgres_options(options_dict):
+def _to_postgres_options(options_dict: JsonDict) -> str:
     return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
 
 
-def _parse_query(database_engine, search_term):
+def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
     """Takes a plain unicode string from the user and converts it into a form
     that can be passed to database.
     We use this so that we can add prefix matching, which isn't something
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 0373af86c8..0e3a23a140 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -788,30 +788,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return None
 
     async def get_current_room_stream_token_for_room_id(
-        self, room_id: Optional[str] = None
+        self, room_id: str
     ) -> RoomStreamToken:
-        """Returns the current position of the rooms stream.
-
-        By default, it returns a live token with the current global stream
-        token. Specifying a `room_id` causes it to return a historic token with
-        the room specific topological token.
-        """
+        """Returns the current position of the rooms stream (historic token)."""
         stream_ordering = self.get_room_max_stream_ordering()
-        if room_id is None:
-            return RoomStreamToken(None, stream_ordering)
-        else:
-            topo = await self.db_pool.runInteraction(
-                "_get_max_topological_txn", self._get_max_topological_txn, room_id
-            )
-            return RoomStreamToken(topo, stream_ordering)
+        topo = await self.db_pool.runInteraction(
+            "_get_max_topological_txn", self._get_max_topological_txn, room_id
+        )
+        return RoomStreamToken(topo, stream_ordering)
 
     def get_stream_id_for_event_txn(
         self,
         txn: LoggingTransaction,
         event_id: str,
-        allow_none=False,
-    ) -> int:
-        return self.db_pool.simple_select_one_onecol_txn(
+        allow_none: bool = False,
+    ) -> Optional[int]:
+        # Type ignore: we pass keyvalues a Dict[str, str]; the function wants
+        # Dict[str, Any]. I think mypy is unhappy because Dict is invariant?
+        return self.db_pool.simple_select_one_onecol_txn(  # type: ignore[call-overload]
             txn=txn,
             table="events",
             keyvalues={"event_id": event_id},
@@ -873,7 +867,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
         rows = txn.fetchall()
-        return rows[0][0] if rows else 0
+        # An aggregate function like MAX() will always return one row per group
+        # so we can safely rely on the lookup here. For example, when a we
+        # lookup a `room_id` which does not exist, `rows` will look like
+        # `[(None,)]`
+        return rows[0][0] if rows[0][0] is not None else 0
 
     @staticmethod
     def _set_before_and_after(
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index afb7d5054d..f51b3d228e 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -11,25 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Mapping
 
 from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
 from .postgres import PostgresEngine
 from .sqlite import Sqlite3Engine
 
 
-def create_engine(database_config) -> BaseDatabaseEngine:
+def create_engine(database_config: Mapping[str, Any]) -> BaseDatabaseEngine:
     name = database_config["name"]
 
     if name == "sqlite3":
-        import sqlite3
-
-        return Sqlite3Engine(sqlite3, database_config)
+        return Sqlite3Engine(database_config)
 
     if name == "psycopg2":
-        # Note that psycopg2cffi-compat provides the psycopg2 module on pypy.
-        import psycopg2
-
-        return PostgresEngine(psycopg2, database_config)
+        return PostgresEngine(database_config)
 
     raise RuntimeError("Unsupported database engine '%s'" % (name,))
 
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index 143cd98ca2..971ff82693 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -13,9 +13,12 @@
 # limitations under the License.
 import abc
 from enum import IntEnum
-from typing import Generic, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, TypeVar
 
-from synapse.storage.types import Connection
+from synapse.storage.types import Connection, Cursor, DBAPI2Module
+
+if TYPE_CHECKING:
+    from synapse.storage.database import LoggingDatabaseConnection
 
 
 class IsolationLevel(IntEnum):
@@ -32,7 +35,7 @@ ConnectionType = TypeVar("ConnectionType", bound=Connection)
 
 
 class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
-    def __init__(self, module, database_config: dict):
+    def __init__(self, module: DBAPI2Module, config: Mapping[str, Any]):
         self.module = module
 
     @property
@@ -69,7 +72,7 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
         ...
 
     @abc.abstractmethod
-    def check_new_database(self, txn) -> None:
+    def check_new_database(self, txn: Cursor) -> None:
         """Gets called when setting up a brand new database. This allows us to
         apply stricter checks on new databases versus existing database.
         """
@@ -79,8 +82,11 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
     def convert_param_style(self, sql: str) -> str:
         ...
 
+    # This method would ideally take a plain ConnectionType, but it seems that
+    # the Sqlite engine expects to use LoggingDatabaseConnection.cursor
+    # instead of sqlite3.Connection.cursor: only the former takes a txn_name.
     @abc.abstractmethod
-    def on_new_connection(self, db_conn: ConnectionType) -> None:
+    def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
         ...
 
     @abc.abstractmethod
@@ -92,7 +98,7 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
         ...
 
     @abc.abstractmethod
-    def lock_table(self, txn, table: str) -> None:
+    def lock_table(self, txn: Cursor, table: str) -> None:
         ...
 
     @property
@@ -102,12 +108,12 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
         ...
 
     @abc.abstractmethod
-    def in_transaction(self, conn: Connection) -> bool:
+    def in_transaction(self, conn: ConnectionType) -> bool:
         """Whether the connection is currently in a transaction."""
         ...
 
     @abc.abstractmethod
-    def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
+    def attempt_to_set_autocommit(self, conn: ConnectionType, autocommit: bool) -> None:
         """Attempt to set the connections autocommit mode.
 
         When True queries are run outside of transactions.
@@ -119,8 +125,8 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
 
     @abc.abstractmethod
     def attempt_to_set_isolation_level(
-        self, conn: Connection, isolation_level: Optional[int]
-    ):
+        self, conn: ConnectionType, isolation_level: Optional[int]
+    ) -> None:
         """Attempt to set the connections isolation level.
 
         Note: This has no effect on SQLite3, as transactions are SERIALIZABLE by default.
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index e8d29e2870..391f8ed24a 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -13,39 +13,47 @@
 # limitations under the License.
 
 import logging
-from typing import Mapping, Optional
+from typing import TYPE_CHECKING, Any, Mapping, NoReturn, Optional, Tuple, cast
 
 from synapse.storage.engines._base import (
     BaseDatabaseEngine,
     IncorrectDatabaseSetup,
     IsolationLevel,
 )
-from synapse.storage.types import Connection
+from synapse.storage.types import Cursor
+
+if TYPE_CHECKING:
+    import psycopg2  # noqa: F401
+
+    from synapse.storage.database import LoggingDatabaseConnection
+
 
 logger = logging.getLogger(__name__)
 
 
-class PostgresEngine(BaseDatabaseEngine):
-    def __init__(self, database_module, database_config):
-        super().__init__(database_module, database_config)
-        self.module.extensions.register_type(self.module.extensions.UNICODE)
+class PostgresEngine(BaseDatabaseEngine["psycopg2.connection"]):
+    def __init__(self, database_config: Mapping[str, Any]):
+        import psycopg2.extensions
+
+        super().__init__(psycopg2, database_config)
+        psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
 
         # Disables passing `bytes` to txn.execute, c.f. #6186. If you do
         # actually want to use bytes than wrap it in `bytearray`.
-        def _disable_bytes_adapter(_):
+        def _disable_bytes_adapter(_: bytes) -> NoReturn:
             raise Exception("Passing bytes to DB is disabled.")
 
-        self.module.extensions.register_adapter(bytes, _disable_bytes_adapter)
-        self.synchronous_commit = database_config.get("synchronous_commit", True)
-        self._version = None  # unknown as yet
+        psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter)
+        self.synchronous_commit: bool = database_config.get("synchronous_commit", True)
+        self._version: Optional[int] = None  # unknown as yet
 
         self.isolation_level_map: Mapping[int, int] = {
-            IsolationLevel.READ_COMMITTED: self.module.extensions.ISOLATION_LEVEL_READ_COMMITTED,
-            IsolationLevel.REPEATABLE_READ: self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
-            IsolationLevel.SERIALIZABLE: self.module.extensions.ISOLATION_LEVEL_SERIALIZABLE,
+            IsolationLevel.READ_COMMITTED: psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED,
+            IsolationLevel.REPEATABLE_READ: psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ,
+            IsolationLevel.SERIALIZABLE: psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE,
         }
         self.default_isolation_level = (
-            self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
+            psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ
         )
         self.config = database_config
 
@@ -53,19 +61,21 @@ class PostgresEngine(BaseDatabaseEngine):
     def single_threaded(self) -> bool:
         return False
 
-    def get_db_locale(self, txn):
+    def get_db_locale(self, txn: Cursor) -> Tuple[str, str]:
         txn.execute(
             "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
         )
-        collation, ctype = txn.fetchone()
+        collation, ctype = cast(Tuple[str, str], txn.fetchone())
         return collation, ctype
 
-    def check_database(self, db_conn, allow_outdated_version: bool = False):
+    def check_database(
+        self, db_conn: "psycopg2.connection", allow_outdated_version: bool = False
+    ) -> None:
         # Get the version of PostgreSQL that we're using. As per the psycopg2
         # docs: The number is formed by converting the major, minor, and
         # revision numbers into two-decimal-digit numbers and appending them
         # together. For example, version 8.1.5 will be returned as 80105
-        self._version = db_conn.server_version
+        self._version = cast(int, db_conn.server_version)
         allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
 
         # Are we on a supported PostgreSQL version?
@@ -108,7 +118,7 @@ class PostgresEngine(BaseDatabaseEngine):
                         ctype,
                     )
 
-    def check_new_database(self, txn):
+    def check_new_database(self, txn: Cursor) -> None:
         """Gets called when setting up a brand new database. This allows us to
         apply stricter checks on new databases versus existing database.
         """
@@ -129,10 +139,10 @@ class PostgresEngine(BaseDatabaseEngine):
                 "See docs/postgres.md for more information." % ("\n".join(errors))
             )
 
-    def convert_param_style(self, sql):
+    def convert_param_style(self, sql: str) -> str:
         return sql.replace("?", "%s")
 
-    def on_new_connection(self, db_conn):
+    def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
         db_conn.set_isolation_level(self.default_isolation_level)
 
         # Set the bytea output to escape, vs the default of hex
@@ -149,14 +159,14 @@ class PostgresEngine(BaseDatabaseEngine):
         db_conn.commit()
 
     @property
-    def can_native_upsert(self):
+    def can_native_upsert(self) -> bool:
         """
         Can we use native UPSERTs?
         """
         return True
 
     @property
-    def supports_using_any_list(self):
+    def supports_using_any_list(self) -> bool:
         """Do we support using `a = ANY(?)` and passing a list"""
         return True
 
@@ -165,27 +175,25 @@ class PostgresEngine(BaseDatabaseEngine):
         """Do we support the `RETURNING` clause in insert/update/delete?"""
         return True
 
-    def is_deadlock(self, error):
-        if isinstance(error, self.module.DatabaseError):
+    def is_deadlock(self, error: Exception) -> bool:
+        import psycopg2.extensions
+
+        if isinstance(error, psycopg2.DatabaseError):
             # https://www.postgresql.org/docs/current/static/errcodes-appendix.html
             # "40001" serialization_failure
             # "40P01" deadlock_detected
             return error.pgcode in ["40001", "40P01"]
         return False
 
-    def is_connection_closed(self, conn):
+    def is_connection_closed(self, conn: "psycopg2.connection") -> bool:
         return bool(conn.closed)
 
-    def lock_table(self, txn, table):
+    def lock_table(self, txn: Cursor, table: str) -> None:
         txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
 
     @property
-    def server_version(self):
-        """Returns a string giving the server version. For example: '8.1.5'
-
-        Returns:
-            string
-        """
+    def server_version(self) -> str:
+        """Returns a string giving the server version. For example: '8.1.5'."""
         # note that this is a bit of a hack because it relies on check_database
         # having been called. Still, that should be a safe bet here.
         numver = self._version
@@ -197,17 +205,21 @@ class PostgresEngine(BaseDatabaseEngine):
         else:
             return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)
 
-    def in_transaction(self, conn: Connection) -> bool:
-        return conn.status != self.module.extensions.STATUS_READY  # type: ignore
+    def in_transaction(self, conn: "psycopg2.connection") -> bool:
+        import psycopg2.extensions
+
+        return conn.status != psycopg2.extensions.STATUS_READY
 
-    def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
-        return conn.set_session(autocommit=autocommit)  # type: ignore
+    def attempt_to_set_autocommit(
+        self, conn: "psycopg2.connection", autocommit: bool
+    ) -> None:
+        return conn.set_session(autocommit=autocommit)
 
     def attempt_to_set_isolation_level(
-        self, conn: Connection, isolation_level: Optional[int]
-    ):
+        self, conn: "psycopg2.connection", isolation_level: Optional[int]
+    ) -> None:
         if isolation_level is None:
             isolation_level = self.default_isolation_level
         else:
             isolation_level = self.isolation_level_map[isolation_level]
-        return conn.set_isolation_level(isolation_level)  # type: ignore
+        return conn.set_isolation_level(isolation_level)
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 6c19e55999..621f2c5efe 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -12,21 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import platform
+import sqlite3
 import struct
 import threading
-import typing
-from typing import Optional
+from typing import TYPE_CHECKING, Any, List, Mapping, Optional
 
 from synapse.storage.engines import BaseDatabaseEngine
-from synapse.storage.types import Connection
+from synapse.storage.types import Cursor
 
-if typing.TYPE_CHECKING:
-    import sqlite3  # noqa: F401
+if TYPE_CHECKING:
+    from synapse.storage.database import LoggingDatabaseConnection
 
 
-class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
-    def __init__(self, database_module, database_config):
-        super().__init__(database_module, database_config)
+class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
+    def __init__(self, database_config: Mapping[str, Any]):
+        super().__init__(sqlite3, database_config)
 
         database = database_config.get("args", {}).get("database")
         self._is_in_memory = database in (
@@ -37,7 +37,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
         if platform.python_implementation() == "PyPy":
             # pypy's sqlite3 module doesn't handle bytearrays, convert them
             # back to bytes.
-            database_module.register_adapter(bytearray, lambda array: bytes(array))
+            sqlite3.register_adapter(bytearray, lambda array: bytes(array))
 
         # The current max state_group, or None if we haven't looked
         # in the DB yet.
@@ -49,41 +49,43 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
         return True
 
     @property
-    def can_native_upsert(self):
+    def can_native_upsert(self) -> bool:
         """
         Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
         more work we haven't done yet to tell what was inserted vs updated.
         """
-        return self.module.sqlite_version_info >= (3, 24, 0)
+        return sqlite3.sqlite_version_info >= (3, 24, 0)
 
     @property
-    def supports_using_any_list(self):
+    def supports_using_any_list(self) -> bool:
         """Do we support using `a = ANY(?)` and passing a list"""
         return False
 
     @property
     def supports_returning(self) -> bool:
         """Do we support the `RETURNING` clause in insert/update/delete?"""
-        return self.module.sqlite_version_info >= (3, 35, 0)
+        return sqlite3.sqlite_version_info >= (3, 35, 0)
 
-    def check_database(self, db_conn, allow_outdated_version: bool = False):
+    def check_database(
+        self, db_conn: sqlite3.Connection, allow_outdated_version: bool = False
+    ) -> None:
         if not allow_outdated_version:
-            version = self.module.sqlite_version_info
+            version = sqlite3.sqlite_version_info
             # Synapse is untested against older SQLite versions, and we don't want
             # to let users upgrade to a version of Synapse with broken support for their
             # sqlite version, because it risks leaving them with a half-upgraded db.
             if version < (3, 22, 0):
                 raise RuntimeError("Synapse requires sqlite 3.22 or above.")
 
-    def check_new_database(self, txn):
+    def check_new_database(self, txn: Cursor) -> None:
         """Gets called when setting up a brand new database. This allows us to
         apply stricter checks on new databases versus existing database.
         """
 
-    def convert_param_style(self, sql):
+    def convert_param_style(self, sql: str) -> str:
         return sql
 
-    def on_new_connection(self, db_conn):
+    def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None:
         # We need to import here to avoid an import loop.
         from synapse.storage.prepare_database import prepare_database
 
@@ -97,48 +99,46 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
         db_conn.execute("PRAGMA foreign_keys = ON;")
         db_conn.commit()
 
-    def is_deadlock(self, error):
+    def is_deadlock(self, error: Exception) -> bool:
         return False
 
-    def is_connection_closed(self, conn):
+    def is_connection_closed(self, conn: sqlite3.Connection) -> bool:
         return False
 
-    def lock_table(self, txn, table):
+    def lock_table(self, txn: Cursor, table: str) -> None:
         return
 
     @property
-    def server_version(self):
-        """Gets a string giving the server version. For example: '3.22.0'
+    def server_version(self) -> str:
+        """Gets a string giving the server version. For example: '3.22.0'."""
+        return "%i.%i.%i" % sqlite3.sqlite_version_info
 
-        Returns:
-            string
-        """
-        return "%i.%i.%i" % self.module.sqlite_version_info
-
-    def in_transaction(self, conn: Connection) -> bool:
-        return conn.in_transaction  # type: ignore
+    def in_transaction(self, conn: sqlite3.Connection) -> bool:
+        return conn.in_transaction
 
-    def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
+    def attempt_to_set_autocommit(
+        self, conn: sqlite3.Connection, autocommit: bool
+    ) -> None:
         # Twisted doesn't let us set attributes on the connections, so we can't
         # set the connection to autocommit mode.
         pass
 
     def attempt_to_set_isolation_level(
-        self, conn: Connection, isolation_level: Optional[int]
-    ):
-        # All transactions are SERIALIZABLE by default in sqllite
+        self, conn: sqlite3.Connection, isolation_level: Optional[int]
+    ) -> None:
+        # All transactions are SERIALIZABLE by default in sqlite
         pass
 
 
 # Following functions taken from: https://github.com/coleifer/peewee
 
 
-def _parse_match_info(buf):
+def _parse_match_info(buf: bytes) -> List[int]:
     bufsize = len(buf)
     return [struct.unpack("@I", buf[i : i + 4])[0] for i in range(0, bufsize, 4)]
 
 
-def _rank(raw_match_info):
+def _rank(raw_match_info: bytes) -> float:
     """Handle match_info called w/default args 'pcx' - based on the example rank
     function http://sqlite.org/fts3.html#appendix_a
     """
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 97118045a1..0fc282866b 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -25,6 +25,7 @@ from typing import (
     Collection,
     Deque,
     Dict,
+    Generator,
     Generic,
     Iterable,
     List,
@@ -207,7 +208,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
 
         return res
 
-    def _handle_queue(self, room_id):
+    def _handle_queue(self, room_id: str) -> None:
         """Attempts to handle the queue for a room if not already being handled.
 
         The queue's callback will be invoked with for each item in the queue,
@@ -227,7 +228,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
 
         self._currently_persisting_rooms.add(room_id)
 
-        async def handle_queue_loop():
+        async def handle_queue_loop() -> None:
             try:
                 queue = self._get_drainining_queue(room_id)
                 for item in queue:
@@ -250,15 +251,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
                         with PreserveLoggingContext():
                             item.deferred.callback(ret)
             finally:
-                queue = self._event_persist_queues.pop(room_id, None)
-                if queue:
-                    self._event_persist_queues[room_id] = queue
+                remaining_queue = self._event_persist_queues.pop(room_id, None)
+                if remaining_queue:
+                    self._event_persist_queues[room_id] = remaining_queue
                 self._currently_persisting_rooms.discard(room_id)
 
         # set handle_queue_loop off in the background
         run_as_background_process("persist_events", handle_queue_loop)
 
-    def _get_drainining_queue(self, room_id):
+    def _get_drainining_queue(
+        self, room_id: str
+    ) -> Generator[_EventPersistQueueItem, None, None]:
         queue = self._event_persist_queues.setdefault(room_id, deque())
 
         try:
@@ -317,7 +320,9 @@ class EventsPersistenceStorage:
         for event, ctx in events_and_contexts:
             partitioned.setdefault(event.room_id, []).append((event, ctx))
 
-        async def enqueue(item):
+        async def enqueue(
+            item: Tuple[str, List[Tuple[EventBase, EventContext]]]
+        ) -> Dict[str, str]:
             room_id, evs_ctxs = item
             return await self._event_persist_queue.add_to_queue(
                 room_id, evs_ctxs, backfilled=backfilled
@@ -487,12 +492,6 @@ class EventsPersistenceStorage:
             # extremities in each room
             new_forward_extremities: Dict[str, Set[str]] = {}
 
-            # map room_id->(type,state_key)->event_id tracking the full
-            # state in each room after adding these events.
-            # This is simply used to prefill the get_current_state_ids
-            # cache
-            current_state_for_room: Dict[str, StateMap[str]] = {}
-
             # map room_id->(to_delete, to_insert) where to_delete is a list
             # of type/state keys to remove from current state, and to_insert
             # is a map (type,key)->event_id giving the state delta in each
@@ -628,14 +627,8 @@ class EventsPersistenceStorage:
 
                             state_delta_for_room[room_id] = delta
 
-                        # If we have the current_state then lets prefill
-                        # the cache with it.
-                        if current_state is not None:
-                            current_state_for_room[room_id] = current_state
-
             await self.persist_events_store._persist_events_and_state_updates(
                 chunk,
-                current_state_for_room=current_state_for_room,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremities=new_forward_extremities,
                 use_negative_stream_ordering=backfilled,
@@ -733,7 +726,8 @@ class EventsPersistenceStorage:
 
             The first state map is the full new current state and the second
             is the delta to the existing current state. If both are None then
-            there has been no change.
+            there has been no change. Either or neither can be None if there
+            has been a change.
 
             The function may prune some old entries from the set of new
             forward extremities if it's safe to do so.
@@ -743,9 +737,6 @@ class EventsPersistenceStorage:
             the new current state is only returned if we've already calculated
             it.
         """
-        # map from state_group to ((type, key) -> event_id) state map
-        state_groups_map = {}
-
         # Map from (prev state group, new state group) -> delta state dict
         state_group_deltas = {}
 
@@ -759,16 +750,6 @@ class EventsPersistenceStorage:
                     )
                 continue
 
-            if ctx.state_group in state_groups_map:
-                continue
-
-            # We're only interested in pulling out state that has already
-            # been cached in the context. We'll pull stuff out of the DB later
-            # if necessary.
-            current_state_ids = ctx.get_cached_current_state_ids()
-            if current_state_ids is not None:
-                state_groups_map[ctx.state_group] = current_state_ids
-
             if ctx.prev_group:
                 state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
 
@@ -826,18 +807,14 @@ class EventsPersistenceStorage:
             delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
             if delta_ids is not None:
                 # We have a delta from the existing to new current state,
-                # so lets just return that. If we happen to already have
-                # the current state in memory then lets also return that,
-                # but it doesn't matter if we don't.
-                new_state = state_groups_map.get(new_state_group)
-                return new_state, delta_ids, new_latest_event_ids
+                # so lets just return that.
+                return None, delta_ids, new_latest_event_ids
 
         # Now that we have calculated new_state_groups we need to get
         # their state IDs so we can resolve to a single state set.
-        missing_state = new_state_groups - set(state_groups_map)
-        if missing_state:
-            group_to_state = await self.state_store._get_state_for_groups(missing_state)
-            state_groups_map.update(group_to_state)
+        state_groups_map = await self.state_store._get_state_for_groups(
+            new_state_groups
+        )
 
         if len(new_state_groups) == 1:
             # If there is only one state group, then we know what the current
@@ -1130,7 +1107,7 @@ class EventsPersistenceStorage:
 
         return False
 
-    async def _handle_potentially_left_users(self, user_ids: Set[str]):
+    async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None:
         """Given a set of remote users check if the server still shares a room with
         them. If not then mark those users' device cache as stale.
         """
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 546d6bae6e..c33df42084 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -85,7 +85,7 @@ def prepare_database(
     database_engine: BaseDatabaseEngine,
     config: Optional[HomeServerConfig],
     databases: Collection[str] = ("main", "state"),
-):
+) -> None:
     """Prepares a physical database for usage. Will either create all necessary tables
     or upgrade from an older schema version.
 
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 871d4ace12..20c344faea 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SCHEMA_VERSION = 69  # remember to update the list below when updating
+SCHEMA_VERSION = 70  # remember to update the list below when updating
 """Represents the expectations made by the codebase about the database schema
 
 This should be incremented whenever the codebase changes its requirements on the
@@ -62,6 +62,9 @@ Changes in SCHEMA_VERSION = 68:
 Changes in SCHEMA_VERSION = 69:
     - We now write to `device_lists_changes_in_room` table.
     - Use sequence to generate future `application_services_txns.txn_id`s
+
+Changes in SCHEMA_VERSION = 70:
+    - event_reference_hashes is no longer written to.
 """
 
 
diff --git a/synapse/storage/schema/main/delta/69/02cache_invalidation_index.sql b/synapse/storage/schema/main/delta/69/02cache_invalidation_index.sql
new file mode 100644
index 0000000000..22ae3b8c00
--- /dev/null
+++ b/synapse/storage/schema/main/delta/69/02cache_invalidation_index.sql
@@ -0,0 +1,18 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Background update to clear the inboxes of hidden and deleted devices.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (6902, 'cache_invalidation_index_by_instance', '{}');
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index d1d5859214..d4a1bd4f9d 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -62,7 +62,7 @@ class StateFilter:
     types: "frozendict[str, Optional[FrozenSet[str]]]"
     include_others: bool = False
 
-    def __attrs_post_init__(self):
+    def __attrs_post_init__(self) -> None:
         # If `include_others` is set we canonicalise the filter by removing
         # wildcards from the types dictionary
         if self.include_others:
@@ -138,7 +138,9 @@ class StateFilter:
         )
 
     @staticmethod
-    def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
+    def freeze(
+        types: Mapping[str, Optional[Collection[str]]], include_others: bool
+    ) -> "StateFilter":
         """
         Returns a (frozen) StateFilter with the same contents as the parameters
         specified here, which can be made of mutable types.
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index d7d6f1d90e..0031df1e06 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
+from types import TracebackType
+from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
 
 from typing_extensions import Protocol
 
@@ -86,5 +87,80 @@ class Connection(Protocol):
     def __enter__(self) -> "Connection":
         ...
 
-    def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> Optional[bool]:
+        ...
+
+
+class DBAPI2Module(Protocol):
+    """The module-level attributes that we use from PEP 249.
+
+    This is NOT a comprehensive stub for the entire DBAPI2."""
+
+    __name__: str
+
+    # Exceptions. See https://peps.python.org/pep-0249/#exceptions
+
+    # For our specific drivers:
+    # - Python's sqlite3 module doesn't contains the same descriptions as the
+    #   DBAPI2 spec, see https://docs.python.org/3/library/sqlite3.html#exceptions
+    # - Psycopg2 maps every Postgres error code onto a unique exception class which
+    #   extends from this hierarchy. See
+    #     https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#exceptions
+    #     https://www.postgresql.org/docs/current/errcodes-appendix.html#ERRCODES-TABLE
+    Warning: Type[Exception]
+    Error: Type[Exception]
+
+    # Errors are divided into `InterfaceError`s (something went wrong in the database
+    # driver) and `DatabaseError`s (something went wrong in the database). These are
+    # both subclasses of `Error`, but we can't currently express this in type
+    # annotations due to https://github.com/python/mypy/issues/8397
+    InterfaceError: Type[Exception]
+    DatabaseError: Type[Exception]
+
+    # Everything below is a subclass of `DatabaseError`.
+
+    # Roughly: the database rejected a nonsensical value. Examples:
+    # - An integer was too big for its data type.
+    # - An invalid date time was provided.
+    # - A string contained a null code point.
+    DataError: Type[Exception]
+
+    # Roughly: something went wrong in the database, but it's not within the application
+    # programmer's control. Examples:
+    # - We failed to establish a connection to the database.
+    # - The connection to the database was lost.
+    # - A deadlock was detected.
+    # - A serialisation failure occurred.
+    # - The database ran out of resources, such as storage, memory, connections, etc.
+    # - The database encountered an error from the operating system.
+    OperationalError: Type[Exception]
+
+    # Roughly: we've given the database data which breaks a rule we asked it to enforce.
+    # Examples:
+    # - Stop, criminal scum! You violated the foreign key constraint
+    # - Also check constraints, non-null constraints, etc.
+    IntegrityError: Type[Exception]
+
+    # Roughly: something went wrong within the database server itself.
+    InternalError: Type[Exception]
+
+    # Roughly: the application did something silly that needs to be fixed. Examples:
+    # - We don't have permissions to do something.
+    # - We tried to create a table with duplicate column names.
+    # - We tried to use a reserved name.
+    # - We referred to a column that doesn't exist.
+    ProgrammingError: Type[Exception]
+
+    # Roughly: we've tried to do something that this database doesn't support.
+    NotSupportedError: Type[Exception]
+
+    def connect(self, **parameters: object) -> Connection:
         ...
+
+
+__all__ = ["Cursor", "Connection", "DBAPI2Module"]
diff --git a/synapse/types.py b/synapse/types.py
index 9ac688b23b..bd8071d51d 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -24,6 +24,7 @@ from typing import (
     Mapping,
     Match,
     MutableMapping,
+    NoReturn,
     Optional,
     Set,
     Tuple,
@@ -35,7 +36,8 @@ from typing import (
 import attr
 from frozendict import frozendict
 from signedjson.key import decode_verify_key_bytes
-from typing_extensions import TypedDict
+from signedjson.types import VerifyKey
+from typing_extensions import Final, TypedDict
 from unpaddedbase64 import decode_base64
 from zope.interface import Interface
 
@@ -55,6 +57,7 @@ from synapse.util.stringutils import parse_and_validate_server_name
 if TYPE_CHECKING:
     from synapse.appservice.api import ApplicationService
     from synapse.storage.databases.main import DataStore, PurgeEventsStore
+    from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
 
 # Define a state map type from type/state_key to T (usually an event ID or
 # event)
@@ -114,7 +117,7 @@ class Requester:
     app_service: Optional["ApplicationService"]
     authenticated_entity: str
 
-    def serialize(self):
+    def serialize(self) -> Dict[str, Any]:
         """Converts self to a type that can be serialized as JSON, and then
         deserialized by `deserialize`
 
@@ -132,7 +135,9 @@ class Requester:
         }
 
     @staticmethod
-    def deserialize(store, input):
+    def deserialize(
+        store: "ApplicationServiceWorkerStore", input: Dict[str, Any]
+    ) -> "Requester":
         """Converts a dict that was produced by `serialize` back into a
         Requester.
 
@@ -236,10 +241,10 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
     domain: str
 
     # Because this is a frozen class, it is deeply immutable.
-    def __copy__(self):
+    def __copy__(self: DS) -> DS:
         return self
 
-    def __deepcopy__(self, memo):
+    def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS:
         return self
 
     @classmethod
@@ -625,6 +630,22 @@ class RoomStreamToken:
             return "s%d" % (self.stream,)
 
 
+class StreamKeyType:
+    """Known stream types.
+
+    A stream is a list of entities ordered by an incrementing "stream token".
+    """
+
+    ROOM: Final = "room_key"
+    PRESENCE: Final = "presence_key"
+    TYPING: Final = "typing_key"
+    RECEIPT: Final = "receipt_key"
+    ACCOUNT_DATA: Final = "account_data_key"
+    PUSH_RULES: Final = "push_rules_key"
+    TO_DEVICE: Final = "to_device_key"
+    DEVICE_LIST: Final = "device_list_key"
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class StreamToken:
     """A collection of keys joined together by underscores in the following
@@ -729,16 +750,18 @@ class StreamToken:
         )
 
     @property
-    def room_stream_id(self):
+    def room_stream_id(self) -> int:
         return self.room_key.stream
 
-    def copy_and_advance(self, key, new_value) -> "StreamToken":
+    def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken":
         """Advance the given key in the token to a new value if and only if the
         new value is after the old value.
+
+        :raises TypeError: if `key` is not the one of the keys tracked by a StreamToken.
         """
-        if key == "room_key":
+        if key == StreamKeyType.ROOM:
             new_token = self.copy_and_replace(
-                "room_key", self.room_key.copy_and_advance(new_value)
+                StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value)
             )
             return new_token
 
@@ -751,7 +774,7 @@ class StreamToken:
         else:
             return self
 
-    def copy_and_replace(self, key, new_value) -> "StreamToken":
+    def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken":
         return attr.evolve(self, **{key: new_value})
 
 
@@ -793,14 +816,14 @@ class ThirdPartyInstanceID:
     # Deny iteration because it will bite you if you try to create a singleton
     # set by:
     #    users = set(user)
-    def __iter__(self):
+    def __iter__(self) -> NoReturn:
         raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
 
     # Because this class is a frozen class, it is deeply immutable.
-    def __copy__(self):
+    def __copy__(self) -> "ThirdPartyInstanceID":
         return self
 
-    def __deepcopy__(self, memo):
+    def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID":
         return self
 
     @classmethod
@@ -852,25 +875,28 @@ class DeviceListUpdates:
         return bool(self.changed or self.left)
 
 
-def get_verify_key_from_cross_signing_key(key_info):
+def get_verify_key_from_cross_signing_key(
+    key_info: Mapping[str, Any]
+) -> Tuple[str, VerifyKey]:
     """Get the key ID and signedjson verify key from a cross-signing key dict
 
     Args:
-        key_info (dict): a cross-signing key dict, which must have a "keys"
+        key_info: a cross-signing key dict, which must have a "keys"
             property that has exactly one item in it
 
     Returns:
-        (str, VerifyKey): the key ID and verify key for the cross-signing key
+        the key ID and verify key for the cross-signing key
     """
-    # make sure that exactly one key is provided
+    # make sure that a `keys` field is provided
     if "keys" not in key_info:
         raise ValueError("Invalid key")
     keys = key_info["keys"]
-    if len(keys) != 1:
-        raise ValueError("Invalid key")
-    # and return that one key
-    for key_id, key_data in keys.items():
+    # and that it contains exactly one key
+    if len(keys) == 1:
+        key_id, key_data = next(iter(keys.items()))
         return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))
+    else:
+        raise ValueError("Invalid key")
 
 
 @attr.s(auto_attribs=True, frozen=True, slots=True)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 45ff0de638..a3b60578e3 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+import math
 import threading
 import weakref
 from enum import Enum
@@ -40,6 +41,7 @@ from twisted.internet.interfaces import IReactorTime
 
 from synapse.config import cache as cache_config
 from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.metrics.jemalloc import get_jemalloc_stats
 from synapse.util import Clock, caches
 from synapse.util.caches import CacheMetric, EvictionReason, register_cache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
@@ -106,10 +108,16 @@ GLOBAL_ROOT = ListNode["_Node"].create_root_node()
 
 
 @wrap_as_background_process("LruCache._expire_old_entries")
-async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
+async def _expire_old_entries(
+    clock: Clock, expiry_seconds: int, autotune_config: Optional[dict]
+) -> None:
     """Walks the global cache list to find cache entries that haven't been
-    accessed in the given number of seconds.
+    accessed in the given number of seconds, or if a given memory threshold has been breached.
     """
+    if autotune_config:
+        max_cache_memory_usage = autotune_config["max_cache_memory_usage"]
+        target_cache_memory_usage = autotune_config["target_cache_memory_usage"]
+        min_cache_ttl = autotune_config["min_cache_ttl"] / 1000
 
     now = int(clock.time())
     node = GLOBAL_ROOT.prev_node
@@ -119,11 +127,36 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
 
     logger.debug("Searching for stale caches")
 
+    evicting_due_to_memory = False
+
+    # determine if we're evicting due to memory
+    jemalloc_interface = get_jemalloc_stats()
+    if jemalloc_interface and autotune_config:
+        try:
+            jemalloc_interface.refresh_stats()
+            mem_usage = jemalloc_interface.get_stat("allocated")
+            if mem_usage > max_cache_memory_usage:
+                logger.info("Begin memory-based cache eviction.")
+                evicting_due_to_memory = True
+        except Exception:
+            logger.warning(
+                "Unable to read allocated memory, skipping memory-based cache eviction."
+            )
+
     while node is not GLOBAL_ROOT:
         # Only the root node isn't a `_TimedListNode`.
         assert isinstance(node, _TimedListNode)
 
-        if node.last_access_ts_secs > now - expiry_seconds:
+        # if node has not aged past expiry_seconds and we are not evicting due to memory usage, there's
+        # nothing to do here
+        if (
+            node.last_access_ts_secs > now - expiry_seconds
+            and not evicting_due_to_memory
+        ):
+            break
+
+        # if entry is newer than min_cache_entry_ttl then do not evict and don't evict anything newer
+        if evicting_due_to_memory and now - node.last_access_ts_secs < min_cache_ttl:
             break
 
         cache_entry = node.get_cache_entry()
@@ -136,10 +169,29 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
         assert cache_entry is not None
         cache_entry.drop_from_cache()
 
+        # Check mem allocation periodically if we are evicting a bunch of caches
+        if jemalloc_interface and evicting_due_to_memory and (i + 1) % 100 == 0:
+            try:
+                jemalloc_interface.refresh_stats()
+                mem_usage = jemalloc_interface.get_stat("allocated")
+                if mem_usage < target_cache_memory_usage:
+                    evicting_due_to_memory = False
+                    logger.info("Stop memory-based cache eviction.")
+            except Exception:
+                logger.warning(
+                    "Unable to read allocated memory, this may affect memory-based cache eviction."
+                )
+                # If we've failed to read the current memory usage then we
+                # should stop trying to evict based on memory usage
+                evicting_due_to_memory = False
+
         # If we do lots of work at once we yield to allow other stuff to happen.
         if (i + 1) % 10000 == 0:
             logger.debug("Waiting during drop")
-            await clock.sleep(0)
+            if node.last_access_ts_secs > now - expiry_seconds:
+                await clock.sleep(0.5)
+            else:
+                await clock.sleep(0)
             logger.debug("Waking during drop")
 
         node = next_node
@@ -156,21 +208,28 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
 
 def setup_expire_lru_cache_entries(hs: "HomeServer") -> None:
     """Start a background job that expires all cache entries if they have not
-    been accessed for the given number of seconds.
+    been accessed for the given number of seconds, or if a given memory usage threshold has been
+    breached.
     """
-    if not hs.config.caches.expiry_time_msec:
+    if not hs.config.caches.expiry_time_msec and not hs.config.caches.cache_autotuning:
         return
 
-    logger.info(
-        "Expiring LRU caches after %d seconds", hs.config.caches.expiry_time_msec / 1000
-    )
+    if hs.config.caches.expiry_time_msec:
+        expiry_time = hs.config.caches.expiry_time_msec / 1000
+        logger.info("Expiring LRU caches after %d seconds", expiry_time)
+    else:
+        expiry_time = math.inf
 
     global USE_GLOBAL_LIST
     USE_GLOBAL_LIST = True
 
     clock = hs.get_clock()
     clock.looping_call(
-        _expire_old_entries, 30 * 1000, clock, hs.config.caches.expiry_time_msec / 1000
+        _expire_old_entries,
+        30 * 1000,
+        clock,
+        expiry_time,
+        hs.config.caches.cache_autotuning,
     )