diff options
Diffstat (limited to 'synapse/util')
-rw-r--r-- | synapse/util/caches/deferred_cache.py | 346 | ||||
-rw-r--r-- | synapse/util/caches/descriptors.py | 115 | ||||
-rw-r--r-- | synapse/util/caches/dictionary_cache.py | 218 | ||||
-rw-r--r-- | synapse/util/caches/lrucache.py | 107 | ||||
-rw-r--r-- | synapse/util/caches/treecache.py | 41 | ||||
-rw-r--r-- | synapse/util/ratelimitutils.py | 117 |
6 files changed, 748 insertions, 196 deletions
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 1d6ec22191..6425f851ea 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -14,15 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import enum import threading from typing import ( Callable, + Collection, + Dict, Generic, - Iterable, MutableMapping, Optional, + Set, Sized, + Tuple, TypeVar, Union, cast, @@ -31,7 +35,6 @@ from typing import ( from prometheus_client import Gauge from twisted.internet import defer -from twisted.python import failure from twisted.python.failure import Failure from synapse.util.async_helpers import ObservableDeferred @@ -94,7 +97,7 @@ class DeferredCache(Generic[KT, VT]): # _pending_deferred_cache maps from the key value to a `CacheEntry` object. self._pending_deferred_cache: Union[ - TreeCache, "MutableMapping[KT, CacheEntry]" + TreeCache, "MutableMapping[KT, CacheEntry[KT, VT]]" ] = cache_type() def metrics_cb() -> None: @@ -159,15 +162,16 @@ class DeferredCache(Generic[KT, VT]): Raises: KeyError if the key is not found in the cache """ - callbacks = [callback] if callback else [] val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) if val is not _Sentinel.sentinel: - val.callbacks.update(callbacks) + val.add_invalidation_callback(key, callback) if update_metrics: m = self.cache.metrics assert m # we always have a name, so should always have metrics m.inc_hits() - return val.deferred.observe() + return val.deferred(key) + + callbacks = (callback,) if callback else () val2 = self.cache.get( key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics @@ -177,6 +181,73 @@ class DeferredCache(Generic[KT, VT]): else: return defer.succeed(val2) + def get_bulk( + self, + keys: Collection[KT], + callback: Optional[Callable[[], None]] = None, + ) -> Tuple[Dict[KT, VT], Optional["defer.Deferred[Dict[KT, VT]]"], Collection[KT]]: + """Bulk lookup of items in the cache. + + Returns: + A 3-tuple of: + 1. a dict of key/value of items already cached; + 2. a deferred that resolves to a dict of key/value of items + we're already fetching; and + 3. a collection of keys that don't appear in the previous two. + """ + + # The cached results + cached = {} + + # List of pending deferreds + pending = [] + + # Dict that gets filled out when the pending deferreds complete + pending_results = {} + + # List of keys that aren't in either cache + missing = [] + + callbacks = (callback,) if callback else () + + for key in keys: + # Check if its in the main cache. + immediate_value = self.cache.get( + key, + _Sentinel.sentinel, + callbacks=callbacks, + ) + if immediate_value is not _Sentinel.sentinel: + cached[key] = immediate_value + continue + + # Check if its in the pending cache + pending_value = self._pending_deferred_cache.get(key, _Sentinel.sentinel) + if pending_value is not _Sentinel.sentinel: + pending_value.add_invalidation_callback(key, callback) + + def completed_cb(value: VT, key: KT) -> VT: + pending_results[key] = value + return value + + # Add a callback to fill out `pending_results` when that completes + d = pending_value.deferred(key).addCallback(completed_cb, key) + pending.append(d) + continue + + # Not in either cache + missing.append(key) + + # If we've got pending deferreds, squash them into a single one that + # returns `pending_results`. + pending_deferred = None + if pending: + pending_deferred = defer.gatherResults( + pending, consumeErrors=True + ).addCallback(lambda _: pending_results) + + return (cached, pending_deferred, missing) + def get_immediate( self, key: KT, default: T, update_metrics: bool = True ) -> Union[VT, T]: @@ -218,84 +289,89 @@ class DeferredCache(Generic[KT, VT]): value: a deferred which will complete with a result to add to the cache callback: An optional callback to be called when the entry is invalidated """ - if not isinstance(value, defer.Deferred): - raise TypeError("not a Deferred") - - callbacks = [callback] if callback else [] self.check_thread() - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry: - existing_entry.invalidate() + self._pending_deferred_cache.pop(key, None) # XXX: why don't we invalidate the entry in `self.cache` yet? - # we can save a whole load of effort if the deferred is ready. - if value.called: - result = value.result - if not isinstance(result, failure.Failure): - self.cache.set(key, cast(VT, result), callbacks) - return value - # otherwise, we'll add an entry to the _pending_deferred_cache for now, # and add callbacks to add it to the cache properly later. + entry = CacheEntrySingle[KT, VT](value) + entry.add_invalidation_callback(key, callback) + self._pending_deferred_cache[key] = entry + deferred = entry.deferred(key).addCallbacks( + self._completed_callback, + self._error_callback, + callbackArgs=(entry, key), + errbackArgs=(entry, key), + ) - observable = ObservableDeferred(value, consumeErrors=True) - observer = observable.observe() - entry = CacheEntry(deferred=observable, callbacks=callbacks) + # we return a new Deferred which will be called before any subsequent observers. + return deferred - self._pending_deferred_cache[key] = entry + def start_bulk_input( + self, + keys: Collection[KT], + callback: Optional[Callable[[], None]] = None, + ) -> "CacheMultipleEntries[KT, VT]": + """Bulk set API for use when fetching multiple keys at once from the DB. - def compare_and_pop() -> bool: - """Check if our entry is still the one in _pending_deferred_cache, and - if so, pop it. - - Returns true if the entries matched. - """ - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry is entry: - return True - - # oops, the _pending_deferred_cache has been updated since - # we started our query, so we are out of date. - # - # Better put back whatever we took out. (We do it this way - # round, rather than peeking into the _pending_deferred_cache - # and then removing on a match, to make the common case faster) - if existing_entry is not None: - self._pending_deferred_cache[key] = existing_entry - - return False - - def cb(result: VT) -> None: - if compare_and_pop(): - self.cache.set(key, result, entry.callbacks) - else: - # we're not going to put this entry into the cache, so need - # to make sure that the invalidation callbacks are called. - # That was probably done when _pending_deferred_cache was - # updated, but it's possible that `set` was called without - # `invalidate` being previously called, in which case it may - # not have been. Either way, let's double-check now. - entry.invalidate() - - def eb(_fail: Failure) -> None: - compare_and_pop() - entry.invalidate() - - # once the deferred completes, we can move the entry from the - # _pending_deferred_cache to the real cache. - # - observer.addCallbacks(cb, eb) + Called *before* starting the fetch from the DB, and the caller *must* + call either `complete_bulk(..)` or `error_bulk(..)` on the return value. + """ - # we return a new Deferred which will be called before any subsequent observers. - return observable.observe() + entry = CacheMultipleEntries[KT, VT]() + entry.add_global_invalidation_callback(callback) + + for key in keys: + self._pending_deferred_cache[key] = entry + + return entry + + def _completed_callback( + self, value: VT, entry: "CacheEntry[KT, VT]", key: KT + ) -> VT: + """Called when a deferred is completed.""" + # We check if the current entry matches the entry associated with the + # deferred. If they don't match then it got invalidated. + current_entry = self._pending_deferred_cache.pop(key, None) + if current_entry is not entry: + if current_entry: + self._pending_deferred_cache[key] = current_entry + return value + + self.cache.set(key, value, entry.get_invalidation_callbacks(key)) + + return value + + def _error_callback( + self, + failure: Failure, + entry: "CacheEntry[KT, VT]", + key: KT, + ) -> Failure: + """Called when a deferred errors.""" + + # We check if the current entry matches the entry associated with the + # deferred. If they don't match then it got invalidated. + current_entry = self._pending_deferred_cache.pop(key, None) + if current_entry is not entry: + if current_entry: + self._pending_deferred_cache[key] = current_entry + return failure + + for cb in entry.get_invalidation_callbacks(key): + cb() + + return failure def prefill( self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None ) -> None: - callbacks = [callback] if callback else [] + callbacks = (callback,) if callback else () self.cache.set(key, value, callbacks=callbacks) + self._pending_deferred_cache.pop(key, None) def invalidate(self, key: KT) -> None: """Delete a key, or tree of entries @@ -311,41 +387,129 @@ class DeferredCache(Generic[KT, VT]): self.cache.del_multi(key) # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, which will (a) stop it being returned - # for future queries and (b) stop it being persisted as a proper entry + # _pending_deferred_cache, which will (a) stop it being returned for + # future queries and (b) stop it being persisted as a proper entry # in self.cache. entry = self._pending_deferred_cache.pop(key, None) - - # run the invalidation callbacks now, rather than waiting for the - # deferred to resolve. if entry: # _pending_deferred_cache.pop should either return a CacheEntry, or, in the # case of a TreeCache, a dict of keys to cache entries. Either way calling # iterate_tree_cache_entry on it will do the right thing. for entry in iterate_tree_cache_entry(entry): - entry.invalidate() + for cb in entry.get_invalidation_callbacks(key): + cb() def invalidate_all(self) -> None: self.check_thread() self.cache.clear() - for entry in self._pending_deferred_cache.values(): - entry.invalidate() + for key, entry in self._pending_deferred_cache.items(): + for cb in entry.get_invalidation_callbacks(key): + cb() + self._pending_deferred_cache.clear() -class CacheEntry: - __slots__ = ["deferred", "callbacks", "invalidated"] +class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta): + """Abstract class for entries in `DeferredCache[KT, VT]`""" - def __init__( - self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]] - ): - self.deferred = deferred - self.callbacks = set(callbacks) - self.invalidated = False - - def invalidate(self) -> None: - if not self.invalidated: - self.invalidated = True - for callback in self.callbacks: - callback() - self.callbacks.clear() + @abc.abstractmethod + def deferred(self, key: KT) -> "defer.Deferred[VT]": + """Get a deferred that a caller can wait on to get the value at the + given key""" + ... + + @abc.abstractmethod + def add_invalidation_callback( + self, key: KT, callback: Optional[Callable[[], None]] + ) -> None: + """Add an invalidation callback""" + ... + + @abc.abstractmethod + def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]: + """Get all invalidation callbacks""" + ... + + +class CacheEntrySingle(CacheEntry[KT, VT]): + """An implementation of `CacheEntry` wrapping a deferred that results in a + single cache entry. + """ + + __slots__ = ["_deferred", "_callbacks"] + + def __init__(self, deferred: "defer.Deferred[VT]") -> None: + self._deferred = ObservableDeferred(deferred, consumeErrors=True) + self._callbacks: Set[Callable[[], None]] = set() + + def deferred(self, key: KT) -> "defer.Deferred[VT]": + return self._deferred.observe() + + def add_invalidation_callback( + self, key: KT, callback: Optional[Callable[[], None]] + ) -> None: + if callback is None: + return + + self._callbacks.add(callback) + + def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]: + return self._callbacks + + +class CacheMultipleEntries(CacheEntry[KT, VT]): + """Cache entry that is used for bulk lookups and insertions.""" + + __slots__ = ["_deferred", "_callbacks", "_global_callbacks"] + + def __init__(self) -> None: + self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None + self._callbacks: Dict[KT, Set[Callable[[], None]]] = {} + self._global_callbacks: Set[Callable[[], None]] = set() + + def deferred(self, key: KT) -> "defer.Deferred[VT]": + if not self._deferred: + self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) + return self._deferred.observe().addCallback(lambda res: res.get(key)) + + def add_invalidation_callback( + self, key: KT, callback: Optional[Callable[[], None]] + ) -> None: + if callback is None: + return + + self._callbacks.setdefault(key, set()).add(callback) + + def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]: + return self._callbacks.get(key, set()) | self._global_callbacks + + def add_global_invalidation_callback( + self, callback: Optional[Callable[[], None]] + ) -> None: + """Add a callback for when any keys get invalidated.""" + if callback is None: + return + + self._global_callbacks.add(callback) + + def complete_bulk( + self, + cache: DeferredCache[KT, VT], + result: Dict[KT, VT], + ) -> None: + """Called when there is a result""" + for key, value in result.items(): + cache._completed_callback(value, self, key) + + if self._deferred: + self._deferred.callback(result) + + def error_bulk( + self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure + ) -> None: + """Called when bulk lookup failed.""" + for key in keys: + cache._error_callback(failure, self, key) + + if self._deferred: + self._deferred.errback(failure) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 867f315b2a..10aff4d04a 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -25,6 +25,7 @@ from typing import ( Generic, Hashable, Iterable, + List, Mapping, Optional, Sequence, @@ -73,8 +74,10 @@ class _CacheDescriptorBase: num_args: Optional[int], uncached_args: Optional[Collection[str]] = None, cache_context: bool = False, + name: Optional[str] = None, ): self.orig = orig + self.name = name or orig.__name__ arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args @@ -211,7 +214,7 @@ class LruCacheDescriptor(_CacheDescriptorBase): def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: cache: LruCache[CacheKey, Any] = LruCache( - cache_name=self.orig.__name__, + cache_name=self.name, max_size=self.max_entries, ) @@ -241,7 +244,7 @@ class LruCacheDescriptor(_CacheDescriptorBase): wrapped = cast(_CachedFunction, _wrapped) wrapped.cache = cache - obj.__dict__[self.orig.__name__] = wrapped + obj.__dict__[self.name] = wrapped return wrapped @@ -301,12 +304,14 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): cache_context: bool = False, iterable: bool = False, prune_unread_entries: bool = True, + name: Optional[str] = None, ): super().__init__( orig, num_args=num_args, uncached_args=uncached_args, cache_context=cache_context, + name=name, ) if tree and self.num_args < 2: @@ -321,7 +326,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: cache: DeferredCache[CacheKey, Any] = DeferredCache( - name=self.orig.__name__, + name=self.name, max_entries=self.max_entries, tree=self.tree, iterable=self.iterable, @@ -372,7 +377,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): wrapped.cache = cache wrapped.num_args = self.num_args - obj.__dict__[self.orig.__name__] = wrapped + obj.__dict__[self.name] = wrapped return wrapped @@ -393,6 +398,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): cached_method_name: str, list_name: str, num_args: Optional[int] = None, + name: Optional[str] = None, ): """ Args: @@ -403,7 +409,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): but including list_name) to use as cache keys. Defaults to all named args of the function. """ - super().__init__(orig, num_args=num_args, uncached_args=None) + super().__init__(orig, num_args=num_args, uncached_args=None, name=name) self.list_name = list_name @@ -435,16 +441,6 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] list_args = arg_dict[self.list_name] - results = {} - - def update_results_dict(res: Any, arg: Hashable) -> None: - results[arg] = res - - # list of deferreds to wait for - cached_defers = [] - - missing = set() - # If the cache takes a single arg then that is used as the key, # otherwise a tuple is used. if num_args == 1: @@ -452,6 +448,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): def arg_to_cache_key(arg: Hashable) -> Hashable: return arg + def cache_key_to_arg(key: tuple) -> Hashable: + return key + else: keylist = list(keyargs) @@ -459,58 +458,53 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): keylist[self.list_pos] = arg return tuple(keylist) - for arg in list_args: - try: - res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback) - if not res.called: - res.addCallback(update_results_dict, arg) - cached_defers.append(res) - else: - results[arg] = res.result - except KeyError: - missing.add(arg) + def cache_key_to_arg(key: tuple) -> Hashable: + return key[self.list_pos] + + cache_keys = [arg_to_cache_key(arg) for arg in list_args] + immediate_results, pending_deferred, missing = cache.get_bulk( + cache_keys, callback=invalidate_callback + ) + + results = {cache_key_to_arg(key): v for key, v in immediate_results.items()} + + cached_defers: List["defer.Deferred[Any]"] = [] + if pending_deferred: + + def update_results(r: Dict) -> None: + for k, v in r.items(): + results[cache_key_to_arg(k)] = v + + pending_deferred.addCallback(update_results) + cached_defers.append(pending_deferred) if missing: - # we need a deferred for each entry in the list, - # which we put in the cache. Each deferred resolves with the - # relevant result for that key. - deferreds_map = {} - for arg in missing: - deferred: "defer.Deferred[Any]" = defer.Deferred() - deferreds_map[arg] = deferred - key = arg_to_cache_key(arg) - cached_defers.append( - cache.set(key, deferred, callback=invalidate_callback) - ) + cache_entry = cache.start_bulk_input(missing, invalidate_callback) def complete_all(res: Dict[Hashable, Any]) -> None: - # the wrapped function has completed. It returns a dict. - # We can now update our own result map, and then resolve the - # observable deferreds in the cache. - for e, d1 in deferreds_map.items(): - val = res.get(e, None) - # make sure we update the results map before running the - # deferreds, because as soon as we run the last deferred, the - # gatherResults() below will complete and return the result - # dict to our caller. - results[e] = val - d1.callback(val) + missing_results = {} + for key in missing: + arg = cache_key_to_arg(key) + val = res.get(arg, None) + + results[arg] = val + missing_results[key] = val + + cache_entry.complete_bulk(cache, missing_results) def errback_all(f: Failure) -> None: - # the wrapped function has failed. Propagate the failure into - # the cache, which will invalidate the entry, and cause the - # relevant cached_deferreds to fail, which will propagate the - # failure to our caller. - for d1 in deferreds_map.values(): - d1.errback(f) + cache_entry.error_bulk(cache, missing, f) args_to_call = dict(arg_dict) - args_to_call[self.list_name] = missing + args_to_call[self.list_name] = { + cache_key_to_arg(key) for key in missing + } # dispatch the call, and attach the two handlers - defer.maybeDeferred( + missing_d = defer.maybeDeferred( preserve_fn(self.orig), **args_to_call ).addCallbacks(complete_all, errback_all) + cached_defers.append(missing_d) if cached_defers: d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( @@ -525,7 +519,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): else: return defer.succeed(results) - obj.__dict__[self.orig.__name__] = wrapped + obj.__dict__[self.name] = wrapped return wrapped @@ -577,6 +571,7 @@ def cached( cache_context: bool = False, iterable: bool = False, prune_unread_entries: bool = True, + name: Optional[str] = None, ) -> Callable[[F], _CachedFunction[F]]: func = lambda orig: DeferredCacheDescriptor( orig, @@ -587,13 +582,18 @@ def cached( cache_context=cache_context, iterable=iterable, prune_unread_entries=prune_unread_entries, + name=name, ) return cast(Callable[[F], _CachedFunction[F]], func) def cachedList( - *, cached_method_name: str, list_name: str, num_args: Optional[int] = None + *, + cached_method_name: str, + list_name: str, + num_args: Optional[int] = None, + name: Optional[str] = None, ) -> Callable[[F], _CachedFunction[F]]: """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`. @@ -628,6 +628,7 @@ def cachedList( cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, + name=name, ) return cast(Callable[[F], _CachedFunction[F]], func) diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index d267703df0..fa91479c97 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -14,11 +14,13 @@ import enum import logging import threading -from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar +from typing import Any, Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union import attr +from typing_extensions import Literal from synapse.util.caches.lrucache import LruCache +from synapse.util.caches.treecache import TreeCache logger = logging.getLogger(__name__) @@ -33,10 +35,12 @@ DV = TypeVar("DV") # This class can't be generic because it uses slots with attrs. # See: https://github.com/python-attrs/attrs/issues/313 -@attr.s(slots=True, auto_attribs=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class DictionaryEntry: # should be: Generic[DKT, DV]. """Returned when getting an entry from the cache + If `full` is true then `known_absent` will be the empty set. + Attributes: full: Whether the cache has the full or dict or just some keys. If not full then not all requested keys will necessarily be present @@ -53,20 +57,90 @@ class DictionaryEntry: # should be: Generic[DKT, DV]. return len(self.value) +class _FullCacheKey(enum.Enum): + """The key we use to cache the full dict.""" + + KEY = object() + + class _Sentinel(enum.Enum): # defining a sentinel in this way allows mypy to correctly handle the # type of a dictionary lookup. sentinel = object() +class _PerKeyValue(Generic[DV]): + """The cached value of a dictionary key. If `value` is the sentinel, + indicates that the requested key is known to *not* be in the full dict. + """ + + __slots__ = ["value"] + + def __init__(self, value: Union[DV, Literal[_Sentinel.sentinel]]) -> None: + self.value = value + + def __len__(self) -> int: + # We add a `__len__` implementation as we use this class in a cache + # where the values are variable length. + return 1 + + class DictionaryCache(Generic[KT, DKT, DV]): """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. fetching a subset of dictionary keys for a particular key. + + This cache has two levels of key. First there is the "cache key" (of type + `KT`), which maps to a dict. The keys to that dict are the "dict key" (of + type `DKT`). The overall structure is therefore `KT->DKT->DV`. For + example, it might look like: + + { + 1: { 1: "a", 2: "b" }, + 2: { 1: "c" }, + } + + It is possible to look up either individual dict keys, or the *complete* + dict for a given cache key. + + Each dict item, and the complete dict is treated as a separate LRU + entry for the purpose of cache expiry. For example, given: + dict_cache.get(1, None) -> DictionaryEntry({1: "a", 2: "b"}) + dict_cache.get(1, [1]) -> DictionaryEntry({1: "a"}) + dict_cache.get(1, [2]) -> DictionaryEntry({2: "b"}) + + ... then the cache entry for the complete dict will expire first, + followed by the cache entry for the '1' dict key, and finally that + for the '2' dict key. """ def __init__(self, name: str, max_entries: int = 1000): - self.cache: LruCache[KT, DictionaryEntry] = LruCache( - max_size=max_entries, cache_name=name, size_callback=len + # We use a single LruCache to store two different types of entries: + # 1. Map from (key, dict_key) -> dict value (or sentinel, indicating + # the key doesn't exist in the dict); and + # 2. Map from (key, _FullCacheKey.KEY) -> full dict. + # + # The former is used when explicit keys of the dictionary are looked up, + # and the latter when the full dictionary is requested. + # + # If when explicit keys are requested and not in the cache, we then look + # to see if we have the full dict and use that if we do. If found in the + # full dict each key is added into the cache. + # + # This set up allows the `LruCache` to prune the full dict entries if + # they haven't been used in a while, even when there have been recent + # queries for subsets of the dict. + # + # Typing: + # * A key of `(KT, DKT)` has a value of `_PerKeyValue` + # * A key of `(KT, _FullCacheKey.KEY)` has a value of `Dict[DKT, DV]` + self.cache: LruCache[ + Tuple[KT, Union[DKT, Literal[_FullCacheKey.KEY]]], + Union[_PerKeyValue, Dict[DKT, DV]], + ] = LruCache( + max_size=max_entries, + cache_name=name, + cache_type=TreeCache, + size_callback=len, ) self.name = name @@ -91,23 +165,83 @@ class DictionaryCache(Generic[KT, DKT, DV]): Args: key dict_keys: If given a set of keys then return only those keys - that exist in the cache. + that exist in the cache. If None then returns the full dict + if it is in the cache. Returns: - DictionaryEntry + DictionaryEntry: If `dict_keys` is not None then `DictionaryEntry` + will contain include the keys that are in the cache. If None then + will either return the full dict if in the cache, or the empty + dict (with `full` set to False) if it isn't. """ - entry = self.cache.get(key, _Sentinel.sentinel) - if entry is not _Sentinel.sentinel: - if dict_keys is None: - return DictionaryEntry( - entry.full, entry.known_absent, dict(entry.value) - ) + if dict_keys is None: + # The caller wants the full set of dictionary keys for this cache key + return self._get_full_dict(key) + + # We are being asked for a subset of keys. + + # First go and check for each requested dict key in the cache, tracking + # which we couldn't find. + values = {} + known_absent = set() + missing = [] + for dict_key in dict_keys: + entry = self.cache.get((key, dict_key), _Sentinel.sentinel) + if entry is _Sentinel.sentinel: + missing.append(dict_key) + continue + + assert isinstance(entry, _PerKeyValue) + + if entry.value is _Sentinel.sentinel: + known_absent.add(dict_key) else: - return DictionaryEntry( - entry.full, - entry.known_absent, - {k: entry.value[k] for k in dict_keys if k in entry.value}, - ) + values[dict_key] = entry.value + + # If we found everything we can return immediately. + if not missing: + return DictionaryEntry(False, known_absent, values) + + # We are missing some keys, so check if we happen to have the full dict in + # the cache. + # + # We don't update the last access time for this cache fetch, as we + # aren't explicitly interested in the full dict and so we don't want + # requests for explicit dict keys to keep the full dict in the cache. + entry = self.cache.get( + (key, _FullCacheKey.KEY), + _Sentinel.sentinel, + update_last_access=False, + ) + if entry is _Sentinel.sentinel: + # Not in the cache, return the subset of keys we found. + return DictionaryEntry(False, known_absent, values) + + # We have the full dict! + assert isinstance(entry, dict) + + for dict_key in missing: + # We explicitly add each dict key to the cache, so that cache hit + # rates and LRU times for each key can be tracked separately. + value = entry.get(dict_key, _Sentinel.sentinel) # type: ignore[arg-type] + self.cache[(key, dict_key)] = _PerKeyValue(value) + + if value is not _Sentinel.sentinel: + values[dict_key] = value + + return DictionaryEntry(True, set(), values) + + def _get_full_dict( + self, + key: KT, + ) -> DictionaryEntry: + """Fetch the full dict for the given key.""" + + # First we check if we have cached the full dict. + entry = self.cache.get((key, _FullCacheKey.KEY), _Sentinel.sentinel) + if entry is not _Sentinel.sentinel: + assert isinstance(entry, dict) + return DictionaryEntry(True, set(), entry) return DictionaryEntry(False, set(), {}) @@ -117,7 +251,13 @@ class DictionaryCache(Generic[KT, DKT, DV]): # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) self.sequence += 1 - self.cache.pop(key, None) + + # We want to drop all information about the dict for the given key, so + # we use `del_multi` to delete it all in one go. + # + # We ignore the type error here: `del_multi` accepts a truncated key + # (when the key type is a tuple). + self.cache.del_multi((key,)) # type: ignore[arg-type] def invalidate_all(self) -> None: self.check_thread() @@ -131,7 +271,16 @@ class DictionaryCache(Generic[KT, DKT, DV]): value: Dict[DKT, DV], fetched_keys: Optional[Iterable[DKT]] = None, ) -> None: - """Updates the entry in the cache + """Updates the entry in the cache. + + Note: This does *not* invalidate any existing entries for the `key`. + In particular, if we add an entry for the cached "full dict" with + `fetched_keys=None`, existing entries for individual dict keys are + not invalidated. Likewise, adding entries for individual keys does + not invalidate any cached value for the full dict. + + In other words: if the underlying data is *changed*, the cache must + be explicitly invalidated via `.invalidate()`. Args: sequence @@ -149,20 +298,27 @@ class DictionaryCache(Generic[KT, DKT, DV]): # Only update the cache if the caches sequence number matches the # number that the cache had before the SELECT was started (SYN-369) if fetched_keys is None: - self._insert(key, value, set()) + self.cache[(key, _FullCacheKey.KEY)] = value else: - self._update_or_insert(key, value, fetched_keys) + self._update_subset(key, value, fetched_keys) - def _update_or_insert( - self, key: KT, value: Dict[DKT, DV], known_absent: Iterable[DKT] + def _update_subset( + self, key: KT, value: Dict[DKT, DV], fetched_keys: Iterable[DKT] ) -> None: - # We pop and reinsert as we need to tell the cache the size may have - # changed + """Add the given dictionary values as explicit keys in the cache. + + Args: + key: top-level cache key + value: The dictionary with all the values that we should cache + fetched_keys: The full set of dict keys that were looked up. Any keys + here not in `value` should be marked as "known absent". + """ + + for dict_key, dict_value in value.items(): + self.cache[(key, dict_key)] = _PerKeyValue(dict_value) - entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {})) - entry.value.update(value) - entry.known_absent.update(known_absent) - self.cache[key] = entry + for dict_key in fetched_keys: + if dict_key in value: + continue - def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None: - self.cache[key] = DictionaryEntry(True, known_absent, value) + self.cache[(key, dict_key)] = _PerKeyValue(_Sentinel.sentinel) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 31f41fec82..aa93109d13 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -25,8 +25,10 @@ from typing import ( Collection, Dict, Generic, + Iterable, List, Optional, + Tuple, Type, TypeVar, Union, @@ -44,7 +46,11 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces from synapse.metrics.jemalloc import get_jemalloc_stats from synapse.util import Clock, caches from synapse.util.caches import CacheMetric, EvictionReason, register_cache -from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry +from synapse.util.caches.treecache import ( + TreeCache, + iterate_tree_cache_entry, + iterate_tree_cache_items, +) from synapse.util.linked_list import ListNode if TYPE_CHECKING: @@ -537,6 +543,7 @@ class LruCache(Generic[KT, VT]): default: Literal[None] = None, callbacks: Collection[Callable[[], None]] = ..., update_metrics: bool = ..., + update_last_access: bool = ..., ) -> Optional[VT]: ... @@ -546,6 +553,7 @@ class LruCache(Generic[KT, VT]): default: T, callbacks: Collection[Callable[[], None]] = ..., update_metrics: bool = ..., + update_last_access: bool = ..., ) -> Union[T, VT]: ... @@ -555,10 +563,27 @@ class LruCache(Generic[KT, VT]): default: Optional[T] = None, callbacks: Collection[Callable[[], None]] = (), update_metrics: bool = True, + update_last_access: bool = True, ) -> Union[None, T, VT]: + """Look up a key in the cache + + Args: + key + default + callbacks: A collection of callbacks that will fire when the + node is removed from the cache (either due to invalidation + or expiry). + update_metrics: Whether to update the hit rate metrics + update_last_access: Whether to update the last access metrics + on a node if successfully fetched. These metrics are used + to determine when to remove the node from the cache. Set + to False if this fetch should *not* prevent a node from + being expired. + """ node = cache.get(key, None) if node is not None: - move_node_to_front(node) + if update_last_access: + move_node_to_front(node) node.add_callbacks(callbacks) if update_metrics and metrics: metrics.inc_hits() @@ -568,6 +593,65 @@ class LruCache(Generic[KT, VT]): metrics.inc_misses() return default + @overload + def cache_get_multi( + key: tuple, + default: Literal[None] = None, + update_metrics: bool = True, + ) -> Union[None, Iterable[Tuple[KT, VT]]]: + ... + + @overload + def cache_get_multi( + key: tuple, + default: T, + update_metrics: bool = True, + ) -> Union[T, Iterable[Tuple[KT, VT]]]: + ... + + @synchronized + def cache_get_multi( + key: tuple, + default: Optional[T] = None, + update_metrics: bool = True, + ) -> Union[None, T, Iterable[Tuple[KT, VT]]]: + """Returns a generator yielding all entries under the given key. + + Can only be used if backed by a tree cache. + + Example: + + cache = LruCache(10, cache_type=TreeCache) + cache[(1, 1)] = "a" + cache[(1, 2)] = "b" + cache[(2, 1)] = "c" + + items = cache.get_multi((1,)) + assert list(items) == [((1, 1), "a"), ((1, 2), "b")] + + Returns: + Either default if the key doesn't exist, or a generator of the + key/value pairs. + """ + + assert isinstance(cache, TreeCache) + + node = cache.get(key, None) + if node is not None: + if update_metrics and metrics: + metrics.inc_hits() + + # We store entries in the `TreeCache` with values of type `_Node`, + # which we need to unwrap. + return ( + (full_key, lru_node.value) + for full_key, lru_node in iterate_tree_cache_items(key, node) + ) + else: + if update_metrics and metrics: + metrics.inc_misses() + return default + @synchronized def cache_set( key: KT, value: VT, callbacks: Collection[Callable[[], None]] = () @@ -674,6 +758,8 @@ class LruCache(Generic[KT, VT]): self.setdefault = cache_set_default self.pop = cache_pop self.del_multi = cache_del_multi + if cache_type is TreeCache: + self.get_multi = cache_get_multi # `invalidate` is exposed for consistency with DeferredCache, so that it can be # invalidated by the cache invalidation replication stream. self.invalidate = cache_del_multi @@ -748,9 +834,26 @@ class AsyncLruCache(Generic[KT, VT]): ) -> Optional[VT]: return self._lru_cache.get(key, update_metrics=update_metrics) + async def get_external( + self, + key: KT, + default: Optional[T] = None, + update_metrics: bool = True, + ) -> Optional[VT]: + # This method should fetch from any configured external cache, in this case noop. + return None + + def get_local( + self, key: KT, default: Optional[T] = None, update_metrics: bool = True + ) -> Optional[VT]: + return self._lru_cache.get(key, update_metrics=update_metrics) + async def set(self, key: KT, value: VT) -> None: self._lru_cache.set(key, value) + def set_local(self, key: KT, value: VT) -> None: + self._lru_cache.set(key, value) + async def invalidate(self, key: KT) -> None: # This method should invalidate any external cache and then invalidate the LruCache. return self._lru_cache.invalidate(key) diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index e78305f787..fec31da2b6 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -64,6 +64,15 @@ class TreeCache: self.size += 1 def get(self, key, default=None): + """When `key` is a full key, fetches the value for the given key (if + any). + + If `key` is only a partial key (i.e. a truncated tuple) then returns a + `TreeCacheNode`, which can be passed to the `iterate_tree_cache_*` + functions to iterate over all entries in the cache with keys that start + with the given partial key. + """ + node = self.root for k in key[:-1]: node = node.get(k, None) @@ -126,6 +135,9 @@ class TreeCache: def values(self): return iterate_tree_cache_entry(self.root) + def items(self): + return iterate_tree_cache_items((), self.root) + def __len__(self) -> int: return self.size @@ -139,3 +151,32 @@ def iterate_tree_cache_entry(d): yield from iterate_tree_cache_entry(value_d) else: yield d + + +def iterate_tree_cache_items(key, value): + """Helper function to iterate over the leaves of a tree, i.e. a dict of that + can contain dicts. + + The provided key is a tuple that will get prepended to the returned keys. + + Example: + + cache = TreeCache() + cache[(1, 1)] = "a" + cache[(1, 2)] = "b" + cache[(2, 1)] = "c" + + tree_node = cache.get((1,)) + + items = iterate_tree_cache_items((1,), tree_node) + assert list(items) == [((1, 1), "a"), ((1, 2), "b")] + + Returns: + A generator yielding key/value pairs. + """ + if isinstance(value, TreeCacheNode): + for sub_key, sub_value in value.items(): + yield from iterate_tree_cache_items((*key, sub_key), sub_value) + else: + # we've reached a leaf of the tree. + yield key, value diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index dfe628c97e..f678b52cb4 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -18,15 +18,19 @@ import logging import typing from typing import Any, DefaultDict, Iterator, List, Set +from prometheus_client.core import Counter + from twisted.internet import defer from synapse.api.errors import LimitExceededError -from synapse.config.ratelimiting import FederationRateLimitConfig +from synapse.config.ratelimiting import FederationRatelimitSettings from synapse.logging.context import ( PreserveLoggingContext, make_deferred_yieldable, run_in_background, ) +from synapse.logging.opentracing import start_active_span +from synapse.metrics import Histogram, LaterGauge from synapse.util import Clock if typing.TYPE_CHECKING: @@ -35,8 +39,34 @@ if typing.TYPE_CHECKING: logger = logging.getLogger(__name__) +# Track how much the ratelimiter is affecting requests +rate_limit_sleep_counter = Counter("synapse_rate_limit_sleep", "") +rate_limit_reject_counter = Counter("synapse_rate_limit_reject", "") +queue_wait_timer = Histogram( + "synapse_rate_limit_queue_wait_time_seconds", + "sec", + [], + buckets=( + 0.005, + 0.01, + 0.025, + 0.05, + 0.1, + 0.25, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 10.0, + 20.0, + "+Inf", + ), +) + + class FederationRateLimiter: - def __init__(self, clock: Clock, config: FederationRateLimitConfig): + def __init__(self, clock: Clock, config: FederationRatelimitSettings): def new_limiter() -> "_PerHostRatelimiter": return _PerHostRatelimiter(clock=clock, config=config) @@ -44,6 +74,27 @@ class FederationRateLimiter: str, "_PerHostRatelimiter" ] = collections.defaultdict(new_limiter) + # We track the number of affected hosts per time-period so we can + # differentiate one really noisy homeserver from a general + # ratelimit tuning problem across the federation. + LaterGauge( + "synapse_rate_limit_sleep_affected_hosts", + "Number of hosts that had requests put to sleep", + [], + lambda: sum( + ratelimiter.should_sleep() for ratelimiter in self.ratelimiters.values() + ), + ) + LaterGauge( + "synapse_rate_limit_reject_affected_hosts", + "Number of hosts that had requests rejected", + [], + lambda: sum( + ratelimiter.should_reject() + for ratelimiter in self.ratelimiters.values() + ), + ) + def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]": """Used to ratelimit an incoming request from a given host @@ -59,11 +110,11 @@ class FederationRateLimiter: Returns: context manager which returns a deferred. """ - return self.ratelimiters[host].ratelimit() + return self.ratelimiters[host].ratelimit(host) class _PerHostRatelimiter: - def __init__(self, clock: Clock, config: FederationRateLimitConfig): + def __init__(self, clock: Clock, config: FederationRatelimitSettings): """ Args: clock @@ -94,19 +145,42 @@ class _PerHostRatelimiter: self.request_times: List[int] = [] @contextlib.contextmanager - def ratelimit(self) -> "Iterator[defer.Deferred[None]]": + def ratelimit(self, host: str) -> "Iterator[defer.Deferred[None]]": # `contextlib.contextmanager` takes a generator and turns it into a # context manager. The generator should only yield once with a value # to be returned by manager. # Exceptions will be reraised at the yield. + self.host = host + request_id = object() - ret = self._on_enter(request_id) + # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant + # type-checking, but we'd need Twisted >= 21.2. + ret = defer.ensureDeferred(self._on_enter_with_tracing(request_id)) try: yield ret finally: self._on_exit(request_id) + def should_reject(self) -> bool: + """ + Whether to reject the request if we already have too many queued up + (either sleeping or in the ready queue). + """ + queue_size = len(self.ready_request_queue) + len(self.sleeping_requests) + return queue_size > self.reject_limit + + def should_sleep(self) -> bool: + """ + Whether to sleep the request if we already have too many requests coming + through within the window. + """ + return len(self.request_times) > self.sleep_limit + + async def _on_enter_with_tracing(self, request_id: object) -> None: + with start_active_span("ratelimit wait"), queue_wait_timer.time(): + await self._on_enter(request_id) + def _on_enter(self, request_id: object) -> "defer.Deferred[None]": time_now = self.clock.time_msec() @@ -117,8 +191,9 @@ class _PerHostRatelimiter: # reject the request if we already have too many queued up (either # sleeping or in the ready queue). - queue_size = len(self.ready_request_queue) + len(self.sleeping_requests) - if queue_size > self.reject_limit: + if self.should_reject(): + logger.debug("Ratelimiter(%s): rejecting request", self.host) + rate_limit_reject_counter.inc() raise LimitExceededError( retry_after_ms=int(self.window_size / self.sleep_limit) ) @@ -130,7 +205,8 @@ class _PerHostRatelimiter: queue_defer: defer.Deferred[None] = defer.Deferred() self.ready_request_queue[request_id] = queue_defer logger.info( - "Ratelimiter: queueing request (queue now %i items)", + "Ratelimiter(%s): queueing request (queue now %i items)", + self.host, len(self.ready_request_queue), ) @@ -139,19 +215,28 @@ class _PerHostRatelimiter: return defer.succeed(None) logger.debug( - "Ratelimit [%s]: len(self.request_times)=%d", + "Ratelimit(%s) [%s]: len(self.request_times)=%d", + self.host, id(request_id), len(self.request_times), ) - if len(self.request_times) > self.sleep_limit: - logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec) + if self.should_sleep(): + logger.debug( + "Ratelimiter(%s) [%s]: sleeping request for %f sec", + self.host, + id(request_id), + self.sleep_sec, + ) + rate_limit_sleep_counter.inc() ret_defer = run_in_background(self.clock.sleep, self.sleep_sec) self.sleeping_requests.add(request_id) def on_wait_finished(_: Any) -> "defer.Deferred[None]": - logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id)) + logger.debug( + "Ratelimit(%s) [%s]: Finished sleeping", self.host, id(request_id) + ) self.sleeping_requests.discard(request_id) queue_defer = queue_request() return queue_defer @@ -161,7 +246,9 @@ class _PerHostRatelimiter: ret_defer = queue_request() def on_start(r: object) -> object: - logger.debug("Ratelimit [%s]: Processing req", id(request_id)) + logger.debug( + "Ratelimit(%s) [%s]: Processing req", self.host, id(request_id) + ) self.current_processing.add(request_id) return r @@ -183,7 +270,7 @@ class _PerHostRatelimiter: return make_deferred_yieldable(ret_defer) def _on_exit(self, request_id: object) -> None: - logger.debug("Ratelimit [%s]: Processed req", id(request_id)) + logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id)) self.current_processing.discard(request_id) try: # start processing the next item on the queue. |