diff options
author | Erik Johnston <erik@matrix.org> | 2017-05-22 15:04:42 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2017-05-22 15:04:42 +0100 |
commit | e3417a06e23c532e6502bdcdcaedac826e231d69 (patch) | |
tree | 607902ecd541f10f20d70dd83910af7ca938488a /synapse/util | |
parent | Merge pull request #2236 from matrix-org/erikj/invalidation (diff) | |
download | synapse-e3417a06e23c532e6502bdcdcaedac826e231d69.tar.xz |
Update list cache to handle one arg case
We update the normal cache descriptors to handle caches with a single argument specially so that the key wasn't a 1-tuple. We need to update the cache list to be aware of this.
Diffstat (limited to 'synapse/util')
-rw-r--r-- | synapse/util/caches/descriptors.py | 50 |
1 files changed, 33 insertions, 17 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 48dcbafeef..77a0d8e35d 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -404,6 +404,7 @@ class CacheDescriptor(_CacheDescriptorBase): wrapped.invalidate_all = cache.invalidate_all wrapped.cache = cache + wrapped.num_args = self.num_args obj.__dict__[self.orig.__name__] = wrapped @@ -451,8 +452,9 @@ class CacheListDescriptor(_CacheDescriptorBase): ) def __get__(self, obj, objtype=None): - - cache = getattr(obj, self.cached_method_name).cache + cached_method = getattr(obj, self.cached_method_name) + cache = cached_method.cache + num_args = cached_method.num_args @functools.wraps(self.orig) def wrapped(*args, **kwargs): @@ -470,11 +472,14 @@ class CacheListDescriptor(_CacheDescriptorBase): cached_defers = {} missing = [] for arg in list_args: - key = list(keyargs) - key[self.list_pos] = arg - try: - res = cache.get(tuple(key), callback=invalidate_callback) + if num_args == 1: + res = cache.get(arg, callback=invalidate_callback) + else: + key = list(keyargs) + key[self.list_pos] = arg + res = cache.get(tuple(key), callback=invalidate_callback) + if not isinstance(res, ObservableDeferred): results[arg] = res elif not res.has_succeeded(): @@ -505,17 +510,28 @@ class CacheListDescriptor(_CacheDescriptorBase): observer = ObservableDeferred(observer) - key = list(keyargs) - key[self.list_pos] = arg - cache.set( - tuple(key), observer, - callback=invalidate_callback - ) - - def invalidate(f, key): - cache.invalidate(key) - return f - observer.addErrback(invalidate, tuple(key)) + if num_args == 1: + cache.set( + arg, observer, + callback=invalidate_callback + ) + + def invalidate(f, key): + cache.invalidate(key) + return f + observer.addErrback(invalidate, arg) + else: + key = list(keyargs) + key[self.list_pos] = arg + cache.set( + tuple(key), observer, + callback=invalidate_callback + ) + + def invalidate(f, key): + cache.invalidate(key) + return f + observer.addErrback(invalidate, tuple(key)) res = observer.observe() res.addCallback(lambda r, arg: (arg, r), arg) |