summary refs log tree commit diff
path: root/synapse/util/caches
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches')
-rw-r--r--synapse/util/caches/__init__.py13
-rw-r--r--synapse/util/caches/deferred_cache.py342
-rw-r--r--synapse/util/caches/descriptors.py521
-rw-r--r--synapse/util/caches/dictionary_cache.py29
-rw-r--r--synapse/util/caches/lrucache.py138
-rw-r--r--synapse/util/caches/response_cache.py50
-rw-r--r--synapse/util/caches/ttlcache.py2
7 files changed, 710 insertions, 385 deletions
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 8fc05be278..89f0b38535 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -16,7 +16,7 @@
 
 import logging
 from sys import intern
-from typing import Callable, Dict, Optional
+from typing import Callable, Dict, Optional, Sized
 
 import attr
 from prometheus_client.core import Gauge
@@ -92,7 +92,7 @@ class CacheMetric:
 def register_cache(
     cache_type: str,
     cache_name: str,
-    cache,
+    cache: Sized,
     collect_callback: Optional[Callable] = None,
     resizable: bool = True,
     resize_callback: Optional[Callable] = None,
@@ -100,12 +100,15 @@ def register_cache(
     """Register a cache object for metric collection and resizing.
 
     Args:
-        cache_type
+        cache_type: a string indicating the "type" of the cache. This is used
+            only for deduplication so isn't too important provided it's constant.
         cache_name: name of the cache
-        cache: cache itself
+        cache: cache itself, which must implement __len__(), and may optionally implement
+             a max_size property
         collect_callback: If given, a function which is called during metric
             collection to update additional metrics.
-        resizable: Whether this cache supports being resized.
+        resizable: Whether this cache supports being resized, in which case either
+            resize_callback must be provided, or the cache must support set_max_size().
         resize_callback: A function which can be called to resize the cache.
 
     Returns:
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
new file mode 100644
index 0000000000..601305487c
--- /dev/null
+++ b/synapse/util/caches/deferred_cache.py
@@ -0,0 +1,342 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import enum
+import threading
+from typing import (
+    Callable,
+    Generic,
+    Iterable,
+    MutableMapping,
+    Optional,
+    TypeVar,
+    Union,
+    cast,
+)
+
+from prometheus_client import Gauge
+
+from twisted.internet import defer
+from twisted.python import failure
+
+from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+
+cache_pending_metric = Gauge(
+    "synapse_util_caches_cache_pending",
+    "Number of lookups currently pending for this cache",
+    ["name"],
+)
+
+T = TypeVar("T")
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
+
+class _Sentinel(enum.Enum):
+    # defining a sentinel in this way allows mypy to correctly handle the
+    # type of a dictionary lookup.
+    sentinel = object()
+
+
+class DeferredCache(Generic[KT, VT]):
+    """Wraps an LruCache, adding support for Deferred results.
+
+    It expects that each entry added with set() will be a Deferred; likewise get()
+    will return a Deferred.
+    """
+
+    __slots__ = (
+        "cache",
+        "thread",
+        "_pending_deferred_cache",
+    )
+
+    def __init__(
+        self,
+        name: str,
+        max_entries: int = 1000,
+        keylen: int = 1,
+        tree: bool = False,
+        iterable: bool = False,
+        apply_cache_factor_from_config: bool = True,
+    ):
+        """
+        Args:
+            name: The name of the cache
+            max_entries: Maximum amount of entries that the cache will hold
+            keylen: The length of the tuple used as the cache key. Ignored unless
+               `tree` is True.
+            tree: Use a TreeCache instead of a dict as the underlying cache type
+            iterable: If True, count each item in the cached object as an entry,
+                rather than each cached object
+            apply_cache_factor_from_config: Whether cache factors specified in the
+                config file affect `max_entries`
+        """
+        cache_type = TreeCache if tree else dict
+
+        # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
+        self._pending_deferred_cache = (
+            cache_type()
+        )  # type: MutableMapping[KT, CacheEntry]
+
+        def metrics_cb():
+            cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
+
+        # cache is used for completed results and maps to the result itself, rather than
+        # a Deferred.
+        self.cache = LruCache(
+            max_size=max_entries,
+            keylen=keylen,
+            cache_name=name,
+            cache_type=cache_type,
+            size_callback=(lambda d: len(d)) if iterable else None,
+            metrics_collection_callback=metrics_cb,
+            apply_cache_factor_from_config=apply_cache_factor_from_config,
+        )  # type: LruCache[KT, VT]
+
+        self.thread = None  # type: Optional[threading.Thread]
+
+    @property
+    def max_entries(self):
+        return self.cache.max_size
+
+    def check_thread(self):
+        expected_thread = self.thread
+        if expected_thread is None:
+            self.thread = threading.current_thread()
+        else:
+            if expected_thread is not threading.current_thread():
+                raise ValueError(
+                    "Cache objects can only be accessed from the main thread"
+                )
+
+    def get(
+        self,
+        key: KT,
+        callback: Optional[Callable[[], None]] = None,
+        update_metrics: bool = True,
+    ) -> defer.Deferred:
+        """Looks the key up in the caches.
+
+        For symmetry with set(), this method does *not* follow the synapse logcontext
+        rules: the logcontext will not be cleared on return, and the Deferred will run
+        its callbacks in the sentinel context. In other words: wrap the result with
+        make_deferred_yieldable() before `await`ing it.
+
+        Args:
+            key:
+            callback: Gets called when the entry in the cache is invalidated
+            update_metrics (bool): whether to update the cache hit rate metrics
+
+        Returns:
+            A Deferred which completes with the result. Note that this may later fail
+            if there is an ongoing set() operation which later completes with a failure.
+
+        Raises:
+            KeyError if the key is not found in the cache
+        """
+        callbacks = [callback] if callback else []
+        val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
+        if val is not _Sentinel.sentinel:
+            val.callbacks.update(callbacks)
+            if update_metrics:
+                m = self.cache.metrics
+                assert m  # we always have a name, so should always have metrics
+                m.inc_hits()
+            return val.deferred.observe()
+
+        val2 = self.cache.get(
+            key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
+        )
+        if val2 is _Sentinel.sentinel:
+            raise KeyError()
+        else:
+            return defer.succeed(val2)
+
+    def get_immediate(
+        self, key: KT, default: T, update_metrics: bool = True
+    ) -> Union[VT, T]:
+        """If we have a *completed* cached value, return it."""
+        return self.cache.get(key, default, update_metrics=update_metrics)
+
+    def set(
+        self,
+        key: KT,
+        value: defer.Deferred,
+        callback: Optional[Callable[[], None]] = None,
+    ) -> defer.Deferred:
+        """Adds a new entry to the cache (or updates an existing one).
+
+        The given `value` *must* be a Deferred.
+
+        First any existing entry for the same key is invalidated. Then a new entry
+        is added to the cache for the given key.
+
+        Until the `value` completes, calls to `get()` for the key will also result in an
+        incomplete Deferred, which will ultimately complete with the same result as
+        `value`.
+
+        If `value` completes successfully, subsequent calls to `get()` will then return
+        a completed deferred with the same result. If it *fails*, the cache is
+        invalidated and subequent calls to `get()` will raise a KeyError.
+
+        If another call to `set()` happens before `value` completes, then (a) any
+        invalidation callbacks registered in the interim will be called, (b) any
+        `get()`s in the interim will continue to complete with the result from the
+        *original* `value`, (c) any future calls to `get()` will complete with the
+        result from the *new* `value`.
+
+        It is expected that `value` does *not* follow the synapse logcontext rules - ie,
+        if it is incomplete, it runs its callbacks in the sentinel context.
+
+        Args:
+            key: Key to be set
+            value: a deferred which will complete with a result to add to the cache
+            callback: An optional callback to be called when the entry is invalidated
+        """
+        if not isinstance(value, defer.Deferred):
+            raise TypeError("not a Deferred")
+
+        callbacks = [callback] if callback else []
+        self.check_thread()
+
+        existing_entry = self._pending_deferred_cache.pop(key, None)
+        if existing_entry:
+            existing_entry.invalidate()
+
+        # XXX: why don't we invalidate the entry in `self.cache` yet?
+
+        # we can save a whole load of effort if the deferred is ready.
+        if value.called:
+            result = value.result
+            if not isinstance(result, failure.Failure):
+                self.cache.set(key, result, callbacks)
+            return value
+
+        # otherwise, we'll add an entry to the _pending_deferred_cache for now,
+        # and add callbacks to add it to the cache properly later.
+
+        observable = ObservableDeferred(value, consumeErrors=True)
+        observer = observable.observe()
+        entry = CacheEntry(deferred=observable, callbacks=callbacks)
+
+        self._pending_deferred_cache[key] = entry
+
+        def compare_and_pop():
+            """Check if our entry is still the one in _pending_deferred_cache, and
+            if so, pop it.
+
+            Returns true if the entries matched.
+            """
+            existing_entry = self._pending_deferred_cache.pop(key, None)
+            if existing_entry is entry:
+                return True
+
+            # oops, the _pending_deferred_cache has been updated since
+            # we started our query, so we are out of date.
+            #
+            # Better put back whatever we took out. (We do it this way
+            # round, rather than peeking into the _pending_deferred_cache
+            # and then removing on a match, to make the common case faster)
+            if existing_entry is not None:
+                self._pending_deferred_cache[key] = existing_entry
+
+            return False
+
+        def cb(result):
+            if compare_and_pop():
+                self.cache.set(key, result, entry.callbacks)
+            else:
+                # we're not going to put this entry into the cache, so need
+                # to make sure that the invalidation callbacks are called.
+                # That was probably done when _pending_deferred_cache was
+                # updated, but it's possible that `set` was called without
+                # `invalidate` being previously called, in which case it may
+                # not have been. Either way, let's double-check now.
+                entry.invalidate()
+
+        def eb(_fail):
+            compare_and_pop()
+            entry.invalidate()
+
+        # once the deferred completes, we can move the entry from the
+        # _pending_deferred_cache to the real cache.
+        #
+        observer.addCallbacks(cb, eb)
+
+        # we return a new Deferred which will be called before any subsequent observers.
+        return observable.observe()
+
+    def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
+        callbacks = [callback] if callback else []
+        self.cache.set(key, value, callbacks=callbacks)
+
+    def invalidate(self, key):
+        self.check_thread()
+        self.cache.pop(key, None)
+
+        # if we have a pending lookup for this key, remove it from the
+        # _pending_deferred_cache, which will (a) stop it being returned
+        # for future queries and (b) stop it being persisted as a proper entry
+        # in self.cache.
+        entry = self._pending_deferred_cache.pop(key, None)
+
+        # run the invalidation callbacks now, rather than waiting for the
+        # deferred to resolve.
+        if entry:
+            entry.invalidate()
+
+    def invalidate_many(self, key: KT):
+        self.check_thread()
+        if not isinstance(key, tuple):
+            raise TypeError("The cache key must be a tuple not %r" % (type(key),))
+        key = cast(KT, key)
+        self.cache.del_multi(key)
+
+        # if we have a pending lookup for this key, remove it from the
+        # _pending_deferred_cache, as above
+        entry_dict = self._pending_deferred_cache.pop(key, None)
+        if entry_dict is not None:
+            for entry in iterate_tree_cache_entry(entry_dict):
+                entry.invalidate()
+
+    def invalidate_all(self):
+        self.check_thread()
+        self.cache.clear()
+        for entry in self._pending_deferred_cache.values():
+            entry.invalidate()
+        self._pending_deferred_cache.clear()
+
+
+class CacheEntry:
+    __slots__ = ["deferred", "callbacks", "invalidated"]
+
+    def __init__(
+        self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
+    ):
+        self.deferred = deferred
+        self.callbacks = set(callbacks)
+        self.invalidated = False
+
+    def invalidate(self):
+        if not self.invalidated:
+            self.invalidated = True
+            for callback in self.callbacks:
+                callback()
+            self.callbacks.clear()
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 98b34f2223..a924140cdf 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -13,25 +13,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import enum
 import functools
 import inspect
 import logging
-import threading
-from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Iterable,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+)
 from weakref import WeakValueDictionary
 
-from prometheus_client import Gauge
-
 from twisted.internet import defer
 
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.caches.deferred_cache import DeferredCache
 from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-
-from . import register_cache
 
 logger = logging.getLogger(__name__)
 
@@ -55,241 +61,8 @@ class _CachedFunction(Generic[F]):
     __call__ = None  # type: F
 
 
-cache_pending_metric = Gauge(
-    "synapse_util_caches_cache_pending",
-    "Number of lookups currently pending for this cache",
-    ["name"],
-)
-
-_CacheSentinel = object()
-
-
-class CacheEntry:
-    __slots__ = ["deferred", "callbacks", "invalidated"]
-
-    def __init__(self, deferred, callbacks):
-        self.deferred = deferred
-        self.callbacks = set(callbacks)
-        self.invalidated = False
-
-    def invalidate(self):
-        if not self.invalidated:
-            self.invalidated = True
-            for callback in self.callbacks:
-                callback()
-            self.callbacks.clear()
-
-
-class Cache:
-    __slots__ = (
-        "cache",
-        "name",
-        "keylen",
-        "thread",
-        "metrics",
-        "_pending_deferred_cache",
-    )
-
-    def __init__(
-        self,
-        name: str,
-        max_entries: int = 1000,
-        keylen: int = 1,
-        tree: bool = False,
-        iterable: bool = False,
-        apply_cache_factor_from_config: bool = True,
-    ):
-        """
-        Args:
-            name: The name of the cache
-            max_entries: Maximum amount of entries that the cache will hold
-            keylen: The length of the tuple used as the cache key
-            tree: Use a TreeCache instead of a dict as the underlying cache type
-            iterable: If True, count each item in the cached object as an entry,
-                rather than each cached object
-            apply_cache_factor_from_config: Whether cache factors specified in the
-                config file affect `max_entries`
-
-        Returns:
-            Cache
-        """
-        cache_type = TreeCache if tree else dict
-        self._pending_deferred_cache = cache_type()
-
-        self.cache = LruCache(
-            max_size=max_entries,
-            keylen=keylen,
-            cache_type=cache_type,
-            size_callback=(lambda d: len(d)) if iterable else None,
-            evicted_callback=self._on_evicted,
-            apply_cache_factor_from_config=apply_cache_factor_from_config,
-        )
-
-        self.name = name
-        self.keylen = keylen
-        self.thread = None  # type: Optional[threading.Thread]
-        self.metrics = register_cache(
-            "cache",
-            name,
-            self.cache,
-            collect_callback=self._metrics_collection_callback,
-        )
-
-    @property
-    def max_entries(self):
-        return self.cache.max_size
-
-    def _on_evicted(self, evicted_count):
-        self.metrics.inc_evictions(evicted_count)
-
-    def _metrics_collection_callback(self):
-        cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
-
-    def check_thread(self):
-        expected_thread = self.thread
-        if expected_thread is None:
-            self.thread = threading.current_thread()
-        else:
-            if expected_thread is not threading.current_thread():
-                raise ValueError(
-                    "Cache objects can only be accessed from the main thread"
-                )
-
-    def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
-        """Looks the key up in the caches.
-
-        Args:
-            key(tuple)
-            default: What is returned if key is not in the caches. If not
-                specified then function throws KeyError instead
-            callback(fn): Gets called when the entry in the cache is invalidated
-            update_metrics (bool): whether to update the cache hit rate metrics
-
-        Returns:
-            Either an ObservableDeferred or the raw result
-        """
-        callbacks = [callback] if callback else []
-        val = self._pending_deferred_cache.get(key, _CacheSentinel)
-        if val is not _CacheSentinel:
-            val.callbacks.update(callbacks)
-            if update_metrics:
-                self.metrics.inc_hits()
-            return val.deferred
-
-        val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
-        if val is not _CacheSentinel:
-            self.metrics.inc_hits()
-            return val
-
-        if update_metrics:
-            self.metrics.inc_misses()
-
-        if default is _CacheSentinel:
-            raise KeyError()
-        else:
-            return default
-
-    def set(self, key, value, callback=None):
-        if not isinstance(value, defer.Deferred):
-            raise TypeError("not a Deferred")
-
-        callbacks = [callback] if callback else []
-        self.check_thread()
-        observable = ObservableDeferred(value, consumeErrors=True)
-        observer = observable.observe()
-        entry = CacheEntry(deferred=observable, callbacks=callbacks)
-
-        existing_entry = self._pending_deferred_cache.pop(key, None)
-        if existing_entry:
-            existing_entry.invalidate()
-
-        self._pending_deferred_cache[key] = entry
-
-        def compare_and_pop():
-            """Check if our entry is still the one in _pending_deferred_cache, and
-            if so, pop it.
-
-            Returns true if the entries matched.
-            """
-            existing_entry = self._pending_deferred_cache.pop(key, None)
-            if existing_entry is entry:
-                return True
-
-            # oops, the _pending_deferred_cache has been updated since
-            # we started our query, so we are out of date.
-            #
-            # Better put back whatever we took out. (We do it this way
-            # round, rather than peeking into the _pending_deferred_cache
-            # and then removing on a match, to make the common case faster)
-            if existing_entry is not None:
-                self._pending_deferred_cache[key] = existing_entry
-
-            return False
-
-        def cb(result):
-            if compare_and_pop():
-                self.cache.set(key, result, entry.callbacks)
-            else:
-                # we're not going to put this entry into the cache, so need
-                # to make sure that the invalidation callbacks are called.
-                # That was probably done when _pending_deferred_cache was
-                # updated, but it's possible that `set` was called without
-                # `invalidate` being previously called, in which case it may
-                # not have been. Either way, let's double-check now.
-                entry.invalidate()
-
-        def eb(_fail):
-            compare_and_pop()
-            entry.invalidate()
-
-        # once the deferred completes, we can move the entry from the
-        # _pending_deferred_cache to the real cache.
-        #
-        observer.addCallbacks(cb, eb)
-        return observable
-
-    def prefill(self, key, value, callback=None):
-        callbacks = [callback] if callback else []
-        self.cache.set(key, value, callbacks=callbacks)
-
-    def invalidate(self, key):
-        self.check_thread()
-        self.cache.pop(key, None)
-
-        # if we have a pending lookup for this key, remove it from the
-        # _pending_deferred_cache, which will (a) stop it being returned
-        # for future queries and (b) stop it being persisted as a proper entry
-        # in self.cache.
-        entry = self._pending_deferred_cache.pop(key, None)
-
-        # run the invalidation callbacks now, rather than waiting for the
-        # deferred to resolve.
-        if entry:
-            entry.invalidate()
-
-    def invalidate_many(self, key):
-        self.check_thread()
-        if not isinstance(key, tuple):
-            raise TypeError("The cache key must be a tuple not %r" % (type(key),))
-        self.cache.del_multi(key)
-
-        # if we have a pending lookup for this key, remove it from the
-        # _pending_deferred_cache, as above
-        entry_dict = self._pending_deferred_cache.pop(key, None)
-        if entry_dict is not None:
-            for entry in iterate_tree_cache_entry(entry_dict):
-                entry.invalidate()
-
-    def invalidate_all(self):
-        self.check_thread()
-        self.cache.clear()
-        for entry in self._pending_deferred_cache.values():
-            entry.invalidate()
-        self._pending_deferred_cache.clear()
-
-
 class _CacheDescriptorBase:
-    def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
+    def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
         self.orig = orig
 
         arg_spec = inspect.getfullargspec(orig)
@@ -338,8 +111,107 @@ class _CacheDescriptorBase:
 
         self.add_cache_context = cache_context
 
+        self.cache_key_builder = get_cache_key_builder(
+            self.arg_names, self.arg_defaults
+        )
+
+
+class _LruCachedFunction(Generic[F]):
+    cache = None  # type: LruCache[CacheKey, Any]
+    __call__ = None  # type: F
+
+
+def lru_cache(
+    max_entries: int = 1000, cache_context: bool = False,
+) -> Callable[[F], _LruCachedFunction[F]]:
+    """A method decorator that applies a memoizing cache around the function.
+
+    This is more-or-less a drop-in equivalent to functools.lru_cache, although note
+    that the signature is slightly different.
+
+    The main differences with functools.lru_cache are:
+        (a) the size of the cache can be controlled via the cache_factor mechanism
+        (b) the wrapped function can request a "cache_context" which provides a
+            callback mechanism to indicate that the result is no longer valid
+        (c) prometheus metrics are exposed automatically.
+
+    The function should take zero or more arguments, which are used as the key for the
+    cache. Single-argument functions use that argument as the cache key; otherwise the
+    arguments are built into a tuple.
+
+    Cached functions can be "chained" (i.e. a cached function can call other cached
+    functions and get appropriately invalidated when they called caches are
+    invalidated) by adding a special "cache_context" argument to the function
+    and passing that as a kwarg to all caches called. For example:
+
+        @lru_cache(cache_context=True)
+        def foo(self, key, cache_context):
+            r1 = self.bar1(key, on_invalidate=cache_context.invalidate)
+            r2 = self.bar2(key, on_invalidate=cache_context.invalidate)
+            return r1 + r2
+
+    The wrapped function also has a 'cache' property which offers direct access to the
+    underlying LruCache.
+    """
+
+    def func(orig: F) -> _LruCachedFunction[F]:
+        desc = LruCacheDescriptor(
+            orig, max_entries=max_entries, cache_context=cache_context,
+        )
+        return cast(_LruCachedFunction[F], desc)
+
+    return func
+
+
+class LruCacheDescriptor(_CacheDescriptorBase):
+    """Helper for @lru_cache"""
 
-class CacheDescriptor(_CacheDescriptorBase):
+    class _Sentinel(enum.Enum):
+        sentinel = object()
+
+    def __init__(
+        self, orig, max_entries: int = 1000, cache_context: bool = False,
+    ):
+        super().__init__(orig, num_args=None, cache_context=cache_context)
+        self.max_entries = max_entries
+
+    def __get__(self, obj, owner):
+        cache = LruCache(
+            cache_name=self.orig.__name__, max_size=self.max_entries,
+        )  # type: LruCache[CacheKey, Any]
+
+        get_cache_key = self.cache_key_builder
+        sentinel = LruCacheDescriptor._Sentinel.sentinel
+
+        @functools.wraps(self.orig)
+        def _wrapped(*args, **kwargs):
+            invalidate_callback = kwargs.pop("on_invalidate", None)
+            callbacks = (invalidate_callback,) if invalidate_callback else ()
+
+            cache_key = get_cache_key(args, kwargs)
+
+            ret = cache.get(cache_key, default=sentinel, callbacks=callbacks)
+            if ret != sentinel:
+                return ret
+
+            # Add our own `cache_context` to argument list if the wrapped function
+            # has asked for one
+            if self.add_cache_context:
+                kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
+
+            ret2 = self.orig(obj, *args, **kwargs)
+            cache.set(cache_key, ret2, callbacks=callbacks)
+
+            return ret2
+
+        wrapped = cast(_CachedFunction, _wrapped)
+        wrapped.cache = cache
+        obj.__dict__[self.orig.__name__] = wrapped
+
+        return wrapped
+
+
+class DeferredCacheDescriptor(_CacheDescriptorBase):
     """ A method decorator that applies a memoizing cache around the function.
 
     This caches deferreds, rather than the results themselves. Deferreds that
@@ -382,7 +254,6 @@ class CacheDescriptor(_CacheDescriptorBase):
         cache_context=False,
         iterable=False,
     ):
-
         super().__init__(orig, num_args=num_args, cache_context=cache_context)
 
         self.max_entries = max_entries
@@ -390,49 +261,15 @@ class CacheDescriptor(_CacheDescriptorBase):
         self.iterable = iterable
 
     def __get__(self, obj, owner):
-        cache = Cache(
+        cache = DeferredCache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
             keylen=self.num_args,
             tree=self.tree,
             iterable=self.iterable,
-        )
-
-        def get_cache_key_gen(args, kwargs):
-            """Given some args/kwargs return a generator that resolves into
-            the cache_key.
-
-            We loop through each arg name, looking up if its in the `kwargs`,
-            otherwise using the next argument in `args`. If there are no more
-            args then we try looking the arg name up in the defaults
-            """
-            pos = 0
-            for nm in self.arg_names:
-                if nm in kwargs:
-                    yield kwargs[nm]
-                elif pos < len(args):
-                    yield args[pos]
-                    pos += 1
-                else:
-                    yield self.arg_defaults[nm]
-
-        # By default our cache key is a tuple, but if there is only one item
-        # then don't bother wrapping in a tuple.  This is to save memory.
-        if self.num_args == 1:
-            nm = self.arg_names[0]
-
-            def get_cache_key(args, kwargs):
-                if nm in kwargs:
-                    return kwargs[nm]
-                elif len(args):
-                    return args[0]
-                else:
-                    return self.arg_defaults[nm]
-
-        else:
+        )  # type: DeferredCache[CacheKey, Any]
 
-            def get_cache_key(args, kwargs):
-                return tuple(get_cache_key_gen(args, kwargs))
+        get_cache_key = self.cache_key_builder
 
         @functools.wraps(self.orig)
         def _wrapped(*args, **kwargs):
@@ -442,32 +279,20 @@ class CacheDescriptor(_CacheDescriptorBase):
 
             cache_key = get_cache_key(args, kwargs)
 
-            # Add our own `cache_context` to argument list if the wrapped function
-            # has asked for one
-            if self.add_cache_context:
-                kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
-
             try:
-                cached_result_d = cache.get(cache_key, callback=invalidate_callback)
-
-                if isinstance(cached_result_d, ObservableDeferred):
-                    observer = cached_result_d.observe()
-                else:
-                    observer = defer.succeed(cached_result_d)
-
+                ret = cache.get(cache_key, callback=invalidate_callback)
             except KeyError:
-                ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
+                # Add our own `cache_context` to argument list if the wrapped function
+                # has asked for one
+                if self.add_cache_context:
+                    kwargs["cache_context"] = _CacheContext.get_instance(
+                        cache, cache_key
+                    )
 
-                def onErr(f):
-                    cache.invalidate(cache_key)
-                    return f
-
-                ret.addErrback(onErr)
-
-                result_d = cache.set(cache_key, ret, callback=invalidate_callback)
-                observer = result_d.observe()
+                ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
+                ret = cache.set(cache_key, ret, callback=invalidate_callback)
 
-            return make_deferred_yieldable(observer)
+            return make_deferred_yieldable(ret)
 
         wrapped = cast(_CachedFunction, _wrapped)
 
@@ -476,7 +301,6 @@ class CacheDescriptor(_CacheDescriptorBase):
             wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
         else:
             wrapped.invalidate = cache.invalidate
-            wrapped.invalidate_all = cache.invalidate_all
             wrapped.invalidate_many = cache.invalidate_many
             wrapped.prefill = cache.prefill
 
@@ -489,7 +313,7 @@ class CacheDescriptor(_CacheDescriptorBase):
         return wrapped
 
 
-class CacheListDescriptor(_CacheDescriptorBase):
+class DeferredCacheListDescriptor(_CacheDescriptorBase):
     """Wraps an existing cache to support bulk fetching of keys.
 
     Given a list of keys it looks in the cache to find any hits, then passes
@@ -526,7 +350,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
 
     def __get__(self, obj, objtype=None):
         cached_method = getattr(obj, self.cached_method_name)
-        cache = cached_method.cache
+        cache = cached_method.cache  # type: DeferredCache[CacheKey, Any]
         num_args = cached_method.num_args
 
         @functools.wraps(self.orig)
@@ -566,14 +390,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
             for arg in list_args:
                 try:
                     res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
-                    if not isinstance(res, ObservableDeferred):
-                        results[arg] = res
-                    elif not res.has_succeeded():
-                        res = res.observe()
+                    if not res.called:
                         res.addCallback(update_results_dict, arg)
                         cached_defers.append(res)
                     else:
-                        results[arg] = res.get_result()
+                        results[arg] = res.result
                 except KeyError:
                     missing.add(arg)
 
@@ -638,11 +459,13 @@ class _CacheContext:
     on a lower level.
     """
 
+    Cache = Union[DeferredCache, LruCache]
+
     _cache_context_objects = (
         WeakValueDictionary()
-    )  # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext]
+    )  # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext]
 
-    def __init__(self, cache, cache_key):  # type: (Cache, CacheKey) -> None
+    def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None:
         self._cache = cache
         self._cache_key = cache_key
 
@@ -651,7 +474,9 @@ class _CacheContext:
         self._cache.invalidate(self._cache_key)
 
     @classmethod
-    def get_instance(cls, cache, cache_key):  # type: (Cache, CacheKey) -> _CacheContext
+    def get_instance(
+        cls, cache: "_CacheContext.Cache", cache_key: CacheKey
+    ) -> "_CacheContext":
         """Returns an instance constructed with the given arguments.
 
         A new instance is only created if none already exists.
@@ -672,7 +497,7 @@ def cached(
     cache_context: bool = False,
     iterable: bool = False,
 ) -> Callable[[F], _CachedFunction[F]]:
-    func = lambda orig: CacheDescriptor(
+    func = lambda orig: DeferredCacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
@@ -714,7 +539,7 @@ def cachedList(
             def batch_do_something(self, first_arg, second_args):
                 ...
     """
-    func = lambda orig: CacheListDescriptor(
+    func = lambda orig: DeferredCacheListDescriptor(
         orig,
         cached_method_name=cached_method_name,
         list_name=list_name,
@@ -722,3 +547,65 @@ def cachedList(
     )
 
     return cast(Callable[[F], _CachedFunction[F]], func)
+
+
+def get_cache_key_builder(
+    param_names: Sequence[str], param_defaults: Mapping[str, Any]
+) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
+    """Construct a function which will build cache keys suitable for a cached function
+
+    Args:
+        param_names: list of formal parameter names for the cached function
+        param_defaults: a mapping from parameter name to default value for that param
+
+    Returns:
+        A function which will take an (args, kwargs) pair and return a cache key
+    """
+
+    # By default our cache key is a tuple, but if there is only one item
+    # then don't bother wrapping in a tuple.  This is to save memory.
+
+    if len(param_names) == 1:
+        nm = param_names[0]
+
+        def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
+            if nm in kwargs:
+                return kwargs[nm]
+            elif len(args):
+                return args[0]
+            else:
+                return param_defaults[nm]
+
+    else:
+
+        def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
+            return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
+
+    return get_cache_key
+
+
+def _get_cache_key_gen(
+    param_names: Iterable[str],
+    param_defaults: Mapping[str, Any],
+    args: Sequence[Any],
+    kwargs: Mapping[str, Any],
+) -> Iterable[Any]:
+    """Given some args/kwargs return a generator that resolves into
+    the cache_key.
+
+    This is essentially the same operation as `inspect.getcallargs`, but optimised so
+    that we don't need to inspect the target function for each call.
+    """
+
+    # We loop through each arg name, looking up if its in the `kwargs`,
+    # otherwise using the next argument in `args`. If there are no more
+    # args then we try looking the arg name up in the defaults.
+    pos = 0
+    for nm in param_names:
+        if nm in kwargs:
+            yield kwargs[nm]
+        elif pos < len(args):
+            yield args[pos]
+            pos += 1
+        else:
+            yield param_defaults[nm]
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 8592b93689..588d2d49f2 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -12,15 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import enum
 import logging
 import threading
 from collections import namedtuple
+from typing import Any
 
 from synapse.util.caches.lrucache import LruCache
 
-from . import register_cache
-
 logger = logging.getLogger(__name__)
 
 
@@ -40,24 +39,25 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va
         return len(self.value)
 
 
+class _Sentinel(enum.Enum):
+    # defining a sentinel in this way allows mypy to correctly handle the
+    # type of a dictionary lookup.
+    sentinel = object()
+
+
 class DictionaryCache:
     """Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
     fetching a subset of dictionary keys for a particular key.
     """
 
     def __init__(self, name, max_entries=1000):
-        self.cache = LruCache(max_size=max_entries, size_callback=len)
+        self.cache = LruCache(
+            max_size=max_entries, cache_name=name, size_callback=len
+        )  # type: LruCache[Any, DictionaryEntry]
 
         self.name = name
         self.sequence = 0
         self.thread = None
-        # caches_by_name[name] = self.cache
-
-        class Sentinel:
-            __slots__ = []
-
-        self.sentinel = Sentinel()
-        self.metrics = register_cache("dictionary", name, self.cache)
 
     def check_thread(self):
         expected_thread = self.thread
@@ -80,10 +80,8 @@ class DictionaryCache:
         Returns:
             DictionaryEntry
         """
-        entry = self.cache.get(key, self.sentinel)
-        if entry is not self.sentinel:
-            self.metrics.inc_hits()
-
+        entry = self.cache.get(key, _Sentinel.sentinel)
+        if entry is not _Sentinel.sentinel:
             if dict_keys is None:
                 return DictionaryEntry(
                     entry.full, entry.known_absent, dict(entry.value)
@@ -95,7 +93,6 @@ class DictionaryCache:
                     {k: entry.value[k] for k in dict_keys if k in entry.value},
                 )
 
-        self.metrics.inc_misses()
         return DictionaryEntry(False, set(), {})
 
     def invalidate(self, key):
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 4bc1a67b58..60bb6ff642 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -15,11 +15,35 @@
 
 import threading
 from functools import wraps
-from typing import Callable, Optional, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Iterable,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+)
+
+from typing_extensions import Literal
 
 from synapse.config import cache as cache_config
+from synapse.util.caches import CacheMetric, register_cache
 from synapse.util.caches.treecache import TreeCache
 
+# Function type: the type used for invalidation callbacks
+FT = TypeVar("FT", bound=Callable[..., Any])
+
+# Key and Value type for the cache
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
+# a general type var, distinct from either KT or VT
+T = TypeVar("T")
+
 
 def enumerate_leaves(node, depth):
     if depth == 0:
@@ -41,30 +65,33 @@ class _Node:
         self.callbacks = callbacks
 
 
-class LruCache:
+class LruCache(Generic[KT, VT]):
     """
-    Least-recently-used cache.
+    Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
+
     Supports del_multi only if cache_type=TreeCache
     If cache_type=TreeCache, all keys must be tuples.
-
-    Can also set callbacks on objects when getting/setting which are fired
-    when that key gets invalidated/evicted.
     """
 
     def __init__(
         self,
         max_size: int,
+        cache_name: Optional[str] = None,
         keylen: int = 1,
         cache_type: Type[Union[dict, TreeCache]] = dict,
         size_callback: Optional[Callable] = None,
-        evicted_callback: Optional[Callable] = None,
+        metrics_collection_callback: Optional[Callable[[], None]] = None,
         apply_cache_factor_from_config: bool = True,
     ):
         """
         Args:
             max_size: The maximum amount of entries the cache can hold
 
-            keylen: The length of the tuple used as the cache key
+            cache_name: The name of this cache, for the prometheus metrics. If unset,
+                no metrics will be reported on this cache.
+
+            keylen: The length of the tuple used as the cache key. Ignored unless
+                cache_type is `TreeCache`.
 
             cache_type (type):
                 type of underlying cache to be used. Typically one of dict
@@ -72,9 +99,13 @@ class LruCache:
 
             size_callback (func(V) -> int | None):
 
-            evicted_callback (func(int)|None):
-                if not None, called on eviction with the size of the evicted
-                entry
+            metrics_collection_callback:
+                metrics collection callback. This is called early in the metrics
+                collection process, before any of the metrics registered with the
+                prometheus Registry are collected, so can be used to update any dynamic
+                metrics.
+
+                Ignored if cache_name is None.
 
             apply_cache_factor_from_config (bool): If true, `max_size` will be
                 multiplied by a cache factor derived from the homeserver config
@@ -93,6 +124,23 @@ class LruCache:
         else:
             self.max_size = int(max_size)
 
+        # register_cache might call our "set_cache_factor" callback; there's nothing to
+        # do yet when we get resized.
+        self._on_resize = None  # type: Optional[Callable[[],None]]
+
+        if cache_name is not None:
+            metrics = register_cache(
+                "lru_cache",
+                cache_name,
+                self,
+                collect_callback=metrics_collection_callback,
+            )  # type: Optional[CacheMetric]
+        else:
+            metrics = None
+
+        # this is exposed for access from outside this class
+        self.metrics = metrics
+
         list_root = _Node(None, None, None, None)
         list_root.next_node = list_root
         list_root.prev_node = list_root
@@ -104,16 +152,16 @@ class LruCache:
                 todelete = list_root.prev_node
                 evicted_len = delete_node(todelete)
                 cache.pop(todelete.key, None)
-                if evicted_callback:
-                    evicted_callback(evicted_len)
+                if metrics:
+                    metrics.inc_evictions(evicted_len)
 
-        def synchronized(f):
+        def synchronized(f: FT) -> FT:
             @wraps(f)
             def inner(*args, **kwargs):
                 with lock:
                     return f(*args, **kwargs)
 
-            return inner
+            return cast(FT, inner)
 
         cached_cache_len = [0]
         if size_callback is not None:
@@ -167,18 +215,45 @@ class LruCache:
             node.callbacks.clear()
             return deleted_len
 
+        @overload
+        def cache_get(
+            key: KT,
+            default: Literal[None] = None,
+            callbacks: Iterable[Callable[[], None]] = ...,
+            update_metrics: bool = ...,
+        ) -> Optional[VT]:
+            ...
+
+        @overload
+        def cache_get(
+            key: KT,
+            default: T,
+            callbacks: Iterable[Callable[[], None]] = ...,
+            update_metrics: bool = ...,
+        ) -> Union[T, VT]:
+            ...
+
         @synchronized
-        def cache_get(key, default=None, callbacks=[]):
+        def cache_get(
+            key: KT,
+            default: Optional[T] = None,
+            callbacks: Iterable[Callable[[], None]] = [],
+            update_metrics: bool = True,
+        ):
             node = cache.get(key, None)
             if node is not None:
                 move_node_to_front(node)
                 node.callbacks.update(callbacks)
+                if update_metrics and metrics:
+                    metrics.inc_hits()
                 return node.value
             else:
+                if update_metrics and metrics:
+                    metrics.inc_misses()
                 return default
 
         @synchronized
-        def cache_set(key, value, callbacks=[]):
+        def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
             node = cache.get(key, None)
             if node is not None:
                 # We sometimes store large objects, e.g. dicts, which cause
@@ -207,7 +282,7 @@ class LruCache:
             evict()
 
         @synchronized
-        def cache_set_default(key, value):
+        def cache_set_default(key: KT, value: VT) -> VT:
             node = cache.get(key, None)
             if node is not None:
                 return node.value
@@ -216,8 +291,16 @@ class LruCache:
                 evict()
                 return value
 
+        @overload
+        def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]:
+            ...
+
+        @overload
+        def cache_pop(key: KT, default: T) -> Union[T, VT]:
+            ...
+
         @synchronized
-        def cache_pop(key, default=None):
+        def cache_pop(key: KT, default: Optional[T] = None):
             node = cache.get(key, None)
             if node:
                 delete_node(node)
@@ -227,18 +310,18 @@ class LruCache:
                 return default
 
         @synchronized
-        def cache_del_multi(key):
+        def cache_del_multi(key: KT) -> None:
             """
             This will only work if constructed with cache_type=TreeCache
             """
             popped = cache.pop(key)
             if popped is None:
                 return
-            for leaf in enumerate_leaves(popped, keylen - len(key)):
+            for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
                 delete_node(leaf)
 
         @synchronized
-        def cache_clear():
+        def cache_clear() -> None:
             list_root.next_node = list_root
             list_root.prev_node = list_root
             for node in cache.values():
@@ -249,15 +332,21 @@ class LruCache:
                 cached_cache_len[0] = 0
 
         @synchronized
-        def cache_contains(key):
+        def cache_contains(key: KT) -> bool:
             return key in cache
 
         self.sentinel = object()
+
+        # make sure that we clear out any excess entries after we get resized.
         self._on_resize = evict
+
         self.get = cache_get
         self.set = cache_set
         self.setdefault = cache_set_default
         self.pop = cache_pop
+        # `invalidate` is exposed for consistency with DeferredCache, so that it can be
+        # invalidated by the cache invalidation replication stream.
+        self.invalidate = cache_pop
         if cache_type is TreeCache:
             self.del_multi = cache_del_multi
         self.len = synchronized(cache_len)
@@ -301,6 +390,7 @@ class LruCache:
         new_size = int(self._original_max_size * factor)
         if new_size != self.max_size:
             self.max_size = new_size
-            self._on_resize()
+            if self._on_resize:
+                self._on_resize()
             return True
         return False
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index df1a721add..32228f42ee 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
 
 from twisted.internet import defer
 
@@ -20,10 +21,15 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches import register_cache
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
+T = TypeVar("T")
+
 
-class ResponseCache:
+class ResponseCache(Generic[T]):
     """
     This caches a deferred response. Until the deferred completes it will be
     returned from the cache. This means that if the client retries the request
@@ -31,8 +37,9 @@ class ResponseCache:
     used rather than trying to compute a new response.
     """
 
-    def __init__(self, hs, name, timeout_ms=0):
-        self.pending_result_cache = {}  # Requests that haven't finished yet.
+    def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
+        # Requests that haven't finished yet.
+        self.pending_result_cache = {}  # type: Dict[T, ObservableDeferred]
 
         self.clock = hs.get_clock()
         self.timeout_sec = timeout_ms / 1000.0
@@ -40,13 +47,13 @@ class ResponseCache:
         self._name = name
         self._metrics = register_cache("response_cache", name, self, resizable=False)
 
-    def size(self):
+    def size(self) -> int:
         return len(self.pending_result_cache)
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.size()
 
-    def get(self, key):
+    def get(self, key: T) -> Optional[defer.Deferred]:
         """Look up the given key.
 
         Can return either a new Deferred (which also doesn't follow the synapse
@@ -58,12 +65,11 @@ class ResponseCache:
         from an absent cache entry.
 
         Args:
-            key (hashable):
+            key: key to get/set in the cache
 
         Returns:
-            twisted.internet.defer.Deferred|None|E: None if there is no entry
-            for this key; otherwise either a deferred result or the result
-            itself.
+            None if there is no entry for this key; otherwise a deferred which
+            resolves to the result.
         """
         result = self.pending_result_cache.get(key)
         if result is not None:
@@ -73,7 +79,7 @@ class ResponseCache:
             self._metrics.inc_misses()
             return None
 
-    def set(self, key, deferred):
+    def set(self, key: T, deferred: defer.Deferred) -> defer.Deferred:
         """Set the entry for the given key to the given deferred.
 
         *deferred* should run its callbacks in the sentinel logcontext (ie,
@@ -85,12 +91,11 @@ class ResponseCache:
         result. You will probably want to make_deferred_yieldable the result.
 
         Args:
-            key (hashable):
-            deferred (twisted.internet.defer.Deferred[T):
+            key: key to get/set in the cache
+            deferred: The deferred which resolves to the result.
 
         Returns:
-            twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual
-                result.
+            A new deferred which resolves to the actual result.
         """
         result = ObservableDeferred(deferred, consumeErrors=True)
         self.pending_result_cache[key] = result
@@ -107,7 +112,9 @@ class ResponseCache:
         result.addBoth(remove)
         return result.observe()
 
-    def wrap(self, key, callback, *args, **kwargs):
+    def wrap(
+        self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any
+    ) -> defer.Deferred:
         """Wrap together a *get* and *set* call, taking care of logcontexts
 
         First looks up the key in the cache, and if it is present makes it
@@ -118,21 +125,20 @@ class ResponseCache:
 
         Example usage:
 
-            @defer.inlineCallbacks
-            def handle_request(request):
+            async def handle_request(request):
                 # etc
                 return result
 
-            result = yield response_cache.wrap(
+            result = await response_cache.wrap(
                 key,
                 handle_request,
                 request,
             )
 
         Args:
-            key (hashable): key to get/set in the cache
+            key: key to get/set in the cache
 
-            callback (callable): function to call if the key is not found in
+            callback: function to call if the key is not found in
                 the cache
 
             *args: positional parameters to pass to the callback, if it is used
@@ -140,7 +146,7 @@ class ResponseCache:
             **kwargs: named parameters to pass to the callback, if it is used
 
         Returns:
-            twisted.internet.defer.Deferred: yieldable result
+            Deferred which resolves to the result
         """
         result = self.get(key)
         if not result:
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 3e180cafd3..6ce2a3d12b 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -34,7 +34,7 @@ class TTLCache:
         self._data = {}
 
         # the _CacheEntries, sorted by expiry time
-        self._expiry_list = SortedList()
+        self._expiry_list = SortedList()  # type: SortedList[_CacheEntry]
 
         self._timer = timer