summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2017-05-22 15:04:42 +0100
committerErik Johnston <erik@matrix.org>2017-05-22 15:04:42 +0100
commite3417a06e23c532e6502bdcdcaedac826e231d69 (patch)
tree607902ecd541f10f20d70dd83910af7ca938488a /synapse/util/caches/descriptors.py
parentMerge pull request #2236 from matrix-org/erikj/invalidation (diff)
downloadsynapse-e3417a06e23c532e6502bdcdcaedac826e231d69.tar.xz
Update list cache to handle one arg case
We update the normal cache descriptors to handle caches with a single
argument specially so that the key wasn't a 1-tuple. We need to update
the cache list to be aware of this.
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py50
1 files changed, 33 insertions, 17 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 48dcbafeef..77a0d8e35d 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -404,6 +404,7 @@ class CacheDescriptor(_CacheDescriptorBase):
 
         wrapped.invalidate_all = cache.invalidate_all
         wrapped.cache = cache
+        wrapped.num_args = self.num_args
 
         obj.__dict__[self.orig.__name__] = wrapped
 
@@ -451,8 +452,9 @@ class CacheListDescriptor(_CacheDescriptorBase):
             )
 
     def __get__(self, obj, objtype=None):
-
-        cache = getattr(obj, self.cached_method_name).cache
+        cached_method = getattr(obj, self.cached_method_name)
+        cache = cached_method.cache
+        num_args = cached_method.num_args
 
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
@@ -470,11 +472,14 @@ class CacheListDescriptor(_CacheDescriptorBase):
             cached_defers = {}
             missing = []
             for arg in list_args:
-                key = list(keyargs)
-                key[self.list_pos] = arg
-
                 try:
-                    res = cache.get(tuple(key), callback=invalidate_callback)
+                    if num_args == 1:
+                        res = cache.get(arg, callback=invalidate_callback)
+                    else:
+                        key = list(keyargs)
+                        key[self.list_pos] = arg
+                        res = cache.get(tuple(key), callback=invalidate_callback)
+
                     if not isinstance(res, ObservableDeferred):
                         results[arg] = res
                     elif not res.has_succeeded():
@@ -505,17 +510,28 @@ class CacheListDescriptor(_CacheDescriptorBase):
 
                     observer = ObservableDeferred(observer)
 
-                    key = list(keyargs)
-                    key[self.list_pos] = arg
-                    cache.set(
-                        tuple(key), observer,
-                        callback=invalidate_callback
-                    )
-
-                    def invalidate(f, key):
-                        cache.invalidate(key)
-                        return f
-                    observer.addErrback(invalidate, tuple(key))
+                    if num_args == 1:
+                        cache.set(
+                            arg, observer,
+                            callback=invalidate_callback
+                        )
+
+                        def invalidate(f, key):
+                            cache.invalidate(key)
+                            return f
+                        observer.addErrback(invalidate, arg)
+                    else:
+                        key = list(keyargs)
+                        key[self.list_pos] = arg
+                        cache.set(
+                            tuple(key), observer,
+                            callback=invalidate_callback
+                        )
+
+                        def invalidate(f, key):
+                            cache.invalidate(key)
+                            return f
+                        observer.addErrback(invalidate, tuple(key))
 
                     res = observer.observe()
                     res.addCallback(lambda r, arg: (arg, r), arg)