summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py115
1 files changed, 91 insertions, 24 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 43f66ec4be..cd48262420 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -13,22 +13,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import functools
 import inspect
 import logging
 import threading
-from collections import namedtuple
+from typing import Any, Tuple, Union, cast
+from weakref import WeakValueDictionary
 
 from six import itervalues
 
 from prometheus_client import Gauge
+from typing_extensions import Protocol
 
 from twisted.internet import defer
 
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches import get_cache_factor_for
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
 
@@ -36,6 +38,20 @@ from . import register_cache
 
 logger = logging.getLogger(__name__)
 
+CacheKey = Union[Tuple, Any]
+
+
+class _CachedFunction(Protocol):
+    invalidate = None  # type: Any
+    invalidate_all = None  # type: Any
+    invalidate_many = None  # type: Any
+    prefill = None  # type: Any
+    cache = None  # type: Any
+    num_args = None  # type: Any
+
+    def __name__(self):
+        ...
+
 
 cache_pending_metric = Gauge(
     "synapse_util_caches_cache_pending",
@@ -65,7 +81,6 @@ class CacheEntry(object):
 class Cache(object):
     __slots__ = (
         "cache",
-        "max_entries",
         "name",
         "keylen",
         "thread",
@@ -73,7 +88,29 @@ class Cache(object):
         "_pending_deferred_cache",
     )
 
-    def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
+    def __init__(
+        self,
+        name: str,
+        max_entries: int = 1000,
+        keylen: int = 1,
+        tree: bool = False,
+        iterable: bool = False,
+        apply_cache_factor_from_config: bool = True,
+    ):
+        """
+        Args:
+            name: The name of the cache
+            max_entries: Maximum amount of entries that the cache will hold
+            keylen: The length of the tuple used as the cache key
+            tree: Use a TreeCache instead of a dict as the underlying cache type
+            iterable: If True, count each item in the cached object as an entry,
+                rather than each cached object
+            apply_cache_factor_from_config: Whether cache factors specified in the
+                config file affect `max_entries`
+
+        Returns:
+            Cache
+        """
         cache_type = TreeCache if tree else dict
         self._pending_deferred_cache = cache_type()
 
@@ -83,6 +120,7 @@ class Cache(object):
             cache_type=cache_type,
             size_callback=(lambda d: len(d)) if iterable else None,
             evicted_callback=self._on_evicted,
+            apply_cache_factor_from_config=apply_cache_factor_from_config,
         )
 
         self.name = name
@@ -95,6 +133,10 @@ class Cache(object):
             collect_callback=self._metrics_collection_callback,
         )
 
+    @property
+    def max_entries(self):
+        return self.cache.max_size
+
     def _on_evicted(self, evicted_count):
         self.metrics.inc_evictions(evicted_count)
 
@@ -245,7 +287,9 @@ class Cache(object):
 
 
 class _CacheDescriptorBase(object):
-    def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
+    def __init__(
+        self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
+    ):
         self.orig = orig
 
         if inlineCallbacks:
@@ -253,7 +297,7 @@ class _CacheDescriptorBase(object):
         else:
             self.function_to_call = orig
 
-        arg_spec = inspect.getargspec(orig)
+        arg_spec = inspect.getfullargspec(orig)
         all_args = arg_spec.args
 
         if "cache_context" in all_args:
@@ -352,13 +396,11 @@ class CacheDescriptor(_CacheDescriptorBase):
             cache_context=cache_context,
         )
 
-        max_entries = int(max_entries * get_cache_factor_for(orig.__name__))
-
         self.max_entries = max_entries
         self.tree = tree
         self.iterable = iterable
 
-    def __get__(self, obj, objtype=None):
+    def __get__(self, obj, owner):
         cache = Cache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
@@ -404,7 +446,7 @@ class CacheDescriptor(_CacheDescriptorBase):
                 return tuple(get_cache_key_gen(args, kwargs))
 
         @functools.wraps(self.orig)
-        def wrapped(*args, **kwargs):
+        def _wrapped(*args, **kwargs):
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -414,7 +456,7 @@ class CacheDescriptor(_CacheDescriptorBase):
             # Add our own `cache_context` to argument list if the wrapped function
             # has asked for one
             if self.add_cache_context:
-                kwargs["cache_context"] = _CacheContext(cache, cache_key)
+                kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
 
             try:
                 cached_result_d = cache.get(cache_key, callback=invalidate_callback)
@@ -422,7 +464,7 @@ class CacheDescriptor(_CacheDescriptorBase):
                 if isinstance(cached_result_d, ObservableDeferred):
                     observer = cached_result_d.observe()
                 else:
-                    observer = cached_result_d
+                    observer = defer.succeed(cached_result_d)
 
             except KeyError:
                 ret = defer.maybeDeferred(
@@ -440,6 +482,8 @@ class CacheDescriptor(_CacheDescriptorBase):
 
             return make_deferred_yieldable(observer)
 
+        wrapped = cast(_CachedFunction, _wrapped)
+
         if self.num_args == 1:
             wrapped.invalidate = lambda key: cache.invalidate(key[0])
             wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
@@ -464,9 +508,8 @@ class CacheListDescriptor(_CacheDescriptorBase):
     Given a list of keys it looks in the cache to find any hits, then passes
     the list of missing keys to the wrapped function.
 
-    Once wrapped, the function returns either a Deferred which resolves to
-    the list of results, or (if all results were cached), just the list of
-    results.
+    Once wrapped, the function returns a Deferred which resolves to the list
+    of results.
     """
 
     def __init__(
@@ -600,21 +643,45 @@ class CacheListDescriptor(_CacheDescriptorBase):
                 )
                 return make_deferred_yieldable(d)
             else:
-                return results
+                return defer.succeed(results)
 
         obj.__dict__[self.orig.__name__] = wrapped
 
         return wrapped
 
 
-class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
-    # We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
-    # which namedtuple does for us (i.e. two _CacheContext are the same if
-    # their caches and keys match). This is important in particular to
-    # dedupe when we add callbacks to lru cache nodes, otherwise the number
-    # of callbacks would grow.
-    def invalidate(self):
-        self.cache.invalidate(self.key)
+class _CacheContext:
+    """Holds cache information from the cached function higher in the calling order.
+
+    Can be used to invalidate the higher level cache entry if something changes
+    on a lower level.
+    """
+
+    _cache_context_objects = (
+        WeakValueDictionary()
+    )  # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext]
+
+    def __init__(self, cache, cache_key):  # type: (Cache, CacheKey) -> None
+        self._cache = cache
+        self._cache_key = cache_key
+
+    def invalidate(self):  # type: () -> None
+        """Invalidates the cache entry referred to by the context."""
+        self._cache.invalidate(self._cache_key)
+
+    @classmethod
+    def get_instance(cls, cache, cache_key):  # type: (Cache, CacheKey) -> _CacheContext
+        """Returns an instance constructed with the given arguments.
+
+        A new instance is only created if none already exists.
+        """
+
+        # We make sure there are no identical _CacheContext instances. This is
+        # important in particular to dedupe when we add callbacks to lru cache
+        # nodes, otherwise the number of callbacks would grow.
+        return cls._cache_context_objects.setdefault(
+            (cache, cache_key), cls(cache, cache_key)
+        )
 
 
 def cached(