diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index c38f01ead0..e7a74d3da8 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -146,7 +146,7 @@ class CacheDescriptor(object):
invalidated) by adding a special "cache_context" argument to the function
and passing that as a kwarg to all caches called. For example::
- @cachedInlineCallbacks()
+ @cachedInlineCallbacks(cache_context=True)
def foo(self, key, cache_context):
r1 = yield self.bar1(key, cache_context=cache_context)
r2 = yield self.bar2(key, cache_context=cache_context)
@@ -154,7 +154,7 @@ class CacheDescriptor(object):
"""
def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
- inlineCallbacks=False):
+ inlineCallbacks=False, cache_context=False):
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.orig = orig
@@ -171,15 +171,28 @@ class CacheDescriptor(object):
all_args = inspect.getargspec(orig)
self.arg_names = all_args.args[1:num_args + 1]
- if "cache_context" in self.arg_names:
- self.arg_names.remove("cache_context")
+ if "cache_context" in all_args.args:
+ if not cache_context:
+ raise ValueError(
+ "Cannot have a 'cache_context' arg without setting"
+ " cache_context=True"
+ )
+ try:
+ self.arg_names.remove("cache_context")
+ except ValueError:
+ pass
+ elif cache_context:
+ raise ValueError(
+ "Cannot have cache_context=True without having an arg"
+ " named `cache_context`"
+ )
- self.add_cache_context = "cache_context" in all_args.args
+ self.add_cache_context = cache_context
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
- " (@cached cannot key off of *args or **kwars)"
+ " (@cached cannot key off of *args or **kwargs)"
% (orig.__name__,)
)
@@ -193,12 +206,16 @@ class CacheDescriptor(object):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
+ # If we're passed a cache_context then we'll want to call its invalidate()
+ # whenever we are invalidated
cache_context = kwargs.pop("cache_context", None)
if cache_context:
context_callback = cache_context.invalidate
else:
context_callback = None
+ # Add our own `cache_context` to argument list if the wrapped function
+ # has asked for one
self_context = _CacheContext(cache, None)
if self.add_cache_context:
kwargs["cache_context"] = self_context
@@ -414,22 +431,24 @@ class _CacheContext(object):
self.cache.invalidate(self.key)
-def cached(max_entries=1000, num_args=1, tree=False):
+def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
tree=tree,
+ cache_context=cache_context,
)
-def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
tree=tree,
inlineCallbacks=True,
+ cache_context=cache_context,
)
|