summary refs log tree commit diff
path: root/synapse/util/caches
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches')
-rw-r--r--synapse/util/caches/__init__.py6
-rw-r--r--synapse/util/caches/descriptors.py100
-rw-r--r--synapse/util/caches/dictionary_cache.py15
-rw-r--r--synapse/util/caches/expiringcache.py18
-rw-r--r--synapse/util/caches/lrucache.py14
-rw-r--r--synapse/util/caches/response_cache.py24
-rw-r--r--synapse/util/caches/stream_change_cache.py13
-rw-r--r--synapse/util/caches/treecache.py1
-rw-r--r--synapse/util/caches/ttlcache.py1
9 files changed, 110 insertions, 82 deletions
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index f37d5bec08..8271229015 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -104,8 +104,8 @@ def register_cache(cache_type, cache_name, cache):
 
 
 KNOWN_KEYS = {
-    key: key for key in
-    (
+    key: key
+    for key in (
         "auth_events",
         "content",
         "depth",
@@ -150,7 +150,7 @@ def intern_dict(dictionary):
 
 
 def _intern_known_values(key, value):
-    intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key",)
+    intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key")
 
     if key in intern_keys:
         return intern_string(value)
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 187510576a..d2f25063aa 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -40,9 +40,7 @@ _CacheSentinel = object()
 
 
 class CacheEntry(object):
-    __slots__ = [
-        "deferred", "callbacks", "invalidated"
-    ]
+    __slots__ = ["deferred", "callbacks", "invalidated"]
 
     def __init__(self, deferred, callbacks):
         self.deferred = deferred
@@ -73,7 +71,9 @@ class Cache(object):
         self._pending_deferred_cache = cache_type()
 
         self.cache = LruCache(
-            max_size=max_entries, keylen=keylen, cache_type=cache_type,
+            max_size=max_entries,
+            keylen=keylen,
+            cache_type=cache_type,
             size_callback=(lambda d: len(d)) if iterable else None,
             evicted_callback=self._on_evicted,
         )
@@ -133,10 +133,7 @@ class Cache(object):
     def set(self, key, value, callback=None):
         callbacks = [callback] if callback else []
         self.check_thread()
-        entry = CacheEntry(
-            deferred=value,
-            callbacks=callbacks,
-        )
+        entry = CacheEntry(deferred=value, callbacks=callbacks)
 
         existing_entry = self._pending_deferred_cache.pop(key, None)
         if existing_entry:
@@ -191,9 +188,7 @@ class Cache(object):
     def invalidate_many(self, key):
         self.check_thread()
         if not isinstance(key, tuple):
-            raise TypeError(
-                "The cache key must be a tuple not %r" % (type(key),)
-            )
+            raise TypeError("The cache key must be a tuple not %r" % (type(key),))
         self.cache.del_multi(key)
 
         # if we have a pending lookup for this key, remove it from the
@@ -244,29 +239,25 @@ class _CacheDescriptorBase(object):
             raise Exception(
                 "Not enough explicit positional arguments to key off for %r: "
                 "got %i args, but wanted %i. (@cached cannot key off *args or "
-                "**kwargs)"
-                % (orig.__name__, len(all_args), num_args)
+                "**kwargs)" % (orig.__name__, len(all_args), num_args)
             )
 
         self.num_args = num_args
 
         # list of the names of the args used as the cache key
-        self.arg_names = all_args[1:num_args + 1]
+        self.arg_names = all_args[1 : num_args + 1]
 
         # self.arg_defaults is a map of arg name to its default value for each
         # argument that has a default value
         if arg_spec.defaults:
-            self.arg_defaults = dict(zip(
-                all_args[-len(arg_spec.defaults):],
-                arg_spec.defaults
-            ))
+            self.arg_defaults = dict(
+                zip(all_args[-len(arg_spec.defaults) :], arg_spec.defaults)
+            )
         else:
             self.arg_defaults = {}
 
         if "cache_context" in self.arg_names:
-            raise Exception(
-                "cache_context arg cannot be included among the cache keys"
-            )
+            raise Exception("cache_context arg cannot be included among the cache keys")
 
         self.add_cache_context = cache_context
 
@@ -304,12 +295,24 @@ class CacheDescriptor(_CacheDescriptorBase):
             ``cache_context``) to use as cache keys. Defaults to all named
             args of the function.
     """
-    def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
-                 inlineCallbacks=False, cache_context=False, iterable=False):
+
+    def __init__(
+        self,
+        orig,
+        max_entries=1000,
+        num_args=None,
+        tree=False,
+        inlineCallbacks=False,
+        cache_context=False,
+        iterable=False,
+    ):
 
         super(CacheDescriptor, self).__init__(
-            orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
-            cache_context=cache_context)
+            orig,
+            num_args=num_args,
+            inlineCallbacks=inlineCallbacks,
+            cache_context=cache_context,
+        )
 
         max_entries = int(max_entries * get_cache_factor_for(orig.__name__))
 
@@ -356,7 +359,9 @@ class CacheDescriptor(_CacheDescriptorBase):
                     return args[0]
                 else:
                     return self.arg_defaults[nm]
+
         else:
+
             def get_cache_key(args, kwargs):
                 return tuple(get_cache_key_gen(args, kwargs))
 
@@ -383,8 +388,7 @@ class CacheDescriptor(_CacheDescriptorBase):
 
             except KeyError:
                 ret = defer.maybeDeferred(
-                    logcontext.preserve_fn(self.function_to_call),
-                    obj, *args, **kwargs
+                    logcontext.preserve_fn(self.function_to_call), obj, *args, **kwargs
                 )
 
                 def onErr(f):
@@ -437,8 +441,9 @@ class CacheListDescriptor(_CacheDescriptorBase):
     results.
     """
 
-    def __init__(self, orig, cached_method_name, list_name, num_args=None,
-                 inlineCallbacks=False):
+    def __init__(
+        self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False
+    ):
         """
         Args:
             orig (function)
@@ -451,7 +456,8 @@ class CacheListDescriptor(_CacheDescriptorBase):
                 be wrapped by defer.inlineCallbacks
         """
         super(CacheListDescriptor, self).__init__(
-            orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
+            orig, num_args=num_args, inlineCallbacks=inlineCallbacks
+        )
 
         self.list_name = list_name
 
@@ -463,7 +469,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
         if self.list_name not in self.arg_names:
             raise Exception(
                 "Couldn't see arguments %r for %r."
-                % (self.list_name, cached_method_name,)
+                % (self.list_name, cached_method_name)
             )
 
     def __get__(self, obj, objtype=None):
@@ -494,8 +500,10 @@ class CacheListDescriptor(_CacheDescriptorBase):
             # If the cache takes a single arg then that is used as the key,
             # otherwise a tuple is used.
             if num_args == 1:
+
                 def arg_to_cache_key(arg):
                     return arg
+
             else:
                 keylist = list(keyargs)
 
@@ -505,8 +513,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
 
             for arg in list_args:
                 try:
-                    res = cache.get(arg_to_cache_key(arg),
-                                    callback=invalidate_callback)
+                    res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
                     if not isinstance(res, ObservableDeferred):
                         results[arg] = res
                     elif not res.has_succeeded():
@@ -554,18 +561,15 @@ class CacheListDescriptor(_CacheDescriptorBase):
                 args_to_call = dict(arg_dict)
                 args_to_call[self.list_name] = list(missing)
 
-                cached_defers.append(defer.maybeDeferred(
-                    logcontext.preserve_fn(self.function_to_call),
-                    **args_to_call
-                ).addCallbacks(complete_all, errback))
+                cached_defers.append(
+                    defer.maybeDeferred(
+                        logcontext.preserve_fn(self.function_to_call), **args_to_call
+                    ).addCallbacks(complete_all, errback)
+                )
 
             if cached_defers:
-                d = defer.gatherResults(
-                    cached_defers,
-                    consumeErrors=True,
-                ).addCallbacks(
-                    lambda _: results,
-                    unwrapFirstError
+                d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
+                    lambda _: results, unwrapFirstError
                 )
                 return logcontext.make_deferred_yieldable(d)
             else:
@@ -586,8 +590,9 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
         self.cache.invalidate(self.key)
 
 
-def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
-           iterable=False):
+def cached(
+    max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
+):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -598,8 +603,9 @@ def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
-                          cache_context=False, iterable=False):
+def cachedInlineCallbacks(
+    max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
+):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 6c0b5a4094..6834e6f3ae 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -35,6 +35,7 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va
             there.
         value (dict): The full or partial dict value
     """
+
     def __len__(self):
         return len(self.value)
 
@@ -84,13 +85,15 @@ class DictionaryCache(object):
             self.metrics.inc_hits()
 
             if dict_keys is None:
-                return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value))
+                return DictionaryEntry(
+                    entry.full, entry.known_absent, dict(entry.value)
+                )
             else:
-                return DictionaryEntry(entry.full, entry.known_absent, {
-                    k: entry.value[k]
-                    for k in dict_keys
-                    if k in entry.value
-                })
+                return DictionaryEntry(
+                    entry.full,
+                    entry.known_absent,
+                    {k: entry.value[k] for k in dict_keys if k in entry.value},
+                )
 
         self.metrics.inc_misses()
         return DictionaryEntry(False, set(), {})
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index f369780277..cddf1ed515 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -28,8 +28,15 @@ SENTINEL = object()
 
 
 class ExpiringCache(object):
-    def __init__(self, cache_name, clock, max_len=0, expiry_ms=0,
-                 reset_expiry_on_get=False, iterable=False):
+    def __init__(
+        self,
+        cache_name,
+        clock,
+        max_len=0,
+        expiry_ms=0,
+        reset_expiry_on_get=False,
+        iterable=False,
+    ):
         """
         Args:
             cache_name (str): Name of this cache, used for logging.
@@ -67,8 +74,7 @@ class ExpiringCache(object):
 
         def f():
             return run_as_background_process(
-                "prune_cache_%s" % self._cache_name,
-                self._prune_cache,
+                "prune_cache_%s" % self._cache_name, self._prune_cache
             )
 
         self._clock.looping_call(f, self._expiry_ms / 2)
@@ -153,7 +159,9 @@ class ExpiringCache(object):
 
         logger.debug(
             "[%s] _prune_cache before: %d, after len: %d",
-            self._cache_name, begin_length, len(self)
+            self._cache_name,
+            begin_length,
+            len(self),
         )
 
     def __len__(self):
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index b684f24e7b..1536cb64f3 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -49,8 +49,15 @@ class LruCache(object):
     Can also set callbacks on objects when getting/setting which are fired
     when that key gets invalidated/evicted.
     """
-    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
-                 evicted_callback=None):
+
+    def __init__(
+        self,
+        max_size,
+        keylen=1,
+        cache_type=dict,
+        size_callback=None,
+        evicted_callback=None,
+    ):
         """
         Args:
             max_size (int):
@@ -93,9 +100,12 @@ class LruCache(object):
 
         cached_cache_len = [0]
         if size_callback is not None:
+
             def cache_len():
                 return cached_cache_len[0]
+
         else:
+
             def cache_len():
                 return len(cache)
 
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index afb03b2e1b..cbe54d45dd 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -35,12 +35,10 @@ class ResponseCache(object):
         self.pending_result_cache = {}  # Requests that haven't finished yet.
 
         self.clock = hs.get_clock()
-        self.timeout_sec = timeout_ms / 1000.
+        self.timeout_sec = timeout_ms / 1000.0
 
         self._name = name
-        self._metrics = register_cache(
-            "response_cache", name, self
-        )
+        self._metrics = register_cache("response_cache", name, self)
 
     def size(self):
         return len(self.pending_result_cache)
@@ -100,8 +98,7 @@ class ResponseCache(object):
         def remove(r):
             if self.timeout_sec:
                 self.clock.call_later(
-                    self.timeout_sec,
-                    self.pending_result_cache.pop, key, None,
+                    self.timeout_sec, self.pending_result_cache.pop, key, None
                 )
             else:
                 self.pending_result_cache.pop(key, None)
@@ -140,21 +137,22 @@ class ResponseCache(object):
 
             *args: positional parameters to pass to the callback, if it is used
 
-            **kwargs: named paramters to pass to the callback, if it is used
+            **kwargs: named parameters to pass to the callback, if it is used
 
         Returns:
             twisted.internet.defer.Deferred: yieldable result
         """
         result = self.get(key)
         if not result:
-            logger.info("[%s]: no cached result for [%s], calculating new one",
-                        self._name, key)
+            logger.info(
+                "[%s]: no cached result for [%s], calculating new one", self._name, key
+            )
             d = run_in_background(callback, *args, **kwargs)
             result = self.set(key, d)
         elif not isinstance(result, defer.Deferred) or result.called:
-            logger.info("[%s]: using completed cached result for [%s]",
-                        self._name, key)
+            logger.info("[%s]: using completed cached result for [%s]", self._name, key)
         else:
-            logger.info("[%s]: using incomplete cached result for [%s]",
-                        self._name, key)
+            logger.info(
+                "[%s]: using incomplete cached result for [%s]", self._name, key
+            )
         return make_deferred_yieldable(result)
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 625aedc940..235f64049c 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -77,9 +77,8 @@ class StreamChangeCache(object):
 
         if stream_pos >= self._earliest_known_stream_pos:
             changed_entities = {
-                self._cache[k] for k in self._cache.islice(
-                    start=self._cache.bisect_right(stream_pos),
-                )
+                self._cache[k]
+                for k in self._cache.islice(start=self._cache.bisect_right(stream_pos))
             }
 
             result = changed_entities.intersection(entities)
@@ -114,8 +113,10 @@ class StreamChangeCache(object):
         assert type(stream_pos) is int
 
         if stream_pos >= self._earliest_known_stream_pos:
-            return [self._cache[k] for k in self._cache.islice(
-                start=self._cache.bisect_right(stream_pos))]
+            return [
+                self._cache[k]
+                for k in self._cache.islice(start=self._cache.bisect_right(stream_pos))
+            ]
         else:
             return None
 
@@ -136,7 +137,7 @@ class StreamChangeCache(object):
             while len(self._cache) > self._max_size:
                 k, r = self._cache.popitem(0)
                 self._earliest_known_stream_pos = max(
-                    k, self._earliest_known_stream_pos,
+                    k, self._earliest_known_stream_pos
                 )
                 self._entity_to_key.pop(r, None)
 
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index dd4c9e6067..9a72218d85 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -9,6 +9,7 @@ class TreeCache(object):
     efficiently.
     Keys must be tuples.
     """
+
     def __init__(self):
         self.size = 0
         self.root = {}
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 5ba1862506..2af8ca43b1 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -155,6 +155,7 @@ class TTLCache(object):
 @attr.s(frozen=True, slots=True)
 class _CacheEntry(object):
     """TTLCache entry"""
+
     # expiry_time is the first attribute, so that entries are sorted by expiry.
     expiry_time = attr.ib()
     key = attr.ib()