summary refs log tree commit diff
path: root/synapse/util/caches
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches')
-rw-r--r--synapse/util/caches/deferred_cache.py346
-rw-r--r--synapse/util/caches/descriptors.py89
-rw-r--r--synapse/util/caches/treecache.py3
3 files changed, 297 insertions, 141 deletions
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 1d6ec22191..6425f851ea 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -14,15 +14,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import enum
 import threading
 from typing import (
     Callable,
+    Collection,
+    Dict,
     Generic,
-    Iterable,
     MutableMapping,
     Optional,
+    Set,
     Sized,
+    Tuple,
     TypeVar,
     Union,
     cast,
@@ -31,7 +35,6 @@ from typing import (
 from prometheus_client import Gauge
 
 from twisted.internet import defer
-from twisted.python import failure
 from twisted.python.failure import Failure
 
 from synapse.util.async_helpers import ObservableDeferred
@@ -94,7 +97,7 @@ class DeferredCache(Generic[KT, VT]):
 
         # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
         self._pending_deferred_cache: Union[
-            TreeCache, "MutableMapping[KT, CacheEntry]"
+            TreeCache, "MutableMapping[KT, CacheEntry[KT, VT]]"
         ] = cache_type()
 
         def metrics_cb() -> None:
@@ -159,15 +162,16 @@ class DeferredCache(Generic[KT, VT]):
         Raises:
             KeyError if the key is not found in the cache
         """
-        callbacks = [callback] if callback else []
         val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
         if val is not _Sentinel.sentinel:
-            val.callbacks.update(callbacks)
+            val.add_invalidation_callback(key, callback)
             if update_metrics:
                 m = self.cache.metrics
                 assert m  # we always have a name, so should always have metrics
                 m.inc_hits()
-            return val.deferred.observe()
+            return val.deferred(key)
+
+        callbacks = (callback,) if callback else ()
 
         val2 = self.cache.get(
             key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
@@ -177,6 +181,73 @@ class DeferredCache(Generic[KT, VT]):
         else:
             return defer.succeed(val2)
 
+    def get_bulk(
+        self,
+        keys: Collection[KT],
+        callback: Optional[Callable[[], None]] = None,
+    ) -> Tuple[Dict[KT, VT], Optional["defer.Deferred[Dict[KT, VT]]"], Collection[KT]]:
+        """Bulk lookup of items in the cache.
+
+        Returns:
+            A 3-tuple of:
+                1. a dict of key/value of items already cached;
+                2. a deferred that resolves to a dict of key/value of items
+                   we're already fetching; and
+                3. a collection of keys that don't appear in the previous two.
+        """
+
+        # The cached results
+        cached = {}
+
+        # List of pending deferreds
+        pending = []
+
+        # Dict that gets filled out when the pending deferreds complete
+        pending_results = {}
+
+        # List of keys that aren't in either cache
+        missing = []
+
+        callbacks = (callback,) if callback else ()
+
+        for key in keys:
+            # Check if its in the main cache.
+            immediate_value = self.cache.get(
+                key,
+                _Sentinel.sentinel,
+                callbacks=callbacks,
+            )
+            if immediate_value is not _Sentinel.sentinel:
+                cached[key] = immediate_value
+                continue
+
+            # Check if its in the pending cache
+            pending_value = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
+            if pending_value is not _Sentinel.sentinel:
+                pending_value.add_invalidation_callback(key, callback)
+
+                def completed_cb(value: VT, key: KT) -> VT:
+                    pending_results[key] = value
+                    return value
+
+                # Add a callback to fill out `pending_results` when that completes
+                d = pending_value.deferred(key).addCallback(completed_cb, key)
+                pending.append(d)
+                continue
+
+            # Not in either cache
+            missing.append(key)
+
+        # If we've got pending deferreds, squash them into a single one that
+        # returns `pending_results`.
+        pending_deferred = None
+        if pending:
+            pending_deferred = defer.gatherResults(
+                pending, consumeErrors=True
+            ).addCallback(lambda _: pending_results)
+
+        return (cached, pending_deferred, missing)
+
     def get_immediate(
         self, key: KT, default: T, update_metrics: bool = True
     ) -> Union[VT, T]:
@@ -218,84 +289,89 @@ class DeferredCache(Generic[KT, VT]):
             value: a deferred which will complete with a result to add to the cache
             callback: An optional callback to be called when the entry is invalidated
         """
-        if not isinstance(value, defer.Deferred):
-            raise TypeError("not a Deferred")
-
-        callbacks = [callback] if callback else []
         self.check_thread()
 
-        existing_entry = self._pending_deferred_cache.pop(key, None)
-        if existing_entry:
-            existing_entry.invalidate()
+        self._pending_deferred_cache.pop(key, None)
 
         # XXX: why don't we invalidate the entry in `self.cache` yet?
 
-        # we can save a whole load of effort if the deferred is ready.
-        if value.called:
-            result = value.result
-            if not isinstance(result, failure.Failure):
-                self.cache.set(key, cast(VT, result), callbacks)
-            return value
-
         # otherwise, we'll add an entry to the _pending_deferred_cache for now,
         # and add callbacks to add it to the cache properly later.
+        entry = CacheEntrySingle[KT, VT](value)
+        entry.add_invalidation_callback(key, callback)
+        self._pending_deferred_cache[key] = entry
+        deferred = entry.deferred(key).addCallbacks(
+            self._completed_callback,
+            self._error_callback,
+            callbackArgs=(entry, key),
+            errbackArgs=(entry, key),
+        )
 
-        observable = ObservableDeferred(value, consumeErrors=True)
-        observer = observable.observe()
-        entry = CacheEntry(deferred=observable, callbacks=callbacks)
+        # we return a new Deferred which will be called before any subsequent observers.
+        return deferred
 
-        self._pending_deferred_cache[key] = entry
+    def start_bulk_input(
+        self,
+        keys: Collection[KT],
+        callback: Optional[Callable[[], None]] = None,
+    ) -> "CacheMultipleEntries[KT, VT]":
+        """Bulk set API for use when fetching multiple keys at once from the DB.
 
-        def compare_and_pop() -> bool:
-            """Check if our entry is still the one in _pending_deferred_cache, and
-            if so, pop it.
-
-            Returns true if the entries matched.
-            """
-            existing_entry = self._pending_deferred_cache.pop(key, None)
-            if existing_entry is entry:
-                return True
-
-            # oops, the _pending_deferred_cache has been updated since
-            # we started our query, so we are out of date.
-            #
-            # Better put back whatever we took out. (We do it this way
-            # round, rather than peeking into the _pending_deferred_cache
-            # and then removing on a match, to make the common case faster)
-            if existing_entry is not None:
-                self._pending_deferred_cache[key] = existing_entry
-
-            return False
-
-        def cb(result: VT) -> None:
-            if compare_and_pop():
-                self.cache.set(key, result, entry.callbacks)
-            else:
-                # we're not going to put this entry into the cache, so need
-                # to make sure that the invalidation callbacks are called.
-                # That was probably done when _pending_deferred_cache was
-                # updated, but it's possible that `set` was called without
-                # `invalidate` being previously called, in which case it may
-                # not have been. Either way, let's double-check now.
-                entry.invalidate()
-
-        def eb(_fail: Failure) -> None:
-            compare_and_pop()
-            entry.invalidate()
-
-        # once the deferred completes, we can move the entry from the
-        # _pending_deferred_cache to the real cache.
-        #
-        observer.addCallbacks(cb, eb)
+        Called *before* starting the fetch from the DB, and the caller *must*
+        call either `complete_bulk(..)` or `error_bulk(..)` on the return value.
+        """
 
-        # we return a new Deferred which will be called before any subsequent observers.
-        return observable.observe()
+        entry = CacheMultipleEntries[KT, VT]()
+        entry.add_global_invalidation_callback(callback)
+
+        for key in keys:
+            self._pending_deferred_cache[key] = entry
+
+        return entry
+
+    def _completed_callback(
+        self, value: VT, entry: "CacheEntry[KT, VT]", key: KT
+    ) -> VT:
+        """Called when a deferred is completed."""
+        # We check if the current entry matches the entry associated with the
+        # deferred. If they don't match then it got invalidated.
+        current_entry = self._pending_deferred_cache.pop(key, None)
+        if current_entry is not entry:
+            if current_entry:
+                self._pending_deferred_cache[key] = current_entry
+            return value
+
+        self.cache.set(key, value, entry.get_invalidation_callbacks(key))
+
+        return value
+
+    def _error_callback(
+        self,
+        failure: Failure,
+        entry: "CacheEntry[KT, VT]",
+        key: KT,
+    ) -> Failure:
+        """Called when a deferred errors."""
+
+        # We check if the current entry matches the entry associated with the
+        # deferred. If they don't match then it got invalidated.
+        current_entry = self._pending_deferred_cache.pop(key, None)
+        if current_entry is not entry:
+            if current_entry:
+                self._pending_deferred_cache[key] = current_entry
+            return failure
+
+        for cb in entry.get_invalidation_callbacks(key):
+            cb()
+
+        return failure
 
     def prefill(
         self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
     ) -> None:
-        callbacks = [callback] if callback else []
+        callbacks = (callback,) if callback else ()
         self.cache.set(key, value, callbacks=callbacks)
+        self._pending_deferred_cache.pop(key, None)
 
     def invalidate(self, key: KT) -> None:
         """Delete a key, or tree of entries
@@ -311,41 +387,129 @@ class DeferredCache(Generic[KT, VT]):
         self.cache.del_multi(key)
 
         # if we have a pending lookup for this key, remove it from the
-        # _pending_deferred_cache, which will (a) stop it being returned
-        # for future queries and (b) stop it being persisted as a proper entry
+        # _pending_deferred_cache, which will (a) stop it being returned for
+        # future queries and (b) stop it being persisted as a proper entry
         # in self.cache.
         entry = self._pending_deferred_cache.pop(key, None)
-
-        # run the invalidation callbacks now, rather than waiting for the
-        # deferred to resolve.
         if entry:
             # _pending_deferred_cache.pop should either return a CacheEntry, or, in the
             # case of a TreeCache, a dict of keys to cache entries. Either way calling
             # iterate_tree_cache_entry on it will do the right thing.
             for entry in iterate_tree_cache_entry(entry):
-                entry.invalidate()
+                for cb in entry.get_invalidation_callbacks(key):
+                    cb()
 
     def invalidate_all(self) -> None:
         self.check_thread()
         self.cache.clear()
-        for entry in self._pending_deferred_cache.values():
-            entry.invalidate()
+        for key, entry in self._pending_deferred_cache.items():
+            for cb in entry.get_invalidation_callbacks(key):
+                cb()
+
         self._pending_deferred_cache.clear()
 
 
-class CacheEntry:
-    __slots__ = ["deferred", "callbacks", "invalidated"]
+class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta):
+    """Abstract class for entries in `DeferredCache[KT, VT]`"""
 
-    def __init__(
-        self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
-    ):
-        self.deferred = deferred
-        self.callbacks = set(callbacks)
-        self.invalidated = False
-
-    def invalidate(self) -> None:
-        if not self.invalidated:
-            self.invalidated = True
-            for callback in self.callbacks:
-                callback()
-            self.callbacks.clear()
+    @abc.abstractmethod
+    def deferred(self, key: KT) -> "defer.Deferred[VT]":
+        """Get a deferred that a caller can wait on to get the value at the
+        given key"""
+        ...
+
+    @abc.abstractmethod
+    def add_invalidation_callback(
+        self, key: KT, callback: Optional[Callable[[], None]]
+    ) -> None:
+        """Add an invalidation callback"""
+        ...
+
+    @abc.abstractmethod
+    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+        """Get all invalidation callbacks"""
+        ...
+
+
+class CacheEntrySingle(CacheEntry[KT, VT]):
+    """An implementation of `CacheEntry` wrapping a deferred that results in a
+    single cache entry.
+    """
+
+    __slots__ = ["_deferred", "_callbacks"]
+
+    def __init__(self, deferred: "defer.Deferred[VT]") -> None:
+        self._deferred = ObservableDeferred(deferred, consumeErrors=True)
+        self._callbacks: Set[Callable[[], None]] = set()
+
+    def deferred(self, key: KT) -> "defer.Deferred[VT]":
+        return self._deferred.observe()
+
+    def add_invalidation_callback(
+        self, key: KT, callback: Optional[Callable[[], None]]
+    ) -> None:
+        if callback is None:
+            return
+
+        self._callbacks.add(callback)
+
+    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+        return self._callbacks
+
+
+class CacheMultipleEntries(CacheEntry[KT, VT]):
+    """Cache entry that is used for bulk lookups and insertions."""
+
+    __slots__ = ["_deferred", "_callbacks", "_global_callbacks"]
+
+    def __init__(self) -> None:
+        self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None
+        self._callbacks: Dict[KT, Set[Callable[[], None]]] = {}
+        self._global_callbacks: Set[Callable[[], None]] = set()
+
+    def deferred(self, key: KT) -> "defer.Deferred[VT]":
+        if not self._deferred:
+            self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+        return self._deferred.observe().addCallback(lambda res: res.get(key))
+
+    def add_invalidation_callback(
+        self, key: KT, callback: Optional[Callable[[], None]]
+    ) -> None:
+        if callback is None:
+            return
+
+        self._callbacks.setdefault(key, set()).add(callback)
+
+    def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
+        return self._callbacks.get(key, set()) | self._global_callbacks
+
+    def add_global_invalidation_callback(
+        self, callback: Optional[Callable[[], None]]
+    ) -> None:
+        """Add a callback for when any keys get invalidated."""
+        if callback is None:
+            return
+
+        self._global_callbacks.add(callback)
+
+    def complete_bulk(
+        self,
+        cache: DeferredCache[KT, VT],
+        result: Dict[KT, VT],
+    ) -> None:
+        """Called when there is a result"""
+        for key, value in result.items():
+            cache._completed_callback(value, self, key)
+
+        if self._deferred:
+            self._deferred.callback(result)
+
+    def error_bulk(
+        self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure
+    ) -> None:
+        """Called when bulk lookup failed."""
+        for key in keys:
+            cache._error_callback(failure, self, key)
+
+        if self._deferred:
+            self._deferred.errback(failure)
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 9d4bc89edb..10aff4d04a 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -25,6 +25,7 @@ from typing import (
     Generic,
     Hashable,
     Iterable,
+    List,
     Mapping,
     Optional,
     Sequence,
@@ -440,16 +441,6 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
             keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
             list_args = arg_dict[self.list_name]
 
-            results = {}
-
-            def update_results_dict(res: Any, arg: Hashable) -> None:
-                results[arg] = res
-
-            # list of deferreds to wait for
-            cached_defers = []
-
-            missing = set()
-
             # If the cache takes a single arg then that is used as the key,
             # otherwise a tuple is used.
             if num_args == 1:
@@ -457,6 +448,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                 def arg_to_cache_key(arg: Hashable) -> Hashable:
                     return arg
 
+                def cache_key_to_arg(key: tuple) -> Hashable:
+                    return key
+
             else:
                 keylist = list(keyargs)
 
@@ -464,58 +458,53 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                     keylist[self.list_pos] = arg
                     return tuple(keylist)
 
-            for arg in list_args:
-                try:
-                    res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
-                    if not res.called:
-                        res.addCallback(update_results_dict, arg)
-                        cached_defers.append(res)
-                    else:
-                        results[arg] = res.result
-                except KeyError:
-                    missing.add(arg)
+                def cache_key_to_arg(key: tuple) -> Hashable:
+                    return key[self.list_pos]
+
+            cache_keys = [arg_to_cache_key(arg) for arg in list_args]
+            immediate_results, pending_deferred, missing = cache.get_bulk(
+                cache_keys, callback=invalidate_callback
+            )
+
+            results = {cache_key_to_arg(key): v for key, v in immediate_results.items()}
+
+            cached_defers: List["defer.Deferred[Any]"] = []
+            if pending_deferred:
+
+                def update_results(r: Dict) -> None:
+                    for k, v in r.items():
+                        results[cache_key_to_arg(k)] = v
+
+                pending_deferred.addCallback(update_results)
+                cached_defers.append(pending_deferred)
 
             if missing:
-                # we need a deferred for each entry in the list,
-                # which we put in the cache. Each deferred resolves with the
-                # relevant result for that key.
-                deferreds_map = {}
-                for arg in missing:
-                    deferred: "defer.Deferred[Any]" = defer.Deferred()
-                    deferreds_map[arg] = deferred
-                    key = arg_to_cache_key(arg)
-                    cached_defers.append(
-                        cache.set(key, deferred, callback=invalidate_callback)
-                    )
+                cache_entry = cache.start_bulk_input(missing, invalidate_callback)
 
                 def complete_all(res: Dict[Hashable, Any]) -> None:
-                    # the wrapped function has completed. It returns a dict.
-                    # We can now update our own result map, and then resolve the
-                    # observable deferreds in the cache.
-                    for e, d1 in deferreds_map.items():
-                        val = res.get(e, None)
-                        # make sure we update the results map before running the
-                        # deferreds, because as soon as we run the last deferred, the
-                        # gatherResults() below will complete and return the result
-                        # dict to our caller.
-                        results[e] = val
-                        d1.callback(val)
+                    missing_results = {}
+                    for key in missing:
+                        arg = cache_key_to_arg(key)
+                        val = res.get(arg, None)
+
+                        results[arg] = val
+                        missing_results[key] = val
+
+                    cache_entry.complete_bulk(cache, missing_results)
 
                 def errback_all(f: Failure) -> None:
-                    # the wrapped function has failed. Propagate the failure into
-                    # the cache, which will invalidate the entry, and cause the
-                    # relevant cached_deferreds to fail, which will propagate the
-                    # failure to our caller.
-                    for d1 in deferreds_map.values():
-                        d1.errback(f)
+                    cache_entry.error_bulk(cache, missing, f)
 
                 args_to_call = dict(arg_dict)
-                args_to_call[self.list_name] = missing
+                args_to_call[self.list_name] = {
+                    cache_key_to_arg(key) for key in missing
+                }
 
                 # dispatch the call, and attach the two handlers
-                defer.maybeDeferred(
+                missing_d = defer.maybeDeferred(
                     preserve_fn(self.orig), **args_to_call
                 ).addCallbacks(complete_all, errback_all)
+                cached_defers.append(missing_d)
 
             if cached_defers:
                 d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index c1b8ec0c73..fec31da2b6 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -135,6 +135,9 @@ class TreeCache:
     def values(self):
         return iterate_tree_cache_entry(self.root)
 
+    def items(self):
+        return iterate_tree_cache_items((), self.root)
+
     def __len__(self) -> int:
         return self.size