summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/storage/push_rule.py2
-rw-r--r--synapse/util/caches/descriptors.py35
-rw-r--r--tests/storage/test__base.py4
3 files changed, 30 insertions, 11 deletions
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index ca929bc239..247dd15694 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -134,7 +134,7 @@ class PushRuleStore(SQLBaseStore):
 
         return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
 
-    @cachedInlineCallbacks(num_args=2)
+    @cachedInlineCallbacks(num_args=2, cache_context=True)
     def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
                                       cache_context):
         # We don't use `state_group`, its there so that we can cache based
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index c38f01ead0..e7a74d3da8 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -146,7 +146,7 @@ class CacheDescriptor(object):
     invalidated) by adding a special "cache_context" argument to the function
     and passing that as a kwarg to all caches called. For example::
 
-        @cachedInlineCallbacks()
+        @cachedInlineCallbacks(cache_context=True)
         def foo(self, key, cache_context):
             r1 = yield self.bar1(key, cache_context=cache_context)
             r2 = yield self.bar2(key, cache_context=cache_context)
@@ -154,7 +154,7 @@ class CacheDescriptor(object):
 
     """
     def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
-                 inlineCallbacks=False):
+                 inlineCallbacks=False, cache_context=False):
         max_entries = int(max_entries * CACHE_SIZE_FACTOR)
 
         self.orig = orig
@@ -171,15 +171,28 @@ class CacheDescriptor(object):
         all_args = inspect.getargspec(orig)
         self.arg_names = all_args.args[1:num_args + 1]
 
-        if "cache_context" in self.arg_names:
-            self.arg_names.remove("cache_context")
+        if "cache_context" in all_args.args:
+            if not cache_context:
+                raise ValueError(
+                    "Cannot have a 'cache_context' arg without setting"
+                    " cache_context=True"
+                )
+            try:
+                self.arg_names.remove("cache_context")
+            except ValueError:
+                pass
+        elif cache_context:
+            raise ValueError(
+                "Cannot have cache_context=True without having an arg"
+                " named `cache_context`"
+            )
 
-        self.add_cache_context = "cache_context" in all_args.args
+        self.add_cache_context = cache_context
 
         if len(self.arg_names) < self.num_args:
             raise Exception(
                 "Not enough explicit positional arguments to key off of for %r."
-                " (@cached cannot key off of *args or **kwars)"
+                " (@cached cannot key off of *args or **kwargs)"
                 % (orig.__name__,)
             )
 
@@ -193,12 +206,16 @@ class CacheDescriptor(object):
 
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
+            # If we're passed a cache_context then we'll want to call its invalidate()
+            # whenever we are invalidated
             cache_context = kwargs.pop("cache_context", None)
             if cache_context:
                 context_callback = cache_context.invalidate
             else:
                 context_callback = None
 
+            # Add our own `cache_context` to argument list if the wrapped function
+            # has asked for one
             self_context = _CacheContext(cache, None)
             if self.add_cache_context:
                 kwargs["cache_context"] = self_context
@@ -414,22 +431,24 @@ class _CacheContext(object):
         self.cache.invalidate(self.key)
 
 
-def cached(max_entries=1000, num_args=1, tree=False):
+def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
         tree=tree,
+        cache_context=cache_context,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
         tree=tree,
         inlineCallbacks=True,
+        cache_context=cache_context,
     )
 
 
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index ed074ce9ec..eab0c8d219 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -211,7 +211,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
                 callcount[0] += 1
                 return key
 
-            @cached()
+            @cached(cache_context=True)
             def func2(self, key, cache_context):
                 callcount2[0] += 1
                 return self.func(key, cache_context=cache_context)
@@ -244,7 +244,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
                 callcount[0] += 1
                 return key
 
-            @cached()
+            @cached(cache_context=True)
             def func2(self, key, cache_context):
                 callcount2[0] += 1
                 return self.func(key, cache_context=cache_context)