summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py100
1 files changed, 53 insertions, 47 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 187510576a..d2f25063aa 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -40,9 +40,7 @@ _CacheSentinel = object()
 
 
 class CacheEntry(object):
-    __slots__ = [
-        "deferred", "callbacks", "invalidated"
-    ]
+    __slots__ = ["deferred", "callbacks", "invalidated"]
 
     def __init__(self, deferred, callbacks):
         self.deferred = deferred
@@ -73,7 +71,9 @@ class Cache(object):
         self._pending_deferred_cache = cache_type()
 
         self.cache = LruCache(
-            max_size=max_entries, keylen=keylen, cache_type=cache_type,
+            max_size=max_entries,
+            keylen=keylen,
+            cache_type=cache_type,
             size_callback=(lambda d: len(d)) if iterable else None,
             evicted_callback=self._on_evicted,
         )
@@ -133,10 +133,7 @@ class Cache(object):
     def set(self, key, value, callback=None):
         callbacks = [callback] if callback else []
         self.check_thread()
-        entry = CacheEntry(
-            deferred=value,
-            callbacks=callbacks,
-        )
+        entry = CacheEntry(deferred=value, callbacks=callbacks)
 
         existing_entry = self._pending_deferred_cache.pop(key, None)
         if existing_entry:
@@ -191,9 +188,7 @@ class Cache(object):
     def invalidate_many(self, key):
         self.check_thread()
         if not isinstance(key, tuple):
-            raise TypeError(
-                "The cache key must be a tuple not %r" % (type(key),)
-            )
+            raise TypeError("The cache key must be a tuple not %r" % (type(key),))
         self.cache.del_multi(key)
 
         # if we have a pending lookup for this key, remove it from the
@@ -244,29 +239,25 @@ class _CacheDescriptorBase(object):
             raise Exception(
                 "Not enough explicit positional arguments to key off for %r: "
                 "got %i args, but wanted %i. (@cached cannot key off *args or "
-                "**kwargs)"
-                % (orig.__name__, len(all_args), num_args)
+                "**kwargs)" % (orig.__name__, len(all_args), num_args)
             )
 
         self.num_args = num_args
 
         # list of the names of the args used as the cache key
-        self.arg_names = all_args[1:num_args + 1]
+        self.arg_names = all_args[1 : num_args + 1]
 
         # self.arg_defaults is a map of arg name to its default value for each
         # argument that has a default value
         if arg_spec.defaults:
-            self.arg_defaults = dict(zip(
-                all_args[-len(arg_spec.defaults):],
-                arg_spec.defaults
-            ))
+            self.arg_defaults = dict(
+                zip(all_args[-len(arg_spec.defaults) :], arg_spec.defaults)
+            )
         else:
             self.arg_defaults = {}
 
         if "cache_context" in self.arg_names:
-            raise Exception(
-                "cache_context arg cannot be included among the cache keys"
-            )
+            raise Exception("cache_context arg cannot be included among the cache keys")
 
         self.add_cache_context = cache_context
 
@@ -304,12 +295,24 @@ class CacheDescriptor(_CacheDescriptorBase):
             ``cache_context``) to use as cache keys. Defaults to all named
             args of the function.
     """
-    def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
-                 inlineCallbacks=False, cache_context=False, iterable=False):
+
+    def __init__(
+        self,
+        orig,
+        max_entries=1000,
+        num_args=None,
+        tree=False,
+        inlineCallbacks=False,
+        cache_context=False,
+        iterable=False,
+    ):
 
         super(CacheDescriptor, self).__init__(
-            orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
-            cache_context=cache_context)
+            orig,
+            num_args=num_args,
+            inlineCallbacks=inlineCallbacks,
+            cache_context=cache_context,
+        )
 
         max_entries = int(max_entries * get_cache_factor_for(orig.__name__))
 
@@ -356,7 +359,9 @@ class CacheDescriptor(_CacheDescriptorBase):
                     return args[0]
                 else:
                     return self.arg_defaults[nm]
+
         else:
+
             def get_cache_key(args, kwargs):
                 return tuple(get_cache_key_gen(args, kwargs))
 
@@ -383,8 +388,7 @@ class CacheDescriptor(_CacheDescriptorBase):
 
             except KeyError:
                 ret = defer.maybeDeferred(
-                    logcontext.preserve_fn(self.function_to_call),
-                    obj, *args, **kwargs
+                    logcontext.preserve_fn(self.function_to_call), obj, *args, **kwargs
                 )
 
                 def onErr(f):
@@ -437,8 +441,9 @@ class CacheListDescriptor(_CacheDescriptorBase):
     results.
     """
 
-    def __init__(self, orig, cached_method_name, list_name, num_args=None,
-                 inlineCallbacks=False):
+    def __init__(
+        self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False
+    ):
         """
         Args:
             orig (function)
@@ -451,7 +456,8 @@ class CacheListDescriptor(_CacheDescriptorBase):
                 be wrapped by defer.inlineCallbacks
         """
         super(CacheListDescriptor, self).__init__(
-            orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
+            orig, num_args=num_args, inlineCallbacks=inlineCallbacks
+        )
 
         self.list_name = list_name
 
@@ -463,7 +469,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
         if self.list_name not in self.arg_names:
             raise Exception(
                 "Couldn't see arguments %r for %r."
-                % (self.list_name, cached_method_name,)
+                % (self.list_name, cached_method_name)
             )
 
     def __get__(self, obj, objtype=None):
@@ -494,8 +500,10 @@ class CacheListDescriptor(_CacheDescriptorBase):
             # If the cache takes a single arg then that is used as the key,
             # otherwise a tuple is used.
             if num_args == 1:
+
                 def arg_to_cache_key(arg):
                     return arg
+
             else:
                 keylist = list(keyargs)
 
@@ -505,8 +513,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
 
             for arg in list_args:
                 try:
-                    res = cache.get(arg_to_cache_key(arg),
-                                    callback=invalidate_callback)
+                    res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
                     if not isinstance(res, ObservableDeferred):
                         results[arg] = res
                     elif not res.has_succeeded():
@@ -554,18 +561,15 @@ class CacheListDescriptor(_CacheDescriptorBase):
                 args_to_call = dict(arg_dict)
                 args_to_call[self.list_name] = list(missing)
 
-                cached_defers.append(defer.maybeDeferred(
-                    logcontext.preserve_fn(self.function_to_call),
-                    **args_to_call
-                ).addCallbacks(complete_all, errback))
+                cached_defers.append(
+                    defer.maybeDeferred(
+                        logcontext.preserve_fn(self.function_to_call), **args_to_call
+                    ).addCallbacks(complete_all, errback)
+                )
 
             if cached_defers:
-                d = defer.gatherResults(
-                    cached_defers,
-                    consumeErrors=True,
-                ).addCallbacks(
-                    lambda _: results,
-                    unwrapFirstError
+                d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
+                    lambda _: results, unwrapFirstError
                 )
                 return logcontext.make_deferred_yieldable(d)
             else:
@@ -586,8 +590,9 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
         self.cache.invalidate(self.key)
 
 
-def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
-           iterable=False):
+def cached(
+    max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
+):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -598,8 +603,9 @@ def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
-                          cache_context=False, iterable=False):
+def cachedInlineCallbacks(
+    max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
+):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,