diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index e7a74d3da8..e93ff40dc0 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -148,8 +148,8 @@ class CacheDescriptor(object):
@cachedInlineCallbacks(cache_context=True)
def foo(self, key, cache_context):
- r1 = yield self.bar1(key, cache_context=cache_context)
- r2 = yield self.bar2(key, cache_context=cache_context)
+ r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
+ r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
defer.returnValue(r1 + r2)
"""
@@ -208,11 +208,7 @@ class CacheDescriptor(object):
def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
- cache_context = kwargs.pop("cache_context", None)
- if cache_context:
- context_callback = cache_context.invalidate
- else:
- context_callback = None
+ invalidate_callback = kwargs.pop("on_invalidate", None)
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
@@ -226,7 +222,7 @@ class CacheDescriptor(object):
self_context.key = cache_key
try:
- cached_result_d = cache.get(cache_key, callback=context_callback)
+ cached_result_d = cache.get(cache_key, callback=invalidate_callback)
observer = cached_result_d.observe()
if DEBUG_CACHES:
@@ -263,7 +259,7 @@ class CacheDescriptor(object):
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
- cache.update(sequence, cache_key, ret, callback=context_callback)
+ cache.update(sequence, cache_key, ret, callback=invalidate_callback)
return preserve_context_over_deferred(ret.observe())
@@ -332,11 +328,9 @@ class CacheListDescriptor(object):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
- cache_context = kwargs.pop("cache_context", None)
- if cache_context:
- context_callback = cache_context.invalidate
- else:
- context_callback = None
+ # If we're passed a cache_context then we'll want to call its invalidate()
+ # whenever we are invalidated
+ invalidate_callback = kwargs.pop("on_invalidate", None)
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
@@ -352,7 +346,7 @@ class CacheListDescriptor(object):
key[self.list_pos] = arg
try:
- res = cache.get(tuple(key), callback=context_callback)
+ res = cache.get(tuple(key), callback=invalidate_callback)
if not res.has_succeeded():
res = res.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
@@ -388,7 +382,7 @@ class CacheListDescriptor(object):
key[self.list_pos] = arg
cache.update(
sequence, tuple(key), observer,
- callback=context_callback
+ callback=invalidate_callback
)
def invalidate(f, key):
|