diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 42f6abb5e1..9387632d0d 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -20,9 +20,11 @@ from sys import intern
from typing import Any, Callable, Dict, List, Optional, Sized, TypeVar
import attr
+from prometheus_client import REGISTRY
from prometheus_client.core import Gauge
from synapse.config.cache import add_resizable_cache
+from synapse.util.metrics import DynamicCollectorRegistry
logger = logging.getLogger(__name__)
@@ -30,27 +32,62 @@ logger = logging.getLogger(__name__)
# Whether to track estimated memory usage of the LruCaches.
TRACK_MEMORY_USAGE = False
+# We track cache metrics in a special registry that lets us update the metrics
+# just before they are returned from the scrape endpoint.
+CACHE_METRIC_REGISTRY = DynamicCollectorRegistry()
caches_by_name: Dict[str, Sized] = {}
-collectors_by_name: Dict[str, "CacheMetric"] = {}
-cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
-cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
-cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name", "reason"])
-cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"])
-cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"])
+cache_size = Gauge(
+ "synapse_util_caches_cache_size", "", ["name"], registry=CACHE_METRIC_REGISTRY
+)
+cache_hits = Gauge(
+ "synapse_util_caches_cache_hits", "", ["name"], registry=CACHE_METRIC_REGISTRY
+)
+cache_evicted = Gauge(
+ "synapse_util_caches_cache_evicted_size",
+ "",
+ ["name", "reason"],
+ registry=CACHE_METRIC_REGISTRY,
+)
+cache_total = Gauge(
+ "synapse_util_caches_cache", "", ["name"], registry=CACHE_METRIC_REGISTRY
+)
+cache_max_size = Gauge(
+ "synapse_util_caches_cache_max_size", "", ["name"], registry=CACHE_METRIC_REGISTRY
+)
cache_memory_usage = Gauge(
"synapse_util_caches_cache_size_bytes",
"Estimated memory usage of the caches",
["name"],
+ registry=CACHE_METRIC_REGISTRY,
)
-response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"])
-response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"])
+response_cache_size = Gauge(
+ "synapse_util_caches_response_cache_size",
+ "",
+ ["name"],
+ registry=CACHE_METRIC_REGISTRY,
+)
+response_cache_hits = Gauge(
+ "synapse_util_caches_response_cache_hits",
+ "",
+ ["name"],
+ registry=CACHE_METRIC_REGISTRY,
+)
response_cache_evicted = Gauge(
- "synapse_util_caches_response_cache:evicted_size", "", ["name", "reason"]
+ "synapse_util_caches_response_cache_evicted_size",
+ "",
+ ["name", "reason"],
+ registry=CACHE_METRIC_REGISTRY,
)
-response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
+response_cache_total = Gauge(
+ "synapse_util_caches_response_cache", "", ["name"], registry=CACHE_METRIC_REGISTRY
+)
+
+
+# Register our custom cache metrics registry with the global registry
+REGISTRY.register(CACHE_METRIC_REGISTRY)
class EvictionReason(Enum):
@@ -160,7 +197,7 @@ def register_cache(
resize_callback: A function which can be called to resize the cache.
Returns:
- CacheMetric: an object which provides inc_{hits,misses,evictions} methods
+ an object which provides inc_{hits,misses,evictions} methods
"""
if resizable:
if not resize_callback:
@@ -170,7 +207,7 @@ def register_cache(
metric = CacheMetric(cache, cache_type, cache_name, collect_callback)
metric_name = "cache_%s_%s" % (cache_type, cache_name)
caches_by_name[cache_name] = cache
- collectors_by_name[metric_name] = metric
+ CACHE_METRIC_REGISTRY.register_hook(metric_name, metric.collect)
return metric
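
# A rough sketch (not from the diff) of the lazy-update pattern introduced
# above. It assumes `DynamicCollectorRegistry` -- defined in
# synapse/util/metrics.py and not shown here -- is a CollectorRegistry whose
# collect() runs registered hooks before yielding metrics, as its use with
# register_hook() and REGISTRY.register() implies. Names prefixed `sketch_`
# are illustrative only.
from typing import Callable, Dict, Iterable

from prometheus_client import REGISTRY, CollectorRegistry, Gauge
from prometheus_client.core import Metric


class SketchDynamicRegistry(CollectorRegistry):
    """Runs per-metric hooks at scrape time, then collects as normal."""

    def __init__(self) -> None:
        super().__init__()
        self._pre_update_hooks: Dict[str, Callable[[], None]] = {}

    def register_hook(self, metric_name: str, hook: Callable[[], None]) -> None:
        self._pre_update_hooks[metric_name] = hook

    def collect(self) -> Iterable[Metric]:
        # Refresh gauge values only when the scrape endpoint is actually hit,
        # rather than on every cache operation.
        for hook in self._pre_update_hooks.values():
            hook()
        yield from super().collect()


sketch_registry = SketchDynamicRegistry()
sketch_hits = Gauge("sketch_cache_hits", "", ["name"], registry=sketch_registry)
sketch_hit_count = {"my_cache": 0}
sketch_registry.register_hook(
    "cache_lru_my_cache",
    lambda: sketch_hits.labels("my_cache").set(sketch_hit_count["my_cache"]),
)
REGISTRY.register(sketch_registry)  # a scrape of REGISTRY now runs the hook
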
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 1d6ec22191..bf7bd351e0 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -14,15 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import enum
import threading
from typing import (
Callable,
+ Collection,
+ Dict,
Generic,
- Iterable,
MutableMapping,
Optional,
+ Set,
Sized,
+ Tuple,
TypeVar,
Union,
cast,
@@ -31,7 +35,6 @@ from typing import (
from prometheus_client import Gauge
from twisted.internet import defer
-from twisted.python import failure
from twisted.python.failure import Failure
from synapse.util.async_helpers import ObservableDeferred
@@ -94,7 +97,7 @@ class DeferredCache(Generic[KT, VT]):
# _pending_deferred_cache maps from the key value to a `CacheEntry` object.
self._pending_deferred_cache: Union[
- TreeCache, "MutableMapping[KT, CacheEntry]"
+ TreeCache, "MutableMapping[KT, CacheEntry[KT, VT]]"
] = cache_type()
def metrics_cb() -> None:
@@ -150,7 +153,7 @@ class DeferredCache(Generic[KT, VT]):
Args:
key:
callback: Gets called when the entry in the cache is invalidated
- update_metrics (bool): whether to update the cache hit rate metrics
+ update_metrics: whether to update the cache hit rate metrics
Returns:
A Deferred which completes with the result. Note that this may later fail
@@ -159,15 +162,16 @@ class DeferredCache(Generic[KT, VT]):
Raises:
KeyError if the key is not found in the cache
"""
- callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if val is not _Sentinel.sentinel:
- val.callbacks.update(callbacks)
+ val.add_invalidation_callback(key, callback)
if update_metrics:
m = self.cache.metrics
assert m # we always have a name, so should always have metrics
m.inc_hits()
- return val.deferred.observe()
+ return val.deferred(key)
+
+ callbacks = (callback,) if callback else ()
val2 = self.cache.get(
key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
@@ -177,6 +181,73 @@ class DeferredCache(Generic[KT, VT]):
else:
return defer.succeed(val2)
+ def get_bulk(
+ self,
+ keys: Collection[KT],
+ callback: Optional[Callable[[], None]] = None,
+ ) -> Tuple[Dict[KT, VT], Optional["defer.Deferred[Dict[KT, VT]]"], Collection[KT]]:
+ """Bulk lookup of items in the cache.
+
+ Returns:
+ A 3-tuple of:
+ 1. a dict of key/value pairs for items already cached;
+ 2. a deferred that resolves to a dict of key/value pairs for items
+ currently being fetched; and
+ 3. a collection of keys that don't appear in the previous two.
+ """
+
+ # The cached results
+ cached = {}
+
+ # List of pending deferreds
+ pending = []
+
+ # Dict that gets filled out when the pending deferreds complete
+ pending_results = {}
+
+ # List of keys that aren't in either cache
+ missing = []
+
+ callbacks = (callback,) if callback else ()
+
+ for key in keys:
+ # Check if it's in the main cache.
+ immediate_value = self.cache.get(
+ key,
+ _Sentinel.sentinel,
+ callbacks=callbacks,
+ )
+ if immediate_value is not _Sentinel.sentinel:
+ cached[key] = immediate_value
+ continue
+
+ # Check if it's in the pending cache.
+ pending_value = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
+ if pending_value is not _Sentinel.sentinel:
+ pending_value.add_invalidation_callback(key, callback)
+
+ def completed_cb(value: VT, key: KT) -> VT:
+ pending_results[key] = value
+ return value
+
+ # Add a callback to fill out `pending_results` when the pending lookup completes
+ d = pending_value.deferred(key).addCallback(completed_cb, key)
+ pending.append(d)
+ continue
+
+ # Not in either cache
+ missing.append(key)
+
+ # If we've got pending deferreds, squash them into a single one that
+ # returns `pending_results`.
+ pending_deferred = None
+ if pending:
+ pending_deferred = defer.gatherResults(
+ pending, consumeErrors=True
+ ).addCallback(lambda _: pending_results)
+
+ return (cached, pending_deferred, missing)
+
def get_immediate(
self, key: KT, default: T, update_metrics: bool = True
) -> Union[VT, T]:
@@ -218,84 +289,89 @@ class DeferredCache(Generic[KT, VT]):
value: a deferred which will complete with a result to add to the cache
callback: An optional callback to be called when the entry is invalidated
"""
- if not isinstance(value, defer.Deferred):
- raise TypeError("not a Deferred")
-
- callbacks = [callback] if callback else []
self.check_thread()
- existing_entry = self._pending_deferred_cache.pop(key, None)
- if existing_entry:
- existing_entry.invalidate()
+ self._pending_deferred_cache.pop(key, None)
# XXX: why don't we invalidate the entry in `self.cache` yet?
- # we can save a whole load of effort if the deferred is ready.
- if value.called:
- result = value.result
- if not isinstance(result, failure.Failure):
- self.cache.set(key, cast(VT, result), callbacks)
- return value
-
# otherwise, we'll add an entry to the _pending_deferred_cache for now,
# and add callbacks to add it to the cache properly later.
+ entry = CacheEntrySingle[KT, VT](value)
+ entry.add_invalidation_callback(key, callback)
+ self._pending_deferred_cache[key] = entry
+ deferred = entry.deferred(key).addCallbacks(
+ self._completed_callback,
+ self._error_callback,
+ callbackArgs=(entry, key),
+ errbackArgs=(entry, key),
+ )
- observable = ObservableDeferred(value, consumeErrors=True)
- observer = observable.observe()
- entry = CacheEntry(deferred=observable, callbacks=callbacks)
+ # we return a new Deferred which will be called before any subsequent observers.
+ return deferred
- self._pending_deferred_cache[key] = entry
+ def start_bulk_input(
+ self,
+ keys: Collection[KT],
+ callback: Optional[Callable[[], None]] = None,
+ ) -> "CacheMultipleEntries[KT, VT]":
+ """Bulk set API for use when fetching multiple keys at once from the DB.
- def compare_and_pop() -> bool:
- """Check if our entry is still the one in _pending_deferred_cache, and
- if so, pop it.
-
- Returns true if the entries matched.
- """
- existing_entry = self._pending_deferred_cache.pop(key, None)
- if existing_entry is entry:
- return True
-
- # oops, the _pending_deferred_cache has been updated since
- # we started our query, so we are out of date.
- #
- # Better put back whatever we took out. (We do it this way
- # round, rather than peeking into the _pending_deferred_cache
- # and then removing on a match, to make the common case faster)
- if existing_entry is not None:
- self._pending_deferred_cache[key] = existing_entry
-
- return False
-
- def cb(result: VT) -> None:
- if compare_and_pop():
- self.cache.set(key, result, entry.callbacks)
- else:
- # we're not going to put this entry into the cache, so need
- # to make sure that the invalidation callbacks are called.
- # That was probably done when _pending_deferred_cache was
- # updated, but it's possible that `set` was called without
- # `invalidate` being previously called, in which case it may
- # not have been. Either way, let's double-check now.
- entry.invalidate()
-
- def eb(_fail: Failure) -> None:
- compare_and_pop()
- entry.invalidate()
-
- # once the deferred completes, we can move the entry from the
- # _pending_deferred_cache to the real cache.
- #
- observer.addCallbacks(cb, eb)
+ Called *before* starting the fetch from the DB, and the caller *must*
+ call either `complete_bulk(..)` or `error_bulk(..)` on the return value.
+ """
- # we return a new Deferred which will be called before any subsequent observers.
- return observable.observe()
+ entry = CacheMultipleEntries[KT, VT]()
+ entry.add_global_invalidation_callback(callback)
+
+ for key in keys:
+ self._pending_deferred_cache[key] = entry
+
+ return entry
+
+ def _completed_callback(
+ self, value: VT, entry: "CacheEntry[KT, VT]", key: KT
+ ) -> VT:
+ """Called when a deferred is completed."""
+ # We check if the current entry matches the entry associated with the
+ # deferred. If they don't match then it got invalidated.
+ current_entry = self._pending_deferred_cache.pop(key, None)
+ if current_entry is not entry:
+ if current_entry:
+ self._pending_deferred_cache[key] = current_entry
+ return value
+
+ self.cache.set(key, value, entry.get_invalidation_callbacks(key))
+
+ return value
+
+ def _error_callback(
+ self,
+ failure: Failure,
+ entry: "CacheEntry[KT, VT]",
+ key: KT,
+ ) -> Failure:
+ """Called when a deferred errors."""
+
+ # We check if the current entry matches the entry associated with the
+ # deferred. If they don't match then it got invalidated.
+ current_entry = self._pending_deferred_cache.pop(key, None)
+ if current_entry is not entry:
+ if current_entry:
+ self._pending_deferred_cache[key] = current_entry
+ return failure
+
+ for cb in entry.get_invalidation_callbacks(key):
+ cb()
+
+ return failure
def prefill(
self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
) -> None:
- callbacks = [callback] if callback else []
+ callbacks = (callback,) if callback else ()
self.cache.set(key, value, callbacks=callbacks)
+ self._pending_deferred_cache.pop(key, None)
def invalidate(self, key: KT) -> None:
"""Delete a key, or tree of entries
@@ -311,41 +387,129 @@ class DeferredCache(Generic[KT, VT]):
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
- # _pending_deferred_cache, which will (a) stop it being returned
- # for future queries and (b) stop it being persisted as a proper entry
+ # _pending_deferred_cache, which will (a) stop it being returned for
+ # future queries and (b) stop it being persisted as a proper entry
# in self.cache.
entry = self._pending_deferred_cache.pop(key, None)
-
- # run the invalidation callbacks now, rather than waiting for the
- # deferred to resolve.
if entry:
# _pending_deferred_cache.pop should either return a CacheEntry, or, in the
# case of a TreeCache, a dict of keys to cache entries. Either way calling
# iterate_tree_cache_entry on it will do the right thing.
- for entry in iterate_tree_cache_entry(entry):
- entry.invalidate()
+ for iter_entry in iterate_tree_cache_entry(entry):
+ for cb in iter_entry.get_invalidation_callbacks(key):
+ cb()
def invalidate_all(self) -> None:
self.check_thread()
self.cache.clear()
- for entry in self._pending_deferred_cache.values():
- entry.invalidate()
+ for key, entry in self._pending_deferred_cache.items():
+ for cb in entry.get_invalidation_callbacks(key):
+ cb()
+
self._pending_deferred_cache.clear()
-class CacheEntry:
- __slots__ = ["deferred", "callbacks", "invalidated"]
+class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta):
+ """Abstract class for entries in `DeferredCache[KT, VT]`"""
- def __init__(
- self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
- ):
- self.deferred = deferred
- self.callbacks = set(callbacks)
- self.invalidated = False
-
- def invalidate(self) -> None:
- if not self.invalidated:
- self.invalidated = True
- for callback in self.callbacks:
- callback()
- self.callbacks.clear()
+ @abc.abstractmethod
+ def deferred(self, key: KT) -> "defer.Deferred[VT]":
+ """Get a deferred that a caller can wait on to get the value at the
+ given key"""
+ ...
+
+ @abc.abstractmethod
+ def add_invalidation_callback(
+ self, key: KT, callback: Optional[Callable[[], None]]
+ ) -> None:
+ """Add an invalidation callback"""
+ ...
+
+ @abc.abstractmethod
+ def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+ """Get all invalidation callbacks"""
+ ...
+
+
+class CacheEntrySingle(CacheEntry[KT, VT]):
+ """An implementation of `CacheEntry` wrapping a deferred that results in a
+ single cache entry.
+ """
+
+ __slots__ = ["_deferred", "_callbacks"]
+
+ def __init__(self, deferred: "defer.Deferred[VT]") -> None:
+ self._deferred = ObservableDeferred(deferred, consumeErrors=True)
+ self._callbacks: Set[Callable[[], None]] = set()
+
+ def deferred(self, key: KT) -> "defer.Deferred[VT]":
+ return self._deferred.observe()
+
+ def add_invalidation_callback(
+ self, key: KT, callback: Optional[Callable[[], None]]
+ ) -> None:
+ if callback is None:
+ return
+
+ self._callbacks.add(callback)
+
+ def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+ return self._callbacks
+
+
+class CacheMultipleEntries(CacheEntry[KT, VT]):
+ """Cache entry that is used for bulk lookups and insertions."""
+
+ __slots__ = ["_deferred", "_callbacks", "_global_callbacks"]
+
+ def __init__(self) -> None:
+ self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None
+ self._callbacks: Dict[KT, Set[Callable[[], None]]] = {}
+ self._global_callbacks: Set[Callable[[], None]] = set()
+
+ def deferred(self, key: KT) -> "defer.Deferred[VT]":
+ if not self._deferred:
+ self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+ return self._deferred.observe().addCallback(lambda res: res.get(key))
+
+ def add_invalidation_callback(
+ self, key: KT, callback: Optional[Callable[[], None]]
+ ) -> None:
+ if callback is None:
+ return
+
+ self._callbacks.setdefault(key, set()).add(callback)
+
+ def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+ return self._callbacks.get(key, set()) | self._global_callbacks
+
+ def add_global_invalidation_callback(
+ self, callback: Optional[Callable[[], None]]
+ ) -> None:
+ """Add a callback for when any keys get invalidated."""
+ if callback is None:
+ return
+
+ self._global_callbacks.add(callback)
+
+ def complete_bulk(
+ self,
+ cache: DeferredCache[KT, VT],
+ result: Dict[KT, VT],
+ ) -> None:
+ """Called when there is a result"""
+ for key, value in result.items():
+ cache._completed_callback(value, self, key)
+
+ if self._deferred:
+ self._deferred.callback(result)
+
+ def error_bulk(
+ self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure
+ ) -> None:
+ """Called when bulk lookup failed."""
+ for key in keys:
+ cache._error_callback(failure, self, key)
+
+ if self._deferred:
+ self._deferred.errback(failure)
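
# A rough sketch (not from the diff) of how the bulk APIs above compose,
# mirroring what DeferredCacheListDescriptor now does internally. The cache
# and `fetch_from_db` are illustrative stand-ins.
from typing import Dict, Iterable, List

from twisted.python.failure import Failure

from synapse.util.caches.deferred_cache import DeferredCache

sketch_cache: DeferredCache[str, int] = DeferredCache("sketch")


async def fetch_from_db(keys: Iterable[str]) -> Dict[str, int]:
    return {key: len(key) for key in keys}  # stand-in for a DB round-trip


async def get_lengths(keys: List[str]) -> Dict[str, int]:
    # 1. Split the keys into already-cached, in-flight, and missing.
    cached, in_flight, missing = sketch_cache.get_bulk(keys)
    results = dict(cached)

    if missing:
        # 2. Claim the missing keys *before* the DB fetch, so concurrent
        #    callers observe this lookup instead of duplicating it.
        entry = sketch_cache.start_bulk_input(missing)
        try:
            fetched = await fetch_from_db(missing)
        except Exception:
            # 3a. Fail every observer and drop the pending entries.
            entry.error_bulk(sketch_cache, missing, Failure())
            raise
        # 3b. Populate the real cache and resolve observers; keys absent
        #     from `fetched` are cached as None.
        entry.complete_bulk(sketch_cache, {k: fetched.get(k) for k in missing})
        results.update(fetched)

    if in_flight:
        # 4. Wait for the lookups some other caller already started.
        results.update(await in_flight)

    return results
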
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 867f315b2a..72227359b9 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import enum
import functools
import inspect
import logging
@@ -25,6 +24,7 @@ from typing import (
Generic,
Hashable,
Iterable,
+ List,
Mapping,
Optional,
Sequence,
@@ -52,7 +52,7 @@ CacheKey = Union[Tuple, Any]
F = TypeVar("F", bound=Callable[..., Any])
-class _CachedFunction(Generic[F]):
+class CachedFunction(Generic[F]):
invalidate: Any = None
invalidate_all: Any = None
prefill: Any = None
@@ -73,8 +73,10 @@ class _CacheDescriptorBase:
num_args: Optional[int],
uncached_args: Optional[Collection[str]] = None,
cache_context: bool = False,
+ name: Optional[str] = None,
):
self.orig = orig
+ self.name = name or orig.__name__
arg_spec = inspect.getfullargspec(orig)
all_args = arg_spec.args
@@ -143,109 +145,6 @@ class _CacheDescriptorBase:
)
-class _LruCachedFunction(Generic[F]):
- cache: LruCache[CacheKey, Any]
- __call__: F
-
-
-def lru_cache(
- *, max_entries: int = 1000, cache_context: bool = False
-) -> Callable[[F], _LruCachedFunction[F]]:
- """A method decorator that applies a memoizing cache around the function.
-
- This is more-or-less a drop-in equivalent to functools.lru_cache, although note
- that the signature is slightly different.
-
- The main differences with functools.lru_cache are:
- (a) the size of the cache can be controlled via the cache_factor mechanism
- (b) the wrapped function can request a "cache_context" which provides a
- callback mechanism to indicate that the result is no longer valid
- (c) prometheus metrics are exposed automatically.
-
- The function should take zero or more arguments, which are used as the key for the
- cache. Single-argument functions use that argument as the cache key; otherwise the
- arguments are built into a tuple.
-
- Cached functions can be "chained" (i.e. a cached function can call other cached
- functions and get appropriately invalidated when they called caches are
- invalidated) by adding a special "cache_context" argument to the function
- and passing that as a kwarg to all caches called. For example:
-
- @lru_cache(cache_context=True)
- def foo(self, key, cache_context):
- r1 = self.bar1(key, on_invalidate=cache_context.invalidate)
- r2 = self.bar2(key, on_invalidate=cache_context.invalidate)
- return r1 + r2
-
- The wrapped function also has a 'cache' property which offers direct access to the
- underlying LruCache.
- """
-
- def func(orig: F) -> _LruCachedFunction[F]:
- desc = LruCacheDescriptor(
- orig,
- max_entries=max_entries,
- cache_context=cache_context,
- )
- return cast(_LruCachedFunction[F], desc)
-
- return func
-
-
-class LruCacheDescriptor(_CacheDescriptorBase):
- """Helper for @lru_cache"""
-
- class _Sentinel(enum.Enum):
- sentinel = object()
-
- def __init__(
- self,
- orig: Callable[..., Any],
- max_entries: int = 1000,
- cache_context: bool = False,
- ):
- super().__init__(
- orig, num_args=None, uncached_args=None, cache_context=cache_context
- )
- self.max_entries = max_entries
-
- def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
- cache: LruCache[CacheKey, Any] = LruCache(
- cache_name=self.orig.__name__,
- max_size=self.max_entries,
- )
-
- get_cache_key = self.cache_key_builder
- sentinel = LruCacheDescriptor._Sentinel.sentinel
-
- @functools.wraps(self.orig)
- def _wrapped(*args: Any, **kwargs: Any) -> Any:
- invalidate_callback = kwargs.pop("on_invalidate", None)
- callbacks = (invalidate_callback,) if invalidate_callback else ()
-
- cache_key = get_cache_key(args, kwargs)
-
- ret = cache.get(cache_key, default=sentinel, callbacks=callbacks)
- if ret != sentinel:
- return ret
-
- # Add our own `cache_context` to argument list if the wrapped function
- # has asked for one
- if self.add_cache_context:
- kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
-
- ret2 = self.orig(obj, *args, **kwargs)
- cache.set(cache_key, ret2, callbacks=callbacks)
-
- return ret2
-
- wrapped = cast(_CachedFunction, _wrapped)
- wrapped.cache = cache
- obj.__dict__[self.orig.__name__] = wrapped
-
- return wrapped
-
-
class DeferredCacheDescriptor(_CacheDescriptorBase):
"""A method decorator that applies a memoizing cache around the function.
@@ -301,12 +200,14 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
cache_context: bool = False,
iterable: bool = False,
prune_unread_entries: bool = True,
+ name: Optional[str] = None,
):
super().__init__(
orig,
num_args=num_args,
uncached_args=uncached_args,
cache_context=cache_context,
+ name=name,
)
if tree and self.num_args < 2:
@@ -321,7 +222,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: DeferredCache[CacheKey, Any] = DeferredCache(
- name=self.orig.__name__,
+ name=self.name,
max_entries=self.max_entries,
tree=self.tree,
iterable=self.iterable,
@@ -358,7 +259,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
return make_deferred_yieldable(ret)
- wrapped = cast(_CachedFunction, _wrapped)
+ wrapped = cast(CachedFunction, _wrapped)
if self.num_args == 1:
assert not self.tree
@@ -372,7 +273,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
wrapped.cache = cache
wrapped.num_args = self.num_args
- obj.__dict__[self.orig.__name__] = wrapped
+ obj.__dict__[self.name] = wrapped
return wrapped
@@ -393,6 +294,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
cached_method_name: str,
list_name: str,
num_args: Optional[int] = None,
+ name: Optional[str] = None,
):
"""
Args:
@@ -403,7 +305,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
but including list_name) to use as cache keys. Defaults to all
named args of the function.
"""
- super().__init__(orig, num_args=num_args, uncached_args=None)
+ super().__init__(orig, num_args=num_args, uncached_args=None, name=name)
self.list_name = list_name
@@ -425,6 +327,12 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
cache: DeferredCache[CacheKey, Any] = cached_method.cache
num_args = cached_method.num_args
+ if num_args != self.num_args:
+ raise TypeError(
+ "Number of args (%s) does not match underlying cache_method_name=%s (%s)."
+ % (self.num_args, self.cached_method_name, num_args)
+ )
+
@functools.wraps(self.orig)
def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]":
# If we're passed a cache_context then we'll want to call its
@@ -435,16 +343,6 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
- results = {}
-
- def update_results_dict(res: Any, arg: Hashable) -> None:
- results[arg] = res
-
- # list of deferreds to wait for
- cached_defers = []
-
- missing = set()
-
# If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used.
if num_args == 1:
@@ -452,6 +350,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
def arg_to_cache_key(arg: Hashable) -> Hashable:
return arg
+ def cache_key_to_arg(key: tuple) -> Hashable:
+ return key
+
else:
keylist = list(keyargs)
@@ -459,58 +360,53 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
keylist[self.list_pos] = arg
return tuple(keylist)
- for arg in list_args:
- try:
- res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
- if not res.called:
- res.addCallback(update_results_dict, arg)
- cached_defers.append(res)
- else:
- results[arg] = res.result
- except KeyError:
- missing.add(arg)
+ def cache_key_to_arg(key: tuple) -> Hashable:
+ return key[self.list_pos]
+
+ cache_keys = [arg_to_cache_key(arg) for arg in list_args]
+ immediate_results, pending_deferred, missing = cache.get_bulk(
+ cache_keys, callback=invalidate_callback
+ )
+
+ results = {cache_key_to_arg(key): v for key, v in immediate_results.items()}
+
+ cached_defers: List["defer.Deferred[Any]"] = []
+ if pending_deferred:
+
+ def update_results(r: Dict) -> None:
+ for k, v in r.items():
+ results[cache_key_to_arg(k)] = v
+
+ pending_deferred.addCallback(update_results)
+ cached_defers.append(pending_deferred)
if missing:
- # we need a deferred for each entry in the list,
- # which we put in the cache. Each deferred resolves with the
- # relevant result for that key.
- deferreds_map = {}
- for arg in missing:
- deferred: "defer.Deferred[Any]" = defer.Deferred()
- deferreds_map[arg] = deferred
- key = arg_to_cache_key(arg)
- cached_defers.append(
- cache.set(key, deferred, callback=invalidate_callback)
- )
+ cache_entry = cache.start_bulk_input(missing, invalidate_callback)
def complete_all(res: Dict[Hashable, Any]) -> None:
- # the wrapped function has completed. It returns a dict.
- # We can now update our own result map, and then resolve the
- # observable deferreds in the cache.
- for e, d1 in deferreds_map.items():
- val = res.get(e, None)
- # make sure we update the results map before running the
- # deferreds, because as soon as we run the last deferred, the
- # gatherResults() below will complete and return the result
- # dict to our caller.
- results[e] = val
- d1.callback(val)
+ missing_results = {}
+ for key in missing:
+ arg = cache_key_to_arg(key)
+ val = res.get(arg, None)
+
+ results[arg] = val
+ missing_results[key] = val
+
+ cache_entry.complete_bulk(cache, missing_results)
def errback_all(f: Failure) -> None:
- # the wrapped function has failed. Propagate the failure into
- # the cache, which will invalidate the entry, and cause the
- # relevant cached_deferreds to fail, which will propagate the
- # failure to our caller.
- for d1 in deferreds_map.values():
- d1.errback(f)
+ cache_entry.error_bulk(cache, missing, f)
args_to_call = dict(arg_dict)
- args_to_call[self.list_name] = missing
+ args_to_call[self.list_name] = {
+ cache_key_to_arg(key) for key in missing
+ }
# dispatch the call, and attach the two handlers
- defer.maybeDeferred(
+ missing_d = defer.maybeDeferred(
preserve_fn(self.orig), **args_to_call
).addCallbacks(complete_all, errback_all)
+ cached_defers.append(missing_d)
if cached_defers:
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
@@ -525,7 +421,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
else:
return defer.succeed(results)
- obj.__dict__[self.orig.__name__] = wrapped
+ obj.__dict__[self.name] = wrapped
return wrapped
@@ -577,7 +473,8 @@ def cached(
cache_context: bool = False,
iterable: bool = False,
prune_unread_entries: bool = True,
-) -> Callable[[F], _CachedFunction[F]]:
+ name: Optional[str] = None,
+) -> Callable[[F], CachedFunction[F]]:
func = lambda orig: DeferredCacheDescriptor(
orig,
max_entries=max_entries,
@@ -587,21 +484,26 @@ def cached(
cache_context=cache_context,
iterable=iterable,
prune_unread_entries=prune_unread_entries,
+ name=name,
)
- return cast(Callable[[F], _CachedFunction[F]], func)
+ return cast(Callable[[F], CachedFunction[F]], func)
def cachedList(
- *, cached_method_name: str, list_name: str, num_args: Optional[int] = None
-) -> Callable[[F], _CachedFunction[F]]:
+ *,
+ cached_method_name: str,
+ list_name: str,
+ num_args: Optional[int] = None,
+ name: Optional[str] = None,
+) -> Callable[[F], CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
Used to do batch lookups for an already created cache. One of the arguments
is specified as a list that is iterated through to lookup keys in the
original cache. A new tuple consisting of the (deduplicated) keys that weren't in
- the cache gets passed to the original function, which is expected to results
- in a map of key to value for each passed value. THe new results are stored in the
+ the cache gets passed to the original function, which is expected to result
+ in a map of key to value for each passed value. The new results are stored in the
original cache. Note that any missing values are cached as None.
Args:
@@ -628,9 +530,10 @@ def cachedList(
cached_method_name=cached_method_name,
list_name=list_name,
num_args=num_args,
+ name=name,
)
- return cast(Callable[[F], _CachedFunction[F]], func)
+ return cast(Callable[[F], CachedFunction[F]], func)
def _get_cache_key_builder(
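
# A rough sketch (not from the diff) of the decorators above in use,
# including the new `name=` override. `SketchStore` and its _db_* helpers
# are illustrative stand-ins.
from typing import Dict, Iterable, Mapping, Optional

from synapse.util.caches.descriptors import cached, cachedList


class SketchStore:
    @cached(name="sketch_get_thing")  # cache/metrics name decoupled from __name__
    async def get_thing(self, thing_id: str) -> Optional[str]:
        return await self._db_fetch_one(thing_id)

    # On first access the descriptor now raises TypeError if this method's
    # argument count disagrees with get_thing's, instead of misbehaving later.
    @cachedList(cached_method_name="get_thing", list_name="thing_ids")
    async def get_things(
        self, thing_ids: Iterable[str]
    ) -> Mapping[str, Optional[str]]:
        # Called only with keys that were neither cached nor in flight (via
        # DeferredCache.get_bulk); results are fed back into get_thing's cache
        # through start_bulk_input/complete_bulk, missing keys cached as None.
        return await self._db_fetch_many(thing_ids)

    async def _db_fetch_one(self, thing_id: str) -> Optional[str]:
        ...

    async def _db_fetch_many(self, thing_ids: Iterable[str]) -> Dict[str, str]:
        ...
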
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index fa91479c97..5eaf70c7ab 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -169,10 +169,11 @@ class DictionaryCache(Generic[KT, DKT, DV]):
if it is in the cache.
Returns:
- DictionaryEntry: If `dict_keys` is not None then `DictionaryEntry`
- will contain include the keys that are in the cache. If None then
- will either return the full dict if in the cache, or the empty
- dict (with `full` set to False) if it isn't.
+ If `dict_keys` is not None then `DictionaryEntry` will contain only
+ the keys that are in the cache.
+
+ If `dict_keys` is None then the full dict is returned if it is in the
+ cache, or the empty dict (with `full` set to False) if it isn't.
"""
if dict_keys is None:
# The caller wants the full set of dictionary keys for this cache key
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index c6a5d0dfc0..01ad02af67 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -207,7 +207,7 @@ class ExpiringCache(Generic[KT, VT]):
items from the cache.
Returns:
- bool: Whether the cache changed size or not.
+ Whether the cache changed size or not.
"""
new_size = int(self._original_max_size * factor)
if new_size != self._max_size:
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index b3bdedb04c..dcf0eac3bf 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -389,11 +389,11 @@ class LruCache(Generic[KT, VT]):
cache_name: The name of this cache, for the prometheus metrics. If unset,
no metrics will be reported on this cache.
- cache_type (type):
+ cache_type:
type of underlying cache to be used. Typically one of dict
or TreeCache.
- size_callback (func(V) -> int | None):
+ size_callback:
metrics_collection_callback:
metrics collection callback. This is called early in the metrics
@@ -403,7 +403,7 @@ class LruCache(Generic[KT, VT]):
Ignored if cache_name is None.
- apply_cache_factor_from_config (bool): If true, `max_size` will be
+ apply_cache_factor_from_config: If true, `max_size` will be
multiplied by a cache factor derived from the homeserver config
clock:
@@ -796,7 +796,7 @@ class LruCache(Generic[KT, VT]):
items from the cache.
Returns:
- bool: Whether the cache changed size or not.
+ Whether the cache changed size or not.
"""
if not self.apply_cache_factor_from_config:
return False
@@ -834,9 +834,26 @@ class AsyncLruCache(Generic[KT, VT]):
) -> Optional[VT]:
return self._lru_cache.get(key, update_metrics=update_metrics)
+ async def get_external(
+ self,
+ key: KT,
+ default: Optional[T] = None,
+ update_metrics: bool = True,
+ ) -> Optional[VT]:
+ # This method should fetch from any configured external cache; here it is a no-op.
+ return None
+
+ def get_local(
+ self, key: KT, default: Optional[T] = None, update_metrics: bool = True
+ ) -> Optional[VT]:
+ return self._lru_cache.get(key, update_metrics=update_metrics)
+
async def set(self, key: KT, value: VT) -> None:
self._lru_cache.set(key, value)
+ def set_local(self, key: KT, value: VT) -> None:
+ self._lru_cache.set(key, value)
+
async def invalidate(self, key: KT) -> None:
# This method should invalidate any external cache and then invalidate the LruCache.
return self._lru_cache.invalidate(key)
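
# A rough sketch (not from the diff) of why AsyncLruCache grew
# get_external/get_local/set_local: a subclass can layer a shared external
# cache over the in-memory LruCache. It assumes AsyncLruCache forwards its
# constructor args to LruCache; `external_client` (any object with async
# get/set) is a hypothetical stand-in.
from typing import Any, Optional

from synapse.util.caches.lrucache import AsyncLruCache


class SketchReplicatedCache(AsyncLruCache[str, bytes]):
    def __init__(self, external_client: Any, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._external = external_client

    async def get_external(
        self,
        key: str,
        default: Optional[bytes] = None,
        update_metrics: bool = True,
    ) -> Optional[bytes]:
        # Consult the shared cache; None still means a miss.
        return await self._external.get(key)

    async def get(
        self,
        key: str,
        default: Optional[bytes] = None,
        update_metrics: bool = True,
    ) -> Optional[bytes]:
        # A local hit wins; otherwise fall back to the external cache.
        local = self.get_local(key, update_metrics=update_metrics)
        if local is not None:
            return local
        return await self.get_external(key, update_metrics=update_metrics)

    async def set(self, key: str, value: bytes) -> None:
        self.set_local(key, value)  # keep the in-memory copy warm
        await self._external.set(key, value)  # and replicate outwards
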
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 330709b8b7..666f4b6895 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -72,7 +72,7 @@ class StreamChangeCache:
items from the cache.
Returns:
- bool: Whether the cache changed size or not.
+ Whether the cache changed size or not.
"""
new_size = math.floor(self._original_max_size * factor)
if new_size != self._max_size:
@@ -188,14 +188,8 @@ class StreamChangeCache:
self._entity_to_key[entity] = stream_pos
self._evict()
- # if the cache is too big, remove entries
- while len(self._cache) > self._max_size:
- k, r = self._cache.popitem(0)
- self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
- for entity in r:
- del self._entity_to_key[entity]
-
def _evict(self) -> None:
+ # if the cache is too big, remove entries
while len(self._cache) > self._max_size:
k, r = self._cache.popitem(0)
self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
@@ -203,7 +197,6 @@ class StreamChangeCache:
self._entity_to_key.pop(entity, None)
def get_max_pos_of_last_change(self, entity: EntityType) -> int:
-
"""Returns an upper bound of the stream id of the last change to an
entity.
"""
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index c1b8ec0c73..fec31da2b6 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -135,6 +135,9 @@ class TreeCache:
def values(self):
return iterate_tree_cache_entry(self.root)
+ def items(self):
+ return iterate_tree_cache_items((), self.root)
+
def __len__(self) -> int:
return self.size
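
# A quick sketch (not from the diff) of the new items() helper, assuming
# TreeCache's usual tuple-key interface: it yields (full key tuple, value)
# pairs, one per leaf.
from synapse.util.caches.treecache import TreeCache

sketch_tree = TreeCache()
sketch_tree[("room1", "@alice:test")] = 1
sketch_tree[("room1", "@bob:test")] = 2
sketch_tree[("room2", "@alice:test")] = 3

assert dict(sketch_tree.items()) == {
    ("room1", "@alice:test"): 1,
    ("room1", "@bob:test"): 2,
    ("room2", "@alice:test"): 3,
}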