diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 8fc05be278..89f0b38535 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -16,7 +16,7 @@
import logging
from sys import intern
-from typing import Callable, Dict, Optional
+from typing import Callable, Dict, Optional, Sized
import attr
from prometheus_client.core import Gauge
@@ -92,7 +92,7 @@ class CacheMetric:
def register_cache(
cache_type: str,
cache_name: str,
- cache,
+ cache: Sized,
collect_callback: Optional[Callable] = None,
resizable: bool = True,
resize_callback: Optional[Callable] = None,
@@ -100,12 +100,15 @@ def register_cache(
"""Register a cache object for metric collection and resizing.
Args:
- cache_type
+ cache_type: a string indicating the "type" of the cache. This is used
+ only for deduplication so isn't too important provided it's constant.
cache_name: name of the cache
- cache: cache itself
+ cache: cache itself, which must implement __len__(), and may optionally implement
+ a max_size property
collect_callback: If given, a function which is called during metric
collection to update additional metrics.
- resizable: Whether this cache supports being resized.
+ resizable: Whether this cache supports being resized, in which case either
+ resize_callback must be provided, or the cache must support set_max_size().
resize_callback: A function which can be called to resize the cache.
Returns:
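
For reviewers: a minimal sketch of what the tightened `cache: Sized` signature expects from callers, based only on the docstring above. The `TokenCache` class is illustrative, not part of this change, and it registers with `resizable=False` (as ResponseCache does further down in this diff) so that no resize support is needed:

from synapse.util.caches import register_cache

class TokenCache:
    """A toy cache for illustration: register_cache only needs it to be Sized."""

    def __init__(self) -> None:
        self._data = {}  # type: dict

    def __len__(self) -> int:
        return len(self._data)

token_cache = TokenCache()

# resizable=False means neither a resize_callback nor set_max_size() is needed
metrics = register_cache("example", "token_cache", token_cache, resizable=False)
metrics.inc_hits()  # the owner of the cache bumps the hit/miss counters itself
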
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
new file mode 100644
index 0000000000..601305487c
--- /dev/null
+++ b/synapse/util/caches/deferred_cache.py
@@ -0,0 +1,342 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import enum
+import threading
+from typing import (
+ Callable,
+ Generic,
+ Iterable,
+ MutableMapping,
+ Optional,
+ TypeVar,
+ Union,
+ cast,
+)
+
+from prometheus_client import Gauge
+
+from twisted.internet import defer
+from twisted.python import failure
+
+from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+
+cache_pending_metric = Gauge(
+ "synapse_util_caches_cache_pending",
+ "Number of lookups currently pending for this cache",
+ ["name"],
+)
+
+T = TypeVar("T")
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
+
+class _Sentinel(enum.Enum):
+ # defining a sentinel in this way allows mypy to correctly handle the
+ # type of a dictionary lookup.
+ sentinel = object()
+
+
+class DeferredCache(Generic[KT, VT]):
+ """Wraps an LruCache, adding support for Deferred results.
+
+ It expects that each entry added with set() will be a Deferred; likewise get()
+ will return a Deferred.
+ """
+
+ __slots__ = (
+ "cache",
+ "thread",
+ "_pending_deferred_cache",
+ )
+
+ def __init__(
+ self,
+ name: str,
+ max_entries: int = 1000,
+ keylen: int = 1,
+ tree: bool = False,
+ iterable: bool = False,
+ apply_cache_factor_from_config: bool = True,
+ ):
+ """
+ Args:
+ name: The name of the cache
+ max_entries: Maximum number of entries that the cache will hold
+ keylen: The length of the tuple used as the cache key. Ignored unless
+ `tree` is True.
+ tree: Use a TreeCache instead of a dict as the underlying cache type
+ iterable: If True, count each item in the cached object as an entry,
+ rather than each cached object
+ apply_cache_factor_from_config: Whether cache factors specified in the
+ config file affect `max_entries`
+ """
+ cache_type = TreeCache if tree else dict
+
+ # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
+ self._pending_deferred_cache = (
+ cache_type()
+ ) # type: MutableMapping[KT, CacheEntry]
+
+ def metrics_cb():
+ cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
+
+ # cache is used for completed results and maps to the result itself, rather than
+ # a Deferred.
+ self.cache = LruCache(
+ max_size=max_entries,
+ keylen=keylen,
+ cache_name=name,
+ cache_type=cache_type,
+ size_callback=(lambda d: len(d)) if iterable else None,
+ metrics_collection_callback=metrics_cb,
+ apply_cache_factor_from_config=apply_cache_factor_from_config,
+ ) # type: LruCache[KT, VT]
+
+ self.thread = None # type: Optional[threading.Thread]
+
+ @property
+ def max_entries(self):
+ return self.cache.max_size
+
+ def check_thread(self):
+ expected_thread = self.thread
+ if expected_thread is None:
+ self.thread = threading.current_thread()
+ else:
+ if expected_thread is not threading.current_thread():
+ raise ValueError(
+ "Cache objects can only be accessed from the main thread"
+ )
+
+ def get(
+ self,
+ key: KT,
+ callback: Optional[Callable[[], None]] = None,
+ update_metrics: bool = True,
+ ) -> defer.Deferred:
+ """Looks the key up in the caches.
+
+ For symmetry with set(), this method does *not* follow the synapse logcontext
+ rules: the logcontext will not be cleared on return, and the Deferred will run
+ its callbacks in the sentinel context. In other words: wrap the result with
+ make_deferred_yieldable() before `await`ing it.
+
+ Args:
+ key:
+ callback: Gets called when the entry in the cache is invalidated
+ update_metrics (bool): whether to update the cache hit rate metrics
+
+ Returns:
+ A Deferred which completes with the result. Note that this may later fail
+ if there is an ongoing set() operation which later completes with a failure.
+
+ Raises:
+ KeyError if the key is not found in the cache
+ """
+ callbacks = [callback] if callback else []
+ val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
+ if val is not _Sentinel.sentinel:
+ val.callbacks.update(callbacks)
+ if update_metrics:
+ m = self.cache.metrics
+ assert m # we always have a name, so should always have metrics
+ m.inc_hits()
+ return val.deferred.observe()
+
+ val2 = self.cache.get(
+ key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
+ )
+ if val2 is _Sentinel.sentinel:
+ raise KeyError()
+ else:
+ return defer.succeed(val2)
+
+ def get_immediate(
+ self, key: KT, default: T, update_metrics: bool = True
+ ) -> Union[VT, T]:
+ """If we have a *completed* cached value, return it."""
+ return self.cache.get(key, default, update_metrics=update_metrics)
+
+ def set(
+ self,
+ key: KT,
+ value: defer.Deferred,
+ callback: Optional[Callable[[], None]] = None,
+ ) -> defer.Deferred:
+ """Adds a new entry to the cache (or updates an existing one).
+
+ The given `value` *must* be a Deferred.
+
+ First any existing entry for the same key is invalidated. Then a new entry
+ is added to the cache for the given key.
+
+ Until the `value` completes, calls to `get()` for the key will also result in an
+ incomplete Deferred, which will ultimately complete with the same result as
+ `value`.
+
+ If `value` completes successfully, subsequent calls to `get()` will then return
+ a completed deferred with the same result. If it *fails*, the cache is
+ invalidated and subsequent calls to `get()` will raise a KeyError.
+
+ If another call to `set()` happens before `value` completes, then (a) any
+ invalidation callbacks registered in the interim will be called, (b) any
+ `get()`s in the interim will continue to complete with the result from the
+ *original* `value`, (c) any future calls to `get()` will complete with the
+ result from the *new* `value`.
+
+ It is expected that `value` does *not* follow the synapse logcontext rules - ie,
+ if it is incomplete, it runs its callbacks in the sentinel context.
+
+ Args:
+ key: Key to be set
+ value: a deferred which will complete with a result to add to the cache
+ callback: An optional callback to be called when the entry is invalidated
+ """
+ if not isinstance(value, defer.Deferred):
+ raise TypeError("not a Deferred")
+
+ callbacks = [callback] if callback else []
+ self.check_thread()
+
+ existing_entry = self._pending_deferred_cache.pop(key, None)
+ if existing_entry:
+ existing_entry.invalidate()
+
+ # XXX: why don't we invalidate the entry in `self.cache` yet?
+
+ # we can save a whole load of effort if the deferred is ready.
+ if value.called:
+ result = value.result
+ if not isinstance(result, failure.Failure):
+ self.cache.set(key, result, callbacks)
+ return value
+
+ # otherwise, we'll add an entry to the _pending_deferred_cache for now,
+ # and add callbacks to add it to the cache properly later.
+
+ observable = ObservableDeferred(value, consumeErrors=True)
+ observer = observable.observe()
+ entry = CacheEntry(deferred=observable, callbacks=callbacks)
+
+ self._pending_deferred_cache[key] = entry
+
+ def compare_and_pop():
+ """Check if our entry is still the one in _pending_deferred_cache, and
+ if so, pop it.
+
+ Returns true if the entries matched.
+ """
+ existing_entry = self._pending_deferred_cache.pop(key, None)
+ if existing_entry is entry:
+ return True
+
+ # oops, the _pending_deferred_cache has been updated since
+ # we started our query, so we are out of date.
+ #
+ # Better put back whatever we took out. (We do it this way
+ # round, rather than peeking into the _pending_deferred_cache
+ # and then removing on a match, to make the common case faster)
+ if existing_entry is not None:
+ self._pending_deferred_cache[key] = existing_entry
+
+ return False
+
+ def cb(result):
+ if compare_and_pop():
+ self.cache.set(key, result, entry.callbacks)
+ else:
+ # we're not going to put this entry into the cache, so need
+ # to make sure that the invalidation callbacks are called.
+ # That was probably done when _pending_deferred_cache was
+ # updated, but it's possible that `set` was called without
+ # `invalidate` being previously called, in which case it may
+ # not have been. Either way, let's double-check now.
+ entry.invalidate()
+
+ def eb(_fail):
+ compare_and_pop()
+ entry.invalidate()
+
+ # once the deferred completes, we can move the entry from the
+ # _pending_deferred_cache to the real cache.
+ #
+ observer.addCallbacks(cb, eb)
+
+ # we return a new Deferred which will be called before any subsequent observers.
+ return observable.observe()
+
+ def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
+ callbacks = [callback] if callback else []
+ self.cache.set(key, value, callbacks=callbacks)
+
+ def invalidate(self, key):
+ self.check_thread()
+ self.cache.pop(key, None)
+
+ # if we have a pending lookup for this key, remove it from the
+ # _pending_deferred_cache, which will (a) stop it being returned
+ # for future queries and (b) stop it being persisted as a proper entry
+ # in self.cache.
+ entry = self._pending_deferred_cache.pop(key, None)
+
+ # run the invalidation callbacks now, rather than waiting for the
+ # deferred to resolve.
+ if entry:
+ entry.invalidate()
+
+ def invalidate_many(self, key: KT):
+ self.check_thread()
+ if not isinstance(key, tuple):
+ raise TypeError("The cache key must be a tuple not %r" % (type(key),))
+ key = cast(KT, key)
+ self.cache.del_multi(key)
+
+ # if we have a pending lookup for this key, remove it from the
+ # _pending_deferred_cache, as above
+ entry_dict = self._pending_deferred_cache.pop(key, None)
+ if entry_dict is not None:
+ for entry in iterate_tree_cache_entry(entry_dict):
+ entry.invalidate()
+
+ def invalidate_all(self):
+ self.check_thread()
+ self.cache.clear()
+ for entry in self._pending_deferred_cache.values():
+ entry.invalidate()
+ self._pending_deferred_cache.clear()
+
+
+class CacheEntry:
+ __slots__ = ["deferred", "callbacks", "invalidated"]
+
+ def __init__(
+ self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
+ ):
+ self.deferred = deferred
+ self.callbacks = set(callbacks)
+ self.invalidated = False
+
+ def invalidate(self):
+ if not self.invalidated:
+ self.invalidated = True
+ for callback in self.callbacks:
+ callback()
+ self.callbacks.clear()
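
For context, a sketch of driving the extracted DeferredCache directly, mirroring the descriptor code later in this diff. The names `profiles` and `fetch_from_db` are illustrative, and this assumes a configured homeserver environment so that the cache factor and metrics machinery are available:

from twisted.internet import defer

from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util.caches.deferred_cache import DeferredCache

profiles = DeferredCache("profiles")  # type: DeferredCache[str, dict]

async def fetch_from_db(user_id: str) -> dict:
    return {"displayname": user_id}  # stand-in for a real database lookup

async def get_profile(user_id: str) -> dict:
    try:
        # hit: either a pending or an already-completed Deferred
        d = profiles.get(user_id)
    except KeyError:
        # miss: start the lookup and publish the pending Deferred, so that
        # concurrent callers share the same in-flight request
        d = defer.maybeDeferred(preserve_fn(fetch_from_db), user_id)
        d = profiles.set(user_id, d)
    # neither get() nor set() follows the logcontext rules, so re-wrap
    return await make_deferred_yieldable(d)
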
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 98b34f2223..a924140cdf 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -13,25 +13,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import enum
import functools
import inspect
import logging
-import threading
-from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
+from typing import (
+ Any,
+ Callable,
+ Generic,
+ Iterable,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ TypeVar,
+ Union,
+ cast,
+)
from weakref import WeakValueDictionary
-from prometheus_client import Gauge
-
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-
-from . import register_cache
logger = logging.getLogger(__name__)
@@ -55,241 +61,8 @@ class _CachedFunction(Generic[F]):
__call__ = None # type: F
-cache_pending_metric = Gauge(
- "synapse_util_caches_cache_pending",
- "Number of lookups currently pending for this cache",
- ["name"],
-)
-
-_CacheSentinel = object()
-
-
-class CacheEntry:
- __slots__ = ["deferred", "callbacks", "invalidated"]
-
- def __init__(self, deferred, callbacks):
- self.deferred = deferred
- self.callbacks = set(callbacks)
- self.invalidated = False
-
- def invalidate(self):
- if not self.invalidated:
- self.invalidated = True
- for callback in self.callbacks:
- callback()
- self.callbacks.clear()
-
-
-class Cache:
- __slots__ = (
- "cache",
- "name",
- "keylen",
- "thread",
- "metrics",
- "_pending_deferred_cache",
- )
-
- def __init__(
- self,
- name: str,
- max_entries: int = 1000,
- keylen: int = 1,
- tree: bool = False,
- iterable: bool = False,
- apply_cache_factor_from_config: bool = True,
- ):
- """
- Args:
- name: The name of the cache
- max_entries: Maximum amount of entries that the cache will hold
- keylen: The length of the tuple used as the cache key
- tree: Use a TreeCache instead of a dict as the underlying cache type
- iterable: If True, count each item in the cached object as an entry,
- rather than each cached object
- apply_cache_factor_from_config: Whether cache factors specified in the
- config file affect `max_entries`
-
- Returns:
- Cache
- """
- cache_type = TreeCache if tree else dict
- self._pending_deferred_cache = cache_type()
-
- self.cache = LruCache(
- max_size=max_entries,
- keylen=keylen,
- cache_type=cache_type,
- size_callback=(lambda d: len(d)) if iterable else None,
- evicted_callback=self._on_evicted,
- apply_cache_factor_from_config=apply_cache_factor_from_config,
- )
-
- self.name = name
- self.keylen = keylen
- self.thread = None # type: Optional[threading.Thread]
- self.metrics = register_cache(
- "cache",
- name,
- self.cache,
- collect_callback=self._metrics_collection_callback,
- )
-
- @property
- def max_entries(self):
- return self.cache.max_size
-
- def _on_evicted(self, evicted_count):
- self.metrics.inc_evictions(evicted_count)
-
- def _metrics_collection_callback(self):
- cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
-
- def check_thread(self):
- expected_thread = self.thread
- if expected_thread is None:
- self.thread = threading.current_thread()
- else:
- if expected_thread is not threading.current_thread():
- raise ValueError(
- "Cache objects can only be accessed from the main thread"
- )
-
- def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
- """Looks the key up in the caches.
-
- Args:
- key(tuple)
- default: What is returned if key is not in the caches. If not
- specified then function throws KeyError instead
- callback(fn): Gets called when the entry in the cache is invalidated
- update_metrics (bool): whether to update the cache hit rate metrics
-
- Returns:
- Either an ObservableDeferred or the raw result
- """
- callbacks = [callback] if callback else []
- val = self._pending_deferred_cache.get(key, _CacheSentinel)
- if val is not _CacheSentinel:
- val.callbacks.update(callbacks)
- if update_metrics:
- self.metrics.inc_hits()
- return val.deferred
-
- val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
- if val is not _CacheSentinel:
- self.metrics.inc_hits()
- return val
-
- if update_metrics:
- self.metrics.inc_misses()
-
- if default is _CacheSentinel:
- raise KeyError()
- else:
- return default
-
- def set(self, key, value, callback=None):
- if not isinstance(value, defer.Deferred):
- raise TypeError("not a Deferred")
-
- callbacks = [callback] if callback else []
- self.check_thread()
- observable = ObservableDeferred(value, consumeErrors=True)
- observer = observable.observe()
- entry = CacheEntry(deferred=observable, callbacks=callbacks)
-
- existing_entry = self._pending_deferred_cache.pop(key, None)
- if existing_entry:
- existing_entry.invalidate()
-
- self._pending_deferred_cache[key] = entry
-
- def compare_and_pop():
- """Check if our entry is still the one in _pending_deferred_cache, and
- if so, pop it.
-
- Returns true if the entries matched.
- """
- existing_entry = self._pending_deferred_cache.pop(key, None)
- if existing_entry is entry:
- return True
-
- # oops, the _pending_deferred_cache has been updated since
- # we started our query, so we are out of date.
- #
- # Better put back whatever we took out. (We do it this way
- # round, rather than peeking into the _pending_deferred_cache
- # and then removing on a match, to make the common case faster)
- if existing_entry is not None:
- self._pending_deferred_cache[key] = existing_entry
-
- return False
-
- def cb(result):
- if compare_and_pop():
- self.cache.set(key, result, entry.callbacks)
- else:
- # we're not going to put this entry into the cache, so need
- # to make sure that the invalidation callbacks are called.
- # That was probably done when _pending_deferred_cache was
- # updated, but it's possible that `set` was called without
- # `invalidate` being previously called, in which case it may
- # not have been. Either way, let's double-check now.
- entry.invalidate()
-
- def eb(_fail):
- compare_and_pop()
- entry.invalidate()
-
- # once the deferred completes, we can move the entry from the
- # _pending_deferred_cache to the real cache.
- #
- observer.addCallbacks(cb, eb)
- return observable
-
- def prefill(self, key, value, callback=None):
- callbacks = [callback] if callback else []
- self.cache.set(key, value, callbacks=callbacks)
-
- def invalidate(self, key):
- self.check_thread()
- self.cache.pop(key, None)
-
- # if we have a pending lookup for this key, remove it from the
- # _pending_deferred_cache, which will (a) stop it being returned
- # for future queries and (b) stop it being persisted as a proper entry
- # in self.cache.
- entry = self._pending_deferred_cache.pop(key, None)
-
- # run the invalidation callbacks now, rather than waiting for the
- # deferred to resolve.
- if entry:
- entry.invalidate()
-
- def invalidate_many(self, key):
- self.check_thread()
- if not isinstance(key, tuple):
- raise TypeError("The cache key must be a tuple not %r" % (type(key),))
- self.cache.del_multi(key)
-
- # if we have a pending lookup for this key, remove it from the
- # _pending_deferred_cache, as above
- entry_dict = self._pending_deferred_cache.pop(key, None)
- if entry_dict is not None:
- for entry in iterate_tree_cache_entry(entry_dict):
- entry.invalidate()
-
- def invalidate_all(self):
- self.check_thread()
- self.cache.clear()
- for entry in self._pending_deferred_cache.values():
- entry.invalidate()
- self._pending_deferred_cache.clear()
-
-
class _CacheDescriptorBase:
- def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
+ def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
self.orig = orig
arg_spec = inspect.getfullargspec(orig)
@@ -338,8 +111,107 @@ class _CacheDescriptorBase:
self.add_cache_context = cache_context
+ self.cache_key_builder = get_cache_key_builder(
+ self.arg_names, self.arg_defaults
+ )
+
+
+class _LruCachedFunction(Generic[F]):
+ cache = None # type: LruCache[CacheKey, Any]
+ __call__ = None # type: F
+
+
+def lru_cache(
+ max_entries: int = 1000, cache_context: bool = False,
+) -> Callable[[F], _LruCachedFunction[F]]:
+ """A method decorator that applies a memoizing cache around the function.
+
+ This is more-or-less a drop-in equivalent to functools.lru_cache, although note
+ that the signature is slightly different.
+
+ The main differences with functools.lru_cache are:
+ (a) the size of the cache can be controlled via the cache_factor mechanism
+ (b) the wrapped function can request a "cache_context" which provides a
+ callback mechanism to indicate that the result is no longer valid
+ (c) prometheus metrics are exposed automatically.
+
+ The function should take zero or more arguments, which are used as the key for the
+ cache. Single-argument functions use that argument as the cache key; otherwise the
+ arguments are built into a tuple.
+
+ Cached functions can be "chained" (i.e. a cached function can call other cached
+ functions and get appropriately invalidated when the caches they call are
+ invalidated) by adding a special "cache_context" argument to the function
+ and passing that as a kwarg to all caches called. For example:
+
+ @lru_cache(cache_context=True)
+ def foo(self, key, cache_context):
+ r1 = self.bar1(key, on_invalidate=cache_context.invalidate)
+ r2 = self.bar2(key, on_invalidate=cache_context.invalidate)
+ return r1 + r2
+
+ The wrapped function also has a 'cache' property which offers direct access to the
+ underlying LruCache.
+ """
+
+ def func(orig: F) -> _LruCachedFunction[F]:
+ desc = LruCacheDescriptor(
+ orig, max_entries=max_entries, cache_context=cache_context,
+ )
+ return cast(_LruCachedFunction[F], desc)
+
+ return func
+
+
+class LruCacheDescriptor(_CacheDescriptorBase):
+ """Helper for @lru_cache"""
-class CacheDescriptor(_CacheDescriptorBase):
+ class _Sentinel(enum.Enum):
+ sentinel = object()
+
+ def __init__(
+ self, orig, max_entries: int = 1000, cache_context: bool = False,
+ ):
+ super().__init__(orig, num_args=None, cache_context=cache_context)
+ self.max_entries = max_entries
+
+ def __get__(self, obj, owner):
+ cache = LruCache(
+ cache_name=self.orig.__name__, max_size=self.max_entries,
+ ) # type: LruCache[CacheKey, Any]
+
+ get_cache_key = self.cache_key_builder
+ sentinel = LruCacheDescriptor._Sentinel.sentinel
+
+ @functools.wraps(self.orig)
+ def _wrapped(*args, **kwargs):
+ invalidate_callback = kwargs.pop("on_invalidate", None)
+ callbacks = (invalidate_callback,) if invalidate_callback else ()
+
+ cache_key = get_cache_key(args, kwargs)
+
+ ret = cache.get(cache_key, default=sentinel, callbacks=callbacks)
+ if ret != sentinel:
+ return ret
+
+ # Add our own `cache_context` to argument list if the wrapped function
+ # has asked for one
+ if self.add_cache_context:
+ kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
+
+ ret2 = self.orig(obj, *args, **kwargs)
+ cache.set(cache_key, ret2, callbacks=callbacks)
+
+ return ret2
+
+ wrapped = cast(_CachedFunction, _wrapped)
+ wrapped.cache = cache
+ obj.__dict__[self.orig.__name__] = wrapped
+
+ return wrapped
+
+
+class DeferredCacheDescriptor(_CacheDescriptorBase):
""" A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
@@ -382,7 +254,6 @@ class CacheDescriptor(_CacheDescriptorBase):
cache_context=False,
iterable=False,
):
-
super().__init__(orig, num_args=num_args, cache_context=cache_context)
self.max_entries = max_entries
@@ -390,49 +261,15 @@ class CacheDescriptor(_CacheDescriptorBase):
self.iterable = iterable
def __get__(self, obj, owner):
- cache = Cache(
+ cache = DeferredCache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
tree=self.tree,
iterable=self.iterable,
- )
-
- def get_cache_key_gen(args, kwargs):
- """Given some args/kwargs return a generator that resolves into
- the cache_key.
-
- We loop through each arg name, looking up if its in the `kwargs`,
- otherwise using the next argument in `args`. If there are no more
- args then we try looking the arg name up in the defaults
- """
- pos = 0
- for nm in self.arg_names:
- if nm in kwargs:
- yield kwargs[nm]
- elif pos < len(args):
- yield args[pos]
- pos += 1
- else:
- yield self.arg_defaults[nm]
-
- # By default our cache key is a tuple, but if there is only one item
- # then don't bother wrapping in a tuple. This is to save memory.
- if self.num_args == 1:
- nm = self.arg_names[0]
-
- def get_cache_key(args, kwargs):
- if nm in kwargs:
- return kwargs[nm]
- elif len(args):
- return args[0]
- else:
- return self.arg_defaults[nm]
-
- else:
+ ) # type: DeferredCache[CacheKey, Any]
- def get_cache_key(args, kwargs):
- return tuple(get_cache_key_gen(args, kwargs))
+ get_cache_key = self.cache_key_builder
@functools.wraps(self.orig)
def _wrapped(*args, **kwargs):
@@ -442,32 +279,20 @@ class CacheDescriptor(_CacheDescriptorBase):
cache_key = get_cache_key(args, kwargs)
- # Add our own `cache_context` to argument list if the wrapped function
- # has asked for one
- if self.add_cache_context:
- kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
-
try:
- cached_result_d = cache.get(cache_key, callback=invalidate_callback)
-
- if isinstance(cached_result_d, ObservableDeferred):
- observer = cached_result_d.observe()
- else:
- observer = defer.succeed(cached_result_d)
-
+ ret = cache.get(cache_key, callback=invalidate_callback)
except KeyError:
- ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
+ # Add our own `cache_context` to argument list if the wrapped function
+ # has asked for one
+ if self.add_cache_context:
+ kwargs["cache_context"] = _CacheContext.get_instance(
+ cache, cache_key
+ )
- def onErr(f):
- cache.invalidate(cache_key)
- return f
-
- ret.addErrback(onErr)
-
- result_d = cache.set(cache_key, ret, callback=invalidate_callback)
- observer = result_d.observe()
+ ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
+ ret = cache.set(cache_key, ret, callback=invalidate_callback)
- return make_deferred_yieldable(observer)
+ return make_deferred_yieldable(ret)
wrapped = cast(_CachedFunction, _wrapped)
@@ -476,7 +301,6 @@ class CacheDescriptor(_CacheDescriptorBase):
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
else:
wrapped.invalidate = cache.invalidate
- wrapped.invalidate_all = cache.invalidate_all
wrapped.invalidate_many = cache.invalidate_many
wrapped.prefill = cache.prefill
@@ -489,7 +313,7 @@ class CacheDescriptor(_CacheDescriptorBase):
return wrapped
-class CacheListDescriptor(_CacheDescriptorBase):
+class DeferredCacheListDescriptor(_CacheDescriptorBase):
"""Wraps an existing cache to support bulk fetching of keys.
Given a list of keys it looks in the cache to find any hits, then passes
@@ -526,7 +350,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
def __get__(self, obj, objtype=None):
cached_method = getattr(obj, self.cached_method_name)
- cache = cached_method.cache
+ cache = cached_method.cache # type: DeferredCache[CacheKey, Any]
num_args = cached_method.num_args
@functools.wraps(self.orig)
@@ -566,14 +390,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
for arg in list_args:
try:
res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
- if not isinstance(res, ObservableDeferred):
- results[arg] = res
- elif not res.has_succeeded():
- res = res.observe()
+ if not res.called:
res.addCallback(update_results_dict, arg)
cached_defers.append(res)
else:
- results[arg] = res.get_result()
+ results[arg] = res.result
except KeyError:
missing.add(arg)
@@ -638,11 +459,13 @@ class _CacheContext:
on a lower level.
"""
+ Cache = Union[DeferredCache, LruCache]
+
_cache_context_objects = (
WeakValueDictionary()
- ) # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext]
+ ) # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext]
- def __init__(self, cache, cache_key): # type: (Cache, CacheKey) -> None
+ def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None:
self._cache = cache
self._cache_key = cache_key
@@ -651,7 +474,9 @@ class _CacheContext:
self._cache.invalidate(self._cache_key)
@classmethod
- def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheContext
+ def get_instance(
+ cls, cache: "_CacheContext.Cache", cache_key: CacheKey
+ ) -> "_CacheContext":
"""Returns an instance constructed with the given arguments.
A new instance is only created if none already exists.
@@ -672,7 +497,7 @@ def cached(
cache_context: bool = False,
iterable: bool = False,
) -> Callable[[F], _CachedFunction[F]]:
- func = lambda orig: CacheDescriptor(
+ func = lambda orig: DeferredCacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
@@ -714,7 +539,7 @@ def cachedList(
def batch_do_something(self, first_arg, second_args):
...
"""
- func = lambda orig: CacheListDescriptor(
+ func = lambda orig: DeferredCacheListDescriptor(
orig,
cached_method_name=cached_method_name,
list_name=list_name,
@@ -722,3 +547,65 @@ def cachedList(
)
return cast(Callable[[F], _CachedFunction[F]], func)
+
+
+def get_cache_key_builder(
+ param_names: Sequence[str], param_defaults: Mapping[str, Any]
+) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
+ """Construct a function which will build cache keys suitable for a cached function
+
+ Args:
+ param_names: list of formal parameter names for the cached function
+ param_defaults: a mapping from parameter name to default value for that param
+
+ Returns:
+ A function which will take an (args, kwargs) pair and return a cache key
+ """
+
+ # By default our cache key is a tuple, but if there is only one item
+ # then don't bother wrapping in a tuple. This is to save memory.
+
+ if len(param_names) == 1:
+ nm = param_names[0]
+
+ def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
+ if nm in kwargs:
+ return kwargs[nm]
+ elif len(args):
+ return args[0]
+ else:
+ return param_defaults[nm]
+
+ else:
+
+ def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
+ return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
+
+ return get_cache_key
+
+
+def _get_cache_key_gen(
+ param_names: Iterable[str],
+ param_defaults: Mapping[str, Any],
+ args: Sequence[Any],
+ kwargs: Mapping[str, Any],
+) -> Iterable[Any]:
+ """Given some args/kwargs return a generator that resolves into
+ the cache_key.
+
+ This is essentially the same operation as `inspect.getcallargs`, but optimised so
+ that we don't need to inspect the target function for each call.
+ """
+
+ # We loop through each arg name, checking whether it's in the `kwargs`,
+ # otherwise using the next argument in `args`. If there are no more
+ # args then we try looking the arg name up in the defaults.
+ pos = 0
+ for nm in param_names:
+ if nm in kwargs:
+ yield kwargs[nm]
+ elif pos < len(args):
+ yield args[pos]
+ pos += 1
+ else:
+ yield param_defaults[nm]
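
A quick sanity check of the extracted key-builder helper, exercising both branches above (the parameter names are illustrative):

from synapse.util.caches.descriptors import get_cache_key_builder

# a single parameter: the bare value is used as the key, to save memory
one_arg = get_cache_key_builder(["user_id"], {})
assert one_arg(("@alice:example.com",), {}) == "@alice:example.com"

# several parameters: the key is a tuple, falling back to param_defaults when
# neither args nor kwargs supply a value
two_args = get_cache_key_builder(["room_id", "limit"], {"limit": 10})
assert two_args(("!abc:example.com",), {}) == ("!abc:example.com", 10)
assert two_args((), {"room_id": "!abc:example.com", "limit": 5}) == (
    "!abc:example.com",
    5,
)
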
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 8592b93689..588d2d49f2 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -12,15 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import enum
import logging
import threading
from collections import namedtuple
+from typing import Any
from synapse.util.caches.lrucache import LruCache
-from . import register_cache
-
logger = logging.getLogger(__name__)
@@ -40,24 +39,25 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va
return len(self.value)
+class _Sentinel(enum.Enum):
+ # defining a sentinel in this way allows mypy to correctly handle the
+ # type of a dictionary lookup.
+ sentinel = object()
+
+
class DictionaryCache:
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key.
"""
def __init__(self, name, max_entries=1000):
- self.cache = LruCache(max_size=max_entries, size_callback=len)
+ self.cache = LruCache(
+ max_size=max_entries, cache_name=name, size_callback=len
+ ) # type: LruCache[Any, DictionaryEntry]
self.name = name
self.sequence = 0
self.thread = None
- # caches_by_name[name] = self.cache
-
- class Sentinel:
- __slots__ = []
-
- self.sentinel = Sentinel()
- self.metrics = register_cache("dictionary", name, self.cache)
def check_thread(self):
expected_thread = self.thread
@@ -80,10 +80,8 @@ class DictionaryCache:
Returns:
DictionaryEntry
"""
- entry = self.cache.get(key, self.sentinel)
- if entry is not self.sentinel:
- self.metrics.inc_hits()
-
+ entry = self.cache.get(key, _Sentinel.sentinel)
+ if entry is not _Sentinel.sentinel:
if dict_keys is None:
return DictionaryEntry(
entry.full, entry.known_absent, dict(entry.value)
@@ -95,7 +93,6 @@ class DictionaryCache:
{k: entry.value[k] for k in dict_keys if k in entry.value},
)
- self.metrics.inc_misses()
return DictionaryEntry(False, set(), {})
def invalidate(self, key):
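
DictionaryCache above and the new DeferredCache both adopt the same enum-based sentinel; a standalone sketch of why it is friendlier to mypy than a bare object() sentinel:

import enum
from typing import Dict, Union

class _Sentinel(enum.Enum):
    # a single-member enum gives the sentinel its own type, so mypy can narrow
    # the result of a lookup with a simple identity check
    sentinel = object()

cache = {}  # type: Dict[str, int]

value = cache.get("key", _Sentinel.sentinel)  # type: Union[int, _Sentinel]
if value is not _Sentinel.sentinel:
    # mypy narrows `value` to int here; with a plain object() default the
    # inferred type would remain object and this addition would be rejected
    print(value + 1)
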
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 4bc1a67b58..60bb6ff642 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -15,11 +15,35 @@
import threading
from functools import wraps
-from typing import Callable, Optional, Type, Union
+from typing import (
+ Any,
+ Callable,
+ Generic,
+ Iterable,
+ Optional,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+ overload,
+)
+
+from typing_extensions import Literal
from synapse.config import cache as cache_config
+from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.treecache import TreeCache
+# Function type: the type used for invalidation callbacks
+FT = TypeVar("FT", bound=Callable[..., Any])
+
+# Key and Value type for the cache
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
+# a general type var, distinct from either KT or VT
+T = TypeVar("T")
+
def enumerate_leaves(node, depth):
if depth == 0:
@@ -41,30 +65,33 @@ class _Node:
self.callbacks = callbacks
-class LruCache:
+class LruCache(Generic[KT, VT]):
"""
- Least-recently-used cache.
+ Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
+
Supports del_multi only if cache_type=TreeCache
If cache_type=TreeCache, all keys must be tuples.
-
- Can also set callbacks on objects when getting/setting which are fired
- when that key gets invalidated/evicted.
"""
def __init__(
self,
max_size: int,
+ cache_name: Optional[str] = None,
keylen: int = 1,
cache_type: Type[Union[dict, TreeCache]] = dict,
size_callback: Optional[Callable] = None,
- evicted_callback: Optional[Callable] = None,
+ metrics_collection_callback: Optional[Callable[[], None]] = None,
apply_cache_factor_from_config: bool = True,
):
"""
Args:
max_size: The maximum amount of entries the cache can hold
- keylen: The length of the tuple used as the cache key
+ cache_name: The name of this cache, for the prometheus metrics. If unset,
+ no metrics will be reported on this cache.
+
+ keylen: The length of the tuple used as the cache key. Ignored unless
+ cache_type is `TreeCache`.
cache_type (type):
type of underlying cache to be used. Typically one of dict
@@ -72,9 +99,13 @@ class LruCache:
size_callback (func(V) -> int | None):
- evicted_callback (func(int)|None):
- if not None, called on eviction with the size of the evicted
- entry
+ metrics_collection_callback:
+ metrics collection callback. This is called early in the metrics
+ collection process, before any of the metrics registered with the
+ prometheus Registry are collected, so can be used to update any dynamic
+ metrics.
+
+ Ignored if cache_name is None.
apply_cache_factor_from_config (bool): If true, `max_size` will be
multiplied by a cache factor derived from the homeserver config
@@ -93,6 +124,23 @@ class LruCache:
else:
self.max_size = int(max_size)
+ # register_cache might call our "set_cache_factor" callback; there's nothing to
+ # do yet when we get resized.
+ self._on_resize = None # type: Optional[Callable[[],None]]
+
+ if cache_name is not None:
+ metrics = register_cache(
+ "lru_cache",
+ cache_name,
+ self,
+ collect_callback=metrics_collection_callback,
+ ) # type: Optional[CacheMetric]
+ else:
+ metrics = None
+
+ # this is exposed for access from outside this class
+ self.metrics = metrics
+
list_root = _Node(None, None, None, None)
list_root.next_node = list_root
list_root.prev_node = list_root
@@ -104,16 +152,16 @@ class LruCache:
todelete = list_root.prev_node
evicted_len = delete_node(todelete)
cache.pop(todelete.key, None)
- if evicted_callback:
- evicted_callback(evicted_len)
+ if metrics:
+ metrics.inc_evictions(evicted_len)
- def synchronized(f):
+ def synchronized(f: FT) -> FT:
@wraps(f)
def inner(*args, **kwargs):
with lock:
return f(*args, **kwargs)
- return inner
+ return cast(FT, inner)
cached_cache_len = [0]
if size_callback is not None:
@@ -167,18 +215,45 @@ class LruCache:
node.callbacks.clear()
return deleted_len
+ @overload
+ def cache_get(
+ key: KT,
+ default: Literal[None] = None,
+ callbacks: Iterable[Callable[[], None]] = ...,
+ update_metrics: bool = ...,
+ ) -> Optional[VT]:
+ ...
+
+ @overload
+ def cache_get(
+ key: KT,
+ default: T,
+ callbacks: Iterable[Callable[[], None]] = ...,
+ update_metrics: bool = ...,
+ ) -> Union[T, VT]:
+ ...
+
@synchronized
- def cache_get(key, default=None, callbacks=[]):
+ def cache_get(
+ key: KT,
+ default: Optional[T] = None,
+ callbacks: Iterable[Callable[[], None]] = [],
+ update_metrics: bool = True,
+ ):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
node.callbacks.update(callbacks)
+ if update_metrics and metrics:
+ metrics.inc_hits()
return node.value
else:
+ if update_metrics and metrics:
+ metrics.inc_misses()
return default
@synchronized
- def cache_set(key, value, callbacks=[]):
+ def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
node = cache.get(key, None)
if node is not None:
# We sometimes store large objects, e.g. dicts, which cause
@@ -207,7 +282,7 @@ class LruCache:
evict()
@synchronized
- def cache_set_default(key, value):
+ def cache_set_default(key: KT, value: VT) -> VT:
node = cache.get(key, None)
if node is not None:
return node.value
@@ -216,8 +291,16 @@ class LruCache:
evict()
return value
+ @overload
+ def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]:
+ ...
+
+ @overload
+ def cache_pop(key: KT, default: T) -> Union[T, VT]:
+ ...
+
@synchronized
- def cache_pop(key, default=None):
+ def cache_pop(key: KT, default: Optional[T] = None):
node = cache.get(key, None)
if node:
delete_node(node)
@@ -227,18 +310,18 @@ class LruCache:
return default
@synchronized
- def cache_del_multi(key):
+ def cache_del_multi(key: KT) -> None:
"""
This will only work if constructed with cache_type=TreeCache
"""
popped = cache.pop(key)
if popped is None:
return
- for leaf in enumerate_leaves(popped, keylen - len(key)):
+ for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
delete_node(leaf)
@synchronized
- def cache_clear():
+ def cache_clear() -> None:
list_root.next_node = list_root
list_root.prev_node = list_root
for node in cache.values():
@@ -249,15 +332,21 @@ class LruCache:
cached_cache_len[0] = 0
@synchronized
- def cache_contains(key):
+ def cache_contains(key: KT) -> bool:
return key in cache
self.sentinel = object()
+
+ # make sure that we clear out any excess entries after we get resized.
self._on_resize = evict
+
self.get = cache_get
self.set = cache_set
self.setdefault = cache_set_default
self.pop = cache_pop
+ # `invalidate` is exposed for consistency with DeferredCache, so that it can be
+ # invalidated by the cache invalidation replication stream.
+ self.invalidate = cache_pop
if cache_type is TreeCache:
self.del_multi = cache_del_multi
self.len = synchronized(cache_len)
@@ -301,6 +390,7 @@ class LruCache:
new_size = int(self._original_max_size * factor)
if new_size != self.max_size:
self.max_size = new_size
- self._on_resize()
+ if self._on_resize:
+ self._on_resize()
return True
return False
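
A short usage sketch tying together the newly generic LruCache, its per-cache metrics and the new `invalidate` alias. The cache name and values are illustrative; note that `max_size` may still be scaled by the configured global cache factor:

from synapse.util.caches.lrucache import LruCache

cache = LruCache(max_size=10, cache_name="example")  # type: LruCache[str, int]

def on_invalidate() -> None:
    print("the entry for 'foo' was invalidated or evicted")

cache.set("foo", 1, callbacks=[on_invalidate])
assert cache.get("foo") == 1       # counted as a hit in cache.metrics
assert cache.get("bar", -1) == -1  # counted as a miss

# `invalidate` is just an alias for `pop`, so the cache invalidation
# replication stream can treat LruCache and DeferredCache uniformly
cache.invalidate("foo")
assert cache.get("foo") is None
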
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index df1a721add..32228f42ee 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
from twisted.internet import defer
@@ -20,10 +21,15 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
+T = TypeVar("T")
+
-class ResponseCache:
+class ResponseCache(Generic[T]):
"""
This caches a deferred response. Until the deferred completes it will be
returned from the cache. This means that if the client retries the request
@@ -31,8 +37,9 @@ class ResponseCache:
used rather than trying to compute a new response.
"""
- def __init__(self, hs, name, timeout_ms=0):
- self.pending_result_cache = {} # Requests that haven't finished yet.
+ def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
+ # Requests that haven't finished yet.
+ self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
self.clock = hs.get_clock()
self.timeout_sec = timeout_ms / 1000.0
@@ -40,13 +47,13 @@ class ResponseCache:
self._name = name
self._metrics = register_cache("response_cache", name, self, resizable=False)
- def size(self):
+ def size(self) -> int:
return len(self.pending_result_cache)
- def __len__(self):
+ def __len__(self) -> int:
return self.size()
- def get(self, key):
+ def get(self, key: T) -> Optional[defer.Deferred]:
"""Look up the given key.
Can return either a new Deferred (which also doesn't follow the synapse
@@ -58,12 +65,11 @@ class ResponseCache:
from an absent cache entry.
Args:
- key (hashable):
+ key: key to get/set in the cache
Returns:
- twisted.internet.defer.Deferred|None|E: None if there is no entry
- for this key; otherwise either a deferred result or the result
- itself.
+ None if there is no entry for this key; otherwise a deferred which
+ resolves to the result.
"""
result = self.pending_result_cache.get(key)
if result is not None:
@@ -73,7 +79,7 @@ class ResponseCache:
self._metrics.inc_misses()
return None
- def set(self, key, deferred):
+ def set(self, key: T, deferred: defer.Deferred) -> defer.Deferred:
"""Set the entry for the given key to the given deferred.
*deferred* should run its callbacks in the sentinel logcontext (ie,
@@ -85,12 +91,11 @@ class ResponseCache:
result. You will probably want to make_deferred_yieldable the result.
Args:
- key (hashable):
- deferred (twisted.internet.defer.Deferred[T):
+ key: key to get/set in the cache
+ deferred: The deferred which resolves to the result.
Returns:
- twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual
- result.
+ A new deferred which resolves to the actual result.
"""
result = ObservableDeferred(deferred, consumeErrors=True)
self.pending_result_cache[key] = result
@@ -107,7 +112,9 @@ class ResponseCache:
result.addBoth(remove)
return result.observe()
- def wrap(self, key, callback, *args, **kwargs):
+ def wrap(
+ self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any
+ ) -> defer.Deferred:
"""Wrap together a *get* and *set* call, taking care of logcontexts
First looks up the key in the cache, and if it is present makes it
@@ -118,21 +125,20 @@ class ResponseCache:
Example usage:
- @defer.inlineCallbacks
- def handle_request(request):
+ async def handle_request(request):
# etc
return result
- result = yield response_cache.wrap(
+ result = await response_cache.wrap(
key,
handle_request,
request,
)
Args:
- key (hashable): key to get/set in the cache
+ key: key to get/set in the cache
- callback (callable): function to call if the key is not found in
+ callback: function to call if the key is not found in
the cache
*args: positional parameters to pass to the callback, if it is used
@@ -140,7 +146,7 @@ class ResponseCache:
**kwargs: named parameters to pass to the callback, if it is used
Returns:
- twisted.internet.defer.Deferred: yieldable result
+ Deferred which resolves to the result
"""
result = self.get(key)
if not result:
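
Finally, a sketch of the now-generic ResponseCache with an explicit key type, following the wrap() docstring above (the handler names are illustrative and `hs` is assumed to be a HomeServer instance):

from typing import Tuple

from synapse.util.caches.response_cache import ResponseCache

class ExampleHandler:
    def __init__(self, hs):
        # keyed on (room_id, event_id): duplicate requests for the same key
        # share one computation while it is in flight (and for timeout_ms after)
        self._cache = ResponseCache(
            hs, "example_requests", timeout_ms=30 * 1000
        )  # type: ResponseCache[Tuple[str, str]]

    async def on_request(self, room_id: str, event_id: str) -> dict:
        return await self._cache.wrap(
            (room_id, event_id), self._do_request, room_id, event_id
        )

    async def _do_request(self, room_id: str, event_id: str) -> dict:
        return {}  # stand-in for the real work
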
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 3e180cafd3..6ce2a3d12b 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -34,7 +34,7 @@ class TTLCache:
self._data = {}
# the _CacheEntries, sorted by expiry time
- self._expiry_list = SortedList()
+ self._expiry_list = SortedList() # type: SortedList[_CacheEntry]
self._timer = timer