summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py128
1 files changed, 93 insertions, 35 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 5c30ed235d..48dcbafeef 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,8 +18,9 @@ from synapse.util.async import ObservableDeferred
 from synapse.util import unwrapFirstError, logcontext
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+from synapse.util.stringutils import to_ascii
 
-from . import DEBUG_CACHES, register_cache
+from . import register_cache
 
 from twisted.internet import defer
 from collections import namedtuple
@@ -76,7 +77,7 @@ class Cache(object):
 
         self.cache = LruCache(
             max_size=max_entries, keylen=keylen, cache_type=cache_type,
-            size_callback=(lambda d: len(d.result)) if iterable else None,
+            size_callback=(lambda d: len(d)) if iterable else None,
         )
 
         self.name = name
@@ -95,13 +96,26 @@ class Cache(object):
                     "Cache objects can only be accessed from the main thread"
                 )
 
-    def get(self, key, default=_CacheSentinel, callback=None):
+    def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
+        """Looks the key up in the caches.
+
+        Args:
+            key(tuple)
+            default: What is returned if key is not in the caches. If not
+                specified then function throws KeyError instead
+            callback(fn): Gets called when the entry in the cache is invalidated
+            update_metrics (bool): whether to update the cache hit rate metrics
+
+        Returns:
+            Either a Deferred or the raw result
+        """
         callbacks = [callback] if callback else []
         val = self._pending_deferred_cache.get(key, _CacheSentinel)
         if val is not _CacheSentinel:
             if val.sequence == self.sequence:
                 val.callbacks.update(callbacks)
-                self.metrics.inc_hits()
+                if update_metrics:
+                    self.metrics.inc_hits()
                 return val.deferred
 
         val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
@@ -109,7 +123,8 @@ class Cache(object):
             self.metrics.inc_hits()
             return val
 
-        self.metrics.inc_misses()
+        if update_metrics:
+            self.metrics.inc_misses()
 
         if default is _CacheSentinel:
             raise KeyError()
@@ -137,7 +152,7 @@ class Cache(object):
             if self.sequence == entry.sequence:
                 existing_entry = self._pending_deferred_cache.pop(key, None)
                 if existing_entry is entry:
-                    self.cache.set(key, entry.deferred, entry.callbacks)
+                    self.cache.set(key, result, entry.callbacks)
                 else:
                     entry.invalidate()
             else:
@@ -152,10 +167,6 @@ class Cache(object):
 
     def invalidate(self, key):
         self.check_thread()
-        if not isinstance(key, tuple):
-            raise TypeError(
-                "The cache key must be a tuple not %r" % (type(key),)
-            )
 
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
@@ -224,8 +235,20 @@ class _CacheDescriptorBase(object):
             )
 
         self.num_args = num_args
+
+        # list of the names of the args used as the cache key
         self.arg_names = all_args[1:num_args + 1]
 
+        # self.arg_defaults is a map of arg name to its default value for each
+        # argument that has a default value
+        if arg_spec.defaults:
+            self.arg_defaults = dict(zip(
+                all_args[-len(arg_spec.defaults):],
+                arg_spec.defaults
+            ))
+        else:
+            self.arg_defaults = {}
+
         if "cache_context" in self.arg_names:
             raise Exception(
                 "cache_context arg cannot be included among the cache keys"
@@ -289,18 +312,47 @@ class CacheDescriptor(_CacheDescriptorBase):
             iterable=self.iterable,
         )
 
+        def get_cache_key_gen(args, kwargs):
+            """Given some args/kwargs return a generator that resolves into
+            the cache_key.
+
+            We loop through each arg name, looking up if its in the `kwargs`,
+            otherwise using the next argument in `args`. If there are no more
+            args then we try looking the arg name up in the defaults
+            """
+            pos = 0
+            for nm in self.arg_names:
+                if nm in kwargs:
+                    yield kwargs[nm]
+                elif pos < len(args):
+                    yield args[pos]
+                    pos += 1
+                else:
+                    yield self.arg_defaults[nm]
+
+        # By default our cache key is a tuple, but if there is only one item
+        # then don't bother wrapping in a tuple.  This is to save memory.
+        if self.num_args == 1:
+            nm = self.arg_names[0]
+
+            def get_cache_key(args, kwargs):
+                if nm in kwargs:
+                    return kwargs[nm]
+                elif len(args):
+                    return args[0]
+                else:
+                    return self.arg_defaults[nm]
+        else:
+            def get_cache_key(args, kwargs):
+                return tuple(get_cache_key_gen(args, kwargs))
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
 
-            # Add temp cache_context so inspect.getcallargs doesn't explode
-            if self.add_cache_context:
-                kwargs["cache_context"] = None
-
-            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+            cache_key = get_cache_key(args, kwargs)
 
             # Add our own `cache_context` to argument list if the wrapped function
             # has asked for one
@@ -310,20 +362,10 @@ class CacheDescriptor(_CacheDescriptorBase):
             try:
                 cached_result_d = cache.get(cache_key, callback=invalidate_callback)
 
-                observer = cached_result_d.observe()
-                if DEBUG_CACHES:
-                    @defer.inlineCallbacks
-                    def check_result(cached_result):
-                        actual_result = yield self.function_to_call(obj, *args, **kwargs)
-                        if actual_result != cached_result:
-                            logger.error(
-                                "Stale cache entry %s%r: cached: %r, actual %r",
-                                self.orig.__name__, cache_key,
-                                cached_result, actual_result,
-                            )
-                            raise ValueError("Stale cache entry")
-                        defer.returnValue(cached_result)
-                    observer.addCallback(check_result)
+                if isinstance(cached_result_d, ObservableDeferred):
+                    observer = cached_result_d.observe()
+                else:
+                    observer = cached_result_d
 
             except KeyError:
                 ret = defer.maybeDeferred(
@@ -337,16 +379,30 @@ class CacheDescriptor(_CacheDescriptorBase):
 
                 ret.addErrback(onErr)
 
+                # If our cache_key is a string, try to convert to ascii to save
+                # a bit of space in large caches
+                if isinstance(cache_key, basestring):
+                    cache_key = to_ascii(cache_key)
+
                 result_d = ObservableDeferred(ret, consumeErrors=True)
                 cache.set(cache_key, result_d, callback=invalidate_callback)
                 observer = result_d.observe()
 
-            return logcontext.make_deferred_yieldable(observer)
+            if isinstance(observer, defer.Deferred):
+                return logcontext.make_deferred_yieldable(observer)
+            else:
+                return observer
+
+        if self.num_args == 1:
+            wrapped.invalidate = lambda key: cache.invalidate(key[0])
+            wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
+        else:
+            wrapped.invalidate = cache.invalidate
+            wrapped.invalidate_all = cache.invalidate_all
+            wrapped.invalidate_many = cache.invalidate_many
+            wrapped.prefill = cache.prefill
 
-        wrapped.invalidate = cache.invalidate
         wrapped.invalidate_all = cache.invalidate_all
-        wrapped.invalidate_many = cache.invalidate_many
-        wrapped.prefill = cache.prefill
         wrapped.cache = cache
 
         obj.__dict__[self.orig.__name__] = wrapped
@@ -419,7 +475,9 @@ class CacheListDescriptor(_CacheDescriptorBase):
 
                 try:
                     res = cache.get(tuple(key), callback=invalidate_callback)
-                    if not res.has_succeeded():
+                    if not isinstance(res, ObservableDeferred):
+                        results[arg] = res
+                    elif not res.has_succeeded():
                         res = res.observe()
                         res.addCallback(lambda r, arg: (arg, r), arg)
                         cached_defers[arg] = res