diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 1d6ec22191..6425f851ea 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -14,15 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import enum
import threading
from typing import (
Callable,
+ Collection,
+ Dict,
Generic,
- Iterable,
MutableMapping,
Optional,
+ Set,
Sized,
+ Tuple,
TypeVar,
Union,
cast,
@@ -31,7 +35,6 @@ from typing import (
from prometheus_client import Gauge
from twisted.internet import defer
-from twisted.python import failure
from twisted.python.failure import Failure
from synapse.util.async_helpers import ObservableDeferred
@@ -94,7 +97,7 @@ class DeferredCache(Generic[KT, VT]):
# _pending_deferred_cache maps from the key value to a `CacheEntry` object.
self._pending_deferred_cache: Union[
- TreeCache, "MutableMapping[KT, CacheEntry]"
+ TreeCache, "MutableMapping[KT, CacheEntry[KT, VT]]"
] = cache_type()
def metrics_cb() -> None:
@@ -159,15 +162,16 @@ class DeferredCache(Generic[KT, VT]):
Raises:
KeyError if the key is not found in the cache
"""
- callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if val is not _Sentinel.sentinel:
- val.callbacks.update(callbacks)
+ val.add_invalidation_callback(key, callback)
if update_metrics:
m = self.cache.metrics
assert m # we always have a name, so should always have metrics
m.inc_hits()
- return val.deferred.observe()
+ return val.deferred(key)
+
+ callbacks = (callback,) if callback else ()
val2 = self.cache.get(
key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
@@ -177,6 +181,73 @@ class DeferredCache(Generic[KT, VT]):
else:
return defer.succeed(val2)
+ def get_bulk(
+ self,
+ keys: Collection[KT],
+ callback: Optional[Callable[[], None]] = None,
+ ) -> Tuple[Dict[KT, VT], Optional["defer.Deferred[Dict[KT, VT]]"], Collection[KT]]:
+ """Bulk lookup of items in the cache.
+
+ Returns:
+ A 3-tuple of:
+ 1. a dict of key/value of items already cached;
+ 2. a deferred that resolves to a dict of key/value of items
+ we're already fetching; and
+ 3. a collection of keys that don't appear in the previous two.
+ """
+
+ # The cached results
+ cached = {}
+
+ # List of pending deferreds
+ pending = []
+
+ # Dict that gets filled out when the pending deferreds complete
+ pending_results = {}
+
+ # List of keys that aren't in either cache
+ missing = []
+
+ callbacks = (callback,) if callback else ()
+
+ for key in keys:
+ # Check if its in the main cache.
+ immediate_value = self.cache.get(
+ key,
+ _Sentinel.sentinel,
+ callbacks=callbacks,
+ )
+ if immediate_value is not _Sentinel.sentinel:
+ cached[key] = immediate_value
+ continue
+
+ # Check if its in the pending cache
+ pending_value = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
+ if pending_value is not _Sentinel.sentinel:
+ pending_value.add_invalidation_callback(key, callback)
+
+ def completed_cb(value: VT, key: KT) -> VT:
+ pending_results[key] = value
+ return value
+
+ # Add a callback to fill out `pending_results` when that completes
+ d = pending_value.deferred(key).addCallback(completed_cb, key)
+ pending.append(d)
+ continue
+
+ # Not in either cache
+ missing.append(key)
+
+ # If we've got pending deferreds, squash them into a single one that
+ # returns `pending_results`.
+ pending_deferred = None
+ if pending:
+ pending_deferred = defer.gatherResults(
+ pending, consumeErrors=True
+ ).addCallback(lambda _: pending_results)
+
+ return (cached, pending_deferred, missing)
+
def get_immediate(
self, key: KT, default: T, update_metrics: bool = True
) -> Union[VT, T]:
@@ -218,84 +289,89 @@ class DeferredCache(Generic[KT, VT]):
value: a deferred which will complete with a result to add to the cache
callback: An optional callback to be called when the entry is invalidated
"""
- if not isinstance(value, defer.Deferred):
- raise TypeError("not a Deferred")
-
- callbacks = [callback] if callback else []
self.check_thread()
- existing_entry = self._pending_deferred_cache.pop(key, None)
- if existing_entry:
- existing_entry.invalidate()
+ self._pending_deferred_cache.pop(key, None)
# XXX: why don't we invalidate the entry in `self.cache` yet?
- # we can save a whole load of effort if the deferred is ready.
- if value.called:
- result = value.result
- if not isinstance(result, failure.Failure):
- self.cache.set(key, cast(VT, result), callbacks)
- return value
-
# otherwise, we'll add an entry to the _pending_deferred_cache for now,
# and add callbacks to add it to the cache properly later.
+ entry = CacheEntrySingle[KT, VT](value)
+ entry.add_invalidation_callback(key, callback)
+ self._pending_deferred_cache[key] = entry
+ deferred = entry.deferred(key).addCallbacks(
+ self._completed_callback,
+ self._error_callback,
+ callbackArgs=(entry, key),
+ errbackArgs=(entry, key),
+ )
- observable = ObservableDeferred(value, consumeErrors=True)
- observer = observable.observe()
- entry = CacheEntry(deferred=observable, callbacks=callbacks)
+ # we return a new Deferred which will be called before any subsequent observers.
+ return deferred
- self._pending_deferred_cache[key] = entry
+ def start_bulk_input(
+ self,
+ keys: Collection[KT],
+ callback: Optional[Callable[[], None]] = None,
+ ) -> "CacheMultipleEntries[KT, VT]":
+ """Bulk set API for use when fetching multiple keys at once from the DB.
- def compare_and_pop() -> bool:
- """Check if our entry is still the one in _pending_deferred_cache, and
- if so, pop it.
-
- Returns true if the entries matched.
- """
- existing_entry = self._pending_deferred_cache.pop(key, None)
- if existing_entry is entry:
- return True
-
- # oops, the _pending_deferred_cache has been updated since
- # we started our query, so we are out of date.
- #
- # Better put back whatever we took out. (We do it this way
- # round, rather than peeking into the _pending_deferred_cache
- # and then removing on a match, to make the common case faster)
- if existing_entry is not None:
- self._pending_deferred_cache[key] = existing_entry
-
- return False
-
- def cb(result: VT) -> None:
- if compare_and_pop():
- self.cache.set(key, result, entry.callbacks)
- else:
- # we're not going to put this entry into the cache, so need
- # to make sure that the invalidation callbacks are called.
- # That was probably done when _pending_deferred_cache was
- # updated, but it's possible that `set` was called without
- # `invalidate` being previously called, in which case it may
- # not have been. Either way, let's double-check now.
- entry.invalidate()
-
- def eb(_fail: Failure) -> None:
- compare_and_pop()
- entry.invalidate()
-
- # once the deferred completes, we can move the entry from the
- # _pending_deferred_cache to the real cache.
- #
- observer.addCallbacks(cb, eb)
+ Called *before* starting the fetch from the DB, and the caller *must*
+ call either `complete_bulk(..)` or `error_bulk(..)` on the return value.
+ """
- # we return a new Deferred which will be called before any subsequent observers.
- return observable.observe()
+ entry = CacheMultipleEntries[KT, VT]()
+ entry.add_global_invalidation_callback(callback)
+
+ for key in keys:
+ self._pending_deferred_cache[key] = entry
+
+ return entry
+
+ def _completed_callback(
+ self, value: VT, entry: "CacheEntry[KT, VT]", key: KT
+ ) -> VT:
+ """Called when a deferred is completed."""
+ # We check if the current entry matches the entry associated with the
+ # deferred. If they don't match then it got invalidated.
+ current_entry = self._pending_deferred_cache.pop(key, None)
+ if current_entry is not entry:
+ if current_entry:
+ self._pending_deferred_cache[key] = current_entry
+ return value
+
+ self.cache.set(key, value, entry.get_invalidation_callbacks(key))
+
+ return value
+
+ def _error_callback(
+ self,
+ failure: Failure,
+ entry: "CacheEntry[KT, VT]",
+ key: KT,
+ ) -> Failure:
+ """Called when a deferred errors."""
+
+ # We check if the current entry matches the entry associated with the
+ # deferred. If they don't match then it got invalidated.
+ current_entry = self._pending_deferred_cache.pop(key, None)
+ if current_entry is not entry:
+ if current_entry:
+ self._pending_deferred_cache[key] = current_entry
+ return failure
+
+ for cb in entry.get_invalidation_callbacks(key):
+ cb()
+
+ return failure
def prefill(
self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
) -> None:
- callbacks = [callback] if callback else []
+ callbacks = (callback,) if callback else ()
self.cache.set(key, value, callbacks=callbacks)
+ self._pending_deferred_cache.pop(key, None)
def invalidate(self, key: KT) -> None:
"""Delete a key, or tree of entries
@@ -311,41 +387,129 @@ class DeferredCache(Generic[KT, VT]):
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
- # _pending_deferred_cache, which will (a) stop it being returned
- # for future queries and (b) stop it being persisted as a proper entry
+ # _pending_deferred_cache, which will (a) stop it being returned for
+ # future queries and (b) stop it being persisted as a proper entry
# in self.cache.
entry = self._pending_deferred_cache.pop(key, None)
-
- # run the invalidation callbacks now, rather than waiting for the
- # deferred to resolve.
if entry:
# _pending_deferred_cache.pop should either return a CacheEntry, or, in the
# case of a TreeCache, a dict of keys to cache entries. Either way calling
# iterate_tree_cache_entry on it will do the right thing.
for entry in iterate_tree_cache_entry(entry):
- entry.invalidate()
+ for cb in entry.get_invalidation_callbacks(key):
+ cb()
def invalidate_all(self) -> None:
self.check_thread()
self.cache.clear()
- for entry in self._pending_deferred_cache.values():
- entry.invalidate()
+ for key, entry in self._pending_deferred_cache.items():
+ for cb in entry.get_invalidation_callbacks(key):
+ cb()
+
self._pending_deferred_cache.clear()
-class CacheEntry:
- __slots__ = ["deferred", "callbacks", "invalidated"]
+class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta):
+ """Abstract class for entries in `DeferredCache[KT, VT]`"""
- def __init__(
- self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
- ):
- self.deferred = deferred
- self.callbacks = set(callbacks)
- self.invalidated = False
-
- def invalidate(self) -> None:
- if not self.invalidated:
- self.invalidated = True
- for callback in self.callbacks:
- callback()
- self.callbacks.clear()
+ @abc.abstractmethod
+ def deferred(self, key: KT) -> "defer.Deferred[VT]":
+ """Get a deferred that a caller can wait on to get the value at the
+ given key"""
+ ...
+
+ @abc.abstractmethod
+ def add_invalidation_callback(
+ self, key: KT, callback: Optional[Callable[[], None]]
+ ) -> None:
+ """Add an invalidation callback"""
+ ...
+
+ @abc.abstractmethod
+ def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+ """Get all invalidation callbacks"""
+ ...
+
+
+class CacheEntrySingle(CacheEntry[KT, VT]):
+ """An implementation of `CacheEntry` wrapping a deferred that results in a
+ single cache entry.
+ """
+
+ __slots__ = ["_deferred", "_callbacks"]
+
+ def __init__(self, deferred: "defer.Deferred[VT]") -> None:
+ self._deferred = ObservableDeferred(deferred, consumeErrors=True)
+ self._callbacks: Set[Callable[[], None]] = set()
+
+ def deferred(self, key: KT) -> "defer.Deferred[VT]":
+ return self._deferred.observe()
+
+ def add_invalidation_callback(
+ self, key: KT, callback: Optional[Callable[[], None]]
+ ) -> None:
+ if callback is None:
+ return
+
+ self._callbacks.add(callback)
+
+ def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+ return self._callbacks
+
+
+class CacheMultipleEntries(CacheEntry[KT, VT]):
+ """Cache entry that is used for bulk lookups and insertions."""
+
+ __slots__ = ["_deferred", "_callbacks", "_global_callbacks"]
+
+ def __init__(self) -> None:
+ self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None
+ self._callbacks: Dict[KT, Set[Callable[[], None]]] = {}
+ self._global_callbacks: Set[Callable[[], None]] = set()
+
+ def deferred(self, key: KT) -> "defer.Deferred[VT]":
+ if not self._deferred:
+ self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+ return self._deferred.observe().addCallback(lambda res: res.get(key))
+
+ def add_invalidation_callback(
+ self, key: KT, callback: Optional[Callable[[], None]]
+ ) -> None:
+ if callback is None:
+ return
+
+ self._callbacks.setdefault(key, set()).add(callback)
+
+ def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+ return self._callbacks.get(key, set()) | self._global_callbacks
+
+ def add_global_invalidation_callback(
+ self, callback: Optional[Callable[[], None]]
+ ) -> None:
+ """Add a callback for when any keys get invalidated."""
+ if callback is None:
+ return
+
+ self._global_callbacks.add(callback)
+
+ def complete_bulk(
+ self,
+ cache: DeferredCache[KT, VT],
+ result: Dict[KT, VT],
+ ) -> None:
+ """Called when there is a result"""
+ for key, value in result.items():
+ cache._completed_callback(value, self, key)
+
+ if self._deferred:
+ self._deferred.callback(result)
+
+ def error_bulk(
+ self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure
+ ) -> None:
+ """Called when bulk lookup failed."""
+ for key in keys:
+ cache._error_callback(failure, self, key)
+
+ if self._deferred:
+ self._deferred.errback(failure)
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 867f315b2a..10aff4d04a 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -25,6 +25,7 @@ from typing import (
Generic,
Hashable,
Iterable,
+ List,
Mapping,
Optional,
Sequence,
@@ -73,8 +74,10 @@ class _CacheDescriptorBase:
num_args: Optional[int],
uncached_args: Optional[Collection[str]] = None,
cache_context: bool = False,
+ name: Optional[str] = None,
):
self.orig = orig
+ self.name = name or orig.__name__
arg_spec = inspect.getfullargspec(orig)
all_args = arg_spec.args
@@ -211,7 +214,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: LruCache[CacheKey, Any] = LruCache(
- cache_name=self.orig.__name__,
+ cache_name=self.name,
max_size=self.max_entries,
)
@@ -241,7 +244,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
wrapped = cast(_CachedFunction, _wrapped)
wrapped.cache = cache
- obj.__dict__[self.orig.__name__] = wrapped
+ obj.__dict__[self.name] = wrapped
return wrapped
@@ -301,12 +304,14 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
cache_context: bool = False,
iterable: bool = False,
prune_unread_entries: bool = True,
+ name: Optional[str] = None,
):
super().__init__(
orig,
num_args=num_args,
uncached_args=uncached_args,
cache_context=cache_context,
+ name=name,
)
if tree and self.num_args < 2:
@@ -321,7 +326,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: DeferredCache[CacheKey, Any] = DeferredCache(
- name=self.orig.__name__,
+ name=self.name,
max_entries=self.max_entries,
tree=self.tree,
iterable=self.iterable,
@@ -372,7 +377,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
wrapped.cache = cache
wrapped.num_args = self.num_args
- obj.__dict__[self.orig.__name__] = wrapped
+ obj.__dict__[self.name] = wrapped
return wrapped
@@ -393,6 +398,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
cached_method_name: str,
list_name: str,
num_args: Optional[int] = None,
+ name: Optional[str] = None,
):
"""
Args:
@@ -403,7 +409,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
but including list_name) to use as cache keys. Defaults to all
named args of the function.
"""
- super().__init__(orig, num_args=num_args, uncached_args=None)
+ super().__init__(orig, num_args=num_args, uncached_args=None, name=name)
self.list_name = list_name
@@ -435,16 +441,6 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
- results = {}
-
- def update_results_dict(res: Any, arg: Hashable) -> None:
- results[arg] = res
-
- # list of deferreds to wait for
- cached_defers = []
-
- missing = set()
-
# If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used.
if num_args == 1:
@@ -452,6 +448,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
def arg_to_cache_key(arg: Hashable) -> Hashable:
return arg
+ def cache_key_to_arg(key: tuple) -> Hashable:
+ return key
+
else:
keylist = list(keyargs)
@@ -459,58 +458,53 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
keylist[self.list_pos] = arg
return tuple(keylist)
- for arg in list_args:
- try:
- res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
- if not res.called:
- res.addCallback(update_results_dict, arg)
- cached_defers.append(res)
- else:
- results[arg] = res.result
- except KeyError:
- missing.add(arg)
+ def cache_key_to_arg(key: tuple) -> Hashable:
+ return key[self.list_pos]
+
+ cache_keys = [arg_to_cache_key(arg) for arg in list_args]
+ immediate_results, pending_deferred, missing = cache.get_bulk(
+ cache_keys, callback=invalidate_callback
+ )
+
+ results = {cache_key_to_arg(key): v for key, v in immediate_results.items()}
+
+ cached_defers: List["defer.Deferred[Any]"] = []
+ if pending_deferred:
+
+ def update_results(r: Dict) -> None:
+ for k, v in r.items():
+ results[cache_key_to_arg(k)] = v
+
+ pending_deferred.addCallback(update_results)
+ cached_defers.append(pending_deferred)
if missing:
- # we need a deferred for each entry in the list,
- # which we put in the cache. Each deferred resolves with the
- # relevant result for that key.
- deferreds_map = {}
- for arg in missing:
- deferred: "defer.Deferred[Any]" = defer.Deferred()
- deferreds_map[arg] = deferred
- key = arg_to_cache_key(arg)
- cached_defers.append(
- cache.set(key, deferred, callback=invalidate_callback)
- )
+ cache_entry = cache.start_bulk_input(missing, invalidate_callback)
def complete_all(res: Dict[Hashable, Any]) -> None:
- # the wrapped function has completed. It returns a dict.
- # We can now update our own result map, and then resolve the
- # observable deferreds in the cache.
- for e, d1 in deferreds_map.items():
- val = res.get(e, None)
- # make sure we update the results map before running the
- # deferreds, because as soon as we run the last deferred, the
- # gatherResults() below will complete and return the result
- # dict to our caller.
- results[e] = val
- d1.callback(val)
+ missing_results = {}
+ for key in missing:
+ arg = cache_key_to_arg(key)
+ val = res.get(arg, None)
+
+ results[arg] = val
+ missing_results[key] = val
+
+ cache_entry.complete_bulk(cache, missing_results)
def errback_all(f: Failure) -> None:
- # the wrapped function has failed. Propagate the failure into
- # the cache, which will invalidate the entry, and cause the
- # relevant cached_deferreds to fail, which will propagate the
- # failure to our caller.
- for d1 in deferreds_map.values():
- d1.errback(f)
+ cache_entry.error_bulk(cache, missing, f)
args_to_call = dict(arg_dict)
- args_to_call[self.list_name] = missing
+ args_to_call[self.list_name] = {
+ cache_key_to_arg(key) for key in missing
+ }
# dispatch the call, and attach the two handlers
- defer.maybeDeferred(
+ missing_d = defer.maybeDeferred(
preserve_fn(self.orig), **args_to_call
).addCallbacks(complete_all, errback_all)
+ cached_defers.append(missing_d)
if cached_defers:
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
@@ -525,7 +519,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
else:
return defer.succeed(results)
- obj.__dict__[self.orig.__name__] = wrapped
+ obj.__dict__[self.name] = wrapped
return wrapped
@@ -577,6 +571,7 @@ def cached(
cache_context: bool = False,
iterable: bool = False,
prune_unread_entries: bool = True,
+ name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]:
func = lambda orig: DeferredCacheDescriptor(
orig,
@@ -587,13 +582,18 @@ def cached(
cache_context=cache_context,
iterable=iterable,
prune_unread_entries=prune_unread_entries,
+ name=name,
)
return cast(Callable[[F], _CachedFunction[F]], func)
def cachedList(
- *, cached_method_name: str, list_name: str, num_args: Optional[int] = None
+ *,
+ cached_method_name: str,
+ list_name: str,
+ num_args: Optional[int] = None,
+ name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
@@ -628,6 +628,7 @@ def cachedList(
cached_method_name=cached_method_name,
list_name=list_name,
num_args=num_args,
+ name=name,
)
return cast(Callable[[F], _CachedFunction[F]], func)
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index d267703df0..fa91479c97 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -14,11 +14,13 @@
import enum
import logging
import threading
-from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar
+from typing import Any, Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union
import attr
+from typing_extensions import Literal
from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.treecache import TreeCache
logger = logging.getLogger(__name__)
@@ -33,10 +35,12 @@ DV = TypeVar("DV")
# This class can't be generic because it uses slots with attrs.
# See: https://github.com/python-attrs/attrs/issues/313
-@attr.s(slots=True, auto_attribs=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class DictionaryEntry: # should be: Generic[DKT, DV].
"""Returned when getting an entry from the cache
+ If `full` is true then `known_absent` will be the empty set.
+
Attributes:
full: Whether the cache has the full or dict or just some keys.
If not full then not all requested keys will necessarily be present
@@ -53,20 +57,90 @@ class DictionaryEntry: # should be: Generic[DKT, DV].
return len(self.value)
+class _FullCacheKey(enum.Enum):
+ """The key we use to cache the full dict."""
+
+ KEY = object()
+
+
class _Sentinel(enum.Enum):
# defining a sentinel in this way allows mypy to correctly handle the
# type of a dictionary lookup.
sentinel = object()
+class _PerKeyValue(Generic[DV]):
+ """The cached value of a dictionary key. If `value` is the sentinel,
+ indicates that the requested key is known to *not* be in the full dict.
+ """
+
+ __slots__ = ["value"]
+
+ def __init__(self, value: Union[DV, Literal[_Sentinel.sentinel]]) -> None:
+ self.value = value
+
+ def __len__(self) -> int:
+ # We add a `__len__` implementation as we use this class in a cache
+ # where the values are variable length.
+ return 1
+
+
class DictionaryCache(Generic[KT, DKT, DV]):
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key.
+
+ This cache has two levels of key. First there is the "cache key" (of type
+ `KT`), which maps to a dict. The keys to that dict are the "dict key" (of
+ type `DKT`). The overall structure is therefore `KT->DKT->DV`. For
+ example, it might look like:
+
+ {
+ 1: { 1: "a", 2: "b" },
+ 2: { 1: "c" },
+ }
+
+ It is possible to look up either individual dict keys, or the *complete*
+ dict for a given cache key.
+
+ Each dict item, and the complete dict is treated as a separate LRU
+ entry for the purpose of cache expiry. For example, given:
+ dict_cache.get(1, None) -> DictionaryEntry({1: "a", 2: "b"})
+ dict_cache.get(1, [1]) -> DictionaryEntry({1: "a"})
+ dict_cache.get(1, [2]) -> DictionaryEntry({2: "b"})
+
+ ... then the cache entry for the complete dict will expire first,
+ followed by the cache entry for the '1' dict key, and finally that
+ for the '2' dict key.
"""
def __init__(self, name: str, max_entries: int = 1000):
- self.cache: LruCache[KT, DictionaryEntry] = LruCache(
- max_size=max_entries, cache_name=name, size_callback=len
+ # We use a single LruCache to store two different types of entries:
+ # 1. Map from (key, dict_key) -> dict value (or sentinel, indicating
+ # the key doesn't exist in the dict); and
+ # 2. Map from (key, _FullCacheKey.KEY) -> full dict.
+ #
+ # The former is used when explicit keys of the dictionary are looked up,
+ # and the latter when the full dictionary is requested.
+ #
+ # If when explicit keys are requested and not in the cache, we then look
+ # to see if we have the full dict and use that if we do. If found in the
+ # full dict each key is added into the cache.
+ #
+ # This set up allows the `LruCache` to prune the full dict entries if
+ # they haven't been used in a while, even when there have been recent
+ # queries for subsets of the dict.
+ #
+ # Typing:
+ # * A key of `(KT, DKT)` has a value of `_PerKeyValue`
+ # * A key of `(KT, _FullCacheKey.KEY)` has a value of `Dict[DKT, DV]`
+ self.cache: LruCache[
+ Tuple[KT, Union[DKT, Literal[_FullCacheKey.KEY]]],
+ Union[_PerKeyValue, Dict[DKT, DV]],
+ ] = LruCache(
+ max_size=max_entries,
+ cache_name=name,
+ cache_type=TreeCache,
+ size_callback=len,
)
self.name = name
@@ -91,23 +165,83 @@ class DictionaryCache(Generic[KT, DKT, DV]):
Args:
key
dict_keys: If given a set of keys then return only those keys
- that exist in the cache.
+ that exist in the cache. If None then returns the full dict
+ if it is in the cache.
Returns:
- DictionaryEntry
+ DictionaryEntry: If `dict_keys` is not None then `DictionaryEntry`
+ will contain include the keys that are in the cache. If None then
+ will either return the full dict if in the cache, or the empty
+ dict (with `full` set to False) if it isn't.
"""
- entry = self.cache.get(key, _Sentinel.sentinel)
- if entry is not _Sentinel.sentinel:
- if dict_keys is None:
- return DictionaryEntry(
- entry.full, entry.known_absent, dict(entry.value)
- )
+ if dict_keys is None:
+ # The caller wants the full set of dictionary keys for this cache key
+ return self._get_full_dict(key)
+
+ # We are being asked for a subset of keys.
+
+ # First go and check for each requested dict key in the cache, tracking
+ # which we couldn't find.
+ values = {}
+ known_absent = set()
+ missing = []
+ for dict_key in dict_keys:
+ entry = self.cache.get((key, dict_key), _Sentinel.sentinel)
+ if entry is _Sentinel.sentinel:
+ missing.append(dict_key)
+ continue
+
+ assert isinstance(entry, _PerKeyValue)
+
+ if entry.value is _Sentinel.sentinel:
+ known_absent.add(dict_key)
else:
- return DictionaryEntry(
- entry.full,
- entry.known_absent,
- {k: entry.value[k] for k in dict_keys if k in entry.value},
- )
+ values[dict_key] = entry.value
+
+ # If we found everything we can return immediately.
+ if not missing:
+ return DictionaryEntry(False, known_absent, values)
+
+ # We are missing some keys, so check if we happen to have the full dict in
+ # the cache.
+ #
+ # We don't update the last access time for this cache fetch, as we
+ # aren't explicitly interested in the full dict and so we don't want
+ # requests for explicit dict keys to keep the full dict in the cache.
+ entry = self.cache.get(
+ (key, _FullCacheKey.KEY),
+ _Sentinel.sentinel,
+ update_last_access=False,
+ )
+ if entry is _Sentinel.sentinel:
+ # Not in the cache, return the subset of keys we found.
+ return DictionaryEntry(False, known_absent, values)
+
+ # We have the full dict!
+ assert isinstance(entry, dict)
+
+ for dict_key in missing:
+ # We explicitly add each dict key to the cache, so that cache hit
+ # rates and LRU times for each key can be tracked separately.
+ value = entry.get(dict_key, _Sentinel.sentinel) # type: ignore[arg-type]
+ self.cache[(key, dict_key)] = _PerKeyValue(value)
+
+ if value is not _Sentinel.sentinel:
+ values[dict_key] = value
+
+ return DictionaryEntry(True, set(), values)
+
+ def _get_full_dict(
+ self,
+ key: KT,
+ ) -> DictionaryEntry:
+ """Fetch the full dict for the given key."""
+
+ # First we check if we have cached the full dict.
+ entry = self.cache.get((key, _FullCacheKey.KEY), _Sentinel.sentinel)
+ if entry is not _Sentinel.sentinel:
+ assert isinstance(entry, dict)
+ return DictionaryEntry(True, set(), entry)
return DictionaryEntry(False, set(), {})
@@ -117,7 +251,13 @@ class DictionaryCache(Generic[KT, DKT, DV]):
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
- self.cache.pop(key, None)
+
+ # We want to drop all information about the dict for the given key, so
+ # we use `del_multi` to delete it all in one go.
+ #
+ # We ignore the type error here: `del_multi` accepts a truncated key
+ # (when the key type is a tuple).
+ self.cache.del_multi((key,)) # type: ignore[arg-type]
def invalidate_all(self) -> None:
self.check_thread()
@@ -131,7 +271,16 @@ class DictionaryCache(Generic[KT, DKT, DV]):
value: Dict[DKT, DV],
fetched_keys: Optional[Iterable[DKT]] = None,
) -> None:
- """Updates the entry in the cache
+ """Updates the entry in the cache.
+
+ Note: This does *not* invalidate any existing entries for the `key`.
+ In particular, if we add an entry for the cached "full dict" with
+ `fetched_keys=None`, existing entries for individual dict keys are
+ not invalidated. Likewise, adding entries for individual keys does
+ not invalidate any cached value for the full dict.
+
+ In other words: if the underlying data is *changed*, the cache must
+ be explicitly invalidated via `.invalidate()`.
Args:
sequence
@@ -149,20 +298,27 @@ class DictionaryCache(Generic[KT, DKT, DV]):
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
if fetched_keys is None:
- self._insert(key, value, set())
+ self.cache[(key, _FullCacheKey.KEY)] = value
else:
- self._update_or_insert(key, value, fetched_keys)
+ self._update_subset(key, value, fetched_keys)
- def _update_or_insert(
- self, key: KT, value: Dict[DKT, DV], known_absent: Iterable[DKT]
+ def _update_subset(
+ self, key: KT, value: Dict[DKT, DV], fetched_keys: Iterable[DKT]
) -> None:
- # We pop and reinsert as we need to tell the cache the size may have
- # changed
+ """Add the given dictionary values as explicit keys in the cache.
+
+ Args:
+ key: top-level cache key
+ value: The dictionary with all the values that we should cache
+ fetched_keys: The full set of dict keys that were looked up. Any keys
+ here not in `value` should be marked as "known absent".
+ """
+
+ for dict_key, dict_value in value.items():
+ self.cache[(key, dict_key)] = _PerKeyValue(dict_value)
- entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
- entry.value.update(value)
- entry.known_absent.update(known_absent)
- self.cache[key] = entry
+ for dict_key in fetched_keys:
+ if dict_key in value:
+ continue
- def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None:
- self.cache[key] = DictionaryEntry(True, known_absent, value)
+ self.cache[(key, dict_key)] = _PerKeyValue(_Sentinel.sentinel)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 31f41fec82..aa93109d13 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -25,8 +25,10 @@ from typing import (
Collection,
Dict,
Generic,
+ Iterable,
List,
Optional,
+ Tuple,
Type,
TypeVar,
Union,
@@ -44,7 +46,11 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
from synapse.metrics.jemalloc import get_jemalloc_stats
from synapse.util import Clock, caches
from synapse.util.caches import CacheMetric, EvictionReason, register_cache
-from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+from synapse.util.caches.treecache import (
+ TreeCache,
+ iterate_tree_cache_entry,
+ iterate_tree_cache_items,
+)
from synapse.util.linked_list import ListNode
if TYPE_CHECKING:
@@ -537,6 +543,7 @@ class LruCache(Generic[KT, VT]):
default: Literal[None] = None,
callbacks: Collection[Callable[[], None]] = ...,
update_metrics: bool = ...,
+ update_last_access: bool = ...,
) -> Optional[VT]:
...
@@ -546,6 +553,7 @@ class LruCache(Generic[KT, VT]):
default: T,
callbacks: Collection[Callable[[], None]] = ...,
update_metrics: bool = ...,
+ update_last_access: bool = ...,
) -> Union[T, VT]:
...
@@ -555,10 +563,27 @@ class LruCache(Generic[KT, VT]):
default: Optional[T] = None,
callbacks: Collection[Callable[[], None]] = (),
update_metrics: bool = True,
+ update_last_access: bool = True,
) -> Union[None, T, VT]:
+ """Look up a key in the cache
+
+ Args:
+ key
+ default
+ callbacks: A collection of callbacks that will fire when the
+ node is removed from the cache (either due to invalidation
+ or expiry).
+ update_metrics: Whether to update the hit rate metrics
+ update_last_access: Whether to update the last access metrics
+ on a node if successfully fetched. These metrics are used
+ to determine when to remove the node from the cache. Set
+ to False if this fetch should *not* prevent a node from
+ being expired.
+ """
node = cache.get(key, None)
if node is not None:
- move_node_to_front(node)
+ if update_last_access:
+ move_node_to_front(node)
node.add_callbacks(callbacks)
if update_metrics and metrics:
metrics.inc_hits()
@@ -568,6 +593,65 @@ class LruCache(Generic[KT, VT]):
metrics.inc_misses()
return default
+ @overload
+ def cache_get_multi(
+ key: tuple,
+ default: Literal[None] = None,
+ update_metrics: bool = True,
+ ) -> Union[None, Iterable[Tuple[KT, VT]]]:
+ ...
+
+ @overload
+ def cache_get_multi(
+ key: tuple,
+ default: T,
+ update_metrics: bool = True,
+ ) -> Union[T, Iterable[Tuple[KT, VT]]]:
+ ...
+
+ @synchronized
+ def cache_get_multi(
+ key: tuple,
+ default: Optional[T] = None,
+ update_metrics: bool = True,
+ ) -> Union[None, T, Iterable[Tuple[KT, VT]]]:
+ """Returns a generator yielding all entries under the given key.
+
+ Can only be used if backed by a tree cache.
+
+ Example:
+
+ cache = LruCache(10, cache_type=TreeCache)
+ cache[(1, 1)] = "a"
+ cache[(1, 2)] = "b"
+ cache[(2, 1)] = "c"
+
+ items = cache.get_multi((1,))
+ assert list(items) == [((1, 1), "a"), ((1, 2), "b")]
+
+ Returns:
+ Either default if the key doesn't exist, or a generator of the
+ key/value pairs.
+ """
+
+ assert isinstance(cache, TreeCache)
+
+ node = cache.get(key, None)
+ if node is not None:
+ if update_metrics and metrics:
+ metrics.inc_hits()
+
+ # We store entries in the `TreeCache` with values of type `_Node`,
+ # which we need to unwrap.
+ return (
+ (full_key, lru_node.value)
+ for full_key, lru_node in iterate_tree_cache_items(key, node)
+ )
+ else:
+ if update_metrics and metrics:
+ metrics.inc_misses()
+ return default
+
@synchronized
def cache_set(
key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
@@ -674,6 +758,8 @@ class LruCache(Generic[KT, VT]):
self.setdefault = cache_set_default
self.pop = cache_pop
self.del_multi = cache_del_multi
+ if cache_type is TreeCache:
+ self.get_multi = cache_get_multi
# `invalidate` is exposed for consistency with DeferredCache, so that it can be
# invalidated by the cache invalidation replication stream.
self.invalidate = cache_del_multi
@@ -748,9 +834,26 @@ class AsyncLruCache(Generic[KT, VT]):
) -> Optional[VT]:
return self._lru_cache.get(key, update_metrics=update_metrics)
+ async def get_external(
+ self,
+ key: KT,
+ default: Optional[T] = None,
+ update_metrics: bool = True,
+ ) -> Optional[VT]:
+ # This method should fetch from any configured external cache, in this case noop.
+ return None
+
+ def get_local(
+ self, key: KT, default: Optional[T] = None, update_metrics: bool = True
+ ) -> Optional[VT]:
+ return self._lru_cache.get(key, update_metrics=update_metrics)
+
async def set(self, key: KT, value: VT) -> None:
self._lru_cache.set(key, value)
+ def set_local(self, key: KT, value: VT) -> None:
+ self._lru_cache.set(key, value)
+
async def invalidate(self, key: KT) -> None:
# This method should invalidate any external cache and then invalidate the LruCache.
return self._lru_cache.invalidate(key)
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index e78305f787..fec31da2b6 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -64,6 +64,15 @@ class TreeCache:
self.size += 1
def get(self, key, default=None):
+ """When `key` is a full key, fetches the value for the given key (if
+ any).
+
+ If `key` is only a partial key (i.e. a truncated tuple) then returns a
+ `TreeCacheNode`, which can be passed to the `iterate_tree_cache_*`
+ functions to iterate over all entries in the cache with keys that start
+ with the given partial key.
+ """
+
node = self.root
for k in key[:-1]:
node = node.get(k, None)
@@ -126,6 +135,9 @@ class TreeCache:
def values(self):
return iterate_tree_cache_entry(self.root)
+ def items(self):
+ return iterate_tree_cache_items((), self.root)
+
def __len__(self) -> int:
return self.size
@@ -139,3 +151,32 @@ def iterate_tree_cache_entry(d):
yield from iterate_tree_cache_entry(value_d)
else:
yield d
+
+
+def iterate_tree_cache_items(key, value):
+ """Helper function to iterate over the leaves of a tree, i.e. a dict of that
+ can contain dicts.
+
+ The provided key is a tuple that will get prepended to the returned keys.
+
+ Example:
+
+ cache = TreeCache()
+ cache[(1, 1)] = "a"
+ cache[(1, 2)] = "b"
+ cache[(2, 1)] = "c"
+
+ tree_node = cache.get((1,))
+
+ items = iterate_tree_cache_items((1,), tree_node)
+ assert list(items) == [((1, 1), "a"), ((1, 2), "b")]
+
+ Returns:
+ A generator yielding key/value pairs.
+ """
+ if isinstance(value, TreeCacheNode):
+ for sub_key, sub_value in value.items():
+ yield from iterate_tree_cache_items((*key, sub_key), sub_value)
+ else:
+ # we've reached a leaf of the tree.
+ yield key, value
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index dfe628c97e..f678b52cb4 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -18,15 +18,19 @@ import logging
import typing
from typing import Any, DefaultDict, Iterator, List, Set
+from prometheus_client.core import Counter
+
from twisted.internet import defer
from synapse.api.errors import LimitExceededError
-from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.config.ratelimiting import FederationRatelimitSettings
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
run_in_background,
)
+from synapse.logging.opentracing import start_active_span
+from synapse.metrics import Histogram, LaterGauge
from synapse.util import Clock
if typing.TYPE_CHECKING:
@@ -35,8 +39,34 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
+# Track how much the ratelimiter is affecting requests
+rate_limit_sleep_counter = Counter("synapse_rate_limit_sleep", "")
+rate_limit_reject_counter = Counter("synapse_rate_limit_reject", "")
+queue_wait_timer = Histogram(
+ "synapse_rate_limit_queue_wait_time_seconds",
+ "sec",
+ [],
+ buckets=(
+ 0.005,
+ 0.01,
+ 0.025,
+ 0.05,
+ 0.1,
+ 0.25,
+ 0.5,
+ 0.75,
+ 1.0,
+ 2.5,
+ 5.0,
+ 10.0,
+ 20.0,
+ "+Inf",
+ ),
+)
+
+
class FederationRateLimiter:
- def __init__(self, clock: Clock, config: FederationRateLimitConfig):
+ def __init__(self, clock: Clock, config: FederationRatelimitSettings):
def new_limiter() -> "_PerHostRatelimiter":
return _PerHostRatelimiter(clock=clock, config=config)
@@ -44,6 +74,27 @@ class FederationRateLimiter:
str, "_PerHostRatelimiter"
] = collections.defaultdict(new_limiter)
+ # We track the number of affected hosts per time-period so we can
+ # differentiate one really noisy homeserver from a general
+ # ratelimit tuning problem across the federation.
+ LaterGauge(
+ "synapse_rate_limit_sleep_affected_hosts",
+ "Number of hosts that had requests put to sleep",
+ [],
+ lambda: sum(
+ ratelimiter.should_sleep() for ratelimiter in self.ratelimiters.values()
+ ),
+ )
+ LaterGauge(
+ "synapse_rate_limit_reject_affected_hosts",
+ "Number of hosts that had requests rejected",
+ [],
+ lambda: sum(
+ ratelimiter.should_reject()
+ for ratelimiter in self.ratelimiters.values()
+ ),
+ )
+
def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
"""Used to ratelimit an incoming request from a given host
@@ -59,11 +110,11 @@ class FederationRateLimiter:
Returns:
context manager which returns a deferred.
"""
- return self.ratelimiters[host].ratelimit()
+ return self.ratelimiters[host].ratelimit(host)
class _PerHostRatelimiter:
- def __init__(self, clock: Clock, config: FederationRateLimitConfig):
+ def __init__(self, clock: Clock, config: FederationRatelimitSettings):
"""
Args:
clock
@@ -94,19 +145,42 @@ class _PerHostRatelimiter:
self.request_times: List[int] = []
@contextlib.contextmanager
- def ratelimit(self) -> "Iterator[defer.Deferred[None]]":
+ def ratelimit(self, host: str) -> "Iterator[defer.Deferred[None]]":
# `contextlib.contextmanager` takes a generator and turns it into a
# context manager. The generator should only yield once with a value
# to be returned by manager.
# Exceptions will be reraised at the yield.
+ self.host = host
+
request_id = object()
- ret = self._on_enter(request_id)
+ # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
+ # type-checking, but we'd need Twisted >= 21.2.
+ ret = defer.ensureDeferred(self._on_enter_with_tracing(request_id))
try:
yield ret
finally:
self._on_exit(request_id)
+ def should_reject(self) -> bool:
+ """
+ Whether to reject the request if we already have too many queued up
+ (either sleeping or in the ready queue).
+ """
+ queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
+ return queue_size > self.reject_limit
+
+ def should_sleep(self) -> bool:
+ """
+ Whether to sleep the request if we already have too many requests coming
+ through within the window.
+ """
+ return len(self.request_times) > self.sleep_limit
+
+ async def _on_enter_with_tracing(self, request_id: object) -> None:
+ with start_active_span("ratelimit wait"), queue_wait_timer.time():
+ await self._on_enter(request_id)
+
def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
time_now = self.clock.time_msec()
@@ -117,8 +191,9 @@ class _PerHostRatelimiter:
# reject the request if we already have too many queued up (either
# sleeping or in the ready queue).
- queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
- if queue_size > self.reject_limit:
+ if self.should_reject():
+ logger.debug("Ratelimiter(%s): rejecting request", self.host)
+ rate_limit_reject_counter.inc()
raise LimitExceededError(
retry_after_ms=int(self.window_size / self.sleep_limit)
)
@@ -130,7 +205,8 @@ class _PerHostRatelimiter:
queue_defer: defer.Deferred[None] = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer
logger.info(
- "Ratelimiter: queueing request (queue now %i items)",
+ "Ratelimiter(%s): queueing request (queue now %i items)",
+ self.host,
len(self.ready_request_queue),
)
@@ -139,19 +215,28 @@ class _PerHostRatelimiter:
return defer.succeed(None)
logger.debug(
- "Ratelimit [%s]: len(self.request_times)=%d",
+ "Ratelimit(%s) [%s]: len(self.request_times)=%d",
+ self.host,
id(request_id),
len(self.request_times),
)
- if len(self.request_times) > self.sleep_limit:
- logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec)
+ if self.should_sleep():
+ logger.debug(
+ "Ratelimiter(%s) [%s]: sleeping request for %f sec",
+ self.host,
+ id(request_id),
+ self.sleep_sec,
+ )
+ rate_limit_sleep_counter.inc()
ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
self.sleeping_requests.add(request_id)
def on_wait_finished(_: Any) -> "defer.Deferred[None]":
- logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
+ logger.debug(
+ "Ratelimit(%s) [%s]: Finished sleeping", self.host, id(request_id)
+ )
self.sleeping_requests.discard(request_id)
queue_defer = queue_request()
return queue_defer
@@ -161,7 +246,9 @@ class _PerHostRatelimiter:
ret_defer = queue_request()
def on_start(r: object) -> object:
- logger.debug("Ratelimit [%s]: Processing req", id(request_id))
+ logger.debug(
+ "Ratelimit(%s) [%s]: Processing req", self.host, id(request_id)
+ )
self.current_processing.add(request_id)
return r
@@ -183,7 +270,7 @@ class _PerHostRatelimiter:
return make_deferred_yieldable(ret_defer)
def _on_exit(self, request_id: object) -> None:
- logger.debug("Ratelimit [%s]: Processed req", id(request_id))
+ logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id))
self.current_processing.discard(request_id)
try:
# start processing the next item on the queue.
|