summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/_scripts/register_new_matrix_user.py102
-rw-r--r--synapse/app/_base.py39
-rw-r--r--synapse/app/generic_worker.py6
-rw-r--r--synapse/app/homeserver.py6
-rw-r--r--synapse/config/_base.py59
-rw-r--r--synapse/config/metrics.py29
-rw-r--r--synapse/config/registration.py33
-rw-r--r--synapse/crypto/event_signing.py2
-rw-r--r--synapse/events/spamcheck.py2
-rw-r--r--synapse/federation/federation_base.py22
-rw-r--r--synapse/federation/federation_client.py23
-rw-r--r--synapse/federation/federation_server.py11
-rw-r--r--synapse/handlers/device.py8
-rw-r--r--synapse/handlers/directory.py12
-rw-r--r--synapse/handlers/event_auth.py9
-rw-r--r--synapse/handlers/events.py9
-rw-r--r--synapse/handlers/federation.py53
-rw-r--r--synapse/handlers/federation_event.py8
-rw-r--r--synapse/handlers/message.py6
-rw-r--r--synapse/handlers/pagination.py6
-rw-r--r--synapse/handlers/presence.py3
-rw-r--r--synapse/handlers/room.py21
-rw-r--r--synapse/handlers/room_member.py13
-rw-r--r--synapse/handlers/sync.py4
-rw-r--r--synapse/handlers/typing.py7
-rw-r--r--synapse/metrics/__init__.py4
-rw-r--r--synapse/metrics/_legacy_exposition.py (renamed from synapse/metrics/_exposition.py)34
-rw-r--r--synapse/replication/slave/storage/push_rule.py1
-rw-r--r--synapse/rest/media/v1/_base.py40
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py4
-rw-r--r--synapse/server_notices/server_notices_manager.py10
-rw-r--r--synapse/state/__init__.py13
-rw-r--r--synapse/state/v2.py69
-rw-r--r--synapse/storage/controllers/state.py3
-rw-r--r--synapse/storage/databases/main/account_data.py3
-rw-r--r--synapse/storage/databases/main/event_push_actions.py256
-rw-r--r--synapse/storage/databases/main/push_rule.py6
-rw-r--r--synapse/storage/databases/main/roommember.py187
-rw-r--r--synapse/storage/util/id_generators.py13
-rw-r--r--synapse/util/caches/__init__.py16
-rw-r--r--synapse/util/caches/deferred_cache.py346
-rw-r--r--synapse/util/caches/descriptors.py115
-rw-r--r--synapse/util/caches/treecache.py3
43 files changed, 1022 insertions, 594 deletions
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index 092601f530..0c4504d5d8 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -1,6 +1,6 @@
 # Copyright 2015, 2016 OpenMarket Ltd
 # Copyright 2018 New Vector
-# Copyright 2021 The Matrix.org Foundation C.I.C.
+# Copyright 2021-22 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,11 +20,22 @@ import hashlib
 import hmac
 import logging
 import sys
-from typing import Callable, Optional
+from typing import Any, Callable, Dict, Optional
 
 import requests
 import yaml
 
+_CONFLICTING_SHARED_SECRET_OPTS_ERROR = """\
+Conflicting options 'registration_shared_secret' and 'registration_shared_secret_path'
+are both defined in config file.
+"""
+
+_NO_SHARED_SECRET_OPTS_ERROR = """\
+No 'registration_shared_secret' or 'registration_shared_secret_path' defined in config.
+"""
+
+_DEFAULT_SERVER_URL = "http://localhost:8008"
+
 
 def request_registration(
     user: str,
@@ -203,31 +214,104 @@ def main() -> None:
 
     parser.add_argument(
         "server_url",
-        default="https://localhost:8448",
         nargs="?",
-        help="URL to use to talk to the homeserver. Defaults to "
-        " 'https://localhost:8448'.",
+        help="URL to use to talk to the homeserver. By default, tries to find a "
+        "suitable URL from the configuration file. Otherwise, defaults to "
+        f"'{_DEFAULT_SERVER_URL}'.",
     )
 
     args = parser.parse_args()
 
     if "config" in args and args.config:
         config = yaml.safe_load(args.config)
-        secret = config.get("registration_shared_secret", None)
+
+    if args.shared_secret:
+        secret = args.shared_secret
+    else:
+        # argparse should check that we have either config or shared secret
+        assert config
+
+        secret = config.get("registration_shared_secret")
+        secret_file = config.get("registration_shared_secret_path")
+        if secret_file:
+            if secret:
+                print(_CONFLICTING_SHARED_SECRET_OPTS_ERROR, file=sys.stderr)
+                sys.exit(1)
+            secret = _read_file(secret_file, "registration_shared_secret_path").strip()
         if not secret:
-            print("No 'registration_shared_secret' defined in config.")
+            print(_NO_SHARED_SECRET_OPTS_ERROR, file=sys.stderr)
             sys.exit(1)
+
+    if args.server_url:
+        server_url = args.server_url
+    elif config:
+        server_url = _find_client_listener(config)
+        if not server_url:
+            server_url = _DEFAULT_SERVER_URL
+            print(
+                "Unable to find a suitable HTTP listener in the configuration file. "
+                f"Trying {server_url} as a last resort.",
+                file=sys.stderr,
+            )
     else:
-        secret = args.shared_secret
+        server_url = _DEFAULT_SERVER_URL
+        print(
+            f"No server url or configuration file given. Defaulting to {server_url}.",
+            file=sys.stderr,
+        )
 
     admin = None
     if args.admin or args.no_admin:
         admin = args.admin
 
     register_new_user(
-        args.user, args.password, args.server_url, secret, admin, args.user_type
+        args.user, args.password, server_url, secret, admin, args.user_type
     )
 
 
+def _read_file(file_path: Any, config_path: str) -> str:
+    """Check the given file exists, and read it into a string
+
+    If it does not, exit with an error indicating the problem
+
+    Args:
+        file_path: the file to be read
+        config_path: where in the configuration file_path came from, so that a useful
+           error can be emitted if it does not exist.
+    Returns:
+        content of the file.
+    """
+    if not isinstance(file_path, str):
+        print(f"{config_path} setting is not a string", file=sys.stderr)
+        sys.exit(1)
+
+    try:
+        with open(file_path) as file_stream:
+            return file_stream.read()
+    except OSError as e:
+        print(f"Error accessing file {file_path}: {e}", file=sys.stderr)
+        sys.exit(1)
+
+
+def _find_client_listener(config: Dict[str, Any]) -> Optional[str]:
+    # try to find a listener in the config. Returns a host:port pair
+    for listener in config.get("listeners", []):
+        if listener.get("type") != "http" or listener.get("tls", False):
+            continue
+
+        if not any(
+            name == "client"
+            for resource in listener.get("resources", [])
+            for name in resource.get("names", [])
+        ):
+            continue
+
+        # TODO: consider bind_addresses
+        return f"http://localhost:{listener['port']}"
+
+    # no suitable listeners?
+    return None
+
+
 if __name__ == "__main__":
     main()
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 923891ae0d..4742435d3b 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -266,15 +266,48 @@ def register_start(
     reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
 
 
-def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
+def listen_metrics(
+    bind_addresses: Iterable[str], port: int, enable_legacy_metric_names: bool
+) -> None:
     """
     Start Prometheus metrics server.
     """
-    from synapse.metrics import RegistryProxy, start_http_server
+    from prometheus_client import start_http_server as start_http_server_prometheus
+
+    from synapse.metrics import (
+        RegistryProxy,
+        start_http_server as start_http_server_legacy,
+    )
 
     for host in bind_addresses:
         logger.info("Starting metrics listener on %s:%d", host, port)
-        start_http_server(port, addr=host, registry=RegistryProxy)
+        if enable_legacy_metric_names:
+            start_http_server_legacy(port, addr=host, registry=RegistryProxy)
+        else:
+            _set_prometheus_client_use_created_metrics(False)
+            start_http_server_prometheus(port, addr=host, registry=RegistryProxy)
+
+
+def _set_prometheus_client_use_created_metrics(new_value: bool) -> None:
+    """
+    Sets whether prometheus_client should expose `_created`-suffixed metrics for
+    all gauges, histograms and summaries.
+    There is no programmatic way to disable this without poking at internals;
+    the proper way is to use an environment variable which prometheus_client
+    loads at import time.
+
+    The motivation for disabling these `_created` metrics is that they're
+    a waste of space as they're not useful but they take up space in Prometheus.
+    """
+
+    import prometheus_client.metrics
+
+    if hasattr(prometheus_client.metrics, "_use_created"):
+        prometheus_client.metrics._use_created = new_value
+    else:
+        logger.error(
+            "Can't disable `_created` metrics in prometheus_client (brittle hack broken?)"
+        )
 
 
 def listen_manhole(
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 30e21d9707..5e3825fca6 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -412,7 +412,11 @@ class GenericWorkerServer(HomeServer):
                         "enable_metrics is not True!"
                     )
                 else:
-                    _base.listen_metrics(listener.bind_addresses, listener.port)
+                    _base.listen_metrics(
+                        listener.bind_addresses,
+                        listener.port,
+                        enable_legacy_metric_names=self.config.metrics.enable_legacy_metrics,
+                    )
             else:
                 logger.warning("Unsupported listener type: %s", listener.type)
 
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 68993d91a9..e57a926032 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -307,7 +307,11 @@ class SynapseHomeServer(HomeServer):
                         "enable_metrics is not True!"
                     )
                 else:
-                    _base.listen_metrics(listener.bind_addresses, listener.port)
+                    _base.listen_metrics(
+                        listener.bind_addresses,
+                        listener.port,
+                        enable_legacy_metric_names=self.config.metrics.enable_legacy_metrics,
+                    )
             else:
                 # this shouldn't happen, as the listener type should have been checked
                 # during parsing
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 7c9cf403ef..1f6362aedd 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -20,6 +20,7 @@ import logging
 import os
 import re
 from collections import OrderedDict
+from enum import Enum, auto
 from hashlib import sha256
 from textwrap import dedent
 from typing import (
@@ -603,18 +604,44 @@ class RootConfig:
             " may specify directories containing *.yaml files.",
         )
 
-        generate_group = parser.add_argument_group("Config generation")
-        generate_group.add_argument(
+        # we nest the mutually-exclusive group inside another group so that the help
+        # text shows them in their own group.
+        generate_mode_group = parser.add_argument_group(
+            "Config generation mode",
+        )
+        generate_mode_exclusive = generate_mode_group.add_mutually_exclusive_group()
+        generate_mode_exclusive.add_argument(
+            # hidden option to make the type and default work
+            "--generate-mode",
+            help=argparse.SUPPRESS,
+            type=_ConfigGenerateMode,
+            default=_ConfigGenerateMode.GENERATE_MISSING_AND_RUN,
+        )
+        generate_mode_exclusive.add_argument(
             "--generate-config",
-            action="store_true",
             help="Generate a config file, then exit.",
+            action="store_const",
+            const=_ConfigGenerateMode.GENERATE_EVERYTHING_AND_EXIT,
+            dest="generate_mode",
         )
-        generate_group.add_argument(
+        generate_mode_exclusive.add_argument(
             "--generate-missing-configs",
             "--generate-keys",
-            action="store_true",
             help="Generate any missing additional config files, then exit.",
+            action="store_const",
+            const=_ConfigGenerateMode.GENERATE_MISSING_AND_EXIT,
+            dest="generate_mode",
         )
+        generate_mode_exclusive.add_argument(
+            "--generate-missing-and-run",
+            help="Generate any missing additional config files, then run. This is the "
+            "default behaviour.",
+            action="store_const",
+            const=_ConfigGenerateMode.GENERATE_MISSING_AND_RUN,
+            dest="generate_mode",
+        )
+
+        generate_group = parser.add_argument_group("Details for --generate-config")
         generate_group.add_argument(
             "-H", "--server-name", help="The server name to generate a config file for."
         )
@@ -670,11 +697,12 @@ class RootConfig:
         config_dir_path = os.path.abspath(config_dir_path)
         data_dir_path = os.getcwd()
 
-        generate_missing_configs = config_args.generate_missing_configs
-
         obj = cls(config_files)
 
-        if config_args.generate_config:
+        if (
+            config_args.generate_mode
+            == _ConfigGenerateMode.GENERATE_EVERYTHING_AND_EXIT
+        ):
             if config_args.report_stats is None:
                 parser.error(
                     "Please specify either --report-stats=yes or --report-stats=no\n\n"
@@ -732,11 +760,14 @@ class RootConfig:
                     )
                     % (config_path,)
                 )
-                generate_missing_configs = True
 
         config_dict = read_config_files(config_files)
-        if generate_missing_configs:
-            obj.generate_missing_files(config_dict, config_dir_path)
+        obj.generate_missing_files(config_dict, config_dir_path)
+
+        if config_args.generate_mode in (
+            _ConfigGenerateMode.GENERATE_EVERYTHING_AND_EXIT,
+            _ConfigGenerateMode.GENERATE_MISSING_AND_EXIT,
+        ):
             return None
 
         obj.parse_config_dict(
@@ -965,6 +996,12 @@ def read_file(file_path: Any, config_path: Iterable[str]) -> str:
         raise ConfigError("Error accessing file %r" % (file_path,), config_path) from e
 
 
+class _ConfigGenerateMode(Enum):
+    GENERATE_MISSING_AND_RUN = auto()
+    GENERATE_MISSING_AND_EXIT = auto()
+    GENERATE_EVERYTHING_AND_EXIT = auto()
+
+
 __all__ = [
     "Config",
     "RootConfig",
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index 3b42be5b5b..f3134834e5 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -42,6 +42,35 @@ class MetricsConfig(Config):
 
     def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         self.enable_metrics = config.get("enable_metrics", False)
+
+        """
+        ### `enable_legacy_metrics` (experimental)
+
+        **Experimental: this option may be removed or have its behaviour
+        changed at any time, with no notice.**
+
+        Set to `true` to publish both legacy and non-legacy Prometheus metric names,
+        or to `false` to only publish non-legacy Prometheus metric names.
+        Defaults to `true`. Has no effect if `enable_metrics` is `false`.
+
+        Legacy metric names include:
+        - metrics containing colons in the name, such as `synapse_util_caches_response_cache:hits`, because colons are supposed to be reserved for user-defined recording rules;
+        - counters that don't end with the `_total` suffix, such as `synapse_federation_client_sent_edus`, therefore not adhering to the OpenMetrics standard.
+
+        These legacy metric names are unconventional and not compliant with OpenMetrics standards.
+        They are included for backwards compatibility.
+
+        Example configuration:
+        ```yaml
+        enable_legacy_metrics: false
+        ```
+
+        See https://github.com/matrix-org/synapse/issues/11106 for context.
+
+        *Since v1.67.0.*
+        """
+        self.enable_legacy_metrics = config.get("enable_legacy_metrics", True)
+
         self.report_stats = config.get("report_stats", None)
         self.report_stats_endpoint = config.get(
             "report_stats_endpoint", "https://matrix.org/report-usage-stats/push"
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index a888d976f2..df1d83dfaa 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -13,10 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
-from typing import Any, Optional
+from typing import Any, Dict, Optional
 
 from synapse.api.constants import RoomCreationPreset
-from synapse.config._base import Config, ConfigError
+from synapse.config._base import Config, ConfigError, read_file
 from synapse.types import JsonDict, RoomAlias, UserID
 from synapse.util.stringutils import random_string_with_symbols, strtobool
 
@@ -27,6 +27,11 @@ password resets, configure Synapse with an SMTP server via the `email` setting,
 remove `account_threepid_delegates.email`.
 """
 
+CONFLICTING_SHARED_SECRET_OPTS_ERROR = """\
+You have configured both `registration_shared_secret` and
+`registration_shared_secret_path`. These are mutually incompatible.
+"""
+
 
 class RegistrationConfig(Config):
     section = "registration"
@@ -53,7 +58,16 @@ class RegistrationConfig(Config):
         self.enable_registration_token_3pid_bypass = config.get(
             "enable_registration_token_3pid_bypass", False
         )
+
+        # read the shared secret, either inline or from an external file
         self.registration_shared_secret = config.get("registration_shared_secret")
+        registration_shared_secret_path = config.get("registration_shared_secret_path")
+        if registration_shared_secret_path:
+            if self.registration_shared_secret:
+                raise ConfigError(CONFLICTING_SHARED_SECRET_OPTS_ERROR)
+            self.registration_shared_secret = read_file(
+                registration_shared_secret_path, ("registration_shared_secret_path",)
+            ).strip()
 
         self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
 
@@ -218,6 +232,21 @@ class RegistrationConfig(Config):
         else:
             return ""
 
+    def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None:
+        # if 'registration_shared_secret_path' is specified, and the target file
+        # does not exist, generate it.
+        registration_shared_secret_path = config.get("registration_shared_secret_path")
+        if registration_shared_secret_path and not self.path_exists(
+            registration_shared_secret_path
+        ):
+            print(
+                "Generating registration shared secret file "
+                + registration_shared_secret_path
+            )
+            secret = random_string_with_symbols(50)
+            with open(registration_shared_secret_path, "w") as f:
+                f.write(f"{secret}\n")
+
     @staticmethod
     def add_arguments(parser: argparse.ArgumentParser) -> None:
         reg_group = parser.add_argument_group("registration")
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 7520647d1e..23b799ac32 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -28,6 +28,7 @@ from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
 from synapse.events.utils import prune_event, prune_event_dict
+from synapse.logging.opentracing import trace
 from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
@@ -35,6 +36,7 @@ logger = logging.getLogger(__name__)
 Hasher = Callable[[bytes], "hashlib._Hash"]
 
 
+@trace
 def check_event_content_hash(
     event: EventBase, hash_algorithm: Hasher = hashlib.sha256
 ) -> bool:
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 4a3bfb38f1..623a2c71ea 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -32,6 +32,7 @@ from typing_extensions import Literal
 
 import synapse
 from synapse.api.errors import Codes
+from synapse.logging.opentracing import trace
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.media_storage import ReadableFileWrapper
 from synapse.spam_checker_api import RegistrationBehaviour
@@ -378,6 +379,7 @@ class SpamChecker:
         if check_media_file_for_spam is not None:
             self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam)
 
+    @trace
     async def check_event_for_spam(
         self, event: "synapse.events.EventBase"
     ) -> Union[Tuple[Codes, JsonDict], str]:
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 2522bf78fc..4269a98db2 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -23,6 +23,7 @@ from synapse.crypto.keyring import Keyring
 from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import prune_event, validate_canonicaljson
 from synapse.http.servlet import assert_params_in_dict
+from synapse.logging.opentracing import log_kv, trace
 from synapse.types import JsonDict, get_domain_from_id
 
 if TYPE_CHECKING:
@@ -55,6 +56,7 @@ class FederationBase:
         self._clock = hs.get_clock()
         self._storage_controllers = hs.get_storage_controllers()
 
+    @trace
     async def _check_sigs_and_hash(
         self, room_version: RoomVersion, pdu: EventBase
     ) -> EventBase:
@@ -97,17 +99,36 @@ class FederationBase:
                     "Event %s seems to have been redacted; using our redacted copy",
                     pdu.event_id,
                 )
+                log_kv(
+                    {
+                        "message": "Event seems to have been redacted; using our redacted copy",
+                        "event_id": pdu.event_id,
+                    }
+                )
             else:
                 logger.warning(
                     "Event %s content has been tampered, redacting",
                     pdu.event_id,
                 )
+                log_kv(
+                    {
+                        "message": "Event content has been tampered, redacting",
+                        "event_id": pdu.event_id,
+                    }
+                )
             return redacted_event
 
         spam_check = await self.spam_checker.check_event_for_spam(pdu)
 
         if spam_check != self.spam_checker.NOT_SPAM:
             logger.warning("Event contains spam, soft-failing %s", pdu.event_id)
+            log_kv(
+                {
+                    "message": "Event contains spam, redacting (to save disk space) "
+                    "as well as soft-failing (to stop using the event in prev_events)",
+                    "event_id": pdu.event_id,
+                }
+            )
             # we redact (to save disk space) as well as soft-failing (to stop
             # using the event in prev_events).
             redacted_event = prune_event(pdu)
@@ -117,6 +138,7 @@ class FederationBase:
         return pdu
 
 
+@trace
 async def _check_sigs_on_pdu(
     keyring: Keyring, room_version: RoomVersion, pdu: EventBase
 ) -> None:
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 987f6dad46..7ee2974bb1 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -61,7 +61,7 @@ from synapse.federation.federation_base import (
 )
 from synapse.federation.transport.client import SendJoinResponse
 from synapse.http.types import QueryParams
-from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace
+from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace
 from synapse.types import JsonDict, UserID, get_domain_from_id
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -587,11 +587,15 @@ class FederationClient(FederationBase):
         Returns:
             A list of PDUs that have valid signatures and hashes.
         """
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "pdus.length",
+            str(len(pdus)),
+        )
 
         # We limit how many PDUs we check at once, as if we try to do hundreds
         # of thousands of PDUs at once we see large memory spikes.
 
-        valid_pdus = []
+        valid_pdus: List[EventBase] = []
 
         async def _execute(pdu: EventBase) -> None:
             valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
@@ -607,6 +611,8 @@ class FederationClient(FederationBase):
 
         return valid_pdus
 
+    @trace
+    @tag_args
     async def _check_sigs_and_hash_and_fetch_one(
         self,
         pdu: EventBase,
@@ -639,16 +645,27 @@ class FederationClient(FederationBase):
         except InvalidEventSignatureError as e:
             logger.warning(
                 "Signature on retrieved event %s was invalid (%s). "
-                "Checking local store/orgin server",
+                "Checking local store/origin server",
                 pdu.event_id,
                 e,
             )
+            log_kv(
+                {
+                    "message": "Signature on retrieved event was invalid. "
+                    "Checking local store/origin server",
+                    "event_id": pdu.event_id,
+                    "InvalidEventSignatureError": e,
+                }
+            )
 
         # Check local db.
         res = await self.store.get_event(
             pdu.event_id, allow_rejected=True, allow_none=True
         )
 
+        # If the PDU fails its signature check and we don't have it in our
+        # database, we then request it from sender's server (if that is not the
+        # same as `origin`).
         pdu_origin = get_domain_from_id(pdu.sender)
         if not res and pdu_origin != origin:
             try:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 75fbc6073d..3bf84cf625 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -763,6 +763,17 @@ class FederationServer(FederationBase):
             The partial knock event.
         """
         origin_host, _ = parse_server_name(origin)
+
+        if await self.store.is_partial_state_room(room_id):
+            # Before we do anything: check if the room is partial-stated.
+            # Note that at the time this check was added, `on_make_knock_request` would
+            # block due to https://github.com/matrix-org/synapse/issues/12997.
+            raise SynapseError(
+                404,
+                "Unable to handle /make_knock right now; this server is not fully joined.",
+                errcode=Codes.NOT_FOUND,
+            )
+
         await self.check_server_matches_acl(origin_host, room_id)
 
         room_version = await self.store.get_room_version(room_id)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index f5c586f657..9c2c3a0e68 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -310,6 +310,7 @@ class DeviceHandler(DeviceWorkerHandler):
         super().__init__(hs)
 
         self.federation_sender = hs.get_federation_sender()
+        self._storage_controllers = hs.get_storage_controllers()
 
         self.device_list_updater = DeviceListUpdater(hs, self)
 
@@ -694,8 +695,11 @@ class DeviceHandler(DeviceWorkerHandler):
 
                     # Ignore any users that aren't ours
                     if self.hs.is_mine_id(user_id):
-                        joined_user_ids = await self.store.get_users_in_room(room_id)
-                        hosts = {get_domain_from_id(u) for u in joined_user_ids}
+                        hosts = set(
+                            await self._storage_controllers.state.get_current_hosts_in_room(
+                                room_id
+                            )
+                        )
                         hosts.discard(self.server_name)
 
                     # Check if we've already sent this update to some hosts
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 948f66a94d..7127d5aefc 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -30,7 +30,7 @@ from synapse.api.errors import (
 from synapse.appservice import ApplicationService
 from synapse.module_api import NOT_SPAM
 from synapse.storage.databases.main.directory import RoomAliasMapping
-from synapse.types import JsonDict, Requester, RoomAlias, get_domain_from_id
+from synapse.types import JsonDict, Requester, RoomAlias
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -83,8 +83,9 @@ class DirectoryHandler:
         # TODO(erikj): Add transactions.
         # TODO(erikj): Check if there is a current association.
         if not servers:
-            users = await self.store.get_users_in_room(room_id)
-            servers = {get_domain_from_id(u) for u in users}
+            servers = await self._storage_controllers.state.get_current_hosts_in_room(
+                room_id
+            )
 
         if not servers:
             raise SynapseError(400, "Failed to get server list")
@@ -287,8 +288,9 @@ class DirectoryHandler:
                 Codes.NOT_FOUND,
             )
 
-        users = await self.store.get_users_in_room(room_id)
-        extra_servers = {get_domain_from_id(u) for u in users}
+        extra_servers = await self._storage_controllers.state.get_current_hosts_in_room(
+            room_id
+        )
         servers_set = set(extra_servers) | set(servers)
 
         # If this server is in the list of servers, return it first.
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index a2dd9c7efa..c3ddc5d182 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -129,12 +129,9 @@ class EventAuthHandler:
         else:
             users = {}
 
-        # Find the user with the highest power level.
-        users_in_room = await self._store.get_users_in_room(room_id)
-        # Only interested in local users.
-        local_users_in_room = [
-            u for u in users_in_room if get_domain_from_id(u) == self._server_name
-        ]
+        # Find the user with the highest power level (only interested in local
+        # users).
+        local_users_in_room = await self._store.get_local_users_in_room(room_id)
         chosen_user = max(
             local_users_in_room,
             key=lambda user: users.get(user, users_default_level),
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index ac13340d3a..949b69cb41 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -151,7 +151,7 @@ class EventHandler:
         """Retrieve a single specified event.
 
         Args:
-            user: The user requesting the event
+            user: The local user requesting the event
             room_id: The expected room id. We'll return None if the
                 event's room does not match.
             event_id: The event ID to obtain.
@@ -173,8 +173,11 @@ class EventHandler:
         if not event:
             return None
 
-        users = await self.store.get_users_in_room(event.room_id)
-        is_peeking = user.to_string() not in users
+        is_user_in_room = await self.store.check_local_user_in_room(
+            user_id=user.to_string(), room_id=event.room_id
+        )
+        # The user is peeking if they aren't in the room already
+        is_peeking = not is_user_in_room
 
         filtered = await filter_events_for_client(
             self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index e151962055..dd4b9f66d1 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -70,7 +70,7 @@ from synapse.replication.http.federation import (
 from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, StateMap, get_domain_from_id
+from synapse.types import JsonDict, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.retryutils import NotRetryingDestination
 from synapse.visibility import filter_events_for_server
@@ -104,37 +104,6 @@ backfill_processing_before_timer = Histogram(
 )
 
 
-def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
-    """Get joined domains from state
-
-    Args:
-        state: State map from type/state key to event.
-
-    Returns:
-        Returns a list of servers with the lowest depth of their joins.
-            Sorted by lowest depth first.
-    """
-    joined_users = [
-        (state_key, int(event.depth))
-        for (e_type, state_key), event in state.items()
-        if e_type == EventTypes.Member and event.membership == Membership.JOIN
-    ]
-
-    joined_domains: Dict[str, int] = {}
-    for u, d in joined_users:
-        try:
-            dom = get_domain_from_id(u)
-            old_d = joined_domains.get(dom)
-            if old_d:
-                joined_domains[dom] = min(d, old_d)
-            else:
-                joined_domains[dom] = d
-        except Exception:
-            pass
-
-    return sorted(joined_domains.items(), key=lambda d: d[1])
-
-
 class _BackfillPointType(Enum):
     # a regular backwards extremity (ie, an event which we don't yet have, but which
     # is referred to by other events in the DAG)
@@ -432,21 +401,19 @@ class FederationHandler:
         )
 
         # Now we need to decide which hosts to hit first.
-
-        # First we try hosts that are already in the room
+        # First we try hosts that are already in the room.
         # TODO: HEURISTIC ALERT.
+        likely_domains = (
+            await self._storage_controllers.state.get_current_hosts_in_room(room_id)
+        )
 
-        curr_state = await self._storage_controllers.state.get_current_state(room_id)
-
-        curr_domains = get_domains_from_state(curr_state)
-
-        likely_domains = [
-            domain for domain, depth in curr_domains if domain != self.server_name
-        ]
-
-        async def try_backfill(domains: List[str]) -> bool:
+        async def try_backfill(domains: Collection[str]) -> bool:
             # TODO: Should we try multiple of these at a time?
             for dom in domains:
+                # We don't want to ask our own server for information we don't have
+                if dom == self.server_name:
+                    continue
+
                 try:
                     await self._federation_event_handler.backfill(
                         dom, room_id, limit=100, extremities=extremities_to_request
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 048c4111f6..ace7adcffb 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -1041,6 +1041,14 @@ class FederationEventHandler:
             InvalidResponseError: if the remote homeserver's response contains fields
                 of the wrong type.
         """
+
+        # It would be better if we could query the difference from our known
+        # state to the given `event_id` so the sending server doesn't have to
+        # send as much and we don't have to process as many events. For example
+        # in a room like #matrix:matrix.org, we get 200k events (77k state_events, 122k
+        # auth_events) from this call.
+        #
+        # Tracked by https://github.com/matrix-org/synapse/issues/13618
         (
             state_event_ids,
             auth_event_ids,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index acd3de06f6..72157d5a36 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -761,8 +761,10 @@ class EventCreationHandler:
     async def _is_server_notices_room(self, room_id: str) -> bool:
         if self.config.servernotices.server_notices_mxid is None:
             return False
-        user_ids = await self.store.get_users_in_room(room_id)
-        return self.config.servernotices.server_notices_mxid in user_ids
+        is_server_notices_room = await self.store.check_local_user_in_room(
+            user_id=self.config.servernotices.server_notices_mxid, room_id=room_id
+        )
+        return is_server_notices_room
 
     async def assert_accepted_privacy_policy(self, requester: Requester) -> None:
         """Check if a user has accepted the privacy policy
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 74e944bce7..a0c39778ab 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -159,11 +159,9 @@ class PaginationHandler:
         self._retention_allowed_lifetime_max = (
             hs.config.retention.retention_allowed_lifetime_max
         )
+        self._is_master = hs.config.worker.worker_app is None
 
-        if (
-            hs.config.worker.run_background_tasks
-            and hs.config.retention.retention_enabled
-        ):
+        if hs.config.retention.retention_enabled and self._is_master:
             # Run the purge jobs described in the configuration file.
             for job in hs.config.retention.retention_purge_jobs:
                 logger.info("Setting up purge job with config: %s", job)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 741504ba9f..4e575ffbaa 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -2051,8 +2051,7 @@ async def get_interested_remotes(
     )
 
     for room_id, states in room_ids_to_states.items():
-        user_ids = await store.get_users_in_room(room_id)
-        hosts = {get_domain_from_id(user_id) for user_id in user_ids}
+        hosts = await store.get_current_hosts_in_room(room_id)
         for host in hosts:
             hosts_and_states.setdefault(host, set()).update(states)
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 2bf0ebd025..f64a8690a5 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -60,7 +60,6 @@ from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase
 from synapse.events.utils import copy_and_fixup_power_levels_contents
 from synapse.federation.federation_client import InvalidResponseError
-from synapse.handlers.federation import get_domains_from_state
 from synapse.handlers.relations import BundledAggregations
 from synapse.module_api import NOT_SPAM
 from synapse.rest.admin._base import assert_user_is_admin
@@ -1284,8 +1283,11 @@ class RoomContextHandler:
         before_limit = math.floor(limit / 2.0)
         after_limit = limit - before_limit
 
-        users = await self.store.get_users_in_room(room_id)
-        is_peeking = user.to_string() not in users
+        is_user_in_room = await self.store.check_local_user_in_room(
+            user_id=user.to_string(), room_id=room_id
+        )
+        # The user is peeking if they aren't in the room already
+        is_peeking = not is_user_in_room
 
         async def filter_evts(events: List[EventBase]) -> List[EventBase]:
             if use_admin_priviledge:
@@ -1459,17 +1461,16 @@ class TimestampLookupHandler:
                 timestamp,
             )
 
-            # Find other homeservers from the given state in the room
-            curr_state = await self._storage_controllers.state.get_current_state(
-                room_id
+            likely_domains = (
+                await self._storage_controllers.state.get_current_hosts_in_room(room_id)
             )
-            curr_domains = get_domains_from_state(curr_state)
-            likely_domains = [
-                domain for domain, depth in curr_domains if domain != self.server_name
-            ]
 
             # Loop through each homeserver candidate until we get a succesful response
             for domain in likely_domains:
+                # We don't want to ask our own server for information we don't have
+                if domain == self.server_name:
+                    continue
+
                 try:
                     remote_response = await self.federation_client.timestamp_to_event(
                         domain, room_id, timestamp, direction
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 65b9a655d4..e726997d83 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1620,8 +1620,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
     async def _is_server_notice_room(self, room_id: str) -> bool:
         if self._server_notices_mxid is None:
             return False
-        user_ids = await self.store.get_users_in_room(room_id)
-        return self._server_notices_mxid in user_ids
+        is_server_notices_room = await self.store.check_local_user_in_room(
+            user_id=self._server_notices_mxid, room_id=room_id
+        )
+        return is_server_notices_room
 
 
 class RoomMemberMasterHandler(RoomMemberHandler):
@@ -1923,8 +1925,11 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         ]:
             raise SynapseError(400, "User %s in room %s" % (user_id, room_id))
 
-        if membership:
-            await self.store.forget(user_id, room_id)
+        # In normal case this call is only required if `membership` is not `None`.
+        # But: After the last member had left the room, the background update
+        # `_background_remove_left_rooms` is deleting rows related to this room from
+        # the table `current_state_events` and `get_current_state_events` is `None`.
+        await self.store.forget(user_id, room_id)
 
 
 def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[str]:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b4d3f3958c..2d95b1fa24 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -2421,10 +2421,10 @@ class SyncHandler:
                     joined_room.room_id, joined_room.event_pos.stream
                 )
             )
-            users_in_room = await self.state.get_current_users_in_room(
+            user_ids_in_room = await self.state.get_current_user_ids_in_room(
                 joined_room.room_id, extrems
             )
-            if user_id in users_in_room:
+            if user_id in user_ids_in_room:
                 joined_room_ids.add(joined_room.room_id)
 
         return frozenset(joined_room_ids)
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index bcac3372a2..a4cd8b8f0c 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -26,7 +26,7 @@ from synapse.metrics.background_process_metrics import (
 )
 from synapse.replication.tcp.streams import TypingStream
 from synapse.streams import EventSource
-from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, StreamKeyType, UserID
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
@@ -362,8 +362,9 @@ class TypingWriterHandler(FollowerTypingHandler):
             )
             return
 
-        users = await self.store.get_users_in_room(room_id)
-        domains = {get_domain_from_id(u) for u in users}
+        domains = await self._storage_controllers.state.get_current_hosts_in_room(
+            room_id
+        )
 
         if self.server_name in domains:
             logger.info("Got typing update from %s: %r", user_id, content)
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 496fce2ecc..c3d3daf877 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -46,12 +46,12 @@ from twisted.python.threadpool import ThreadPool
 
 # This module is imported for its side effects; flake8 needn't warn that it's unused.
 import synapse.metrics._reactor_metrics  # noqa: F401
-from synapse.metrics._exposition import (
+from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
+from synapse.metrics._legacy_exposition import (
     MetricsResource,
     generate_latest,
     start_http_server,
 )
-from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
 from synapse.metrics._types import Collector
 from synapse.util import SYNAPSE_VERSION
 
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_legacy_exposition.py
index 353d0a63b6..ff640a49af 100644
--- a/synapse/metrics/_exposition.py
+++ b/synapse/metrics/_legacy_exposition.py
@@ -80,7 +80,27 @@ def sample_line(line: Sample, name: str) -> str:
     return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp)
 
 
+# Mapping from new metric names to legacy metric names.
+# We translate these back to their old names when exposing them through our
+# legacy vendored exporter.
+# Only this legacy exposition module applies these name changes.
+LEGACY_METRIC_NAMES = {
+    "synapse_util_caches_cache_hits": "synapse_util_caches_cache:hits",
+    "synapse_util_caches_cache_size": "synapse_util_caches_cache:size",
+    "synapse_util_caches_cache_evicted_size": "synapse_util_caches_cache:evicted_size",
+    "synapse_util_caches_cache_total": "synapse_util_caches_cache:total",
+    "synapse_util_caches_response_cache_size": "synapse_util_caches_response_cache:size",
+    "synapse_util_caches_response_cache_hits": "synapse_util_caches_response_cache:hits",
+    "synapse_util_caches_response_cache_evicted_size": "synapse_util_caches_response_cache:evicted_size",
+    "synapse_util_caches_response_cache_total": "synapse_util_caches_response_cache:total",
+}
+
+
 def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes:
+    """
+    Generate metrics in legacy format. Modern metrics are generated directly
+    by prometheus-client.
+    """
 
     # Trigger the cache metrics to be rescraped, which updates the common
     # metrics but do not produce metrics themselves
@@ -94,7 +114,8 @@ def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> byt
             # No samples, don't bother.
             continue
 
-        mname = metric.name
+        # Translate to legacy metric name if it has one.
+        mname = LEGACY_METRIC_NAMES.get(metric.name, metric.name)
         mnewname = metric.name
         mtype = metric.type
 
@@ -124,7 +145,7 @@ def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> byt
         om_samples: Dict[str, List[str]] = {}
         for s in metric.samples:
             for suffix in ["_created", "_gsum", "_gcount"]:
-                if s.name == metric.name + suffix:
+                if s.name == mname + suffix:
                     # OpenMetrics specific sample, put in a gauge at the end.
                     # (these come from gaugehistograms which don't get renamed,
                     # so no need to faff with mnewname)
@@ -140,12 +161,12 @@ def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> byt
             if emit_help:
                 output.append(
                     "# HELP {}{} {}\n".format(
-                        metric.name,
+                        mname,
                         suffix,
                         metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
                     )
                 )
-            output.append(f"# TYPE {metric.name}{suffix} gauge\n")
+            output.append(f"# TYPE {mname}{suffix} gauge\n")
             output.extend(lines)
 
         # Get rid of the weird colon things while we're at it
@@ -170,11 +191,12 @@ def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> byt
             # Get rid of the OpenMetrics specific samples (we should already have
             # dealt with them above anyway.)
             for suffix in ["_created", "_gsum", "_gcount"]:
-                if s.name == metric.name + suffix:
+                if s.name == mname + suffix:
                     break
             else:
+                sample_name = LEGACY_METRIC_NAMES.get(s.name, s.name)
                 output.append(
-                    sample_line(s, s.name.replace(":total", "").replace(":", "_"))
+                    sample_line(s, sample_name.replace(":total", "").replace(":", "_"))
                 )
 
     return "".join(output).encode("utf-8")
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 52ee3f7e58..5e65eaf1e0 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -31,6 +31,5 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
             self._push_rules_stream_id_gen.advance(instance_name, token)
             for row in rows:
                 self.get_push_rules_for_user.invalidate((row.user_id,))
-                self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
                 self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index c35d42fab8..d30878f704 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -254,30 +254,32 @@ async def respond_with_responder(
         file_size: Size in bytes of the media. If not known it should be None
         upload_name: The name of the requested file, if any.
     """
-    if request._disconnected:
-        logger.warning(
-            "Not sending response to request %s, already disconnected.", request
-        )
-        return
-
     if not responder:
         respond_404(request)
         return
 
-    logger.debug("Responding to media request with responder %s", responder)
-    add_file_headers(request, media_type, file_size, upload_name)
-    try:
-        with responder:
+    # If we have a responder we *must* use it as a context manager.
+    with responder:
+        if request._disconnected:
+            logger.warning(
+                "Not sending response to request %s, already disconnected.", request
+            )
+            return
+
+        logger.debug("Responding to media request with responder %s", responder)
+        add_file_headers(request, media_type, file_size, upload_name)
+        try:
+
             await responder.write_to_consumer(request)
-    except Exception as e:
-        # The majority of the time this will be due to the client having gone
-        # away. Unfortunately, Twisted simply throws a generic exception at us
-        # in that case.
-        logger.warning("Failed to write to consumer: %s %s", type(e), e)
-
-        # Unregister the producer, if it has one, so Twisted doesn't complain
-        if request.producer:
-            request.unregisterProducer()
+        except Exception as e:
+            # The majority of the time this will be due to the client having gone
+            # away. Unfortunately, Twisted simply throws a generic exception at us
+            # in that case.
+            logger.warning("Failed to write to consumer: %s %s", type(e), e)
+
+            # Unregister the producer, if it has one, so Twisted doesn't complain
+            if request.producer:
+                request.unregisterProducer()
 
     finish_request(request)
 
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index b36c98a08e..a8f6fd6b35 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -732,10 +732,6 @@ class PreviewUrlResource(DirectServeJsonResource):
 
         logger.debug("Running url preview cache expiry")
 
-        if not (await self.store.db_pool.updates.has_completed_background_updates()):
-            logger.debug("Still running DB updates; skipping url preview cache expiry")
-            return
-
         def try_remove_parent_dirs(dirs: Iterable[str]) -> None:
             """Attempt to remove the given chain of parent directories
 
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 70d054a8f4..564e3705c2 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -102,6 +102,10 @@ class ServerNoticesManager:
         Returns:
             The room's ID, or None if no room could be found.
         """
+        # If there is no server notices MXID, then there is no server notices room
+        if self.server_notices_mxid is None:
+            return None
+
         rooms = await self._store.get_rooms_for_local_user_where_membership_is(
             user_id, [Membership.INVITE, Membership.JOIN]
         )
@@ -111,8 +115,10 @@ class ServerNoticesManager:
             # be joined. This is kinda deliberate, in that if somebody somehow
             # manages to invite the system user to a room, that doesn't make it
             # the server notices room.
-            user_ids = await self._store.get_users_in_room(room.room_id)
-            if len(user_ids) <= 2 and self.server_notices_mxid in user_ids:
+            is_server_notices_room = await self._store.check_local_user_in_room(
+                user_id=self.server_notices_mxid, room_id=room.room_id
+            )
+            if is_server_notices_room:
                 # we found a room which our user shares with the system notice
                 # user
                 return room.room_id
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index c355e4f98a..3047e1b1ad 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -44,7 +44,6 @@ from synapse.logging.context import ContextResourceUsage
 from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.roommember import ProfileInfo
 from synapse.storage.state import StateFilter
 from synapse.types import StateMap
 from synapse.util.async_helpers import Linearizer
@@ -210,11 +209,11 @@ class StateHandler:
         ret = await self.resolve_state_groups_for_events(room_id, event_ids)
         return await ret.get_state(self._state_storage_controller, state_filter)
 
-    async def get_current_users_in_room(
+    async def get_current_user_ids_in_room(
         self, room_id: str, latest_event_ids: List[str]
-    ) -> Dict[str, ProfileInfo]:
+    ) -> Set[str]:
         """
-        Get the users who are currently in a room.
+        Get the users IDs who are currently in a room.
 
         Note: This is much slower than using the equivalent method
         `DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`,
@@ -225,15 +224,15 @@ class StateHandler:
             room_id: The ID of the room.
             latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
         Returns:
-            Dictionary of user IDs to their profileinfo.
+            Set of user IDs in the room.
         """
 
         assert latest_event_ids is not None
 
-        logger.debug("calling resolve_state_groups from get_current_users_in_room")
+        logger.debug("calling resolve_state_groups from get_current_user_ids_in_room")
         entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
         state = await entry.get_state(self._state_storage_controller, StateFilter.all())
-        return await self.store.get_joined_users_from_state(room_id, state, entry)
+        return await self.store.get_joined_user_ids_from_state(room_id, state, entry)
 
     async def get_hosts_in_room_at_events(
         self, room_id: str, event_ids: Collection[str]
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index cf3045f82e..af03851c71 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -271,40 +271,41 @@ async def _get_power_level_for_sender(
 async def _get_auth_chain_difference(
     room_id: str,
     state_sets: Sequence[Mapping[Any, str]],
-    event_map: Dict[str, EventBase],
+    unpersisted_events: Dict[str, EventBase],
     state_res_store: StateResolutionStore,
 ) -> Set[str]:
     """Compare the auth chains of each state set and return the set of events
-    that only appear in some but not all of the auth chains.
+    that only appear in some, but not all of the auth chains.
 
     Args:
-        state_sets
-        event_map
-        state_res_store
+        state_sets: The input state sets we are trying to resolve across.
+        unpersisted_events: A map from event ID to EventBase containing all unpersisted
+            events involved in this resolution.
+        state_res_store:
 
     Returns:
-        Set of event IDs
+        The auth difference of the given state sets, as a set of event IDs.
     """
 
     # The `StateResolutionStore.get_auth_chain_difference` function assumes that
     # all events passed to it (and their auth chains) have been persisted
-    # previously. This is not the case for any events in the `event_map`, and so
-    # we need to manually handle those events.
+    # previously. We need to manually handle any other events that are yet to be
+    # persisted.
     #
-    # We do this by:
-    #   1. calculating the auth chain difference for the state sets based on the
-    #      events in `event_map` alone
-    #   2. replacing any events in the state_sets that are also in `event_map`
-    #      with their auth events (recursively), and then calling
-    #      `store.get_auth_chain_difference` as normal
-    #   3. adding the results of 1 and 2 together.
-
-    # Map from event ID in `event_map` to their auth event IDs, and their auth
-    # event IDs if they appear in the `event_map`. This is the intersection of
-    # the event's auth chain with the events in the `event_map` *plus* their
+    # We do this in three steps:
+    #   1. Compute the set of unpersisted events belonging to the auth difference.
+    #   2. Replacing any unpersisted events in the state_sets with their auth events,
+    #      recursively, until the state_sets contain only persisted events.
+    #      Then we call `store.get_auth_chain_difference` as normal, which computes
+    #      the set of persisted events belonging to the auth difference.
+    #   3. Adding the results of 1 and 2 together.
+
+    # Map from event ID in `unpersisted_events` to their auth event IDs, and their auth
+    # event IDs if they appear in the `unpersisted_events`. This is the intersection of
+    # the event's auth chain with the events in `unpersisted_events` *plus* their
     # auth event IDs.
     events_to_auth_chain: Dict[str, Set[str]] = {}
-    for event in event_map.values():
+    for event in unpersisted_events.values():
         chain = {event.event_id}
         events_to_auth_chain[event.event_id] = chain
 
@@ -312,16 +313,16 @@ async def _get_auth_chain_difference(
         while to_search:
             for auth_id in to_search.pop().auth_event_ids():
                 chain.add(auth_id)
-                auth_event = event_map.get(auth_id)
+                auth_event = unpersisted_events.get(auth_id)
                 if auth_event:
                     to_search.append(auth_event)
 
-    # We now a) calculate the auth chain difference for the unpersisted events
-    # and b) work out the state sets to pass to the store.
+    # We now 1) calculate the auth chain difference for the unpersisted events
+    # and 2) work out the state sets to pass to the store.
     #
-    # Note: If the `event_map` is empty (which is the common case), we can do a
+    # Note: If there are no `unpersisted_events` (which is the common case), we can do a
     # much simpler calculation.
-    if event_map:
+    if unpersisted_events:
         # The list of state sets to pass to the store, where each state set is a set
         # of the event ids making up the state. This is similar to `state_sets`,
         # except that (a) we only have event ids, not the complete
@@ -344,14 +345,18 @@ async def _get_auth_chain_difference(
             for event_id in state_set.values():
                 event_chain = events_to_auth_chain.get(event_id)
                 if event_chain is not None:
-                    # We have an event in `event_map`. We add all the auth
-                    # events that it references (that aren't also in `event_map`).
-                    set_ids.update(e for e in event_chain if e not in event_map)
+                    # We have an unpersisted event. We add all the auth
+                    # events that it references which are also unpersisted.
+                    set_ids.update(
+                        e for e in event_chain if e not in unpersisted_events
+                    )
 
                     # We also add the full chain of unpersisted event IDs
                     # referenced by this state set, so that we can work out the
                     # auth chain difference of the unpersisted events.
-                    unpersisted_ids.update(e for e in event_chain if e in event_map)
+                    unpersisted_ids.update(
+                        e for e in event_chain if e in unpersisted_events
+                    )
                 else:
                     set_ids.add(event_id)
 
@@ -361,15 +366,15 @@ async def _get_auth_chain_difference(
         union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
         intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
 
-        difference_from_event_map: Collection[str] = union - intersection
+        auth_difference_unpersisted_part: Collection[str] = union - intersection
     else:
-        difference_from_event_map = ()
+        auth_difference_unpersisted_part = ()
         state_sets_ids = [set(state_set.values()) for state_set in state_sets]
 
     difference = await state_res_store.get_auth_chain_difference(
         room_id, state_sets_ids
     )
-    difference.update(difference_from_event_map)
+    difference.update(auth_difference_unpersisted_part)
 
     return difference
 
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index f9ffd0e29e..ba5380ce3e 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -23,7 +23,6 @@ from typing import (
     List,
     Mapping,
     Optional,
-    Set,
     Tuple,
 )
 
@@ -520,7 +519,7 @@ class StateStorageController:
         )
         return state_map.get(key)
 
-    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+    async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
         """Get current hosts in room based on current state."""
 
         await self._partial_state_room_tracker.await_full_state(room_id)
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 9af9f4f18e..c38b8a9e5a 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -650,9 +650,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             txn, self.get_account_data_for_room, (user_id,)
         )
         self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,))
-        self._invalidate_cache_and_stream(
-            txn, self.get_push_rules_enabled_for_user, (user_id,)
-        )
         # This user might be contained in the ignored_by cache for other users,
         # so we have to invalidate it all.
         self._invalidate_all_cache_and_stream(txn, self.ignored_by)
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index eabf9c9739..9f410d69de 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -459,6 +459,32 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         return await self.db_pool.runInteraction("get_push_action_users_in_range", f)
 
+    def _get_receipts_by_room_txn(
+        self, txn: LoggingTransaction, user_id: str
+    ) -> List[Tuple[str, int]]:
+        receipt_types_clause, args = make_in_list_sql_clause(
+            self.database_engine,
+            "receipt_type",
+            (
+                ReceiptTypes.READ,
+                ReceiptTypes.READ_PRIVATE,
+                ReceiptTypes.UNSTABLE_READ_PRIVATE,
+            ),
+        )
+
+        sql = f"""
+            SELECT room_id, MAX(stream_ordering)
+            FROM receipts_linearized
+            INNER JOIN events USING (room_id, event_id)
+            WHERE {receipt_types_clause}
+            AND user_id = ?
+            GROUP BY room_id
+        """
+
+        args.extend((user_id,))
+        txn.execute(sql, args)
+        return cast(List[Tuple[str, int]], txn.fetchall())
+
     async def get_unread_push_actions_for_user_in_range_for_http(
         self,
         user_id: str,
@@ -482,106 +508,45 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will have between 0~limit entries.
         """
 
-        # find rooms that have a read receipt in them and return the next
-        # push actions
-        def get_after_receipt(
-            txn: LoggingTransaction,
-        ) -> List[Tuple[str, str, int, str, bool]]:
-            # find rooms that have a read receipt in them and return the next
-            # push actions
-
-            receipt_types_clause, args = make_in_list_sql_clause(
-                self.database_engine,
-                "receipt_type",
-                (
-                    ReceiptTypes.READ,
-                    ReceiptTypes.READ_PRIVATE,
-                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
-                ),
-            )
-
-            sql = f"""
-                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
-                    ep.highlight
-                FROM (
-                    SELECT room_id,
-                        MAX(stream_ordering) as stream_ordering
-                    FROM events
-                    INNER JOIN receipts_linearized USING (room_id, event_id)
-                    WHERE {receipt_types_clause} AND user_id = ?
-                    GROUP BY room_id
-                ) AS rl,
-                event_push_actions AS ep
-                WHERE
-                    ep.room_id = rl.room_id
-                    AND ep.stream_ordering > rl.stream_ordering
-                    AND ep.user_id = ?
-                    AND ep.stream_ordering > ?
-                    AND ep.stream_ordering <= ?
-                    AND ep.notif = 1
-                ORDER BY ep.stream_ordering ASC LIMIT ?
-            """
-            args.extend(
-                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
-            )
-            txn.execute(sql, args)
-            return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
-
-        after_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
+        receipts_by_room = dict(
+            await self.db_pool.runInteraction(
+                "get_unread_push_actions_for_user_in_range_http_receipts",
+                self._get_receipts_by_room_txn,
+                user_id=user_id,
+            ),
         )
 
-        # There are rooms with push actions in them but you don't have a read receipt in
-        # them e.g. rooms you've been invited to, so get push actions for rooms which do
-        # not have read receipts in them too.
-        def get_no_receipt(
+        def get_push_actions_txn(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, int, str, bool]]:
-            receipt_types_clause, args = make_in_list_sql_clause(
-                self.database_engine,
-                "receipt_type",
-                (
-                    ReceiptTypes.READ,
-                    ReceiptTypes.READ_PRIVATE,
-                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
-                ),
-            )
-
-            sql = f"""
-                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
-                    ep.highlight
+            sql = """
+                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight
                 FROM event_push_actions AS ep
-                INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
-                    ep.room_id NOT IN (
-                        SELECT room_id FROM receipts_linearized
-                        WHERE {receipt_types_clause} AND user_id = ?
-                        GROUP BY room_id
-                    )
-                    AND ep.user_id = ?
+                    ep.user_id = ?
                     AND ep.stream_ordering > ?
                     AND ep.stream_ordering <= ?
                     AND ep.notif = 1
                 ORDER BY ep.stream_ordering ASC LIMIT ?
             """
-            args.extend(
-                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
-            )
-            txn.execute(sql, args)
+            txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
             return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
 
-        no_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
+        push_actions = await self.db_pool.runInteraction(
+            "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn
         )
 
         notifs = [
             HttpPushAction(
-                event_id=row[0],
-                room_id=row[1],
-                stream_ordering=row[2],
-                actions=_deserialize_action(row[3], row[4]),
+                event_id=event_id,
+                room_id=room_id,
+                stream_ordering=stream_ordering,
+                actions=_deserialize_action(actions, highlight),
             )
-            for row in after_read_receipt + no_read_receipt
+            for event_id, room_id, stream_ordering, actions, highlight in push_actions
+            # Only include push actions with a stream ordering after any receipt, or without any
+            # receipt present (invited to but never read rooms).
+            if stream_ordering > receipts_by_room.get(room_id, 0)
         ]
 
         # Now sort it so it's ordered correctly, since currently it will
@@ -617,106 +582,49 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will have between 0~limit entries.
         """
 
-        # find rooms that have a read receipt in them and return the most recent
-        # push actions
-        def get_after_receipt(
-            txn: LoggingTransaction,
-        ) -> List[Tuple[str, str, int, str, bool, int]]:
-            receipt_types_clause, args = make_in_list_sql_clause(
-                self.database_engine,
-                "receipt_type",
-                (
-                    ReceiptTypes.READ,
-                    ReceiptTypes.READ_PRIVATE,
-                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
-                ),
-            )
-
-            sql = f"""
-                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
-                    ep.highlight, e.received_ts
-                FROM (
-                    SELECT room_id,
-                        MAX(stream_ordering) as stream_ordering
-                    FROM events
-                    INNER JOIN receipts_linearized USING (room_id, event_id)
-                    WHERE {receipt_types_clause} AND user_id = ?
-                    GROUP BY room_id
-                ) AS rl,
-                event_push_actions AS ep
-                INNER JOIN events AS e USING (room_id, event_id)
-                WHERE
-                    ep.room_id = rl.room_id
-                    AND ep.stream_ordering > rl.stream_ordering
-                    AND ep.user_id = ?
-                    AND ep.stream_ordering > ?
-                    AND ep.stream_ordering <= ?
-                    AND ep.notif = 1
-                ORDER BY ep.stream_ordering DESC LIMIT ?
-            """
-            args.extend(
-                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
-            )
-            txn.execute(sql, args)
-            return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
-
-        after_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
+        receipts_by_room = dict(
+            await self.db_pool.runInteraction(
+                "get_unread_push_actions_for_user_in_range_email_receipts",
+                self._get_receipts_by_room_txn,
+                user_id=user_id,
+            ),
         )
 
-        # There are rooms with push actions in them but you don't have a read receipt in
-        # them e.g. rooms you've been invited to, so get push actions for rooms which do
-        # not have read receipts in them too.
-        def get_no_receipt(
+        def get_push_actions_txn(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, int, str, bool, int]]:
-            receipt_types_clause, args = make_in_list_sql_clause(
-                self.database_engine,
-                "receipt_type",
-                (
-                    ReceiptTypes.READ,
-                    ReceiptTypes.READ_PRIVATE,
-                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
-                ),
-            )
-
-            sql = f"""
+            sql = """
                 SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
                     ep.highlight, e.received_ts
                 FROM event_push_actions AS ep
                 INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
-                    ep.room_id NOT IN (
-                        SELECT room_id FROM receipts_linearized
-                        WHERE {receipt_types_clause} AND user_id = ?
-                        GROUP BY room_id
-                    )
-                    AND ep.user_id = ?
+                    ep.user_id = ?
                     AND ep.stream_ordering > ?
                     AND ep.stream_ordering <= ?
                     AND ep.notif = 1
                 ORDER BY ep.stream_ordering DESC LIMIT ?
             """
-            args.extend(
-                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
-            )
-            txn.execute(sql, args)
+            txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
             return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
 
-        no_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
+        push_actions = await self.db_pool.runInteraction(
+            "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn
         )
 
         # Make a list of dicts from the two sets of results.
         notifs = [
             EmailPushAction(
-                event_id=row[0],
-                room_id=row[1],
-                stream_ordering=row[2],
-                actions=_deserialize_action(row[3], row[4]),
-                received_ts=row[5],
+                event_id=event_id,
+                room_id=room_id,
+                stream_ordering=stream_ordering,
+                actions=_deserialize_action(actions, highlight),
+                received_ts=received_ts,
             )
-            for row in after_read_receipt + no_read_receipt
+            for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions
+            # Only include push actions with a stream ordering after any receipt, or without any
+            # receipt present (invited to but never read rooms).
+            if stream_ordering > receipts_by_room.get(room_id, 0)
         ]
 
         # Now sort it so it's ordered correctly, since currently it will
@@ -792,26 +700,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 int(count_as_unread),  # unread column
             )
 
-        def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
-            # We don't use simple_insert_many here to avoid the overhead
-            # of generating lists of dicts.
-
-            sql = """
-                INSERT INTO event_push_actions_staging
-                    (event_id, user_id, actions, notif, highlight, unread)
-                VALUES (?, ?, ?, ?, ?, ?)
-            """
-
-            txn.execute_batch(
-                sql,
-                (
-                    _gen_entry(user_id, actions)
-                    for user_id, actions in user_id_actions.items()
-                ),
-            )
-
-        return await self.db_pool.runInteraction(
-            "add_push_actions_to_staging", _add_push_actions_to_staging_txn
+        await self.db_pool.simple_insert_many(
+            "event_push_actions_staging",
+            keys=("event_id", "user_id", "actions", "notif", "highlight", "unread"),
+            values=[
+                _gen_entry(user_id, actions)
+                for user_id, actions in user_id_actions.items()
+            ],
+            desc="add_push_actions_to_staging",
         )
 
     async def remove_push_actions_from_staging(self, event_id: str) -> None:
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 255620f996..5079edd1e0 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -165,7 +165,6 @@ class PushRulesWorkerStore(
 
         return _load_rules(rows, enabled_map, self.hs.config.experimental)
 
-    @cached(max_entries=5000)
     async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
         results = await self.db_pool.simple_select_list(
             table="push_rules_enable",
@@ -229,9 +228,6 @@ class PushRulesWorkerStore(
 
         return results
 
-    @cachedList(
-        cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids"
-    )
     async def bulk_get_push_rules_enabled(
         self, user_ids: Collection[str]
     ) -> Dict[str, Dict[str, bool]]:
@@ -246,6 +242,7 @@ class PushRulesWorkerStore(
             iterable=user_ids,
             retcols=("user_name", "rule_id", "enabled"),
             desc="bulk_get_push_rules_enabled",
+            batch_size=1000,
         )
         for row in rows:
             enabled = bool(row["enabled"])
@@ -792,7 +789,6 @@ class PushRuleStore(PushRulesWorkerStore):
         self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values)
 
         txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
-        txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
         txn.call_after(
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 827c1f1efd..06500457bd 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -187,27 +187,48 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     @cached(max_entries=100000, iterable=True)
     async def get_users_in_room(self, room_id: str) -> List[str]:
+        """
+        Returns a list of users in the room sorted by longest in the room first
+        (aka. with the lowest depth). This is done to match the sort in
+        `get_current_hosts_in_room()` and so we can re-use the cache but it's
+        not horrible to have here either.
+        """
+
         return await self.db_pool.runInteraction(
             "get_users_in_room", self.get_users_in_room_txn, room_id
         )
 
     def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
+        """
+        Returns a list of users in the room sorted by longest in the room first
+        (aka. with the lowest depth). This is done to match the sort in
+        `get_current_hosts_in_room()` and so we can re-use the cache but it's
+        not horrible to have here either.
+        """
         # If we can assume current_state_events.membership is up to date
         # then we can avoid a join, which is a Very Good Thing given how
         # frequently this function gets called.
         if self._current_state_events_membership_up_to_date:
             sql = """
-                SELECT state_key FROM current_state_events
-                WHERE type = 'm.room.member' AND room_id = ? AND membership = ?
+                SELECT c.state_key FROM current_state_events as c
+                /* Get the depth of the event from the events table */
+                INNER JOIN events AS e USING (event_id)
+                WHERE c.type = 'm.room.member' AND c.room_id = ? AND membership = ?
+                /* Sorted by lowest depth first */
+                ORDER BY e.depth ASC;
             """
         else:
             sql = """
-                SELECT state_key FROM room_memberships as m
+                SELECT c.state_key FROM room_memberships as m
+                /* Get the depth of the event from the events table */
+                INNER JOIN events AS e USING (event_id)
                 INNER JOIN current_state_events as c
                 ON m.event_id = c.event_id
                 AND m.room_id = c.room_id
                 AND m.user_id = c.state_key
                 WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
+                /* Sorted by lowest depth first */
+                ORDER BY e.depth ASC;
             """
 
         txn.execute(sql, (room_id, Membership.JOIN))
@@ -534,6 +555,32 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             desc="get_local_users_in_room",
         )
 
+    async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool:
+        """
+        Check whether a given local user is currently joined to the given room.
+
+        Returns:
+            A boolean indicating whether the user is currently joined to the room
+
+        Raises:
+            Exeption when called with a non-local user to this homeserver
+        """
+        if not self.hs.is_mine_id(user_id):
+            raise Exception(
+                "Cannot call 'check_local_user_in_room' on "
+                "non-local user %s" % (user_id,),
+            )
+
+        (
+            membership,
+            member_event_id,
+        ) = await self.get_local_current_membership_for_user_in_room(
+            user_id=user_id,
+            room_id=room_id,
+        )
+
+        return membership == Membership.JOIN
+
     async def get_local_current_membership_for_user_in_room(
         self, user_id: str, room_id: str
     ) -> Tuple[Optional[str], Optional[str]]:
@@ -835,9 +882,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return shared_room_ids or frozenset()
 
-    async def get_joined_users_from_state(
+    async def get_joined_user_ids_from_state(
         self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
-    ) -> Dict[str, ProfileInfo]:
+    ) -> Set[str]:
         state_group: Union[object, int] = state_entry.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
@@ -848,25 +895,25 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         assert state_group is not None
         with Measure(self._clock, "get_joined_users_from_state"):
-            return await self._get_joined_users_from_context(
+            return await self._get_joined_user_ids_from_context(
                 room_id, state_group, state, context=state_entry
             )
 
     @cached(num_args=2, iterable=True, max_entries=100000)
-    async def _get_joined_users_from_context(
+    async def _get_joined_user_ids_from_context(
         self,
         room_id: str,
         state_group: Union[object, int],
         current_state_ids: StateMap[str],
         event: Optional[EventBase] = None,
         context: Optional["_StateCacheEntry"] = None,
-    ) -> Dict[str, ProfileInfo]:
+    ) -> Set[str]:
         # We don't use `state_group`, it's there so that we can cache based
         # on it. However, it's important that it's never None, since two current_states
         # with a state_group of None are likely to be different.
         assert state_group is not None
 
-        users_in_room = {}
+        users_in_room = set()
         member_event_ids = [
             e_id
             for key, e_id in current_state_ids.items()
@@ -879,11 +926,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # If we do then we can reuse that result and simply update it with
             # any membership changes in `delta_ids`
             if context.prev_group and context.delta_ids:
-                prev_res = self._get_joined_users_from_context.cache.get_immediate(
+                prev_res = self._get_joined_user_ids_from_context.cache.get_immediate(
                     (room_id, context.prev_group), None
                 )
-                if prev_res and isinstance(prev_res, dict):
-                    users_in_room = dict(prev_res)
+                if prev_res and isinstance(prev_res, set):
+                    users_in_room = prev_res
                     member_event_ids = [
                         e_id
                         for key, e_id in context.delta_ids.items()
@@ -891,7 +938,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                     ]
                     for etype, state_key in context.delta_ids:
                         if etype == EventTypes.Member:
-                            users_in_room.pop(state_key, None)
+                            users_in_room.discard(state_key)
 
         # We check if we have any of the member event ids in the event cache
         # before we ask the DB
@@ -908,71 +955,64 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             ev_entry = event_map.get(event_id)
             if ev_entry and not ev_entry.event.rejected_reason:
                 if ev_entry.event.membership == Membership.JOIN:
-                    users_in_room[ev_entry.event.state_key] = ProfileInfo(
-                        display_name=ev_entry.event.content.get("displayname", None),
-                        avatar_url=ev_entry.event.content.get("avatar_url", None),
-                    )
+                    users_in_room.add(ev_entry.event.state_key)
             else:
                 missing_member_event_ids.append(event_id)
 
         if missing_member_event_ids:
-            event_to_memberships = await self._get_joined_profiles_from_event_ids(
+            event_to_memberships = await self._get_user_ids_from_membership_event_ids(
                 missing_member_event_ids
             )
-            users_in_room.update(row for row in event_to_memberships.values() if row)
+            users_in_room.update(
+                user_id for user_id in event_to_memberships.values() if user_id
+            )
 
         if event is not None and event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
                 if event.event_id in member_event_ids:
-                    users_in_room[event.state_key] = ProfileInfo(
-                        display_name=event.content.get("displayname", None),
-                        avatar_url=event.content.get("avatar_url", None),
-                    )
+                    users_in_room.add(event.state_key)
 
         return users_in_room
 
-    @cached(max_entries=10000)
-    def _get_joined_profile_from_event_id(
+    @cached(
+        max_entries=10000,
+        # This name matches the old function that has been replaced - the cache name
+        # is kept here to maintain backwards compatibility.
+        name="_get_joined_profile_from_event_id",
+    )
+    def _get_user_id_from_membership_event_id(
         self, event_id: str
     ) -> Optional[Tuple[str, ProfileInfo]]:
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="_get_joined_profile_from_event_id",
+        cached_method_name="_get_user_id_from_membership_event_id",
         list_name="event_ids",
     )
-    async def _get_joined_profiles_from_event_ids(
+    async def _get_user_ids_from_membership_event_ids(
         self, event_ids: Iterable[str]
-    ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
+    ) -> Dict[str, Optional[str]]:
         """For given set of member event_ids check if they point to a join
-        event and if so return the associated user and profile info.
+        event.
 
         Args:
             event_ids: The member event IDs to lookup
 
         Returns:
-            Map from event ID to `user_id` and ProfileInfo (or None if not join event).
+            Map from event ID to `user_id`, or None if event is not a join.
         """
 
         rows = await self.db_pool.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=event_ids,
-            retcols=("user_id", "display_name", "avatar_url", "event_id"),
+            retcols=("user_id", "event_id"),
             keyvalues={"membership": Membership.JOIN},
             batch_size=1000,
-            desc="_get_joined_profiles_from_event_ids",
+            desc="_get_user_ids_from_membership_event_ids",
         )
 
-        return {
-            row["event_id"]: (
-                row["user_id"],
-                ProfileInfo(
-                    avatar_url=row["avatar_url"], display_name=row["display_name"]
-                ),
-            )
-            for row in rows
-        }
+        return {row["event_id"]: row["user_id"] for row in rows}
 
     @cached(max_entries=10000)
     async def is_host_joined(self, room_id: str, host: str) -> bool:
@@ -1018,37 +1058,70 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return True
 
     @cached(iterable=True, max_entries=10000)
-    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
-        """Get current hosts in room based on current state."""
+    async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
+        """
+        Get current hosts in room based on current state.
+
+        The heuristic of sorting by servers who have been in the room the
+        longest is good because they're most likely to have anything we ask
+        about.
+
+        Returns:
+            Returns a list of servers sorted by longest in the room first. (aka.
+            sorted by join with the lowest depth first).
+        """
 
         # First we check if we already have `get_users_in_room` in the cache, as
         # we can just calculate result from that
         users = self.get_users_in_room.cache.get_immediate(
             (room_id,), None, update_metrics=False
         )
-        if users is not None:
-            return {get_domain_from_id(u) for u in users}
-
-        if isinstance(self.database_engine, Sqlite3Engine):
+        if users is None and isinstance(self.database_engine, Sqlite3Engine):
             # If we're using SQLite then let's just always use
             # `get_users_in_room` rather than funky SQL.
             users = await self.get_users_in_room(room_id)
-            return {get_domain_from_id(u) for u in users}
+
+        if users is not None:
+            # Because `users` is sorted from lowest -> highest depth, the list
+            # of domains will also be sorted that way.
+            domains: List[str] = []
+            # We use a `Set` just for fast lookups
+            domain_set: Set[str] = set()
+            for u in users:
+                domain = get_domain_from_id(u)
+                if domain not in domain_set:
+                    domain_set.add(domain)
+                    domains.append(domain)
+            return domains
 
         # For PostgreSQL we can use a regex to pull out the domains from the
         # joined users in `current_state_events` via regex.
 
-        def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
+        def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> List[str]:
+            # Returns a list of servers currently joined in the room sorted by
+            # longest in the room first (aka. with the lowest depth). The
+            # heuristic of sorting by servers who have been in the room the
+            # longest is good because they're most likely to have anything we
+            # ask about.
             sql = """
-                SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
-                FROM current_state_events
+                SELECT
+                    /* Match the domain part of the MXID */
+                    substring(c.state_key FROM '@[^:]*:(.*)$') as server_domain
+                FROM current_state_events c
+                /* Get the depth of the event from the events table */
+                INNER JOIN events AS e USING (event_id)
                 WHERE
-                    type = 'm.room.member'
-                    AND membership = 'join'
-                    AND room_id = ?
+                    /* Find any join state events in the room */
+                    c.type = 'm.room.member'
+                    AND c.membership = 'join'
+                    AND c.room_id = ?
+                /* Group all state events from the same domain into their own buckets (groups) */
+                GROUP BY server_domain
+                /* Sorted by lowest depth first */
+                ORDER BY min(e.depth) ASC;
             """
             txn.execute(sql, (room_id,))
-            return {d for d, in txn}
+            return [d for d, in txn]
 
         return await self.db_pool.runInteraction(
             "get_current_hosts_in_room", get_current_hosts_in_room_txn
@@ -1131,12 +1204,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             else:
                 # The cache doesn't match the state group or prev state group,
                 # so we calculate the result from first principles.
-                joined_users = await self.get_joined_users_from_state(
+                joined_user_ids = await self.get_joined_user_ids_from_state(
                     room_id, state, state_entry
                 )
 
                 cache.hosts_to_joined_users = {}
-                for user_id in joined_users:
+                for user_id in joined_user_ids:
                     host = intern_string(get_domain_from_id(user_id))
                     cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
 
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 3c13859faa..2dfe4c0b66 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -460,8 +460,17 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
                 # Cast safety: this corresponds to the types returned by the query above.
                 rows.extend(cast(Iterable[Tuple[str, int]], cur))
 
-            # Sort so that we handle rows in order for each instance.
-            rows.sort()
+            # Sort by stream_id (ascending, lowest -> highest) so that we handle
+            # rows in order for each instance because we don't want to overwrite
+            # the current_position of an instance to a lower stream ID than
+            # we're actually at.
+            def sort_by_stream_id_key_func(row: Tuple[str, int]) -> int:
+                (instance, stream_id) = row
+                # If `stream_id` is ever `None`, we will see a `TypeError: '<'
+                # not supported between instances of 'NoneType' and 'X'` error.
+                return stream_id
+
+            rows.sort(key=sort_by_stream_id_key_func)
 
             with self._lock:
                 for (
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 42f6abb5e1..bdf9b0dc8c 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -34,10 +34,10 @@ TRACK_MEMORY_USAGE = False
 caches_by_name: Dict[str, Sized] = {}
 collectors_by_name: Dict[str, "CacheMetric"] = {}
 
-cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
-cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
-cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name", "reason"])
-cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"])
+cache_size = Gauge("synapse_util_caches_cache_size", "", ["name"])
+cache_hits = Gauge("synapse_util_caches_cache_hits", "", ["name"])
+cache_evicted = Gauge("synapse_util_caches_cache_evicted_size", "", ["name", "reason"])
+cache_total = Gauge("synapse_util_caches_cache_total", "", ["name"])
 cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"])
 cache_memory_usage = Gauge(
     "synapse_util_caches_cache_size_bytes",
@@ -45,12 +45,12 @@ cache_memory_usage = Gauge(
     ["name"],
 )
 
-response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"])
-response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"])
+response_cache_size = Gauge("synapse_util_caches_response_cache_size", "", ["name"])
+response_cache_hits = Gauge("synapse_util_caches_response_cache_hits", "", ["name"])
 response_cache_evicted = Gauge(
-    "synapse_util_caches_response_cache:evicted_size", "", ["name", "reason"]
+    "synapse_util_caches_response_cache_evicted_size", "", ["name", "reason"]
 )
-response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
+response_cache_total = Gauge("synapse_util_caches_response_cache_total", "", ["name"])
 
 
 class EvictionReason(Enum):
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 1d6ec22191..6425f851ea 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -14,15 +14,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import enum
 import threading
 from typing import (
     Callable,
+    Collection,
+    Dict,
     Generic,
-    Iterable,
     MutableMapping,
     Optional,
+    Set,
     Sized,
+    Tuple,
     TypeVar,
     Union,
     cast,
@@ -31,7 +35,6 @@ from typing import (
 from prometheus_client import Gauge
 
 from twisted.internet import defer
-from twisted.python import failure
 from twisted.python.failure import Failure
 
 from synapse.util.async_helpers import ObservableDeferred
@@ -94,7 +97,7 @@ class DeferredCache(Generic[KT, VT]):
 
         # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
         self._pending_deferred_cache: Union[
-            TreeCache, "MutableMapping[KT, CacheEntry]"
+            TreeCache, "MutableMapping[KT, CacheEntry[KT, VT]]"
         ] = cache_type()
 
         def metrics_cb() -> None:
@@ -159,15 +162,16 @@ class DeferredCache(Generic[KT, VT]):
         Raises:
             KeyError if the key is not found in the cache
         """
-        callbacks = [callback] if callback else []
         val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
         if val is not _Sentinel.sentinel:
-            val.callbacks.update(callbacks)
+            val.add_invalidation_callback(key, callback)
             if update_metrics:
                 m = self.cache.metrics
                 assert m  # we always have a name, so should always have metrics
                 m.inc_hits()
-            return val.deferred.observe()
+            return val.deferred(key)
+
+        callbacks = (callback,) if callback else ()
 
         val2 = self.cache.get(
             key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
@@ -177,6 +181,73 @@ class DeferredCache(Generic[KT, VT]):
         else:
             return defer.succeed(val2)
 
+    def get_bulk(
+        self,
+        keys: Collection[KT],
+        callback: Optional[Callable[[], None]] = None,
+    ) -> Tuple[Dict[KT, VT], Optional["defer.Deferred[Dict[KT, VT]]"], Collection[KT]]:
+        """Bulk lookup of items in the cache.
+
+        Returns:
+            A 3-tuple of:
+                1. a dict of key/value of items already cached;
+                2. a deferred that resolves to a dict of key/value of items
+                   we're already fetching; and
+                3. a collection of keys that don't appear in the previous two.
+        """
+
+        # The cached results
+        cached = {}
+
+        # List of pending deferreds
+        pending = []
+
+        # Dict that gets filled out when the pending deferreds complete
+        pending_results = {}
+
+        # List of keys that aren't in either cache
+        missing = []
+
+        callbacks = (callback,) if callback else ()
+
+        for key in keys:
+            # Check if its in the main cache.
+            immediate_value = self.cache.get(
+                key,
+                _Sentinel.sentinel,
+                callbacks=callbacks,
+            )
+            if immediate_value is not _Sentinel.sentinel:
+                cached[key] = immediate_value
+                continue
+
+            # Check if its in the pending cache
+            pending_value = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
+            if pending_value is not _Sentinel.sentinel:
+                pending_value.add_invalidation_callback(key, callback)
+
+                def completed_cb(value: VT, key: KT) -> VT:
+                    pending_results[key] = value
+                    return value
+
+                # Add a callback to fill out `pending_results` when that completes
+                d = pending_value.deferred(key).addCallback(completed_cb, key)
+                pending.append(d)
+                continue
+
+            # Not in either cache
+            missing.append(key)
+
+        # If we've got pending deferreds, squash them into a single one that
+        # returns `pending_results`.
+        pending_deferred = None
+        if pending:
+            pending_deferred = defer.gatherResults(
+                pending, consumeErrors=True
+            ).addCallback(lambda _: pending_results)
+
+        return (cached, pending_deferred, missing)
+
     def get_immediate(
         self, key: KT, default: T, update_metrics: bool = True
     ) -> Union[VT, T]:
@@ -218,84 +289,89 @@ class DeferredCache(Generic[KT, VT]):
             value: a deferred which will complete with a result to add to the cache
             callback: An optional callback to be called when the entry is invalidated
         """
-        if not isinstance(value, defer.Deferred):
-            raise TypeError("not a Deferred")
-
-        callbacks = [callback] if callback else []
         self.check_thread()
 
-        existing_entry = self._pending_deferred_cache.pop(key, None)
-        if existing_entry:
-            existing_entry.invalidate()
+        self._pending_deferred_cache.pop(key, None)
 
         # XXX: why don't we invalidate the entry in `self.cache` yet?
 
-        # we can save a whole load of effort if the deferred is ready.
-        if value.called:
-            result = value.result
-            if not isinstance(result, failure.Failure):
-                self.cache.set(key, cast(VT, result), callbacks)
-            return value
-
         # otherwise, we'll add an entry to the _pending_deferred_cache for now,
         # and add callbacks to add it to the cache properly later.
+        entry = CacheEntrySingle[KT, VT](value)
+        entry.add_invalidation_callback(key, callback)
+        self._pending_deferred_cache[key] = entry
+        deferred = entry.deferred(key).addCallbacks(
+            self._completed_callback,
+            self._error_callback,
+            callbackArgs=(entry, key),
+            errbackArgs=(entry, key),
+        )
 
-        observable = ObservableDeferred(value, consumeErrors=True)
-        observer = observable.observe()
-        entry = CacheEntry(deferred=observable, callbacks=callbacks)
+        # we return a new Deferred which will be called before any subsequent observers.
+        return deferred
 
-        self._pending_deferred_cache[key] = entry
+    def start_bulk_input(
+        self,
+        keys: Collection[KT],
+        callback: Optional[Callable[[], None]] = None,
+    ) -> "CacheMultipleEntries[KT, VT]":
+        """Bulk set API for use when fetching multiple keys at once from the DB.
 
-        def compare_and_pop() -> bool:
-            """Check if our entry is still the one in _pending_deferred_cache, and
-            if so, pop it.
-
-            Returns true if the entries matched.
-            """
-            existing_entry = self._pending_deferred_cache.pop(key, None)
-            if existing_entry is entry:
-                return True
-
-            # oops, the _pending_deferred_cache has been updated since
-            # we started our query, so we are out of date.
-            #
-            # Better put back whatever we took out. (We do it this way
-            # round, rather than peeking into the _pending_deferred_cache
-            # and then removing on a match, to make the common case faster)
-            if existing_entry is not None:
-                self._pending_deferred_cache[key] = existing_entry
-
-            return False
-
-        def cb(result: VT) -> None:
-            if compare_and_pop():
-                self.cache.set(key, result, entry.callbacks)
-            else:
-                # we're not going to put this entry into the cache, so need
-                # to make sure that the invalidation callbacks are called.
-                # That was probably done when _pending_deferred_cache was
-                # updated, but it's possible that `set` was called without
-                # `invalidate` being previously called, in which case it may
-                # not have been. Either way, let's double-check now.
-                entry.invalidate()
-
-        def eb(_fail: Failure) -> None:
-            compare_and_pop()
-            entry.invalidate()
-
-        # once the deferred completes, we can move the entry from the
-        # _pending_deferred_cache to the real cache.
-        #
-        observer.addCallbacks(cb, eb)
+        Called *before* starting the fetch from the DB, and the caller *must*
+        call either `complete_bulk(..)` or `error_bulk(..)` on the return value.
+        """
 
-        # we return a new Deferred which will be called before any subsequent observers.
-        return observable.observe()
+        entry = CacheMultipleEntries[KT, VT]()
+        entry.add_global_invalidation_callback(callback)
+
+        for key in keys:
+            self._pending_deferred_cache[key] = entry
+
+        return entry
+
+    def _completed_callback(
+        self, value: VT, entry: "CacheEntry[KT, VT]", key: KT
+    ) -> VT:
+        """Called when a deferred is completed."""
+        # We check if the current entry matches the entry associated with the
+        # deferred. If they don't match then it got invalidated.
+        current_entry = self._pending_deferred_cache.pop(key, None)
+        if current_entry is not entry:
+            if current_entry:
+                self._pending_deferred_cache[key] = current_entry
+            return value
+
+        self.cache.set(key, value, entry.get_invalidation_callbacks(key))
+
+        return value
+
+    def _error_callback(
+        self,
+        failure: Failure,
+        entry: "CacheEntry[KT, VT]",
+        key: KT,
+    ) -> Failure:
+        """Called when a deferred errors."""
+
+        # We check if the current entry matches the entry associated with the
+        # deferred. If they don't match then it got invalidated.
+        current_entry = self._pending_deferred_cache.pop(key, None)
+        if current_entry is not entry:
+            if current_entry:
+                self._pending_deferred_cache[key] = current_entry
+            return failure
+
+        for cb in entry.get_invalidation_callbacks(key):
+            cb()
+
+        return failure
 
     def prefill(
         self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
     ) -> None:
-        callbacks = [callback] if callback else []
+        callbacks = (callback,) if callback else ()
         self.cache.set(key, value, callbacks=callbacks)
+        self._pending_deferred_cache.pop(key, None)
 
     def invalidate(self, key: KT) -> None:
         """Delete a key, or tree of entries
@@ -311,41 +387,129 @@ class DeferredCache(Generic[KT, VT]):
         self.cache.del_multi(key)
 
         # if we have a pending lookup for this key, remove it from the
-        # _pending_deferred_cache, which will (a) stop it being returned
-        # for future queries and (b) stop it being persisted as a proper entry
+        # _pending_deferred_cache, which will (a) stop it being returned for
+        # future queries and (b) stop it being persisted as a proper entry
         # in self.cache.
         entry = self._pending_deferred_cache.pop(key, None)
-
-        # run the invalidation callbacks now, rather than waiting for the
-        # deferred to resolve.
         if entry:
             # _pending_deferred_cache.pop should either return a CacheEntry, or, in the
             # case of a TreeCache, a dict of keys to cache entries. Either way calling
             # iterate_tree_cache_entry on it will do the right thing.
             for entry in iterate_tree_cache_entry(entry):
-                entry.invalidate()
+                for cb in entry.get_invalidation_callbacks(key):
+                    cb()
 
     def invalidate_all(self) -> None:
         self.check_thread()
         self.cache.clear()
-        for entry in self._pending_deferred_cache.values():
-            entry.invalidate()
+        for key, entry in self._pending_deferred_cache.items():
+            for cb in entry.get_invalidation_callbacks(key):
+                cb()
+
         self._pending_deferred_cache.clear()
 
 
-class CacheEntry:
-    __slots__ = ["deferred", "callbacks", "invalidated"]
+class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta):
+    """Abstract class for entries in `DeferredCache[KT, VT]`"""
 
-    def __init__(
-        self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
-    ):
-        self.deferred = deferred
-        self.callbacks = set(callbacks)
-        self.invalidated = False
-
-    def invalidate(self) -> None:
-        if not self.invalidated:
-            self.invalidated = True
-            for callback in self.callbacks:
-                callback()
-            self.callbacks.clear()
+    @abc.abstractmethod
+    def deferred(self, key: KT) -> "defer.Deferred[VT]":
+        """Get a deferred that a caller can wait on to get the value at the
+        given key"""
+        ...
+
+    @abc.abstractmethod
+    def add_invalidation_callback(
+        self, key: KT, callback: Optional[Callable[[], None]]
+    ) -> None:
+        """Add an invalidation callback"""
+        ...
+
+    @abc.abstractmethod
+    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+        """Get all invalidation callbacks"""
+        ...
+
+
+class CacheEntrySingle(CacheEntry[KT, VT]):
+    """An implementation of `CacheEntry` wrapping a deferred that results in a
+    single cache entry.
+    """
+
+    __slots__ = ["_deferred", "_callbacks"]
+
+    def __init__(self, deferred: "defer.Deferred[VT]") -> None:
+        self._deferred = ObservableDeferred(deferred, consumeErrors=True)
+        self._callbacks: Set[Callable[[], None]] = set()
+
+    def deferred(self, key: KT) -> "defer.Deferred[VT]":
+        return self._deferred.observe()
+
+    def add_invalidation_callback(
+        self, key: KT, callback: Optional[Callable[[], None]]
+    ) -> None:
+        if callback is None:
+            return
+
+        self._callbacks.add(callback)
+
+    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+        return self._callbacks
+
+
+class CacheMultipleEntries(CacheEntry[KT, VT]):
+    """Cache entry that is used for bulk lookups and insertions."""
+
+    __slots__ = ["_deferred", "_callbacks", "_global_callbacks"]
+
+    def __init__(self) -> None:
+        self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None
+        self._callbacks: Dict[KT, Set[Callable[[], None]]] = {}
+        self._global_callbacks: Set[Callable[[], None]] = set()
+
+    def deferred(self, key: KT) -> "defer.Deferred[VT]":
+        if not self._deferred:
+            self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+        return self._deferred.observe().addCallback(lambda res: res.get(key))
+
+    def add_invalidation_callback(
+        self, key: KT, callback: Optional[Callable[[], None]]
+    ) -> None:
+        if callback is None:
+            return
+
+        self._callbacks.setdefault(key, set()).add(callback)
+
+    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+        return self._callbacks.get(key, set()) | self._global_callbacks
+
+    def add_global_invalidation_callback(
+        self, callback: Optional[Callable[[], None]]
+    ) -> None:
+        """Add a callback for when any keys get invalidated."""
+        if callback is None:
+            return
+
+        self._global_callbacks.add(callback)
+
+    def complete_bulk(
+        self,
+        cache: DeferredCache[KT, VT],
+        result: Dict[KT, VT],
+    ) -> None:
+        """Called when there is a result"""
+        for key, value in result.items():
+            cache._completed_callback(value, self, key)
+
+        if self._deferred:
+            self._deferred.callback(result)
+
+    def error_bulk(
+        self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure
+    ) -> None:
+        """Called when bulk lookup failed."""
+        for key in keys:
+            cache._error_callback(failure, self, key)
+
+        if self._deferred:
+            self._deferred.errback(failure)
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 867f315b2a..10aff4d04a 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -25,6 +25,7 @@ from typing import (
     Generic,
     Hashable,
     Iterable,
+    List,
     Mapping,
     Optional,
     Sequence,
@@ -73,8 +74,10 @@ class _CacheDescriptorBase:
         num_args: Optional[int],
         uncached_args: Optional[Collection[str]] = None,
         cache_context: bool = False,
+        name: Optional[str] = None,
     ):
         self.orig = orig
+        self.name = name or orig.__name__
 
         arg_spec = inspect.getfullargspec(orig)
         all_args = arg_spec.args
@@ -211,7 +214,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
     def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: LruCache[CacheKey, Any] = LruCache(
-            cache_name=self.orig.__name__,
+            cache_name=self.name,
             max_size=self.max_entries,
         )
 
@@ -241,7 +244,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
         wrapped = cast(_CachedFunction, _wrapped)
         wrapped.cache = cache
-        obj.__dict__[self.orig.__name__] = wrapped
+        obj.__dict__[self.name] = wrapped
 
         return wrapped
 
@@ -301,12 +304,14 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         cache_context: bool = False,
         iterable: bool = False,
         prune_unread_entries: bool = True,
+        name: Optional[str] = None,
     ):
         super().__init__(
             orig,
             num_args=num_args,
             uncached_args=uncached_args,
             cache_context=cache_context,
+            name=name,
         )
 
         if tree and self.num_args < 2:
@@ -321,7 +326,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
 
     def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: DeferredCache[CacheKey, Any] = DeferredCache(
-            name=self.orig.__name__,
+            name=self.name,
             max_entries=self.max_entries,
             tree=self.tree,
             iterable=self.iterable,
@@ -372,7 +377,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         wrapped.cache = cache
         wrapped.num_args = self.num_args
 
-        obj.__dict__[self.orig.__name__] = wrapped
+        obj.__dict__[self.name] = wrapped
 
         return wrapped
 
@@ -393,6 +398,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
         cached_method_name: str,
         list_name: str,
         num_args: Optional[int] = None,
+        name: Optional[str] = None,
     ):
         """
         Args:
@@ -403,7 +409,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                 but including list_name) to use as cache keys. Defaults to all
                 named args of the function.
         """
-        super().__init__(orig, num_args=num_args, uncached_args=None)
+        super().__init__(orig, num_args=num_args, uncached_args=None, name=name)
 
         self.list_name = list_name
 
@@ -435,16 +441,6 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
             keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
             list_args = arg_dict[self.list_name]
 
-            results = {}
-
-            def update_results_dict(res: Any, arg: Hashable) -> None:
-                results[arg] = res
-
-            # list of deferreds to wait for
-            cached_defers = []
-
-            missing = set()
-
             # If the cache takes a single arg then that is used as the key,
             # otherwise a tuple is used.
             if num_args == 1:
@@ -452,6 +448,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                 def arg_to_cache_key(arg: Hashable) -> Hashable:
                     return arg
 
+                def cache_key_to_arg(key: tuple) -> Hashable:
+                    return key
+
             else:
                 keylist = list(keyargs)
 
@@ -459,58 +458,53 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                     keylist[self.list_pos] = arg
                     return tuple(keylist)
 
-            for arg in list_args:
-                try:
-                    res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
-                    if not res.called:
-                        res.addCallback(update_results_dict, arg)
-                        cached_defers.append(res)
-                    else:
-                        results[arg] = res.result
-                except KeyError:
-                    missing.add(arg)
+                def cache_key_to_arg(key: tuple) -> Hashable:
+                    return key[self.list_pos]
+
+            cache_keys = [arg_to_cache_key(arg) for arg in list_args]
+            immediate_results, pending_deferred, missing = cache.get_bulk(
+                cache_keys, callback=invalidate_callback
+            )
+
+            results = {cache_key_to_arg(key): v for key, v in immediate_results.items()}
+
+            cached_defers: List["defer.Deferred[Any]"] = []
+            if pending_deferred:
+
+                def update_results(r: Dict) -> None:
+                    for k, v in r.items():
+                        results[cache_key_to_arg(k)] = v
+
+                pending_deferred.addCallback(update_results)
+                cached_defers.append(pending_deferred)
 
             if missing:
-                # we need a deferred for each entry in the list,
-                # which we put in the cache. Each deferred resolves with the
-                # relevant result for that key.
-                deferreds_map = {}
-                for arg in missing:
-                    deferred: "defer.Deferred[Any]" = defer.Deferred()
-                    deferreds_map[arg] = deferred
-                    key = arg_to_cache_key(arg)
-                    cached_defers.append(
-                        cache.set(key, deferred, callback=invalidate_callback)
-                    )
+                cache_entry = cache.start_bulk_input(missing, invalidate_callback)
 
                 def complete_all(res: Dict[Hashable, Any]) -> None:
-                    # the wrapped function has completed. It returns a dict.
-                    # We can now update our own result map, and then resolve the
-                    # observable deferreds in the cache.
-                    for e, d1 in deferreds_map.items():
-                        val = res.get(e, None)
-                        # make sure we update the results map before running the
-                        # deferreds, because as soon as we run the last deferred, the
-                        # gatherResults() below will complete and return the result
-                        # dict to our caller.
-                        results[e] = val
-                        d1.callback(val)
+                    missing_results = {}
+                    for key in missing:
+                        arg = cache_key_to_arg(key)
+                        val = res.get(arg, None)
+
+                        results[arg] = val
+                        missing_results[key] = val
+
+                    cache_entry.complete_bulk(cache, missing_results)
 
                 def errback_all(f: Failure) -> None:
-                    # the wrapped function has failed. Propagate the failure into
-                    # the cache, which will invalidate the entry, and cause the
-                    # relevant cached_deferreds to fail, which will propagate the
-                    # failure to our caller.
-                    for d1 in deferreds_map.values():
-                        d1.errback(f)
+                    cache_entry.error_bulk(cache, missing, f)
 
                 args_to_call = dict(arg_dict)
-                args_to_call[self.list_name] = missing
+                args_to_call[self.list_name] = {
+                    cache_key_to_arg(key) for key in missing
+                }
 
                 # dispatch the call, and attach the two handlers
-                defer.maybeDeferred(
+                missing_d = defer.maybeDeferred(
                     preserve_fn(self.orig), **args_to_call
                 ).addCallbacks(complete_all, errback_all)
+                cached_defers.append(missing_d)
 
             if cached_defers:
                 d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
@@ -525,7 +519,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
             else:
                 return defer.succeed(results)
 
-        obj.__dict__[self.orig.__name__] = wrapped
+        obj.__dict__[self.name] = wrapped
 
         return wrapped
 
@@ -577,6 +571,7 @@ def cached(
     cache_context: bool = False,
     iterable: bool = False,
     prune_unread_entries: bool = True,
+    name: Optional[str] = None,
 ) -> Callable[[F], _CachedFunction[F]]:
     func = lambda orig: DeferredCacheDescriptor(
         orig,
@@ -587,13 +582,18 @@ def cached(
         cache_context=cache_context,
         iterable=iterable,
         prune_unread_entries=prune_unread_entries,
+        name=name,
     )
 
     return cast(Callable[[F], _CachedFunction[F]], func)
 
 
 def cachedList(
-    *, cached_method_name: str, list_name: str, num_args: Optional[int] = None
+    *,
+    cached_method_name: str,
+    list_name: str,
+    num_args: Optional[int] = None,
+    name: Optional[str] = None,
 ) -> Callable[[F], _CachedFunction[F]]:
     """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
 
@@ -628,6 +628,7 @@ def cachedList(
         cached_method_name=cached_method_name,
         list_name=list_name,
         num_args=num_args,
+        name=name,
     )
 
     return cast(Callable[[F], _CachedFunction[F]], func)
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index c1b8ec0c73..fec31da2b6 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -135,6 +135,9 @@ class TreeCache:
     def values(self):
         return iterate_tree_cache_entry(self.root)
 
+    def items(self):
+        return iterate_tree_cache_items((), self.root)
+
     def __len__(self) -> int:
         return self.size