summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/room_versions.py2
-rw-r--r--synapse/app/_base.py5
-rw-r--r--synapse/config/experimental.py23
-rw-r--r--synapse/config/tls.py22
-rw-r--r--synapse/federation/transport/server.py13
-rw-r--r--synapse/replication/slave/storage/devices.py2
-rw-r--r--synapse/rest/admin/media.py28
-rw-r--r--synapse/rest/client/v1/room.py4
-rw-r--r--synapse/storage/databases/main/cache.py6
-rw-r--r--synapse/storage/databases/main/devices.py2
-rw-r--r--synapse/storage/databases/main/event_push_actions.py2
-rw-r--r--synapse/storage/databases/main/events.py8
-rw-r--r--synapse/storage/databases/main/media_repository.py7
-rw-r--r--synapse/storage/databases/main/receipts.py6
-rw-r--r--synapse/util/batching_queue.py70
-rw-r--r--synapse/util/caches/deferred_cache.py42
-rw-r--r--synapse/util/caches/descriptors.py8
-rw-r--r--synapse/util/caches/lrucache.py18
-rw-r--r--synapse/util/caches/treecache.py3
19 files changed, 138 insertions, 133 deletions
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index c9f9596ada..373a4669d0 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -181,6 +181,6 @@ KNOWN_ROOM_VERSIONS = {
         RoomVersions.V5,
         RoomVersions.V6,
         RoomVersions.MSC2176,
+        RoomVersions.MSC3083,
     )
-    # Note that we do not include MSC3083 here unless it is enabled in the config.
 }  # type: Dict[str, RoomVersion]
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 59918d789e..1329af2e2b 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -261,13 +261,10 @@ def refresh_certificate(hs):
     Refresh the TLS certificates that Synapse is using by re-reading them from
     disk and updating the TLS context factories to use them.
     """
-
     if not hs.config.has_tls_listener():
-        # attempt to reload the certs for the good of the tls_fingerprints
-        hs.config.read_certificate_from_disk(require_cert_and_key=False)
         return
 
-    hs.config.read_certificate_from_disk(require_cert_and_key=True)
+    hs.config.read_certificate_from_disk()
     hs.tls_server_context_factory = context_factory.ServerContextFactory(hs.config)
 
     if hs._listening_services:
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index cc67377f0f..6ebce4b2f7 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
 from synapse.config._base import Config
 from synapse.types import JsonDict
 
@@ -28,27 +27,5 @@ class ExperimentalConfig(Config):
         # MSC2858 (multiple SSO identity providers)
         self.msc2858_enabled = experimental.get("msc2858_enabled", False)  # type: bool
 
-        # Spaces (MSC1772, MSC2946, MSC3083, etc)
-        self.spaces_enabled = experimental.get("spaces_enabled", True)  # type: bool
-        if self.spaces_enabled:
-            KNOWN_ROOM_VERSIONS[RoomVersions.MSC3083.identifier] = RoomVersions.MSC3083
-
         # MSC3026 (busy presence state)
         self.msc3026_enabled = experimental.get("msc3026_enabled", False)  # type: bool
-
-    def generate_config_section(self, **kwargs):
-        return """\
-        # Enable experimental features in Synapse.
-        #
-        # Experimental features might break or be removed without a deprecation
-        # period.
-        #
-        experimental_features:
-          # Support for Spaces (MSC1772), it enables the following:
-          #
-          # * The Spaces Summary API (MSC2946).
-          # * Restricting room membership based on space membership (MSC3083).
-          #
-          # Uncomment to disable support for Spaces.
-          #spaces_enabled: false
-        """
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 26f1150ca5..0e9bba53c9 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -215,28 +215,12 @@ class TlsConfig(Config):
         days_remaining = (expires_on - now).days
         return days_remaining
 
-    def read_certificate_from_disk(self, require_cert_and_key: bool):
+    def read_certificate_from_disk(self):
         """
         Read the certificates and private key from disk.
-
-        Args:
-            require_cert_and_key: set to True to throw an error if the certificate
-                and key file are not given
         """
-        if require_cert_and_key:
-            self.tls_private_key = self.read_tls_private_key()
-            self.tls_certificate = self.read_tls_certificate()
-        elif self.tls_certificate_file:
-            # we only need the certificate for the tls_fingerprints. Reload it if we
-            # can, but it's not a fatal error if we can't.
-            try:
-                self.tls_certificate = self.read_tls_certificate()
-            except Exception as e:
-                logger.info(
-                    "Unable to read TLS certificate (%s). Ignoring as no "
-                    "tls listeners enabled.",
-                    e,
-                )
+        self.tls_private_key = self.read_tls_private_key()
+        self.tls_certificate = self.read_tls_certificate()
 
     def generate_config_section(
         self,
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 40eab45549..59e0a434dc 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1562,13 +1562,12 @@ def register_servlets(
                 server_name=hs.hostname,
             ).register(resource)
 
-        if hs.config.experimental.spaces_enabled:
-            FederationSpaceSummaryServlet(
-                handler=hs.get_space_summary_handler(),
-                authenticator=authenticator,
-                ratelimiter=ratelimiter,
-                server_name=hs.hostname,
-            ).register(resource)
+        FederationSpaceSummaryServlet(
+            handler=hs.get_space_summary_handler(),
+            authenticator=authenticator,
+            ratelimiter=ratelimiter,
+            server_name=hs.hostname,
+        ).register(resource)
 
     if "openid" in servlet_groups:
         for servletclass in OPENID_SERVLET_CLASSES:
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 70207420a6..26bdead565 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -68,7 +68,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             if row.entity.startswith("@"):
                 self._device_list_stream_cache.entity_has_changed(row.entity, token)
                 self.get_cached_devices_for_user.invalidate((row.entity,))
-                self._get_cached_user_device.invalidate_many((row.entity,))
+                self._get_cached_user_device.invalidate((row.entity,))
                 self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
 
             else:
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 24dd46113a..2c71af4279 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -137,8 +137,31 @@ class ProtectMediaByID(RestServlet):
 
         logging.info("Protecting local media by ID: %s", media_id)
 
-        # Quarantine this media id
-        await self.store.mark_local_media_as_safe(media_id)
+        # Protect this media id
+        await self.store.mark_local_media_as_safe(media_id, safe=True)
+
+        return 200, {}
+
+
+class UnprotectMediaByID(RestServlet):
+    """Unprotect local media from being quarantined."""
+
+    PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]+)")
+
+    def __init__(self, hs: "HomeServer"):
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_POST(
+        self, request: SynapseRequest, media_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        logging.info("Unprotecting local media by ID: %s", media_id)
+
+        # Unprotect this media id
+        await self.store.mark_local_media_as_safe(media_id, safe=False)
 
         return 200, {}
 
@@ -269,6 +292,7 @@ def register_servlets_for_media_repo(hs: "HomeServer", http_server):
     QuarantineMediaByID(hs).register(http_server)
     QuarantineMediaByUser(hs).register(http_server)
     ProtectMediaByID(hs).register(http_server)
+    UnprotectMediaByID(hs).register(http_server)
     ListMediaInRoom(hs).register(http_server)
     DeleteMediaByID(hs).register(http_server)
     DeleteMediaByDateSize(hs).register(http_server)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 51813cccbe..d6d55893af 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -1060,9 +1060,7 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False):
     RoomRedactEventRestServlet(hs).register(http_server)
     RoomTypingRestServlet(hs).register(http_server)
     RoomEventContextServlet(hs).register(http_server)
-
-    if hs.config.experimental.spaces_enabled:
-        RoomSpaceSummaryRestServlet(hs).register(http_server)
+    RoomSpaceSummaryRestServlet(hs).register(http_server)
 
     # Some servlets only get registered for the main process.
     if not is_worker:
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index ecc1f935e2..f7872501a0 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -171,7 +171,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         self.get_latest_event_ids_in_room.invalidate((room_id,))
 
-        self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
+        self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))
 
         if not backfilled:
             self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
@@ -184,8 +184,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self.get_invited_rooms_for_local_user.invalidate((state_key,))
 
         if relates_to:
-            self.get_relations_for_event.invalidate_many((relates_to,))
-            self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
+            self.get_relations_for_event.invalidate((relates_to,))
+            self.get_aggregation_groups_for_event.invalidate((relates_to,))
             self.get_applicable_edit.invalidate((relates_to,))
 
     async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index fd87ba71ab..18f07d96dc 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1282,7 +1282,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
         txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
-        txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,))
+        txn.call_after(self._get_cached_user_device.invalidate, (user_id,))
         txn.call_after(
             self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
         )
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 5845322118..d1237c65cc 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -860,7 +860,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                                   not be deleted.
         """
         txn.call_after(
-            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+            self.get_unread_event_push_actions_by_room_for_user.invalidate,
             (room_id, user_id),
         )
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index fd25c8112d..897fa06639 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1748,9 +1748,9 @@ class PersistEventsStore:
             },
         )
 
-        txn.call_after(self.store.get_relations_for_event.invalidate_many, (parent_id,))
+        txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,))
         txn.call_after(
-            self.store.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
+            self.store.get_aggregation_groups_for_event.invalidate, (parent_id,)
         )
 
         if rel_type == RelationTypes.REPLACE:
@@ -1903,7 +1903,7 @@ class PersistEventsStore:
 
                 for user_id in user_ids:
                     txn.call_after(
-                        self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+                        self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
                         (room_id, user_id),
                     )
 
@@ -1917,7 +1917,7 @@ class PersistEventsStore:
     def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
         # Sad that we have to blow away the cache for the whole room here
         txn.call_after(
-            self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+            self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
             (room_id,),
         )
         txn.execute(
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index c584868188..2fa945d171 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -143,6 +143,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "created_ts",
                 "quarantined_by",
                 "url_cache",
+                "safe_from_quarantine",
             ),
             allow_none=True,
             desc="get_local_media",
@@ -296,12 +297,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_local_media",
         )
 
-    async def mark_local_media_as_safe(self, media_id: str) -> None:
-        """Mark a local media as safe from quarantining."""
+    async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> None:
+        """Mark a local media as safe or unsafe from quarantining."""
         await self.db_pool.simple_update_one(
             table="local_media_repository",
             keyvalues={"media_id": media_id},
-            updatevalues={"safe_from_quarantine": True},
+            updatevalues={"safe_from_quarantine": safe},
             desc="mark_local_media_as_safe",
         )
 
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 3647276acb..edeaacd7a6 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -460,7 +460,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
         self.get_receipts_for_user.invalidate((user_id, receipt_type))
-        self._get_linearized_receipts_for_room.invalidate_many((room_id,))
+        self._get_linearized_receipts_for_room.invalidate((room_id,))
         self.get_last_receipt_event_id_for_user.invalidate(
             (user_id, room_id, receipt_type)
         )
@@ -659,9 +659,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         )
         txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
         # FIXME: This shouldn't invalidate the whole cache
-        txn.call_after(
-            self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
-        )
+        txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))
 
         self.db_pool.simple_delete_txn(
             txn,
diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py
index 44bbb7b1a8..8fd5bfb69b 100644
--- a/synapse/util/batching_queue.py
+++ b/synapse/util/batching_queue.py
@@ -25,10 +25,11 @@ from typing import (
     TypeVar,
 )
 
+from prometheus_client import Gauge
+
 from twisted.internet import defer
 
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
-from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util import Clock
 
@@ -38,6 +39,24 @@ logger = logging.getLogger(__name__)
 V = TypeVar("V")
 R = TypeVar("R")
 
+number_queued = Gauge(
+    "synapse_util_batching_queue_number_queued",
+    "The number of items waiting in the queue across all keys",
+    labelnames=("name",),
+)
+
+number_in_flight = Gauge(
+    "synapse_util_batching_queue_number_pending",
+    "The number of items across all keys either being processed or waiting in a queue",
+    labelnames=("name",),
+)
+
+number_of_keys = Gauge(
+    "synapse_util_batching_queue_number_of_keys",
+    "The number of distinct keys that have items queued",
+    labelnames=("name",),
+)
+
 
 class BatchingQueue(Generic[V, R]):
     """A queue that batches up work, calling the provided processing function
@@ -48,10 +67,20 @@ class BatchingQueue(Generic[V, R]):
     called, and will keep being called until the queue has been drained (for the
     given key).
 
+    If the processing function raises an exception then the exception is proxied
+    through to the callers waiting on that batch of work.
+
     Note that the return value of `add_to_queue` will be the return value of the
     processing function that processed the given item. This means that the
     returned value will likely include data for other items that were in the
     batch.
+
+    Args:
+        name: A name for the queue, used for logging contexts and metrics.
+            This must be unique, otherwise the metrics will be wrong.
+        clock: The clock to use to schedule work.
+        process_batch_callback: The callback to to be run to process a batch of
+            work.
     """
 
     def __init__(
@@ -73,19 +102,15 @@ class BatchingQueue(Generic[V, R]):
         # The function to call with batches of values.
         self._process_batch_callback = process_batch_callback
 
-        LaterGauge(
-            "synapse_util_batching_queue_number_queued",
-            "The number of items waiting in the queue across all keys",
-            labels=("name",),
-            caller=lambda: sum(len(v) for v in self._next_values.values()),
+        number_queued.labels(self._name).set_function(
+            lambda: sum(len(q) for q in self._next_values.values())
         )
 
-        LaterGauge(
-            "synapse_util_batching_queue_number_of_keys",
-            "The number of distinct keys that have items queued",
-            labels=("name",),
-            caller=lambda: len(self._next_values),
-        )
+        number_of_keys.labels(self._name).set_function(lambda: len(self._next_values))
+
+        self._number_in_flight_metric = number_in_flight.labels(
+            self._name
+        )  # type: Gauge
 
     async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
         """Adds the value to the queue with the given key, returning the result
@@ -107,17 +132,18 @@ class BatchingQueue(Generic[V, R]):
         if key not in self._processing_keys:
             run_as_background_process(self._name, self._process_queue, key)
 
-        return await make_deferred_yieldable(d)
+        with self._number_in_flight_metric.track_inprogress():
+            return await make_deferred_yieldable(d)
 
     async def _process_queue(self, key: Hashable) -> None:
         """A background task to repeatedly pull things off the queue for the
         given key and call the `self._process_batch_callback` with the values.
         """
 
-        try:
-            if key in self._processing_keys:
-                return
+        if key in self._processing_keys:
+            return
 
+        try:
             self._processing_keys.add(key)
 
             while True:
@@ -137,16 +163,16 @@ class BatchingQueue(Generic[V, R]):
                     values = [value for value, _ in next_values]
                     results = await self._process_batch_callback(values)
 
-                    for _, deferred in next_values:
-                        with PreserveLoggingContext():
+                    with PreserveLoggingContext():
+                        for _, deferred in next_values:
                             deferred.callback(results)
 
                 except Exception as e:
-                    for _, deferred in next_values:
-                        if deferred.called:
-                            continue
+                    with PreserveLoggingContext():
+                        for _, deferred in next_values:
+                            if deferred.called:
+                                continue
 
-                        with PreserveLoggingContext():
                             deferred.errback(e)
 
         finally:
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 371e7e4dd0..1044139119 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -16,16 +16,7 @@
 
 import enum
 import threading
-from typing import (
-    Callable,
-    Generic,
-    Iterable,
-    MutableMapping,
-    Optional,
-    TypeVar,
-    Union,
-    cast,
-)
+from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, Union
 
 from prometheus_client import Gauge
 
@@ -91,7 +82,7 @@ class DeferredCache(Generic[KT, VT]):
         # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
         self._pending_deferred_cache = (
             cache_type()
-        )  # type: MutableMapping[KT, CacheEntry]
+        )  # type: Union[TreeCache, MutableMapping[KT, CacheEntry]]
 
         def metrics_cb():
             cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
@@ -287,8 +278,17 @@ class DeferredCache(Generic[KT, VT]):
         self.cache.set(key, value, callbacks=callbacks)
 
     def invalidate(self, key):
+        """Delete a key, or tree of entries
+
+        If the cache is backed by a regular dict, then "key" must be of
+        the right type for this cache
+
+        If the cache is backed by a TreeCache, then "key" must be a tuple, but
+        may be of lower cardinality than the TreeCache - in which case the whole
+        subtree is deleted.
+        """
         self.check_thread()
-        self.cache.pop(key, None)
+        self.cache.del_multi(key)
 
         # if we have a pending lookup for this key, remove it from the
         # _pending_deferred_cache, which will (a) stop it being returned
@@ -299,20 +299,10 @@ class DeferredCache(Generic[KT, VT]):
         # run the invalidation callbacks now, rather than waiting for the
         # deferred to resolve.
         if entry:
-            entry.invalidate()
-
-    def invalidate_many(self, key: KT):
-        self.check_thread()
-        if not isinstance(key, tuple):
-            raise TypeError("The cache key must be a tuple not %r" % (type(key),))
-        key = cast(KT, key)
-        self.cache.del_multi(key)
-
-        # if we have a pending lookup for this key, remove it from the
-        # _pending_deferred_cache, as above
-        entry_dict = self._pending_deferred_cache.pop(key, None)
-        if entry_dict is not None:
-            for entry in iterate_tree_cache_entry(entry_dict):
+            # _pending_deferred_cache.pop should either return a CacheEntry, or, in the
+            # case of a TreeCache, a dict of keys to cache entries. Either way calling
+            # iterate_tree_cache_entry on it will do the right thing.
+            for entry in iterate_tree_cache_entry(entry):
                 entry.invalidate()
 
     def invalidate_all(self):
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 2ac24a2f25..d77e8edeea 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -48,7 +48,6 @@ F = TypeVar("F", bound=Callable[..., Any])
 class _CachedFunction(Generic[F]):
     invalidate = None  # type: Any
     invalidate_all = None  # type: Any
-    invalidate_many = None  # type: Any
     prefill = None  # type: Any
     cache = None  # type: Any
     num_args = None  # type: Any
@@ -262,6 +261,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
     ):
         super().__init__(orig, num_args=num_args, cache_context=cache_context)
 
+        if tree and self.num_args < 2:
+            raise RuntimeError(
+                "tree=True is nonsensical for cached functions with a single parameter"
+            )
+
         self.max_entries = max_entries
         self.tree = tree
         self.iterable = iterable
@@ -302,11 +306,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         wrapped = cast(_CachedFunction, _wrapped)
 
         if self.num_args == 1:
+            assert not self.tree
             wrapped.invalidate = lambda key: cache.invalidate(key[0])
             wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
         else:
             wrapped.invalidate = cache.invalidate
-            wrapped.invalidate_many = cache.invalidate_many
             wrapped.prefill = cache.prefill
 
         wrapped.invalidate_all = cache.invalidate_all
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 54df407ff7..d89e9d9b1d 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -152,7 +152,6 @@ class LruCache(Generic[KT, VT]):
     """
     Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
 
-    Supports del_multi only if cache_type=TreeCache
     If cache_type=TreeCache, all keys must be tuples.
     """
 
@@ -393,10 +392,16 @@ class LruCache(Generic[KT, VT]):
 
         @synchronized
         def cache_del_multi(key: KT) -> None:
+            """Delete an entry, or tree of entries
+
+            If the LruCache is backed by a regular dict, then "key" must be of
+            the right type for this cache
+
+            If the LruCache is backed by a TreeCache, then "key" must be a tuple, but
+            may be of lower cardinality than the TreeCache - in which case the whole
+            subtree is deleted.
             """
-            This will only work if constructed with cache_type=TreeCache
-            """
-            popped = cache.pop(key)
+            popped = cache.pop(key, None)
             if popped is None:
                 return
             # for each deleted node, we now need to remove it from the linked list
@@ -430,11 +435,10 @@ class LruCache(Generic[KT, VT]):
         self.set = cache_set
         self.setdefault = cache_set_default
         self.pop = cache_pop
+        self.del_multi = cache_del_multi
         # `invalidate` is exposed for consistency with DeferredCache, so that it can be
         # invalidated by the cache invalidation replication stream.
-        self.invalidate = cache_pop
-        if cache_type is TreeCache:
-            self.del_multi = cache_del_multi
+        self.invalidate = cache_del_multi
         self.len = synchronized(cache_len)
         self.contains = cache_contains
         self.clear = cache_clear
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 73502a8b06..a6df81ebff 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -89,6 +89,9 @@ class TreeCache:
             value. If the key is partial, the TreeCacheNode corresponding to the part
             of the tree that was removed.
         """
+        if not isinstance(key, tuple):
+            raise TypeError("The cache key must be a tuple not %r" % (type(key),))
+
         # a list of the nodes we have touched on the way down the tree
         nodes = []