summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/caches/deferred_cache.py346
-rw-r--r--synapse/util/caches/descriptors.py115
-rw-r--r--synapse/util/caches/dictionary_cache.py218
-rw-r--r--synapse/util/caches/lrucache.py107
-rw-r--r--synapse/util/caches/treecache.py41
-rw-r--r--synapse/util/ratelimitutils.py117
6 files changed, 748 insertions, 196 deletions
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 1d6ec22191..6425f851ea 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -14,15 +14,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import enum
 import threading
 from typing import (
     Callable,
+    Collection,
+    Dict,
     Generic,
-    Iterable,
     MutableMapping,
     Optional,
+    Set,
     Sized,
+    Tuple,
     TypeVar,
     Union,
     cast,
@@ -31,7 +35,6 @@ from typing import (
 from prometheus_client import Gauge
 
 from twisted.internet import defer
-from twisted.python import failure
 from twisted.python.failure import Failure
 
 from synapse.util.async_helpers import ObservableDeferred
@@ -94,7 +97,7 @@ class DeferredCache(Generic[KT, VT]):
 
         # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
         self._pending_deferred_cache: Union[
-            TreeCache, "MutableMapping[KT, CacheEntry]"
+            TreeCache, "MutableMapping[KT, CacheEntry[KT, VT]]"
         ] = cache_type()
 
         def metrics_cb() -> None:
@@ -159,15 +162,16 @@ class DeferredCache(Generic[KT, VT]):
         Raises:
             KeyError if the key is not found in the cache
         """
-        callbacks = [callback] if callback else []
         val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
         if val is not _Sentinel.sentinel:
-            val.callbacks.update(callbacks)
+            val.add_invalidation_callback(key, callback)
             if update_metrics:
                 m = self.cache.metrics
                 assert m  # we always have a name, so should always have metrics
                 m.inc_hits()
-            return val.deferred.observe()
+            return val.deferred(key)
+
+        callbacks = (callback,) if callback else ()
 
         val2 = self.cache.get(
             key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
@@ -177,6 +181,73 @@ class DeferredCache(Generic[KT, VT]):
         else:
             return defer.succeed(val2)
 
+    def get_bulk(
+        self,
+        keys: Collection[KT],
+        callback: Optional[Callable[[], None]] = None,
+    ) -> Tuple[Dict[KT, VT], Optional["defer.Deferred[Dict[KT, VT]]"], Collection[KT]]:
+        """Bulk lookup of items in the cache.
+
+        Returns:
+            A 3-tuple of:
+                1. a dict of key/value of items already cached;
+                2. a deferred that resolves to a dict of key/value of items
+                   we're already fetching; and
+                3. a collection of keys that don't appear in the previous two.
+        """
+
+        # The cached results
+        cached = {}
+
+        # List of pending deferreds
+        pending = []
+
+        # Dict that gets filled out when the pending deferreds complete
+        pending_results = {}
+
+        # List of keys that aren't in either cache
+        missing = []
+
+        callbacks = (callback,) if callback else ()
+
+        for key in keys:
+            # Check if its in the main cache.
+            immediate_value = self.cache.get(
+                key,
+                _Sentinel.sentinel,
+                callbacks=callbacks,
+            )
+            if immediate_value is not _Sentinel.sentinel:
+                cached[key] = immediate_value
+                continue
+
+            # Check if its in the pending cache
+            pending_value = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
+            if pending_value is not _Sentinel.sentinel:
+                pending_value.add_invalidation_callback(key, callback)
+
+                def completed_cb(value: VT, key: KT) -> VT:
+                    pending_results[key] = value
+                    return value
+
+                # Add a callback to fill out `pending_results` when that completes
+                d = pending_value.deferred(key).addCallback(completed_cb, key)
+                pending.append(d)
+                continue
+
+            # Not in either cache
+            missing.append(key)
+
+        # If we've got pending deferreds, squash them into a single one that
+        # returns `pending_results`.
+        pending_deferred = None
+        if pending:
+            pending_deferred = defer.gatherResults(
+                pending, consumeErrors=True
+            ).addCallback(lambda _: pending_results)
+
+        return (cached, pending_deferred, missing)
+
     def get_immediate(
         self, key: KT, default: T, update_metrics: bool = True
     ) -> Union[VT, T]:
@@ -218,84 +289,89 @@ class DeferredCache(Generic[KT, VT]):
             value: a deferred which will complete with a result to add to the cache
             callback: An optional callback to be called when the entry is invalidated
         """
-        if not isinstance(value, defer.Deferred):
-            raise TypeError("not a Deferred")
-
-        callbacks = [callback] if callback else []
         self.check_thread()
 
-        existing_entry = self._pending_deferred_cache.pop(key, None)
-        if existing_entry:
-            existing_entry.invalidate()
+        self._pending_deferred_cache.pop(key, None)
 
         # XXX: why don't we invalidate the entry in `self.cache` yet?
 
-        # we can save a whole load of effort if the deferred is ready.
-        if value.called:
-            result = value.result
-            if not isinstance(result, failure.Failure):
-                self.cache.set(key, cast(VT, result), callbacks)
-            return value
-
         # otherwise, we'll add an entry to the _pending_deferred_cache for now,
         # and add callbacks to add it to the cache properly later.
+        entry = CacheEntrySingle[KT, VT](value)
+        entry.add_invalidation_callback(key, callback)
+        self._pending_deferred_cache[key] = entry
+        deferred = entry.deferred(key).addCallbacks(
+            self._completed_callback,
+            self._error_callback,
+            callbackArgs=(entry, key),
+            errbackArgs=(entry, key),
+        )
 
-        observable = ObservableDeferred(value, consumeErrors=True)
-        observer = observable.observe()
-        entry = CacheEntry(deferred=observable, callbacks=callbacks)
+        # we return a new Deferred which will be called before any subsequent observers.
+        return deferred
 
-        self._pending_deferred_cache[key] = entry
+    def start_bulk_input(
+        self,
+        keys: Collection[KT],
+        callback: Optional[Callable[[], None]] = None,
+    ) -> "CacheMultipleEntries[KT, VT]":
+        """Bulk set API for use when fetching multiple keys at once from the DB.
 
-        def compare_and_pop() -> bool:
-            """Check if our entry is still the one in _pending_deferred_cache, and
-            if so, pop it.
-
-            Returns true if the entries matched.
-            """
-            existing_entry = self._pending_deferred_cache.pop(key, None)
-            if existing_entry is entry:
-                return True
-
-            # oops, the _pending_deferred_cache has been updated since
-            # we started our query, so we are out of date.
-            #
-            # Better put back whatever we took out. (We do it this way
-            # round, rather than peeking into the _pending_deferred_cache
-            # and then removing on a match, to make the common case faster)
-            if existing_entry is not None:
-                self._pending_deferred_cache[key] = existing_entry
-
-            return False
-
-        def cb(result: VT) -> None:
-            if compare_and_pop():
-                self.cache.set(key, result, entry.callbacks)
-            else:
-                # we're not going to put this entry into the cache, so need
-                # to make sure that the invalidation callbacks are called.
-                # That was probably done when _pending_deferred_cache was
-                # updated, but it's possible that `set` was called without
-                # `invalidate` being previously called, in which case it may
-                # not have been. Either way, let's double-check now.
-                entry.invalidate()
-
-        def eb(_fail: Failure) -> None:
-            compare_and_pop()
-            entry.invalidate()
-
-        # once the deferred completes, we can move the entry from the
-        # _pending_deferred_cache to the real cache.
-        #
-        observer.addCallbacks(cb, eb)
+        Called *before* starting the fetch from the DB, and the caller *must*
+        call either `complete_bulk(..)` or `error_bulk(..)` on the return value.
+        """
 
-        # we return a new Deferred which will be called before any subsequent observers.
-        return observable.observe()
+        entry = CacheMultipleEntries[KT, VT]()
+        entry.add_global_invalidation_callback(callback)
+
+        for key in keys:
+            self._pending_deferred_cache[key] = entry
+
+        return entry
+
+    def _completed_callback(
+        self, value: VT, entry: "CacheEntry[KT, VT]", key: KT
+    ) -> VT:
+        """Called when a deferred is completed."""
+        # We check if the current entry matches the entry associated with the
+        # deferred. If they don't match then it got invalidated.
+        current_entry = self._pending_deferred_cache.pop(key, None)
+        if current_entry is not entry:
+            if current_entry:
+                self._pending_deferred_cache[key] = current_entry
+            return value
+
+        self.cache.set(key, value, entry.get_invalidation_callbacks(key))
+
+        return value
+
+    def _error_callback(
+        self,
+        failure: Failure,
+        entry: "CacheEntry[KT, VT]",
+        key: KT,
+    ) -> Failure:
+        """Called when a deferred errors."""
+
+        # We check if the current entry matches the entry associated with the
+        # deferred. If they don't match then it got invalidated.
+        current_entry = self._pending_deferred_cache.pop(key, None)
+        if current_entry is not entry:
+            if current_entry:
+                self._pending_deferred_cache[key] = current_entry
+            return failure
+
+        for cb in entry.get_invalidation_callbacks(key):
+            cb()
+
+        return failure
 
     def prefill(
         self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
     ) -> None:
-        callbacks = [callback] if callback else []
+        callbacks = (callback,) if callback else ()
         self.cache.set(key, value, callbacks=callbacks)
+        self._pending_deferred_cache.pop(key, None)
 
     def invalidate(self, key: KT) -> None:
         """Delete a key, or tree of entries
@@ -311,41 +387,129 @@ class DeferredCache(Generic[KT, VT]):
         self.cache.del_multi(key)
 
         # if we have a pending lookup for this key, remove it from the
-        # _pending_deferred_cache, which will (a) stop it being returned
-        # for future queries and (b) stop it being persisted as a proper entry
+        # _pending_deferred_cache, which will (a) stop it being returned for
+        # future queries and (b) stop it being persisted as a proper entry
         # in self.cache.
         entry = self._pending_deferred_cache.pop(key, None)
-
-        # run the invalidation callbacks now, rather than waiting for the
-        # deferred to resolve.
         if entry:
             # _pending_deferred_cache.pop should either return a CacheEntry, or, in the
             # case of a TreeCache, a dict of keys to cache entries. Either way calling
             # iterate_tree_cache_entry on it will do the right thing.
             for entry in iterate_tree_cache_entry(entry):
-                entry.invalidate()
+                for cb in entry.get_invalidation_callbacks(key):
+                    cb()
 
     def invalidate_all(self) -> None:
         self.check_thread()
         self.cache.clear()
-        for entry in self._pending_deferred_cache.values():
-            entry.invalidate()
+        for key, entry in self._pending_deferred_cache.items():
+            for cb in entry.get_invalidation_callbacks(key):
+                cb()
+
         self._pending_deferred_cache.clear()
 
 
-class CacheEntry:
-    __slots__ = ["deferred", "callbacks", "invalidated"]
+class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta):
+    """Abstract class for entries in `DeferredCache[KT, VT]`"""
 
-    def __init__(
-        self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
-    ):
-        self.deferred = deferred
-        self.callbacks = set(callbacks)
-        self.invalidated = False
-
-    def invalidate(self) -> None:
-        if not self.invalidated:
-            self.invalidated = True
-            for callback in self.callbacks:
-                callback()
-            self.callbacks.clear()
+    @abc.abstractmethod
+    def deferred(self, key: KT) -> "defer.Deferred[VT]":
+        """Get a deferred that a caller can wait on to get the value at the
+        given key"""
+        ...
+
+    @abc.abstractmethod
+    def add_invalidation_callback(
+        self, key: KT, callback: Optional[Callable[[], None]]
+    ) -> None:
+        """Add an invalidation callback"""
+        ...
+
+    @abc.abstractmethod
+    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+        """Get all invalidation callbacks"""
+        ...
+
+
+class CacheEntrySingle(CacheEntry[KT, VT]):
+    """An implementation of `CacheEntry` wrapping a deferred that results in a
+    single cache entry.
+    """
+
+    __slots__ = ["_deferred", "_callbacks"]
+
+    def __init__(self, deferred: "defer.Deferred[VT]") -> None:
+        self._deferred = ObservableDeferred(deferred, consumeErrors=True)
+        self._callbacks: Set[Callable[[], None]] = set()
+
+    def deferred(self, key: KT) -> "defer.Deferred[VT]":
+        return self._deferred.observe()
+
+    def add_invalidation_callback(
+        self, key: KT, callback: Optional[Callable[[], None]]
+    ) -> None:
+        if callback is None:
+            return
+
+        self._callbacks.add(callback)
+
+    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+        return self._callbacks
+
+
+class CacheMultipleEntries(CacheEntry[KT, VT]):
+    """Cache entry that is used for bulk lookups and insertions."""
+
+    __slots__ = ["_deferred", "_callbacks", "_global_callbacks"]
+
+    def __init__(self) -> None:
+        self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None
+        self._callbacks: Dict[KT, Set[Callable[[], None]]] = {}
+        self._global_callbacks: Set[Callable[[], None]] = set()
+
+    def deferred(self, key: KT) -> "defer.Deferred[VT]":
+        if not self._deferred:
+            self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+        return self._deferred.observe().addCallback(lambda res: res.get(key))
+
+    def add_invalidation_callback(
+        self, key: KT, callback: Optional[Callable[[], None]]
+    ) -> None:
+        if callback is None:
+            return
+
+        self._callbacks.setdefault(key, set()).add(callback)
+
+    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+        return self._callbacks.get(key, set()) | self._global_callbacks
+
+    def add_global_invalidation_callback(
+        self, callback: Optional[Callable[[], None]]
+    ) -> None:
+        """Add a callback for when any keys get invalidated."""
+        if callback is None:
+            return
+
+        self._global_callbacks.add(callback)
+
+    def complete_bulk(
+        self,
+        cache: DeferredCache[KT, VT],
+        result: Dict[KT, VT],
+    ) -> None:
+        """Called when there is a result"""
+        for key, value in result.items():
+            cache._completed_callback(value, self, key)
+
+        if self._deferred:
+            self._deferred.callback(result)
+
+    def error_bulk(
+        self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure
+    ) -> None:
+        """Called when bulk lookup failed."""
+        for key in keys:
+            cache._error_callback(failure, self, key)
+
+        if self._deferred:
+            self._deferred.errback(failure)
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 867f315b2a..10aff4d04a 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -25,6 +25,7 @@ from typing import (
     Generic,
     Hashable,
     Iterable,
+    List,
     Mapping,
     Optional,
     Sequence,
@@ -73,8 +74,10 @@ class _CacheDescriptorBase:
         num_args: Optional[int],
         uncached_args: Optional[Collection[str]] = None,
         cache_context: bool = False,
+        name: Optional[str] = None,
     ):
         self.orig = orig
+        self.name = name or orig.__name__
 
         arg_spec = inspect.getfullargspec(orig)
         all_args = arg_spec.args
@@ -211,7 +214,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
     def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: LruCache[CacheKey, Any] = LruCache(
-            cache_name=self.orig.__name__,
+            cache_name=self.name,
             max_size=self.max_entries,
         )
 
@@ -241,7 +244,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
         wrapped = cast(_CachedFunction, _wrapped)
         wrapped.cache = cache
-        obj.__dict__[self.orig.__name__] = wrapped
+        obj.__dict__[self.name] = wrapped
 
         return wrapped
 
@@ -301,12 +304,14 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         cache_context: bool = False,
         iterable: bool = False,
         prune_unread_entries: bool = True,
+        name: Optional[str] = None,
     ):
         super().__init__(
             orig,
             num_args=num_args,
             uncached_args=uncached_args,
             cache_context=cache_context,
+            name=name,
         )
 
         if tree and self.num_args < 2:
@@ -321,7 +326,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
 
     def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: DeferredCache[CacheKey, Any] = DeferredCache(
-            name=self.orig.__name__,
+            name=self.name,
             max_entries=self.max_entries,
             tree=self.tree,
             iterable=self.iterable,
@@ -372,7 +377,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         wrapped.cache = cache
         wrapped.num_args = self.num_args
 
-        obj.__dict__[self.orig.__name__] = wrapped
+        obj.__dict__[self.name] = wrapped
 
         return wrapped
 
@@ -393,6 +398,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
         cached_method_name: str,
         list_name: str,
         num_args: Optional[int] = None,
+        name: Optional[str] = None,
     ):
         """
         Args:
@@ -403,7 +409,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                 but including list_name) to use as cache keys. Defaults to all
                 named args of the function.
         """
-        super().__init__(orig, num_args=num_args, uncached_args=None)
+        super().__init__(orig, num_args=num_args, uncached_args=None, name=name)
 
         self.list_name = list_name
 
@@ -435,16 +441,6 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
             keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
             list_args = arg_dict[self.list_name]
 
-            results = {}
-
-            def update_results_dict(res: Any, arg: Hashable) -> None:
-                results[arg] = res
-
-            # list of deferreds to wait for
-            cached_defers = []
-
-            missing = set()
-
             # If the cache takes a single arg then that is used as the key,
             # otherwise a tuple is used.
             if num_args == 1:
@@ -452,6 +448,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                 def arg_to_cache_key(arg: Hashable) -> Hashable:
                     return arg
 
+                def cache_key_to_arg(key: tuple) -> Hashable:
+                    return key
+
             else:
                 keylist = list(keyargs)
 
@@ -459,58 +458,53 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                     keylist[self.list_pos] = arg
                     return tuple(keylist)
 
-            for arg in list_args:
-                try:
-                    res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
-                    if not res.called:
-                        res.addCallback(update_results_dict, arg)
-                        cached_defers.append(res)
-                    else:
-                        results[arg] = res.result
-                except KeyError:
-                    missing.add(arg)
+                def cache_key_to_arg(key: tuple) -> Hashable:
+                    return key[self.list_pos]
+
+            cache_keys = [arg_to_cache_key(arg) for arg in list_args]
+            immediate_results, pending_deferred, missing = cache.get_bulk(
+                cache_keys, callback=invalidate_callback
+            )
+
+            results = {cache_key_to_arg(key): v for key, v in immediate_results.items()}
+
+            cached_defers: List["defer.Deferred[Any]"] = []
+            if pending_deferred:
+
+                def update_results(r: Dict) -> None:
+                    for k, v in r.items():
+                        results[cache_key_to_arg(k)] = v
+
+                pending_deferred.addCallback(update_results)
+                cached_defers.append(pending_deferred)
 
             if missing:
-                # we need a deferred for each entry in the list,
-                # which we put in the cache. Each deferred resolves with the
-                # relevant result for that key.
-                deferreds_map = {}
-                for arg in missing:
-                    deferred: "defer.Deferred[Any]" = defer.Deferred()
-                    deferreds_map[arg] = deferred
-                    key = arg_to_cache_key(arg)
-                    cached_defers.append(
-                        cache.set(key, deferred, callback=invalidate_callback)
-                    )
+                cache_entry = cache.start_bulk_input(missing, invalidate_callback)
 
                 def complete_all(res: Dict[Hashable, Any]) -> None:
-                    # the wrapped function has completed. It returns a dict.
-                    # We can now update our own result map, and then resolve the
-                    # observable deferreds in the cache.
-                    for e, d1 in deferreds_map.items():
-                        val = res.get(e, None)
-                        # make sure we update the results map before running the
-                        # deferreds, because as soon as we run the last deferred, the
-                        # gatherResults() below will complete and return the result
-                        # dict to our caller.
-                        results[e] = val
-                        d1.callback(val)
+                    missing_results = {}
+                    for key in missing:
+                        arg = cache_key_to_arg(key)
+                        val = res.get(arg, None)
+
+                        results[arg] = val
+                        missing_results[key] = val
+
+                    cache_entry.complete_bulk(cache, missing_results)
 
                 def errback_all(f: Failure) -> None:
-                    # the wrapped function has failed. Propagate the failure into
-                    # the cache, which will invalidate the entry, and cause the
-                    # relevant cached_deferreds to fail, which will propagate the
-                    # failure to our caller.
-                    for d1 in deferreds_map.values():
-                        d1.errback(f)
+                    cache_entry.error_bulk(cache, missing, f)
 
                 args_to_call = dict(arg_dict)
-                args_to_call[self.list_name] = missing
+                args_to_call[self.list_name] = {
+                    cache_key_to_arg(key) for key in missing
+                }
 
                 # dispatch the call, and attach the two handlers
-                defer.maybeDeferred(
+                missing_d = defer.maybeDeferred(
                     preserve_fn(self.orig), **args_to_call
                 ).addCallbacks(complete_all, errback_all)
+                cached_defers.append(missing_d)
 
             if cached_defers:
                 d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
@@ -525,7 +519,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
             else:
                 return defer.succeed(results)
 
-        obj.__dict__[self.orig.__name__] = wrapped
+        obj.__dict__[self.name] = wrapped
 
         return wrapped
 
@@ -577,6 +571,7 @@ def cached(
     cache_context: bool = False,
     iterable: bool = False,
     prune_unread_entries: bool = True,
+    name: Optional[str] = None,
 ) -> Callable[[F], _CachedFunction[F]]:
     func = lambda orig: DeferredCacheDescriptor(
         orig,
@@ -587,13 +582,18 @@ def cached(
         cache_context=cache_context,
         iterable=iterable,
         prune_unread_entries=prune_unread_entries,
+        name=name,
     )
 
     return cast(Callable[[F], _CachedFunction[F]], func)
 
 
 def cachedList(
-    *, cached_method_name: str, list_name: str, num_args: Optional[int] = None
+    *,
+    cached_method_name: str,
+    list_name: str,
+    num_args: Optional[int] = None,
+    name: Optional[str] = None,
 ) -> Callable[[F], _CachedFunction[F]]:
     """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
 
@@ -628,6 +628,7 @@ def cachedList(
         cached_method_name=cached_method_name,
         list_name=list_name,
         num_args=num_args,
+        name=name,
     )
 
     return cast(Callable[[F], _CachedFunction[F]], func)
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index d267703df0..fa91479c97 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -14,11 +14,13 @@
 import enum
 import logging
 import threading
-from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar
+from typing import Any, Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union
 
 import attr
+from typing_extensions import Literal
 
 from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.treecache import TreeCache
 
 logger = logging.getLogger(__name__)
 
@@ -33,10 +35,12 @@ DV = TypeVar("DV")
 
 # This class can't be generic because it uses slots with attrs.
 # See: https://github.com/python-attrs/attrs/issues/313
-@attr.s(slots=True, auto_attribs=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class DictionaryEntry:  # should be: Generic[DKT, DV].
     """Returned when getting an entry from the cache
 
+    If `full` is true then `known_absent` will be the empty set.
+
     Attributes:
         full: Whether the cache has the full or dict or just some keys.
             If not full then not all requested keys will necessarily be present
@@ -53,20 +57,90 @@ class DictionaryEntry:  # should be: Generic[DKT, DV].
         return len(self.value)
 
 
+class _FullCacheKey(enum.Enum):
+    """The key we use to cache the full dict."""
+
+    KEY = object()
+
+
 class _Sentinel(enum.Enum):
     # defining a sentinel in this way allows mypy to correctly handle the
     # type of a dictionary lookup.
     sentinel = object()
 
 
+class _PerKeyValue(Generic[DV]):
+    """The cached value of a dictionary key. If `value` is the sentinel,
+    indicates that the requested key is known to *not* be in the full dict.
+    """
+
+    __slots__ = ["value"]
+
+    def __init__(self, value: Union[DV, Literal[_Sentinel.sentinel]]) -> None:
+        self.value = value
+
+    def __len__(self) -> int:
+        # We add a `__len__` implementation as we use this class in a cache
+        # where the values are variable length.
+        return 1
+
+
 class DictionaryCache(Generic[KT, DKT, DV]):
     """Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
     fetching a subset of dictionary keys for a particular key.
+
+    This cache has two levels of key. First there is the "cache key" (of type
+    `KT`), which maps to a dict. The keys to that dict are the "dict key" (of
+    type `DKT`). The overall structure is therefore `KT->DKT->DV`. For
+    example, it might look like:
+
+       {
+           1: { 1: "a", 2: "b" },
+           2: { 1: "c" },
+       }
+
+    It is possible to look up either individual dict keys, or the *complete*
+    dict for a given cache key.
+
+    Each dict item, and the complete dict is treated as a separate LRU
+    entry for the purpose of cache expiry. For example, given:
+        dict_cache.get(1, None)  -> DictionaryEntry({1: "a", 2: "b"})
+        dict_cache.get(1, [1])  -> DictionaryEntry({1: "a"})
+        dict_cache.get(1, [2])  -> DictionaryEntry({2: "b"})
+
+    ... then the cache entry for the complete dict will expire first,
+    followed by the cache entry for the '1' dict key, and finally that
+    for the '2' dict key.
     """
 
     def __init__(self, name: str, max_entries: int = 1000):
-        self.cache: LruCache[KT, DictionaryEntry] = LruCache(
-            max_size=max_entries, cache_name=name, size_callback=len
+        # We use a single LruCache to store two different types of entries:
+        #   1. Map from (key, dict_key) -> dict value (or sentinel, indicating
+        #      the key doesn't exist in the dict); and
+        #   2. Map from (key, _FullCacheKey.KEY) -> full dict.
+        #
+        # The former is used when explicit keys of the dictionary are looked up,
+        # and the latter when the full dictionary is requested.
+        #
+        # If when explicit keys are requested and not in the cache, we then look
+        # to see if we have the full dict and use that if we do. If found in the
+        # full dict each key is added into the cache.
+        #
+        # This set up allows the `LruCache` to prune the full dict entries if
+        # they haven't been used in a while, even when there have been recent
+        # queries for subsets of the dict.
+        #
+        # Typing:
+        #     * A key of `(KT, DKT)` has a value of `_PerKeyValue`
+        #     * A key of `(KT, _FullCacheKey.KEY)` has a value of `Dict[DKT, DV]`
+        self.cache: LruCache[
+            Tuple[KT, Union[DKT, Literal[_FullCacheKey.KEY]]],
+            Union[_PerKeyValue, Dict[DKT, DV]],
+        ] = LruCache(
+            max_size=max_entries,
+            cache_name=name,
+            cache_type=TreeCache,
+            size_callback=len,
         )
 
         self.name = name
@@ -91,23 +165,83 @@ class DictionaryCache(Generic[KT, DKT, DV]):
         Args:
             key
             dict_keys: If given a set of keys then return only those keys
-                that exist in the cache.
+                that exist in the cache. If None then returns the full dict
+                if it is in the cache.
 
         Returns:
-            DictionaryEntry
+            DictionaryEntry: If `dict_keys` is not None then `DictionaryEntry`
+            will contain include the keys that are in the cache. If None then
+            will either return the full dict if in the cache, or the empty
+            dict (with `full` set to False) if it isn't.
         """
-        entry = self.cache.get(key, _Sentinel.sentinel)
-        if entry is not _Sentinel.sentinel:
-            if dict_keys is None:
-                return DictionaryEntry(
-                    entry.full, entry.known_absent, dict(entry.value)
-                )
+        if dict_keys is None:
+            # The caller wants the full set of dictionary keys for this cache key
+            return self._get_full_dict(key)
+
+        # We are being asked for a subset of keys.
+
+        # First go and check for each requested dict key in the cache, tracking
+        # which we couldn't find.
+        values = {}
+        known_absent = set()
+        missing = []
+        for dict_key in dict_keys:
+            entry = self.cache.get((key, dict_key), _Sentinel.sentinel)
+            if entry is _Sentinel.sentinel:
+                missing.append(dict_key)
+                continue
+
+            assert isinstance(entry, _PerKeyValue)
+
+            if entry.value is _Sentinel.sentinel:
+                known_absent.add(dict_key)
             else:
-                return DictionaryEntry(
-                    entry.full,
-                    entry.known_absent,
-                    {k: entry.value[k] for k in dict_keys if k in entry.value},
-                )
+                values[dict_key] = entry.value
+
+        # If we found everything we can return immediately.
+        if not missing:
+            return DictionaryEntry(False, known_absent, values)
+
+        # We are missing some keys, so check if we happen to have the full dict in
+        # the cache.
+        #
+        # We don't update the last access time for this cache fetch, as we
+        # aren't explicitly interested in the full dict and so we don't want
+        # requests for explicit dict keys to keep the full dict in the cache.
+        entry = self.cache.get(
+            (key, _FullCacheKey.KEY),
+            _Sentinel.sentinel,
+            update_last_access=False,
+        )
+        if entry is _Sentinel.sentinel:
+            # Not in the cache, return the subset of keys we found.
+            return DictionaryEntry(False, known_absent, values)
+
+        # We have the full dict!
+        assert isinstance(entry, dict)
+
+        for dict_key in missing:
+            # We explicitly add each dict key to the cache, so that cache hit
+            # rates and LRU times for each key can be tracked separately.
+            value = entry.get(dict_key, _Sentinel.sentinel)  # type: ignore[arg-type]
+            self.cache[(key, dict_key)] = _PerKeyValue(value)
+
+            if value is not _Sentinel.sentinel:
+                values[dict_key] = value
+
+        return DictionaryEntry(True, set(), values)
+
+    def _get_full_dict(
+        self,
+        key: KT,
+    ) -> DictionaryEntry:
+        """Fetch the full dict for the given key."""
+
+        # First we check if we have cached the full dict.
+        entry = self.cache.get((key, _FullCacheKey.KEY), _Sentinel.sentinel)
+        if entry is not _Sentinel.sentinel:
+            assert isinstance(entry, dict)
+            return DictionaryEntry(True, set(), entry)
 
         return DictionaryEntry(False, set(), {})
 
@@ -117,7 +251,13 @@ class DictionaryCache(Generic[KT, DKT, DV]):
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
         self.sequence += 1
-        self.cache.pop(key, None)
+
+        # We want to drop all information about the dict for the given key, so
+        # we use `del_multi` to delete it all in one go.
+        #
+        # We ignore the type error here: `del_multi` accepts a truncated key
+        # (when the key type is a tuple).
+        self.cache.del_multi((key,))  # type: ignore[arg-type]
 
     def invalidate_all(self) -> None:
         self.check_thread()
@@ -131,7 +271,16 @@ class DictionaryCache(Generic[KT, DKT, DV]):
         value: Dict[DKT, DV],
         fetched_keys: Optional[Iterable[DKT]] = None,
     ) -> None:
-        """Updates the entry in the cache
+        """Updates the entry in the cache.
+
+        Note: This does *not* invalidate any existing entries for the `key`.
+        In particular, if we add an entry for the cached "full dict" with
+        `fetched_keys=None`, existing entries for individual dict keys are
+        not invalidated. Likewise, adding entries for individual keys does
+        not invalidate any cached value for the full dict.
+
+        In other words: if the underlying data is *changed*, the cache must
+        be explicitly invalidated via `.invalidate()`.
 
         Args:
             sequence
@@ -149,20 +298,27 @@ class DictionaryCache(Generic[KT, DKT, DV]):
             # Only update the cache if the caches sequence number matches the
             # number that the cache had before the SELECT was started (SYN-369)
             if fetched_keys is None:
-                self._insert(key, value, set())
+                self.cache[(key, _FullCacheKey.KEY)] = value
             else:
-                self._update_or_insert(key, value, fetched_keys)
+                self._update_subset(key, value, fetched_keys)
 
-    def _update_or_insert(
-        self, key: KT, value: Dict[DKT, DV], known_absent: Iterable[DKT]
+    def _update_subset(
+        self, key: KT, value: Dict[DKT, DV], fetched_keys: Iterable[DKT]
     ) -> None:
-        # We pop and reinsert as we need to tell the cache the size may have
-        # changed
+        """Add the given dictionary values as explicit keys in the cache.
+
+        Args:
+            key: top-level cache key
+            value: The dictionary with all the values that we should cache
+            fetched_keys: The full set of dict keys that were looked up. Any keys
+                here not in `value` should be marked as "known absent".
+        """
+
+        for dict_key, dict_value in value.items():
+            self.cache[(key, dict_key)] = _PerKeyValue(dict_value)
 
-        entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
-        entry.value.update(value)
-        entry.known_absent.update(known_absent)
-        self.cache[key] = entry
+        for dict_key in fetched_keys:
+            if dict_key in value:
+                continue
 
-    def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None:
-        self.cache[key] = DictionaryEntry(True, known_absent, value)
+            self.cache[(key, dict_key)] = _PerKeyValue(_Sentinel.sentinel)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 31f41fec82..aa93109d13 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -25,8 +25,10 @@ from typing import (
     Collection,
     Dict,
     Generic,
+    Iterable,
     List,
     Optional,
+    Tuple,
     Type,
     TypeVar,
     Union,
@@ -44,7 +46,11 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
 from synapse.metrics.jemalloc import get_jemalloc_stats
 from synapse.util import Clock, caches
 from synapse.util.caches import CacheMetric, EvictionReason, register_cache
-from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+from synapse.util.caches.treecache import (
+    TreeCache,
+    iterate_tree_cache_entry,
+    iterate_tree_cache_items,
+)
 from synapse.util.linked_list import ListNode
 
 if TYPE_CHECKING:
@@ -537,6 +543,7 @@ class LruCache(Generic[KT, VT]):
             default: Literal[None] = None,
             callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
+            update_last_access: bool = ...,
         ) -> Optional[VT]:
             ...
 
@@ -546,6 +553,7 @@ class LruCache(Generic[KT, VT]):
             default: T,
             callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
+            update_last_access: bool = ...,
         ) -> Union[T, VT]:
             ...
 
@@ -555,10 +563,27 @@ class LruCache(Generic[KT, VT]):
             default: Optional[T] = None,
             callbacks: Collection[Callable[[], None]] = (),
             update_metrics: bool = True,
+            update_last_access: bool = True,
         ) -> Union[None, T, VT]:
+            """Look up a key in the cache
+
+            Args:
+                key
+                default
+                callbacks: A collection of callbacks that will fire when the
+                    node is removed from the cache (either due to invalidation
+                    or expiry).
+                update_metrics: Whether to update the hit rate metrics
+                update_last_access: Whether to update the last access metrics
+                    on a node if successfully fetched. These metrics are used
+                    to determine when to remove the node from the cache. Set
+                    to False if this fetch should *not* prevent a node from
+                    being expired.
+            """
             node = cache.get(key, None)
             if node is not None:
-                move_node_to_front(node)
+                if update_last_access:
+                    move_node_to_front(node)
                 node.add_callbacks(callbacks)
                 if update_metrics and metrics:
                     metrics.inc_hits()
@@ -568,6 +593,65 @@ class LruCache(Generic[KT, VT]):
                     metrics.inc_misses()
                 return default
 
+        @overload
+        def cache_get_multi(
+            key: tuple,
+            default: Literal[None] = None,
+            update_metrics: bool = True,
+        ) -> Union[None, Iterable[Tuple[KT, VT]]]:
+            ...
+
+        @overload
+        def cache_get_multi(
+            key: tuple,
+            default: T,
+            update_metrics: bool = True,
+        ) -> Union[T, Iterable[Tuple[KT, VT]]]:
+            ...
+
+        @synchronized
+        def cache_get_multi(
+            key: tuple,
+            default: Optional[T] = None,
+            update_metrics: bool = True,
+        ) -> Union[None, T, Iterable[Tuple[KT, VT]]]:
+            """Returns a generator yielding all entries under the given key.
+
+            Can only be used if backed by a tree cache.
+
+            Example:
+
+                cache = LruCache(10, cache_type=TreeCache)
+                cache[(1, 1)] = "a"
+                cache[(1, 2)] = "b"
+                cache[(2, 1)] = "c"
+
+                items = cache.get_multi((1,))
+                assert list(items) == [((1, 1), "a"), ((1, 2), "b")]
+
+            Returns:
+                Either default if the key doesn't exist, or a generator of the
+                key/value pairs.
+            """
+
+            assert isinstance(cache, TreeCache)
+
+            node = cache.get(key, None)
+            if node is not None:
+                if update_metrics and metrics:
+                    metrics.inc_hits()
+
+                # We store entries in the `TreeCache` with values of type `_Node`,
+                # which we need to unwrap.
+                return (
+                    (full_key, lru_node.value)
+                    for full_key, lru_node in iterate_tree_cache_items(key, node)
+                )
+            else:
+                if update_metrics and metrics:
+                    metrics.inc_misses()
+                return default
+
         @synchronized
         def cache_set(
             key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
@@ -674,6 +758,8 @@ class LruCache(Generic[KT, VT]):
         self.setdefault = cache_set_default
         self.pop = cache_pop
         self.del_multi = cache_del_multi
+        if cache_type is TreeCache:
+            self.get_multi = cache_get_multi
         # `invalidate` is exposed for consistency with DeferredCache, so that it can be
         # invalidated by the cache invalidation replication stream.
         self.invalidate = cache_del_multi
@@ -748,9 +834,26 @@ class AsyncLruCache(Generic[KT, VT]):
     ) -> Optional[VT]:
         return self._lru_cache.get(key, update_metrics=update_metrics)
 
+    async def get_external(
+        self,
+        key: KT,
+        default: Optional[T] = None,
+        update_metrics: bool = True,
+    ) -> Optional[VT]:
+        # This method should fetch from any configured external cache, in this case noop.
+        return None
+
+    def get_local(
+        self, key: KT, default: Optional[T] = None, update_metrics: bool = True
+    ) -> Optional[VT]:
+        return self._lru_cache.get(key, update_metrics=update_metrics)
+
     async def set(self, key: KT, value: VT) -> None:
         self._lru_cache.set(key, value)
 
+    def set_local(self, key: KT, value: VT) -> None:
+        self._lru_cache.set(key, value)
+
     async def invalidate(self, key: KT) -> None:
         # This method should invalidate any external cache and then invalidate the LruCache.
         return self._lru_cache.invalidate(key)
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index e78305f787..fec31da2b6 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -64,6 +64,15 @@ class TreeCache:
         self.size += 1
 
     def get(self, key, default=None):
+        """When `key` is a full key, fetches the value for the given key (if
+        any).
+
+        If `key` is only a partial key (i.e. a truncated tuple) then returns a
+        `TreeCacheNode`, which can be passed to the `iterate_tree_cache_*`
+        functions to iterate over all entries in the cache with keys that start
+        with the given partial key.
+        """
+
         node = self.root
         for k in key[:-1]:
             node = node.get(k, None)
@@ -126,6 +135,9 @@ class TreeCache:
     def values(self):
         return iterate_tree_cache_entry(self.root)
 
+    def items(self):
+        return iterate_tree_cache_items((), self.root)
+
     def __len__(self) -> int:
         return self.size
 
@@ -139,3 +151,32 @@ def iterate_tree_cache_entry(d):
             yield from iterate_tree_cache_entry(value_d)
     else:
         yield d
+
+
+def iterate_tree_cache_items(key, value):
+    """Helper function to iterate over the leaves of a tree, i.e. a dict of that
+    can contain dicts.
+
+    The provided key is a tuple that will get prepended to the returned keys.
+
+    Example:
+
+        cache = TreeCache()
+        cache[(1, 1)] = "a"
+        cache[(1, 2)] = "b"
+        cache[(2, 1)] = "c"
+
+        tree_node = cache.get((1,))
+
+        items = iterate_tree_cache_items((1,), tree_node)
+        assert list(items) == [((1, 1), "a"), ((1, 2), "b")]
+
+    Returns:
+        A generator yielding key/value pairs.
+    """
+    if isinstance(value, TreeCacheNode):
+        for sub_key, sub_value in value.items():
+            yield from iterate_tree_cache_items((*key, sub_key), sub_value)
+    else:
+        # we've reached a leaf of the tree.
+        yield key, value
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index dfe628c97e..f678b52cb4 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -18,15 +18,19 @@ import logging
 import typing
 from typing import Any, DefaultDict, Iterator, List, Set
 
+from prometheus_client.core import Counter
+
 from twisted.internet import defer
 
 from synapse.api.errors import LimitExceededError
-from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.config.ratelimiting import FederationRatelimitSettings
 from synapse.logging.context import (
     PreserveLoggingContext,
     make_deferred_yieldable,
     run_in_background,
 )
+from synapse.logging.opentracing import start_active_span
+from synapse.metrics import Histogram, LaterGauge
 from synapse.util import Clock
 
 if typing.TYPE_CHECKING:
@@ -35,8 +39,34 @@ if typing.TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+# Track how much the ratelimiter is affecting requests
+rate_limit_sleep_counter = Counter("synapse_rate_limit_sleep", "")
+rate_limit_reject_counter = Counter("synapse_rate_limit_reject", "")
+queue_wait_timer = Histogram(
+    "synapse_rate_limit_queue_wait_time_seconds",
+    "sec",
+    [],
+    buckets=(
+        0.005,
+        0.01,
+        0.025,
+        0.05,
+        0.1,
+        0.25,
+        0.5,
+        0.75,
+        1.0,
+        2.5,
+        5.0,
+        10.0,
+        20.0,
+        "+Inf",
+    ),
+)
+
+
 class FederationRateLimiter:
-    def __init__(self, clock: Clock, config: FederationRateLimitConfig):
+    def __init__(self, clock: Clock, config: FederationRatelimitSettings):
         def new_limiter() -> "_PerHostRatelimiter":
             return _PerHostRatelimiter(clock=clock, config=config)
 
@@ -44,6 +74,27 @@ class FederationRateLimiter:
             str, "_PerHostRatelimiter"
         ] = collections.defaultdict(new_limiter)
 
+        # We track the number of affected hosts per time-period so we can
+        # differentiate one really noisy homeserver from a general
+        # ratelimit tuning problem across the federation.
+        LaterGauge(
+            "synapse_rate_limit_sleep_affected_hosts",
+            "Number of hosts that had requests put to sleep",
+            [],
+            lambda: sum(
+                ratelimiter.should_sleep() for ratelimiter in self.ratelimiters.values()
+            ),
+        )
+        LaterGauge(
+            "synapse_rate_limit_reject_affected_hosts",
+            "Number of hosts that had requests rejected",
+            [],
+            lambda: sum(
+                ratelimiter.should_reject()
+                for ratelimiter in self.ratelimiters.values()
+            ),
+        )
+
     def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
         """Used to ratelimit an incoming request from a given host
 
@@ -59,11 +110,11 @@ class FederationRateLimiter:
         Returns:
             context manager which returns a deferred.
         """
-        return self.ratelimiters[host].ratelimit()
+        return self.ratelimiters[host].ratelimit(host)
 
 
 class _PerHostRatelimiter:
-    def __init__(self, clock: Clock, config: FederationRateLimitConfig):
+    def __init__(self, clock: Clock, config: FederationRatelimitSettings):
         """
         Args:
             clock
@@ -94,19 +145,42 @@ class _PerHostRatelimiter:
         self.request_times: List[int] = []
 
     @contextlib.contextmanager
-    def ratelimit(self) -> "Iterator[defer.Deferred[None]]":
+    def ratelimit(self, host: str) -> "Iterator[defer.Deferred[None]]":
         # `contextlib.contextmanager` takes a generator and turns it into a
         # context manager. The generator should only yield once with a value
         # to be returned by manager.
         # Exceptions will be reraised at the yield.
 
+        self.host = host
+
         request_id = object()
-        ret = self._on_enter(request_id)
+        # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
+        # type-checking, but we'd need Twisted >= 21.2.
+        ret = defer.ensureDeferred(self._on_enter_with_tracing(request_id))
         try:
             yield ret
         finally:
             self._on_exit(request_id)
 
+    def should_reject(self) -> bool:
+        """
+        Whether to reject the request if we already have too many queued up
+        (either sleeping or in the ready queue).
+        """
+        queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
+        return queue_size > self.reject_limit
+
+    def should_sleep(self) -> bool:
+        """
+        Whether to sleep the request if we already have too many requests coming
+        through within the window.
+        """
+        return len(self.request_times) > self.sleep_limit
+
+    async def _on_enter_with_tracing(self, request_id: object) -> None:
+        with start_active_span("ratelimit wait"), queue_wait_timer.time():
+            await self._on_enter(request_id)
+
     def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
         time_now = self.clock.time_msec()
 
@@ -117,8 +191,9 @@ class _PerHostRatelimiter:
 
         # reject the request if we already have too many queued up (either
         # sleeping or in the ready queue).
-        queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
-        if queue_size > self.reject_limit:
+        if self.should_reject():
+            logger.debug("Ratelimiter(%s): rejecting request", self.host)
+            rate_limit_reject_counter.inc()
             raise LimitExceededError(
                 retry_after_ms=int(self.window_size / self.sleep_limit)
             )
@@ -130,7 +205,8 @@ class _PerHostRatelimiter:
                 queue_defer: defer.Deferred[None] = defer.Deferred()
                 self.ready_request_queue[request_id] = queue_defer
                 logger.info(
-                    "Ratelimiter: queueing request (queue now %i items)",
+                    "Ratelimiter(%s): queueing request (queue now %i items)",
+                    self.host,
                     len(self.ready_request_queue),
                 )
 
@@ -139,19 +215,28 @@ class _PerHostRatelimiter:
                 return defer.succeed(None)
 
         logger.debug(
-            "Ratelimit [%s]: len(self.request_times)=%d",
+            "Ratelimit(%s) [%s]: len(self.request_times)=%d",
+            self.host,
             id(request_id),
             len(self.request_times),
         )
 
-        if len(self.request_times) > self.sleep_limit:
-            logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec)
+        if self.should_sleep():
+            logger.debug(
+                "Ratelimiter(%s) [%s]: sleeping request for %f sec",
+                self.host,
+                id(request_id),
+                self.sleep_sec,
+            )
+            rate_limit_sleep_counter.inc()
             ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
 
             self.sleeping_requests.add(request_id)
 
             def on_wait_finished(_: Any) -> "defer.Deferred[None]":
-                logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
+                logger.debug(
+                    "Ratelimit(%s) [%s]: Finished sleeping", self.host, id(request_id)
+                )
                 self.sleeping_requests.discard(request_id)
                 queue_defer = queue_request()
                 return queue_defer
@@ -161,7 +246,9 @@ class _PerHostRatelimiter:
             ret_defer = queue_request()
 
         def on_start(r: object) -> object:
-            logger.debug("Ratelimit [%s]: Processing req", id(request_id))
+            logger.debug(
+                "Ratelimit(%s) [%s]: Processing req", self.host, id(request_id)
+            )
             self.current_processing.add(request_id)
             return r
 
@@ -183,7 +270,7 @@ class _PerHostRatelimiter:
         return make_deferred_yieldable(ret_defer)
 
     def _on_exit(self, request_id: object) -> None:
-        logger.debug("Ratelimit [%s]: Processed req", id(request_id))
+        logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id))
         self.current_processing.discard(request_id)
         try:
             # start processing the next item on the queue.