summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py254
1 files changed, 8 insertions, 246 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 98b34f2223..1f43886804 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -13,25 +13,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import functools
 import inspect
 import logging
-import threading
 from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
 from weakref import WeakValueDictionary
 
-from prometheus_client import Gauge
-
 from twisted.internet import defer
 
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-
-from . import register_cache
+from synapse.util.caches.deferred_cache import DeferredCache
 
 logger = logging.getLogger(__name__)
 
@@ -55,239 +48,6 @@ class _CachedFunction(Generic[F]):
     __call__ = None  # type: F
 
 
-cache_pending_metric = Gauge(
-    "synapse_util_caches_cache_pending",
-    "Number of lookups currently pending for this cache",
-    ["name"],
-)
-
-_CacheSentinel = object()
-
-
-class CacheEntry:
-    __slots__ = ["deferred", "callbacks", "invalidated"]
-
-    def __init__(self, deferred, callbacks):
-        self.deferred = deferred
-        self.callbacks = set(callbacks)
-        self.invalidated = False
-
-    def invalidate(self):
-        if not self.invalidated:
-            self.invalidated = True
-            for callback in self.callbacks:
-                callback()
-            self.callbacks.clear()
-
-
-class Cache:
-    __slots__ = (
-        "cache",
-        "name",
-        "keylen",
-        "thread",
-        "metrics",
-        "_pending_deferred_cache",
-    )
-
-    def __init__(
-        self,
-        name: str,
-        max_entries: int = 1000,
-        keylen: int = 1,
-        tree: bool = False,
-        iterable: bool = False,
-        apply_cache_factor_from_config: bool = True,
-    ):
-        """
-        Args:
-            name: The name of the cache
-            max_entries: Maximum amount of entries that the cache will hold
-            keylen: The length of the tuple used as the cache key
-            tree: Use a TreeCache instead of a dict as the underlying cache type
-            iterable: If True, count each item in the cached object as an entry,
-                rather than each cached object
-            apply_cache_factor_from_config: Whether cache factors specified in the
-                config file affect `max_entries`
-
-        Returns:
-            Cache
-        """
-        cache_type = TreeCache if tree else dict
-        self._pending_deferred_cache = cache_type()
-
-        self.cache = LruCache(
-            max_size=max_entries,
-            keylen=keylen,
-            cache_type=cache_type,
-            size_callback=(lambda d: len(d)) if iterable else None,
-            evicted_callback=self._on_evicted,
-            apply_cache_factor_from_config=apply_cache_factor_from_config,
-        )
-
-        self.name = name
-        self.keylen = keylen
-        self.thread = None  # type: Optional[threading.Thread]
-        self.metrics = register_cache(
-            "cache",
-            name,
-            self.cache,
-            collect_callback=self._metrics_collection_callback,
-        )
-
-    @property
-    def max_entries(self):
-        return self.cache.max_size
-
-    def _on_evicted(self, evicted_count):
-        self.metrics.inc_evictions(evicted_count)
-
-    def _metrics_collection_callback(self):
-        cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
-
-    def check_thread(self):
-        expected_thread = self.thread
-        if expected_thread is None:
-            self.thread = threading.current_thread()
-        else:
-            if expected_thread is not threading.current_thread():
-                raise ValueError(
-                    "Cache objects can only be accessed from the main thread"
-                )
-
-    def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
-        """Looks the key up in the caches.
-
-        Args:
-            key(tuple)
-            default: What is returned if key is not in the caches. If not
-                specified then function throws KeyError instead
-            callback(fn): Gets called when the entry in the cache is invalidated
-            update_metrics (bool): whether to update the cache hit rate metrics
-
-        Returns:
-            Either an ObservableDeferred or the raw result
-        """
-        callbacks = [callback] if callback else []
-        val = self._pending_deferred_cache.get(key, _CacheSentinel)
-        if val is not _CacheSentinel:
-            val.callbacks.update(callbacks)
-            if update_metrics:
-                self.metrics.inc_hits()
-            return val.deferred
-
-        val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
-        if val is not _CacheSentinel:
-            self.metrics.inc_hits()
-            return val
-
-        if update_metrics:
-            self.metrics.inc_misses()
-
-        if default is _CacheSentinel:
-            raise KeyError()
-        else:
-            return default
-
-    def set(self, key, value, callback=None):
-        if not isinstance(value, defer.Deferred):
-            raise TypeError("not a Deferred")
-
-        callbacks = [callback] if callback else []
-        self.check_thread()
-        observable = ObservableDeferred(value, consumeErrors=True)
-        observer = observable.observe()
-        entry = CacheEntry(deferred=observable, callbacks=callbacks)
-
-        existing_entry = self._pending_deferred_cache.pop(key, None)
-        if existing_entry:
-            existing_entry.invalidate()
-
-        self._pending_deferred_cache[key] = entry
-
-        def compare_and_pop():
-            """Check if our entry is still the one in _pending_deferred_cache, and
-            if so, pop it.
-
-            Returns true if the entries matched.
-            """
-            existing_entry = self._pending_deferred_cache.pop(key, None)
-            if existing_entry is entry:
-                return True
-
-            # oops, the _pending_deferred_cache has been updated since
-            # we started our query, so we are out of date.
-            #
-            # Better put back whatever we took out. (We do it this way
-            # round, rather than peeking into the _pending_deferred_cache
-            # and then removing on a match, to make the common case faster)
-            if existing_entry is not None:
-                self._pending_deferred_cache[key] = existing_entry
-
-            return False
-
-        def cb(result):
-            if compare_and_pop():
-                self.cache.set(key, result, entry.callbacks)
-            else:
-                # we're not going to put this entry into the cache, so need
-                # to make sure that the invalidation callbacks are called.
-                # That was probably done when _pending_deferred_cache was
-                # updated, but it's possible that `set` was called without
-                # `invalidate` being previously called, in which case it may
-                # not have been. Either way, let's double-check now.
-                entry.invalidate()
-
-        def eb(_fail):
-            compare_and_pop()
-            entry.invalidate()
-
-        # once the deferred completes, we can move the entry from the
-        # _pending_deferred_cache to the real cache.
-        #
-        observer.addCallbacks(cb, eb)
-        return observable
-
-    def prefill(self, key, value, callback=None):
-        callbacks = [callback] if callback else []
-        self.cache.set(key, value, callbacks=callbacks)
-
-    def invalidate(self, key):
-        self.check_thread()
-        self.cache.pop(key, None)
-
-        # if we have a pending lookup for this key, remove it from the
-        # _pending_deferred_cache, which will (a) stop it being returned
-        # for future queries and (b) stop it being persisted as a proper entry
-        # in self.cache.
-        entry = self._pending_deferred_cache.pop(key, None)
-
-        # run the invalidation callbacks now, rather than waiting for the
-        # deferred to resolve.
-        if entry:
-            entry.invalidate()
-
-    def invalidate_many(self, key):
-        self.check_thread()
-        if not isinstance(key, tuple):
-            raise TypeError("The cache key must be a tuple not %r" % (type(key),))
-        self.cache.del_multi(key)
-
-        # if we have a pending lookup for this key, remove it from the
-        # _pending_deferred_cache, as above
-        entry_dict = self._pending_deferred_cache.pop(key, None)
-        if entry_dict is not None:
-            for entry in iterate_tree_cache_entry(entry_dict):
-                entry.invalidate()
-
-    def invalidate_all(self):
-        self.check_thread()
-        self.cache.clear()
-        for entry in self._pending_deferred_cache.values():
-            entry.invalidate()
-        self._pending_deferred_cache.clear()
-
-
 class _CacheDescriptorBase:
     def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
         self.orig = orig
@@ -390,13 +150,13 @@ class CacheDescriptor(_CacheDescriptorBase):
         self.iterable = iterable
 
     def __get__(self, obj, owner):
-        cache = Cache(
+        cache = DeferredCache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
             keylen=self.num_args,
             tree=self.tree,
             iterable=self.iterable,
-        )
+        )  # type: DeferredCache[Tuple, Any]
 
         def get_cache_key_gen(args, kwargs):
             """Given some args/kwargs return a generator that resolves into
@@ -640,9 +400,9 @@ class _CacheContext:
 
     _cache_context_objects = (
         WeakValueDictionary()
-    )  # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext]
+    )  # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext]
 
-    def __init__(self, cache, cache_key):  # type: (Cache, CacheKey) -> None
+    def __init__(self, cache, cache_key):  # type: (DeferredCache, CacheKey) -> None
         self._cache = cache
         self._cache_key = cache_key
 
@@ -651,7 +411,9 @@ class _CacheContext:
         self._cache.invalidate(self._cache_key)
 
     @classmethod
-    def get_instance(cls, cache, cache_key):  # type: (Cache, CacheKey) -> _CacheContext
+    def get_instance(
+        cls, cache, cache_key
+    ):  # type: (DeferredCache, CacheKey) -> _CacheContext
         """Returns an instance constructed with the given arguments.
 
         A new instance is only created if none already exists.