diff options
author | Erik Johnston <erik@matrix.org> | 2022-08-23 15:53:27 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-08-23 14:53:27 +0000 |
commit | f7ddfe17a30a50205a23bf5ca4d7d71e691e1e48 (patch) | |
tree | 3af1cd2f62868f6e1fe0f455e7cf19739161fdfa /synapse/util/caches/deferred_cache.py | |
parent | Fix regression caused by #13573 (#13600) (diff) | |
download | synapse-f7ddfe17a30a50205a23bf5ca4d7d71e691e1e48.tar.xz |
Speed up `@cachedList` (#13591)
This speeds things up by ~2x. The vast majority of the time is now spent in `LruCache` moving things around the linked lists. We do this via two things: 1. Don't create a deferred per-key during bulk set operations in `DeferredCache`. Instead, only create them if a subsequent caller asks for the key. 2. Add a bulk lookup API to `DeferredCache` rather than use a loop.
Diffstat (limited to 'synapse/util/caches/deferred_cache.py')
-rw-r--r-- | synapse/util/caches/deferred_cache.py | 346 |
1 files changed, 255 insertions, 91 deletions
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 1d6ec22191..6425f851ea 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -14,15 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import enum import threading from typing import ( Callable, + Collection, + Dict, Generic, - Iterable, MutableMapping, Optional, + Set, Sized, + Tuple, TypeVar, Union, cast, @@ -31,7 +35,6 @@ from typing import ( from prometheus_client import Gauge from twisted.internet import defer -from twisted.python import failure from twisted.python.failure import Failure from synapse.util.async_helpers import ObservableDeferred @@ -94,7 +97,7 @@ class DeferredCache(Generic[KT, VT]): # _pending_deferred_cache maps from the key value to a `CacheEntry` object. self._pending_deferred_cache: Union[ - TreeCache, "MutableMapping[KT, CacheEntry]" + TreeCache, "MutableMapping[KT, CacheEntry[KT, VT]]" ] = cache_type() def metrics_cb() -> None: @@ -159,15 +162,16 @@ class DeferredCache(Generic[KT, VT]): Raises: KeyError if the key is not found in the cache """ - callbacks = [callback] if callback else [] val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) if val is not _Sentinel.sentinel: - val.callbacks.update(callbacks) + val.add_invalidation_callback(key, callback) if update_metrics: m = self.cache.metrics assert m # we always have a name, so should always have metrics m.inc_hits() - return val.deferred.observe() + return val.deferred(key) + + callbacks = (callback,) if callback else () val2 = self.cache.get( key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics @@ -177,6 +181,73 @@ class DeferredCache(Generic[KT, VT]): else: return defer.succeed(val2) + def get_bulk( + self, + keys: Collection[KT], + callback: Optional[Callable[[], None]] = None, + ) -> Tuple[Dict[KT, VT], Optional["defer.Deferred[Dict[KT, VT]]"], Collection[KT]]: + """Bulk lookup of items in the cache. + + Returns: + A 3-tuple of: + 1. a dict of key/value of items already cached; + 2. a deferred that resolves to a dict of key/value of items + we're already fetching; and + 3. a collection of keys that don't appear in the previous two. + """ + + # The cached results + cached = {} + + # List of pending deferreds + pending = [] + + # Dict that gets filled out when the pending deferreds complete + pending_results = {} + + # List of keys that aren't in either cache + missing = [] + + callbacks = (callback,) if callback else () + + for key in keys: + # Check if its in the main cache. + immediate_value = self.cache.get( + key, + _Sentinel.sentinel, + callbacks=callbacks, + ) + if immediate_value is not _Sentinel.sentinel: + cached[key] = immediate_value + continue + + # Check if its in the pending cache + pending_value = self._pending_deferred_cache.get(key, _Sentinel.sentinel) + if pending_value is not _Sentinel.sentinel: + pending_value.add_invalidation_callback(key, callback) + + def completed_cb(value: VT, key: KT) -> VT: + pending_results[key] = value + return value + + # Add a callback to fill out `pending_results` when that completes + d = pending_value.deferred(key).addCallback(completed_cb, key) + pending.append(d) + continue + + # Not in either cache + missing.append(key) + + # If we've got pending deferreds, squash them into a single one that + # returns `pending_results`. + pending_deferred = None + if pending: + pending_deferred = defer.gatherResults( + pending, consumeErrors=True + ).addCallback(lambda _: pending_results) + + return (cached, pending_deferred, missing) + def get_immediate( self, key: KT, default: T, update_metrics: bool = True ) -> Union[VT, T]: @@ -218,84 +289,89 @@ class DeferredCache(Generic[KT, VT]): value: a deferred which will complete with a result to add to the cache callback: An optional callback to be called when the entry is invalidated """ - if not isinstance(value, defer.Deferred): - raise TypeError("not a Deferred") - - callbacks = [callback] if callback else [] self.check_thread() - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry: - existing_entry.invalidate() + self._pending_deferred_cache.pop(key, None) # XXX: why don't we invalidate the entry in `self.cache` yet? - # we can save a whole load of effort if the deferred is ready. - if value.called: - result = value.result - if not isinstance(result, failure.Failure): - self.cache.set(key, cast(VT, result), callbacks) - return value - # otherwise, we'll add an entry to the _pending_deferred_cache for now, # and add callbacks to add it to the cache properly later. + entry = CacheEntrySingle[KT, VT](value) + entry.add_invalidation_callback(key, callback) + self._pending_deferred_cache[key] = entry + deferred = entry.deferred(key).addCallbacks( + self._completed_callback, + self._error_callback, + callbackArgs=(entry, key), + errbackArgs=(entry, key), + ) - observable = ObservableDeferred(value, consumeErrors=True) - observer = observable.observe() - entry = CacheEntry(deferred=observable, callbacks=callbacks) + # we return a new Deferred which will be called before any subsequent observers. + return deferred - self._pending_deferred_cache[key] = entry + def start_bulk_input( + self, + keys: Collection[KT], + callback: Optional[Callable[[], None]] = None, + ) -> "CacheMultipleEntries[KT, VT]": + """Bulk set API for use when fetching multiple keys at once from the DB. - def compare_and_pop() -> bool: - """Check if our entry is still the one in _pending_deferred_cache, and - if so, pop it. - - Returns true if the entries matched. - """ - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry is entry: - return True - - # oops, the _pending_deferred_cache has been updated since - # we started our query, so we are out of date. - # - # Better put back whatever we took out. (We do it this way - # round, rather than peeking into the _pending_deferred_cache - # and then removing on a match, to make the common case faster) - if existing_entry is not None: - self._pending_deferred_cache[key] = existing_entry - - return False - - def cb(result: VT) -> None: - if compare_and_pop(): - self.cache.set(key, result, entry.callbacks) - else: - # we're not going to put this entry into the cache, so need - # to make sure that the invalidation callbacks are called. - # That was probably done when _pending_deferred_cache was - # updated, but it's possible that `set` was called without - # `invalidate` being previously called, in which case it may - # not have been. Either way, let's double-check now. - entry.invalidate() - - def eb(_fail: Failure) -> None: - compare_and_pop() - entry.invalidate() - - # once the deferred completes, we can move the entry from the - # _pending_deferred_cache to the real cache. - # - observer.addCallbacks(cb, eb) + Called *before* starting the fetch from the DB, and the caller *must* + call either `complete_bulk(..)` or `error_bulk(..)` on the return value. + """ - # we return a new Deferred which will be called before any subsequent observers. - return observable.observe() + entry = CacheMultipleEntries[KT, VT]() + entry.add_global_invalidation_callback(callback) + + for key in keys: + self._pending_deferred_cache[key] = entry + + return entry + + def _completed_callback( + self, value: VT, entry: "CacheEntry[KT, VT]", key: KT + ) -> VT: + """Called when a deferred is completed.""" + # We check if the current entry matches the entry associated with the + # deferred. If they don't match then it got invalidated. + current_entry = self._pending_deferred_cache.pop(key, None) + if current_entry is not entry: + if current_entry: + self._pending_deferred_cache[key] = current_entry + return value + + self.cache.set(key, value, entry.get_invalidation_callbacks(key)) + + return value + + def _error_callback( + self, + failure: Failure, + entry: "CacheEntry[KT, VT]", + key: KT, + ) -> Failure: + """Called when a deferred errors.""" + + # We check if the current entry matches the entry associated with the + # deferred. If they don't match then it got invalidated. + current_entry = self._pending_deferred_cache.pop(key, None) + if current_entry is not entry: + if current_entry: + self._pending_deferred_cache[key] = current_entry + return failure + + for cb in entry.get_invalidation_callbacks(key): + cb() + + return failure def prefill( self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None ) -> None: - callbacks = [callback] if callback else [] + callbacks = (callback,) if callback else () self.cache.set(key, value, callbacks=callbacks) + self._pending_deferred_cache.pop(key, None) def invalidate(self, key: KT) -> None: """Delete a key, or tree of entries @@ -311,41 +387,129 @@ class DeferredCache(Generic[KT, VT]): self.cache.del_multi(key) # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, which will (a) stop it being returned - # for future queries and (b) stop it being persisted as a proper entry + # _pending_deferred_cache, which will (a) stop it being returned for + # future queries and (b) stop it being persisted as a proper entry # in self.cache. entry = self._pending_deferred_cache.pop(key, None) - - # run the invalidation callbacks now, rather than waiting for the - # deferred to resolve. if entry: # _pending_deferred_cache.pop should either return a CacheEntry, or, in the # case of a TreeCache, a dict of keys to cache entries. Either way calling # iterate_tree_cache_entry on it will do the right thing. for entry in iterate_tree_cache_entry(entry): - entry.invalidate() + for cb in entry.get_invalidation_callbacks(key): + cb() def invalidate_all(self) -> None: self.check_thread() self.cache.clear() - for entry in self._pending_deferred_cache.values(): - entry.invalidate() + for key, entry in self._pending_deferred_cache.items(): + for cb in entry.get_invalidation_callbacks(key): + cb() + self._pending_deferred_cache.clear() -class CacheEntry: - __slots__ = ["deferred", "callbacks", "invalidated"] +class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta): + """Abstract class for entries in `DeferredCache[KT, VT]`""" - def __init__( - self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]] - ): - self.deferred = deferred - self.callbacks = set(callbacks) - self.invalidated = False - - def invalidate(self) -> None: - if not self.invalidated: - self.invalidated = True - for callback in self.callbacks: - callback() - self.callbacks.clear() + @abc.abstractmethod + def deferred(self, key: KT) -> "defer.Deferred[VT]": + """Get a deferred that a caller can wait on to get the value at the + given key""" + ... + + @abc.abstractmethod + def add_invalidation_callback( + self, key: KT, callback: Optional[Callable[[], None]] + ) -> None: + """Add an invalidation callback""" + ... + + @abc.abstractmethod + def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]: + """Get all invalidation callbacks""" + ... + + +class CacheEntrySingle(CacheEntry[KT, VT]): + """An implementation of `CacheEntry` wrapping a deferred that results in a + single cache entry. + """ + + __slots__ = ["_deferred", "_callbacks"] + + def __init__(self, deferred: "defer.Deferred[VT]") -> None: + self._deferred = ObservableDeferred(deferred, consumeErrors=True) + self._callbacks: Set[Callable[[], None]] = set() + + def deferred(self, key: KT) -> "defer.Deferred[VT]": + return self._deferred.observe() + + def add_invalidation_callback( + self, key: KT, callback: Optional[Callable[[], None]] + ) -> None: + if callback is None: + return + + self._callbacks.add(callback) + + def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]: + return self._callbacks + + +class CacheMultipleEntries(CacheEntry[KT, VT]): + """Cache entry that is used for bulk lookups and insertions.""" + + __slots__ = ["_deferred", "_callbacks", "_global_callbacks"] + + def __init__(self) -> None: + self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None + self._callbacks: Dict[KT, Set[Callable[[], None]]] = {} + self._global_callbacks: Set[Callable[[], None]] = set() + + def deferred(self, key: KT) -> "defer.Deferred[VT]": + if not self._deferred: + self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) + return self._deferred.observe().addCallback(lambda res: res.get(key)) + + def add_invalidation_callback( + self, key: KT, callback: Optional[Callable[[], None]] + ) -> None: + if callback is None: + return + + self._callbacks.setdefault(key, set()).add(callback) + + def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]: + return self._callbacks.get(key, set()) | self._global_callbacks + + def add_global_invalidation_callback( + self, callback: Optional[Callable[[], None]] + ) -> None: + """Add a callback for when any keys get invalidated.""" + if callback is None: + return + + self._global_callbacks.add(callback) + + def complete_bulk( + self, + cache: DeferredCache[KT, VT], + result: Dict[KT, VT], + ) -> None: + """Called when there is a result""" + for key, value in result.items(): + cache._completed_callback(value, self, key) + + if self._deferred: + self._deferred.callback(result) + + def error_bulk( + self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure + ) -> None: + """Called when bulk lookup failed.""" + for key in keys: + cache._error_callback(failure, self, key) + + if self._deferred: + self._deferred.errback(failure) |