path: root/synapse
Diffstat (limited to 'synapse')
-rw-r--r--  synapse/api/auth.py | 8
-rw-r--r--  synapse/api/room_versions.py | 2
-rw-r--r--  synapse/app/_base.py | 5
-rw-r--r--  synapse/app/generic_worker.py | 4
-rw-r--r--  synapse/config/experimental.py | 23
-rw-r--r--  synapse/config/tls.py | 22
-rw-r--r--  synapse/federation/transport/server.py | 16
-rw-r--r--  synapse/handlers/federation.py | 12
-rw-r--r--  synapse/http/servlet.py | 196
-rw-r--r--  synapse/logging/opentracing.py | 31
-rw-r--r--  synapse/metrics/background_process_metrics.py | 10
-rw-r--r--  synapse/replication/slave/storage/devices.py | 2
-rw-r--r--  synapse/rest/admin/media.py | 28
-rw-r--r--  synapse/rest/client/v1/room.py | 12
-rw-r--r--  synapse/rest/client/v2_alpha/report_event.py | 13
-rw-r--r--  synapse/storage/databases/main/cache.py | 7
-rw-r--r--  synapse/storage/databases/main/devices.py | 2
-rw-r--r--  synapse/storage/databases/main/event_push_actions.py | 2
-rw-r--r--  synapse/storage/databases/main/events.py | 8
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 61
-rw-r--r--  synapse/storage/databases/main/media_repository.py | 7
-rw-r--r--  synapse/storage/databases/main/purge_events.py | 26
-rw-r--r--  synapse/storage/databases/main/receipts.py | 6
-rw-r--r--  synapse/storage/databases/main/room.py | 2
-rw-r--r--  synapse/util/batching_queue.py | 70
-rw-r--r--  synapse/util/caches/deferred_cache.py | 42
-rw-r--r--  synapse/util/caches/descriptors.py | 8
-rw-r--r--  synapse/util/caches/lrucache.py | 18
-rw-r--r--  synapse/util/caches/treecache.py | 3
29 files changed, 419 insertions, 227 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 458306eba5..26a3b38918 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -206,11 +206,11 @@ class Auth:
                 requester = create_requester(user_id, app_service=app_service)
 
                 request.requester = user_id
+                if user_id in self._force_tracing_for_users:
+                    opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
                 opentracing.set_tag("authenticated_entity", user_id)
                 opentracing.set_tag("user_id", user_id)
                 opentracing.set_tag("appservice_id", app_service.id)
-                if user_id in self._force_tracing_for_users:
-                    opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
 
                 return requester
 
@@ -259,12 +259,12 @@ class Auth:
             )
 
             request.requester = requester
+            if user_info.token_owner in self._force_tracing_for_users:
+                opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
             opentracing.set_tag("authenticated_entity", user_info.token_owner)
             opentracing.set_tag("user_id", user_info.user_id)
             if device_id:
                 opentracing.set_tag("device_id", device_id)
-            if user_info.token_owner in self._force_tracing_for_users:
-                opentracing.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
 
             return requester
         except KeyError:
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index c9f9596ada..373a4669d0 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -181,6 +181,6 @@ KNOWN_ROOM_VERSIONS = {
         RoomVersions.V5,
         RoomVersions.V6,
         RoomVersions.MSC2176,
+        RoomVersions.MSC3083,
     )
-    # Note that we do not include MSC3083 here unless it is enabled in the config.
 }  # type: Dict[str, RoomVersion]
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 59918d789e..1329af2e2b 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -261,13 +261,10 @@ def refresh_certificate(hs):
     Refresh the TLS certificates that Synapse is using by re-reading them from
     disk and updating the TLS context factories to use them.
     """
-
     if not hs.config.has_tls_listener():
-        # attempt to reload the certs for the good of the tls_fingerprints
-        hs.config.read_certificate_from_disk(require_cert_and_key=False)
         return
 
-    hs.config.read_certificate_from_disk(require_cert_and_key=True)
+    hs.config.read_certificate_from_disk()
     hs.tls_server_context_factory = context_factory.ServerContextFactory(hs.config)
 
     if hs._listening_services:
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 91ad326f19..57c2fc2e88 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -109,7 +109,7 @@ from synapse.storage.databases.main.monthly_active_users import (
     MonthlyActiveUsersWorkerStore,
 )
 from synapse.storage.databases.main.presence import PresenceStore
-from synapse.storage.databases.main.search import SearchWorkerStore
+from synapse.storage.databases.main.search import SearchStore
 from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.databases.main.transactions import TransactionWorkerStore
 from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
@@ -242,7 +242,7 @@ class GenericWorkerSlavedStore(
     MonthlyActiveUsersWorkerStore,
     MediaRepositoryStore,
     ServerMetricsStore,
-    SearchWorkerStore,
+    SearchStore,
     TransactionWorkerStore,
     BaseSlavedStore,
 ):
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index cc67377f0f..6ebce4b2f7 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
 from synapse.config._base import Config
 from synapse.types import JsonDict
 
@@ -28,27 +27,5 @@ class ExperimentalConfig(Config):
         # MSC2858 (multiple SSO identity providers)
         self.msc2858_enabled = experimental.get("msc2858_enabled", False)  # type: bool
 
-        # Spaces (MSC1772, MSC2946, MSC3083, etc)
-        self.spaces_enabled = experimental.get("spaces_enabled", True)  # type: bool
-        if self.spaces_enabled:
-            KNOWN_ROOM_VERSIONS[RoomVersions.MSC3083.identifier] = RoomVersions.MSC3083
-
         # MSC3026 (busy presence state)
         self.msc3026_enabled = experimental.get("msc3026_enabled", False)  # type: bool
-
-    def generate_config_section(self, **kwargs):
-        return """\
-        # Enable experimental features in Synapse.
-        #
-        # Experimental features might break or be removed without a deprecation
-        # period.
-        #
-        experimental_features:
-          # Support for Spaces (MSC1772), it enables the following:
-          #
-          # * The Spaces Summary API (MSC2946).
-          # * Restricting room membership based on space membership (MSC3083).
-          #
-          # Uncomment to disable support for Spaces.
-          #spaces_enabled: false
-        """
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 26f1150ca5..0e9bba53c9 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -215,28 +215,12 @@ class TlsConfig(Config):
         days_remaining = (expires_on - now).days
         return days_remaining
 
-    def read_certificate_from_disk(self, require_cert_and_key: bool):
+    def read_certificate_from_disk(self):
         """
         Read the certificates and private key from disk.
-
-        Args:
-            require_cert_and_key: set to True to throw an error if the certificate
-                and key file are not given
         """
-        if require_cert_and_key:
-            self.tls_private_key = self.read_tls_private_key()
-            self.tls_certificate = self.read_tls_certificate()
-        elif self.tls_certificate_file:
-            # we only need the certificate for the tls_fingerprints. Reload it if we
-            # can, but it's not a fatal error if we can't.
-            try:
-                self.tls_certificate = self.read_tls_certificate()
-            except Exception as e:
-                logger.info(
-                    "Unable to read TLS certificate (%s). Ignoring as no "
-                    "tls listeners enabled.",
-                    e,
-                )
+        self.tls_private_key = self.read_tls_private_key()
+        self.tls_certificate = self.read_tls_certificate()
 
     def generate_config_section(
         self,
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 40eab45549..fdeaa0f37c 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -37,6 +37,7 @@ from synapse.http.servlet import (
 )
 from synapse.logging.context import run_in_background
 from synapse.logging.opentracing import (
+    SynapseTags,
     start_active_span,
     start_active_span_from_request,
     tags,
@@ -314,7 +315,7 @@ class BaseFederationServlet:
                 raise
 
             request_tags = {
-                "request_id": request.get_request_id(),
+                SynapseTags.REQUEST_ID: request.get_request_id(),
                 tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
                 tags.HTTP_METHOD: request.get_method(),
                 tags.HTTP_URL: request.get_redacted_uri(),
@@ -1562,13 +1563,12 @@ def register_servlets(
                 server_name=hs.hostname,
             ).register(resource)
 
-        if hs.config.experimental.spaces_enabled:
-            FederationSpaceSummaryServlet(
-                handler=hs.get_space_summary_handler(),
-                authenticator=authenticator,
-                ratelimiter=ratelimiter,
-                server_name=hs.hostname,
-            ).register(resource)
+        FederationSpaceSummaryServlet(
+            handler=hs.get_space_summary_handler(),
+            authenticator=authenticator,
+            ratelimiter=ratelimiter,
+            server_name=hs.hostname,
+        ).register(resource)
 
     if "openid" in servlet_groups:
         for servletclass in OPENID_SERVLET_CLASSES:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bf11315251..49ed7cabcc 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -577,7 +577,9 @@ class FederationHandler(BaseHandler):
 
         # Fetch the state events from the DB, and check we have the auth events.
         event_map = await self.store.get_events(state_event_ids, allow_rejected=True)
-        auth_events_in_store = await self.store.have_seen_events(auth_event_ids)
+        auth_events_in_store = await self.store.have_seen_events(
+            room_id, auth_event_ids
+        )
 
         # Check for missing events. We handle state and auth events separately,
         # as we want to pull the state from the DB, but we don't for the auth
@@ -610,7 +612,7 @@ class FederationHandler(BaseHandler):
 
             if missing_auth_events:
                 auth_events_in_store = await self.store.have_seen_events(
-                    missing_auth_events
+                    room_id, missing_auth_events
                 )
                 missing_auth_events.difference_update(auth_events_in_store)
 
@@ -710,7 +712,7 @@ class FederationHandler(BaseHandler):
 
         missing_auth_events = set(auth_event_ids) - fetched_events.keys()
         missing_auth_events.difference_update(
-            await self.store.have_seen_events(missing_auth_events)
+            await self.store.have_seen_events(room_id, missing_auth_events)
         )
         logger.debug("We are also missing %i auth events", len(missing_auth_events))
 
@@ -2475,7 +2477,7 @@ class FederationHandler(BaseHandler):
         #
         # we start by checking if they are in the store, and then try calling /event_auth/.
         if missing_auth:
-            have_events = await self.store.have_seen_events(missing_auth)
+            have_events = await self.store.have_seen_events(event.room_id, missing_auth)
             logger.debug("Events %s are in the store", have_events)
             missing_auth.difference_update(have_events)
 
@@ -2494,7 +2496,7 @@ class FederationHandler(BaseHandler):
                     return context
 
                 seen_remotes = await self.store.have_seen_events(
-                    [e.event_id for e in remote_auth_chain]
+                    event.room_id, [e.event_id for e in remote_auth_chain]
                 )
 
                 for e in remote_auth_chain:
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 31897546a9..3f4f2411fc 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -15,6 +15,9 @@
 """ This module contains base REST classes for constructing REST servlets. """
 
 import logging
+from typing import Iterable, List, Optional, Union, overload
+
+from typing_extensions import Literal
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.util import json_decoder
@@ -107,12 +110,11 @@ def parse_boolean_from_args(args, name, default=None, required=False):
 
 def parse_string(
     request,
-    name,
-    default=None,
-    required=False,
-    allowed_values=None,
-    param_type="string",
-    encoding="ascii",
+    name: Union[bytes, str],
+    default: Optional[str] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: Optional[str] = "ascii",
 ):
     """
     Parse a string parameter from the request query string.
@@ -122,18 +124,17 @@ def parse_string(
 
     Args:
         request: the twisted HTTP request.
-        name (bytes|unicode): the name of the query parameter.
-        default (bytes|unicode|None): value to use if the parameter is absent,
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent,
             defaults to None. Must be bytes if encoding is None.
-        required (bool): whether to raise a 400 SynapseError if the
+        required: whether to raise a 400 SynapseError if the
             parameter is absent, defaults to False.
-        allowed_values (list[bytes|unicode]): List of allowed values for the
+        allowed_values: List of allowed values for the
             string, or None if any value is allowed, defaults to None. Must be
             the same type as name, if given.
-        encoding (str|None): The encoding to decode the string content with.
-
+        encoding: The encoding to decode the string content with.
     Returns:
-        bytes/unicode|None: A string value or the default. Unicode if encoding
+        A string value or the default. Unicode if encoding
         was given, bytes otherwise.
 
     Raises:
@@ -142,45 +143,105 @@ def parse_string(
             is not one of those allowed values.
     """
     return parse_string_from_args(
-        request.args, name, default, required, allowed_values, param_type, encoding
+        request.args, name, default, required, allowed_values, encoding
     )
 
 
-def parse_string_from_args(
-    args,
-    name,
-    default=None,
-    required=False,
-    allowed_values=None,
-    param_type="string",
-    encoding="ascii",
-):
+def _parse_string_value(
+    value: Union[str, bytes],
+    allowed_values: Optional[Iterable[str]],
+    name: str,
+    encoding: Optional[str],
+) -> Union[str, bytes]:
+    if encoding:
+        try:
+            value = value.decode(encoding)
+        except ValueError:
+            raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
+
+    if allowed_values is not None and value not in allowed_values:
+        message = "Query parameter %r must be one of [%s]" % (
+            name,
+            ", ".join(repr(v) for v in allowed_values),
+        )
+        raise SynapseError(400, message)
+    else:
+        return value
+
+
+@overload
+def parse_strings_from_args(
+    args: List[str],
+    name: Union[bytes, str],
+    default: Optional[List[str]] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: Literal[None] = None,
+) -> Optional[List[bytes]]:
+    ...
+
+
+@overload
+def parse_strings_from_args(
+    args: List[str],
+    name: Union[bytes, str],
+    default: Optional[List[str]] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: str = "ascii",
+) -> Optional[List[str]]:
+    ...
+
+
+def parse_strings_from_args(
+    args: List[str],
+    name: Union[bytes, str],
+    default: Optional[List[str]] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: Optional[str] = "ascii",
+) -> Optional[List[Union[bytes, str]]]:
+    """
+    Parse a string parameter from the request query string list.
+
+    If encoding is not None, the content of the query param will be
+    decoded to Unicode using the encoding, otherwise it will be returned as bytes.
+
+    Args:
+        args: the twisted HTTP request.args list.
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent,
+            defaults to None. Must be bytes if encoding is None.
+        required: whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+        allowed_values: List of allowed values for the
+            string, or None if any value is allowed, defaults to None. Must be
+            the same type as name, if given.
+        encoding: The encoding to decode the string content with.
+
+    Returns:
+        A list of string values, or the default. Unicode strings if encoding
+        was given, bytes otherwise.
+
+    Raises:
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present, must be one of a list of allowed values and
+            is not one of those allowed values.
+    """
 
     if not isinstance(name, bytes):
         name = name.encode("ascii")
 
     if name in args:
-        value = args[name][0]
-
-        if encoding:
-            try:
-                value = value.decode(encoding)
-            except ValueError:
-                raise SynapseError(
-                    400, "Query parameter %r must be %s" % (name, encoding)
-                )
-
-        if allowed_values is not None and value not in allowed_values:
-            message = "Query parameter %r must be one of [%s]" % (
-                name,
-                ", ".join(repr(v) for v in allowed_values),
-            )
-            raise SynapseError(400, message)
-        else:
-            return value
+        values = args[name]
+
+        return [
+            _parse_string_value(value, allowed_values, name=name, encoding=encoding)
+            for value in values
+        ]
     else:
         if required:
-            message = "Missing %s query parameter %r" % (param_type, name)
+            message = "Missing string query parameter %r" % (name)
             raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
         else:
 
@@ -190,6 +251,55 @@ def parse_string_from_args(
             return default
 
 
+def parse_string_from_args(
+    args: List[str],
+    name: Union[bytes, str],
+    default: Optional[str] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: Optional[str] = "ascii",
+) -> Optional[Union[bytes, str]]:
+    """
+    Parse the string parameter from the request query string list
+    and return the first result.
+
+    If encoding is not None, the content of the query param will be
+    decoded to Unicode using the encoding, otherwise it will be returned as bytes.
+
+    Args:
+        args: the twisted HTTP request.args list.
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent,
+            defaults to None. Must be bytes if encoding is None.
+        required: whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+        allowed_values: List of allowed values for the
+            string, or None if any value is allowed, defaults to None. Must be
+            the same type as name, if given.
+        encoding: The encoding to decode the string content with.
+
+    Returns:
+        A string value or the default. Unicode if encoding
+        was given, bytes otherwise.
+
+    Raises:
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present, must be one of a list of allowed values and
+            is not one of those allowed values.
+    """
+
+    strings = parse_strings_from_args(
+        args,
+        name,
+        default=[default],
+        required=required,
+        allowed_values=allowed_values,
+        encoding=encoding,
+    )
+
+    return strings[0]
+
+
 def parse_json_value_from_request(request, allow_empty_body=False):
     """Parse a JSON value from the body of a twisted HTTP request.
 
@@ -215,7 +325,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
     try:
         content = json_decoder.decode(content_bytes.decode("utf-8"))
     except Exception as e:
-        logger.warning("Unable to parse JSON: %s", e)
+        logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes)
         raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
 
     return content
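
The new parse_strings_from_args returns every value supplied for a query parameter, while parse_string_from_args keeps its old single-value behaviour by delegating to it and taking the first element. A rough usage sketch (not taken from the diff), assuming the usual Twisted request.args mapping of bytes keys to lists of bytes values:

    # Hedged example; the args dict below stands in for a Twisted request.args.
    from synapse.http.servlet import parse_string_from_args, parse_strings_from_args

    args = {b"event_id": [b"$abc:example.com", b"$def:example.com"]}

    # All values for the parameter, decoded with the default "ascii" encoding:
    event_ids = parse_strings_from_args(args, "event_id", required=True)
    # -> ["$abc:example.com", "$def:example.com"]

    # Just the first value, as before:
    first = parse_string_from_args(args, "event_id")
    # -> "$abc:example.com"
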
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index fba2fa3904..f64845b80c 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -265,6 +265,12 @@ class SynapseTags:
     # Whether the sync response has new data to be returned to the client.
     SYNC_RESULT = "sync.new_data"
 
+    # incoming HTTP request ID  (as written in the logs)
+    REQUEST_ID = "request_id"
+
+    # HTTP request tag (used to distinguish full vs incremental syncs, etc)
+    REQUEST_TAG = "request_tag"
+
 
 # Block everything by default
 # A regex which matches the server_names to expose traces for.
@@ -588,7 +594,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
 
     span = opentracing.tracer.active_span
     carrier = {}  # type: Dict[str, str]
-    opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
+    opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
 
     for key, value in carrier.items():
         headers.addRawHeaders(key, value)
@@ -625,7 +631,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
     span = opentracing.tracer.active_span
 
     carrier = {}  # type: Dict[str, str]
-    opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
+    opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
 
     for key, value in carrier.items():
         headers[key.encode()] = [value.encode()]
@@ -659,7 +665,7 @@ def inject_active_span_text_map(carrier, destination, check_destination=True):
         return
 
     opentracing.tracer.inject(
-        opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
+        opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
     )
 
 
@@ -681,7 +687,7 @@ def get_active_span_text_map(destination=None):
 
     carrier = {}  # type: Dict[str, str]
     opentracing.tracer.inject(
-        opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
+        opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
     )
 
     return carrier
@@ -696,7 +702,7 @@ def active_span_context_as_string():
     carrier = {}  # type: Dict[str, str]
     if opentracing:
         opentracing.tracer.inject(
-            opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
+            opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
         )
     return json_encoder.encode(carrier)
 
@@ -824,7 +830,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
         return
 
     request_tags = {
-        "request_id": request.get_request_id(),
+        SynapseTags.REQUEST_ID: request.get_request_id(),
         tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
         tags.HTTP_METHOD: request.get_method(),
         tags.HTTP_URL: request.get_redacted_uri(),
@@ -833,9 +839,9 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
 
     request_name = request.request_metrics.name
     if extract_context:
-        scope = start_active_span_from_request(request, request_name, tags=request_tags)
+        scope = start_active_span_from_request(request, request_name)
     else:
-        scope = start_active_span(request_name, tags=request_tags)
+        scope = start_active_span(request_name)
 
     with scope:
         try:
@@ -845,4 +851,11 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
             # with JsonResource).
             scope.span.set_operation_name(request.request_metrics.name)
 
-            scope.span.set_tag("request_tag", request.request_metrics.start_context.tag)
+            # set the tags *after* the servlet completes, in case it decided to
+            # prioritise the span (tags will get dropped on unprioritised spans)
+            request_tags[
+                SynapseTags.REQUEST_TAG
+            ] = request.request_metrics.start_context.tag
+
+            for k, v in request_tags.items():
+                scope.span.set_tag(k, v)
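
trace_servlet now applies the request tags only after the servlet has finished, because tags set on a span that has not yet been prioritised may be dropped; by the time the handler returns, a forced-tracing user (see the auth.py change above) will already have raised the sampling priority. A hedged, generic illustration of that ordering, not the actual Synapse helper:

    # Defer tag application until after the handler has run, so any sampling
    # decision taken while handling the request is already in effect when the
    # tags are attached. `span` is an opentracing span, `handler` a callable.
    def run_with_deferred_tags(span, tags: dict, handler):
        try:
            return handler()
        finally:
            for key, value in tags.items():
                span.set_tag(key, value)
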
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 714caf84c3..0d6d643d35 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -22,7 +22,11 @@ from prometheus_client.core import REGISTRY, Counter, Gauge
 from twisted.internet import defer
 
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
-from synapse.logging.opentracing import noop_context_manager, start_active_span
+from synapse.logging.opentracing import (
+    SynapseTags,
+    noop_context_manager,
+    start_active_span,
+)
 from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
@@ -202,7 +206,9 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
             try:
                 ctx = noop_context_manager()
                 if bg_start_span:
-                    ctx = start_active_span(desc, tags={"request_id": str(context)})
+                    ctx = start_active_span(
+                        desc, tags={SynapseTags.REQUEST_ID: str(context)}
+                    )
                 with ctx:
                     return await maybe_awaitable(func(*args, **kwargs))
             except Exception:
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 70207420a6..26bdead565 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -68,7 +68,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             if row.entity.startswith("@"):
                 self._device_list_stream_cache.entity_has_changed(row.entity, token)
                 self.get_cached_devices_for_user.invalidate((row.entity,))
-                self._get_cached_user_device.invalidate_many((row.entity,))
+                self._get_cached_user_device.invalidate((row.entity,))
                 self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
 
             else:
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 24dd46113a..2c71af4279 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -137,8 +137,31 @@ class ProtectMediaByID(RestServlet):
 
         logging.info("Protecting local media by ID: %s", media_id)
 
-        # Quarantine this media id
-        await self.store.mark_local_media_as_safe(media_id)
+        # Protect this media id
+        await self.store.mark_local_media_as_safe(media_id, safe=True)
+
+        return 200, {}
+
+
+class UnprotectMediaByID(RestServlet):
+    """Unprotect local media from being quarantined."""
+
+    PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]+)")
+
+    def __init__(self, hs: "HomeServer"):
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_POST(
+        self, request: SynapseRequest, media_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        logging.info("Unprotecting local media by ID: %s", media_id)
+
+        # Unprotect this media id
+        await self.store.mark_local_media_as_safe(media_id, safe=False)
 
         return 200, {}
 
@@ -269,6 +292,7 @@ def register_servlets_for_media_repo(hs: "HomeServer", http_server):
     QuarantineMediaByID(hs).register(http_server)
     QuarantineMediaByUser(hs).register(http_server)
     ProtectMediaByID(hs).register(http_server)
+    UnprotectMediaByID(hs).register(http_server)
     ListMediaInRoom(hs).register(http_server)
     DeleteMediaByID(hs).register(http_server)
     DeleteMediaByDateSize(hs).register(http_server)
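
The new UnprotectMediaByID servlet is the inverse of ProtectMediaByID; both now call mark_local_media_as_safe with an explicit safe flag. A rough sketch of exercising the two endpoints, assuming ProtectMediaByID registers under /media/protect and that admin_patterns applies the usual /_synapse/admin/v1 prefix (both assumptions, not shown in this diff):

    # Hedged example; homeserver URL, prefix and access token are placeholders.
    import requests

    base = "https://homeserver.example.com/_synapse/admin/v1"
    headers = {"Authorization": "Bearer <admin_access_token>"}
    media_id = "abcdefghijklmnop"

    # Mark the media as safe from quarantine ...
    requests.post(f"{base}/media/protect/{media_id}", headers=headers).raise_for_status()

    # ... and later undo it with the new endpoint.
    requests.post(f"{base}/media/unprotect/{media_id}", headers=headers).raise_for_status()
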
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 51813cccbe..70286b0ff7 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -1060,18 +1060,16 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False):
     RoomRedactEventRestServlet(hs).register(http_server)
     RoomTypingRestServlet(hs).register(http_server)
     RoomEventContextServlet(hs).register(http_server)
-
-    if hs.config.experimental.spaces_enabled:
-        RoomSpaceSummaryRestServlet(hs).register(http_server)
+    RoomSpaceSummaryRestServlet(hs).register(http_server)
+    RoomEventServlet(hs).register(http_server)
+    JoinedRoomsRestServlet(hs).register(http_server)
+    RoomAliasListServlet(hs).register(http_server)
+    SearchRestServlet(hs).register(http_server)
 
     # Some servlets only get registered for the main process.
     if not is_worker:
         RoomCreateRestServlet(hs).register(http_server)
         RoomForgetRestServlet(hs).register(http_server)
-        SearchRestServlet(hs).register(http_server)
-        JoinedRoomsRestServlet(hs).register(http_server)
-        RoomEventServlet(hs).register(http_server)
-        RoomAliasListServlet(hs).register(http_server)
 
 
 def register_deprecated_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index 2c169abbf3..07ea39a8a3 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -16,11 +16,7 @@ import logging
 from http import HTTPStatus
 
 from synapse.api.errors import Codes, SynapseError
-from synapse.http.servlet import (
-    RestServlet,
-    assert_params_in_dict,
-    parse_json_object_from_request,
-)
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
 
 from ._base import client_patterns
 
@@ -42,15 +38,14 @@ class ReportEventRestServlet(RestServlet):
         user_id = requester.user.to_string()
 
         body = parse_json_object_from_request(request)
-        assert_params_in_dict(body, ("reason", "score"))
 
-        if not isinstance(body["reason"], str):
+        if not isinstance(body.get("reason", ""), str):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "Param 'reason' must be a string",
                 Codes.BAD_JSON,
             )
-        if not isinstance(body["score"], int):
+        if not isinstance(body.get("score", 0), int):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "Param 'score' must be an integer",
@@ -61,7 +56,7 @@ class ReportEventRestServlet(RestServlet):
             room_id=room_id,
             event_id=event_id,
             user_id=user_id,
-            reason=body["reason"],
+            reason=body.get("reason"),
             content=body,
             received_ts=self.clock.time_msec(),
         )
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index ecc1f935e2..c57ae5ef15 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -168,10 +168,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         backfilled,
     ):
         self._invalidate_get_event_cache(event_id)
+        self.have_seen_event.invalidate((room_id, event_id))
 
         self.get_latest_event_ids_in_room.invalidate((room_id,))
 
-        self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
+        self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))
 
         if not backfilled:
             self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
@@ -184,8 +185,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self.get_invited_rooms_for_local_user.invalidate((state_key,))
 
         if relates_to:
-            self.get_relations_for_event.invalidate_many((relates_to,))
-            self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
+            self.get_relations_for_event.invalidate((relates_to,))
+            self.get_aggregation_groups_for_event.invalidate((relates_to,))
             self.get_applicable_edit.invalidate((relates_to,))
 
     async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index fd87ba71ab..18f07d96dc 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1282,7 +1282,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
         txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
-        txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,))
+        txn.call_after(self._get_cached_user_device.invalidate, (user_id,))
         txn.call_after(
             self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
         )
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 5845322118..d1237c65cc 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -860,7 +860,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                                   not be deleted.
         """
         txn.call_after(
-            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+            self.get_unread_event_push_actions_by_room_for_user.invalidate,
             (room_id, user_id),
         )
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index fd25c8112d..897fa06639 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1748,9 +1748,9 @@ class PersistEventsStore:
             },
         )
 
-        txn.call_after(self.store.get_relations_for_event.invalidate_many, (parent_id,))
+        txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,))
         txn.call_after(
-            self.store.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
+            self.store.get_aggregation_groups_for_event.invalidate, (parent_id,)
         )
 
         if rel_type == RelationTypes.REPLACE:
@@ -1903,7 +1903,7 @@ class PersistEventsStore:
 
                 for user_id in user_ids:
                     txn.call_after(
-                        self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+                        self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
                         (room_id, user_id),
                     )
 
@@ -1917,7 +1917,7 @@ class PersistEventsStore:
     def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
         # Sad that we have to blow away the cache for the whole room here
         txn.call_after(
-            self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+            self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
             (room_id,),
         )
         txn.execute(
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 6963bbf7f4..403a5ddaba 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -22,6 +22,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Set,
     Tuple,
     overload,
 )
@@ -55,7 +56,7 @@ from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
@@ -1045,32 +1046,74 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {r["event_id"] for r in rows}
 
-    async def have_seen_events(self, event_ids):
+    async def have_seen_events(
+        self, room_id: str, event_ids: Iterable[str]
+    ) -> Set[str]:
         """Given a list of event ids, check if we have already processed them.
 
+        The room_id is only used to structure the cache (so that it can later be
+        invalidated by room_id) - there is no guarantee that the events are actually
+        in the room in question.
+
         Args:
-            event_ids (iterable[str]):
+            room_id: Room we are polling
+            event_ids: events we are looking for
 
         Returns:
             set[str]: The events we have already seen.
         """
+        res = await self._have_seen_events_dict(
+            (room_id, event_id) for event_id in event_ids
+        )
+        return {eid for ((_rid, eid), have_event) in res.items() if have_event}
+
+    @cachedList("have_seen_event", "keys")
+    async def _have_seen_events_dict(
+        self, keys: Iterable[Tuple[str, str]]
+    ) -> Dict[Tuple[str, str], bool]:
+        """Helper for have_seen_events
+
+        Returns:
+             a dict {(room_id, event_id)-> bool}
+        """
         # if the event cache contains the event, obviously we've seen it.
-        results = {x for x in event_ids if self._get_event_cache.contains(x)}
 
-        def have_seen_events_txn(txn, chunk):
-            sql = "SELECT event_id FROM events as e WHERE "
+        cache_results = {
+            (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,))
+        }
+        results = {x: True for x in cache_results}
+
+        def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
+            # we deliberately do *not* query the database for room_id, to make the
+            # query an index-only lookup on `events_event_id_key`.
+            #
+            # We therefore pull the events from the database into a set...
+
+            sql = "SELECT event_id FROM events AS e WHERE "
             clause, args = make_in_list_sql_clause(
-                txn.database_engine, "e.event_id", chunk
+                txn.database_engine, "e.event_id", [eid for (_rid, eid) in chunk]
             )
             txn.execute(sql + clause, args)
-            results.update(row[0] for row in txn)
+            found_events = {eid for eid, in txn}
 
-        for chunk in batch_iter((x for x in event_ids if x not in results), 100):
+            # ... and then we can update the results for each row in the batch
+            results.update({(rid, eid): (eid in found_events) for (rid, eid) in chunk})
+
+        # each batch requires its own index scan, so we make the batches as big as
+        # possible.
+        for chunk in batch_iter((k for k in keys if k not in cache_results), 500):
             await self.db_pool.runInteraction(
                 "have_seen_events", have_seen_events_txn, chunk
             )
+
         return results
 
+    @cached(max_entries=100000, tree=True)
+    async def have_seen_event(self, room_id: str, event_id: str):
+        # this only exists for the benefit of the @cachedList descriptor on
+        # _have_seen_events_dict
+        raise NotImplementedError()
+
     def _get_current_state_event_counts_txn(self, txn, room_id):
         """
         See get_current_state_event_counts.
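
have_seen_events is now backed by a per-(room_id, event_id) cache: the @cached stub have_seen_event exists only so that the @cachedList helper _have_seen_events_dict has a tree-structured cache to fill, which in turn lets purge_events invalidate by (room_id,) prefix. A minimal sketch of that @cached/@cachedList pairing (the store and the database lookup are illustrative, not the real schema):

    from typing import Dict, Iterable, Tuple

    from synapse.util.caches.descriptors import cached, cachedList


    class ExampleStore:
        @cached(max_entries=100000, tree=True)
        async def have_seen_event(self, room_id: str, event_id: str) -> bool:
            # Never called directly: it only provides the (room_id, event_id)
            # keyed cache that the batch helper below reads from and fills.
            raise NotImplementedError()

        @cachedList("have_seen_event", "keys")
        async def _have_seen_events_dict(
            self, keys: Iterable[Tuple[str, str]]
        ) -> Dict[Tuple[str, str], bool]:
            # Only called for keys missing from the cache; must return an entry
            # for every requested key, which the descriptor then caches.
            return {key: await self._lookup_in_db(key) for key in keys}  # hypothetical helper
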
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index c584868188..2fa945d171 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -143,6 +143,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "created_ts",
                 "quarantined_by",
                 "url_cache",
+                "safe_from_quarantine",
             ),
             allow_none=True,
             desc="get_local_media",
@@ -296,12 +297,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_local_media",
         )
 
-    async def mark_local_media_as_safe(self, media_id: str) -> None:
-        """Mark a local media as safe from quarantining."""
+    async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> None:
+        """Mark a local media as safe or unsafe from quarantining."""
         await self.db_pool.simple_update_one(
             table="local_media_repository",
             keyvalues={"media_id": media_id},
-            updatevalues={"safe_from_quarantine": True},
+            updatevalues={"safe_from_quarantine": safe},
             desc="mark_local_media_as_safe",
         )
 
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 8f83748b5e..7fb7780d0f 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -16,14 +16,14 @@ import logging
 from typing import Any, List, Set, Tuple
 
 from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore
+from synapse.storage.databases.main import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.state import StateGroupWorkerStore
 from synapse.types import RoomStreamToken
 
 logger = logging.getLogger(__name__)
 
 
-class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
+class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
     async def purge_history(
         self, room_id: str, token: str, delete_local_events: bool
     ) -> Set[int]:
@@ -203,8 +203,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
             "DELETE FROM event_to_state_groups "
             "WHERE event_id IN (SELECT event_id from events_to_purge)"
         )
-        for event_id, _ in event_rows:
-            txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
 
         # Delete all remote non-state events
         for table in (
@@ -283,6 +281,20 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         # so make sure to keep this actually last.
         txn.execute("DROP TABLE events_to_purge")
 
+        for event_id, should_delete in event_rows:
+            self._invalidate_cache_and_stream(
+                txn, self._get_state_group_for_event, (event_id,)
+            )
+
+            # XXX: This is racy, since have_seen_events could be called between the
+            #    transaction completing and the invalidation running. On the other hand,
+            #    that's no different to calling `have_seen_events` just before the
+            #    event is deleted from the database.
+            if should_delete:
+                self._invalidate_cache_and_stream(
+                    txn, self.have_seen_event, (room_id, event_id)
+                )
+
         logger.info("[purge] done")
 
         return referenced_state_groups
@@ -422,7 +434,11 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         #       index on them. In any case we should be clearing out 'stream' tables
         #       periodically anyway (#5888)
 
-        # TODO: we could probably usefully do a bunch of cache invalidation here
+        # TODO: we could probably usefully do a bunch more cache invalidation here
+
+        # XXX: as with purge_history, this is racy, but no worse than other races
+        #   that already exist.
+        self._invalidate_cache_and_stream(txn, self.have_seen_event, (room_id,))
 
         logger.info("[purge] done")
 
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 3647276acb..edeaacd7a6 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -460,7 +460,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
         self.get_receipts_for_user.invalidate((user_id, receipt_type))
-        self._get_linearized_receipts_for_room.invalidate_many((room_id,))
+        self._get_linearized_receipts_for_room.invalidate((room_id,))
         self.get_last_receipt_event_id_for_user.invalidate(
             (user_id, room_id, receipt_type)
         )
@@ -659,9 +659,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         )
         txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
         # FIXME: This shouldn't invalidate the whole cache
-        txn.call_after(
-            self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
-        )
+        txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))
 
         self.db_pool.simple_delete_txn(
             txn,
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 5f38634f48..0cf450f81d 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1498,7 +1498,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         room_id: str,
         event_id: str,
         user_id: str,
-        reason: str,
+        reason: Optional[str],
         content: JsonDict,
         received_ts: int,
     ) -> None:
diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py
index 44bbb7b1a8..8fd5bfb69b 100644
--- a/synapse/util/batching_queue.py
+++ b/synapse/util/batching_queue.py
@@ -25,10 +25,11 @@ from typing import (
     TypeVar,
 )
 
+from prometheus_client import Gauge
+
 from twisted.internet import defer
 
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
-from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util import Clock
 
@@ -38,6 +39,24 @@ logger = logging.getLogger(__name__)
 V = TypeVar("V")
 R = TypeVar("R")
 
+number_queued = Gauge(
+    "synapse_util_batching_queue_number_queued",
+    "The number of items waiting in the queue across all keys",
+    labelnames=("name",),
+)
+
+number_in_flight = Gauge(
+    "synapse_util_batching_queue_number_pending",
+    "The number of items across all keys either being processed or waiting in a queue",
+    labelnames=("name",),
+)
+
+number_of_keys = Gauge(
+    "synapse_util_batching_queue_number_of_keys",
+    "The number of distinct keys that have items queued",
+    labelnames=("name",),
+)
+
 
 class BatchingQueue(Generic[V, R]):
     """A queue that batches up work, calling the provided processing function
@@ -48,10 +67,20 @@ class BatchingQueue(Generic[V, R]):
     called, and will keep being called until the queue has been drained (for the
     given key).
 
+    If the processing function raises an exception then the exception is proxied
+    through to the callers waiting on that batch of work.
+
     Note that the return value of `add_to_queue` will be the return value of the
     processing function that processed the given item. This means that the
     returned value will likely include data for other items that were in the
     batch.
+
+    Args:
+        name: A name for the queue, used for logging contexts and metrics.
+            This must be unique, otherwise the metrics will be wrong.
+        clock: The clock to use to schedule work.
+        process_batch_callback: The callback to be run to process a batch of
+            work.
     """
 
     def __init__(
@@ -73,19 +102,15 @@ class BatchingQueue(Generic[V, R]):
         # The function to call with batches of values.
         self._process_batch_callback = process_batch_callback
 
-        LaterGauge(
-            "synapse_util_batching_queue_number_queued",
-            "The number of items waiting in the queue across all keys",
-            labels=("name",),
-            caller=lambda: sum(len(v) for v in self._next_values.values()),
+        number_queued.labels(self._name).set_function(
+            lambda: sum(len(q) for q in self._next_values.values())
         )
 
-        LaterGauge(
-            "synapse_util_batching_queue_number_of_keys",
-            "The number of distinct keys that have items queued",
-            labels=("name",),
-            caller=lambda: len(self._next_values),
-        )
+        number_of_keys.labels(self._name).set_function(lambda: len(self._next_values))
+
+        self._number_in_flight_metric = number_in_flight.labels(
+            self._name
+        )  # type: Gauge
 
     async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
         """Adds the value to the queue with the given key, returning the result
@@ -107,17 +132,18 @@ class BatchingQueue(Generic[V, R]):
         if key not in self._processing_keys:
             run_as_background_process(self._name, self._process_queue, key)
 
-        return await make_deferred_yieldable(d)
+        with self._number_in_flight_metric.track_inprogress():
+            return await make_deferred_yieldable(d)
 
     async def _process_queue(self, key: Hashable) -> None:
         """A background task to repeatedly pull things off the queue for the
         given key and call the `self._process_batch_callback` with the values.
         """
 
-        try:
-            if key in self._processing_keys:
-                return
+        if key in self._processing_keys:
+            return
 
+        try:
             self._processing_keys.add(key)
 
             while True:
@@ -137,16 +163,16 @@ class BatchingQueue(Generic[V, R]):
                     values = [value for value, _ in next_values]
                     results = await self._process_batch_callback(values)
 
-                    for _, deferred in next_values:
-                        with PreserveLoggingContext():
+                    with PreserveLoggingContext():
+                        for _, deferred in next_values:
                             deferred.callback(results)
 
                 except Exception as e:
-                    for _, deferred in next_values:
-                        if deferred.called:
-                            continue
+                    with PreserveLoggingContext():
+                        for _, deferred in next_values:
+                            if deferred.called:
+                                continue
 
-                        with PreserveLoggingContext():
                             deferred.errback(e)
 
         finally:
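
Beyond moving the metrics from LaterGauge to plain prometheus_client Gauges driven by set_function, the expanded docstring spells out the queue's contract: every caller of add_to_queue receives the result (or the exception) of the batch its value was processed in. A rough usage sketch under those assumptions (the callback and wiring are illustrative):

    from typing import List

    from synapse.util.batching_queue import BatchingQueue


    def build_queue(clock) -> BatchingQueue:
        async def _process_batch(values: List[str]) -> int:
            # Receives everything queued for a key since the last batch; every
            # waiting caller gets this same return value (or the exception).
            return len(values)

        return BatchingQueue("example_queue", clock, _process_batch)


    async def example(queue: BatchingQueue) -> None:
        # The result is whatever _process_batch returned for the batch this
        # value ended up in, so it may reflect other queued values too.
        batch_size = await queue.add_to_queue("some-value", key="room1")
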
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 371e7e4dd0..1044139119 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -16,16 +16,7 @@
 
 import enum
 import threading
-from typing import (
-    Callable,
-    Generic,
-    Iterable,
-    MutableMapping,
-    Optional,
-    TypeVar,
-    Union,
-    cast,
-)
+from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, Union
 
 from prometheus_client import Gauge
 
@@ -91,7 +82,7 @@ class DeferredCache(Generic[KT, VT]):
         # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
         self._pending_deferred_cache = (
             cache_type()
-        )  # type: MutableMapping[KT, CacheEntry]
+        )  # type: Union[TreeCache, MutableMapping[KT, CacheEntry]]
 
         def metrics_cb():
             cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
@@ -287,8 +278,17 @@ class DeferredCache(Generic[KT, VT]):
         self.cache.set(key, value, callbacks=callbacks)
 
     def invalidate(self, key):
+        """Delete a key, or tree of entries
+
+        If the cache is backed by a regular dict, then "key" must be of
+        the right type for this cache
+
+        If the cache is backed by a TreeCache, then "key" must be a tuple, but
+        may be of lower cardinality than the TreeCache - in which case the whole
+        subtree is deleted.
+        """
         self.check_thread()
-        self.cache.pop(key, None)
+        self.cache.del_multi(key)
 
         # if we have a pending lookup for this key, remove it from the
         # _pending_deferred_cache, which will (a) stop it being returned
@@ -299,20 +299,10 @@ class DeferredCache(Generic[KT, VT]):
         # run the invalidation callbacks now, rather than waiting for the
         # deferred to resolve.
         if entry:
-            entry.invalidate()
-
-    def invalidate_many(self, key: KT):
-        self.check_thread()
-        if not isinstance(key, tuple):
-            raise TypeError("The cache key must be a tuple not %r" % (type(key),))
-        key = cast(KT, key)
-        self.cache.del_multi(key)
-
-        # if we have a pending lookup for this key, remove it from the
-        # _pending_deferred_cache, as above
-        entry_dict = self._pending_deferred_cache.pop(key, None)
-        if entry_dict is not None:
-            for entry in iterate_tree_cache_entry(entry_dict):
+            # _pending_deferred_cache.pop should either return a CacheEntry, or, in the
+            # case of a TreeCache, a dict of keys to cache entries. Either way calling
+            # iterate_tree_cache_entry on it will do the right thing.
+            for entry in iterate_tree_cache_entry(entry):
                 entry.invalidate()
 
     def invalidate_all(self):
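
With invalidate_many folded into invalidate, a DeferredCache backed by a TreeCache accepts either a full key or a shorter tuple prefix, using iterate_tree_cache_entry to invalidate any pending lookups under that prefix. A small hedged sketch of the prefix behaviour (the cache name and keys are made up):

    from synapse.util.caches.deferred_cache import DeferredCache

    cache: DeferredCache = DeferredCache("example", tree=True)

    cache.prefill(("room1", "$event_a"), True)
    cache.prefill(("room1", "$event_b"), True)
    cache.prefill(("room2", "$event_c"), True)

    # Full-key invalidation works as before ...
    cache.invalidate(("room1", "$event_a"))

    # ... and a shorter tuple now drops the whole ("room1", ...) subtree,
    # which previously needed invalidate_many().
    cache.invalidate(("room1",))
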
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 2ac24a2f25..d77e8edeea 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -48,7 +48,6 @@ F = TypeVar("F", bound=Callable[..., Any])
 class _CachedFunction(Generic[F]):
     invalidate = None  # type: Any
     invalidate_all = None  # type: Any
-    invalidate_many = None  # type: Any
     prefill = None  # type: Any
     cache = None  # type: Any
     num_args = None  # type: Any
@@ -262,6 +261,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
     ):
         super().__init__(orig, num_args=num_args, cache_context=cache_context)
 
+        if tree and self.num_args < 2:
+            raise RuntimeError(
+                "tree=True is nonsensical for cached functions with a single parameter"
+            )
+
         self.max_entries = max_entries
         self.tree = tree
         self.iterable = iterable
@@ -302,11 +306,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         wrapped = cast(_CachedFunction, _wrapped)
 
         if self.num_args == 1:
+            assert not self.tree
             wrapped.invalidate = lambda key: cache.invalidate(key[0])
             wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
         else:
             wrapped.invalidate = cache.invalidate
-            wrapped.invalidate_many = cache.invalidate_many
             wrapped.prefill = cache.prefill
 
         wrapped.invalidate_all = cache.invalidate_all
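
At the decorator level the same consolidation removes wrapped.invalidate_many: a @cached(tree=True) function (which must now take at least two key arguments) is invalidated through wrapped.invalidate with either the full key or a prefix tuple. A hedged sketch (the store and helper are illustrative):

    from synapse.util.caches.descriptors import cached


    class ExampleStore:
        @cached(max_entries=10000, tree=True)
        async def get_thing(self, room_id: str, event_id: str) -> dict:
            return await self._fetch_thing(room_id, event_id)  # hypothetical helper

        def on_event_changed(self, room_id: str, event_id: str) -> None:
            # Full-key invalidation, unchanged.
            self.get_thing.invalidate((room_id, event_id))

        def on_room_purged(self, room_id: str) -> None:
            # Prefix invalidation now also goes through invalidate(); the
            # invalidate_many attribute no longer exists.
            self.get_thing.invalidate((room_id,))
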
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 54df407ff7..d89e9d9b1d 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -152,7 +152,6 @@ class LruCache(Generic[KT, VT]):
     """
     Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
 
-    Supports del_multi only if cache_type=TreeCache
     If cache_type=TreeCache, all keys must be tuples.
     """
 
@@ -393,10 +392,16 @@ class LruCache(Generic[KT, VT]):
 
         @synchronized
         def cache_del_multi(key: KT) -> None:
+            """Delete an entry, or tree of entries
+
+            If the LruCache is backed by a regular dict, then "key" must be of
+            the right type for this cache
+
+            If the LruCache is backed by a TreeCache, then "key" must be a tuple, but
+            may be of lower cardinality than the TreeCache - in which case the whole
+            subtree is deleted.
             """
-            This will only work if constructed with cache_type=TreeCache
-            """
-            popped = cache.pop(key)
+            popped = cache.pop(key, None)
             if popped is None:
                 return
             # for each deleted node, we now need to remove it from the linked list
@@ -430,11 +435,10 @@ class LruCache(Generic[KT, VT]):
         self.set = cache_set
         self.setdefault = cache_set_default
         self.pop = cache_pop
+        self.del_multi = cache_del_multi
         # `invalidate` is exposed for consistency with DeferredCache, so that it can be
         # invalidated by the cache invalidation replication stream.
-        self.invalidate = cache_pop
-        if cache_type is TreeCache:
-            self.del_multi = cache_del_multi
+        self.invalidate = cache_del_multi
         self.len = synchronized(cache_len)
         self.contains = cache_contains
         self.clear = cache_clear
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 73502a8b06..a6df81ebff 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -89,6 +89,9 @@ class TreeCache:
             value. If the key is partial, the TreeCacheNode corresponding to the part
             of the tree that was removed.
         """
+        if not isinstance(key, tuple):
+            raise TypeError("The cache key must be a tuple not %r" % (type(key),))
+
         # a list of the nodes we have touched on the way down the tree
         nodes = []