summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/push_rule.py6
-rw-r--r--synapse/util/caches/descriptors.py26
-rw-r--r--tests/storage/test__base.py4
3 files changed, 15 insertions, 21 deletions
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 247dd15694..78334a98cf 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -156,14 +156,14 @@ class PushRuleStore(SQLBaseStore):
         # users in the room who have pushers need to get push rules run because
         # that's how their pushers work
         if_users_with_pushers = yield self.get_if_users_have_pushers(
-            local_users_in_room, cache_context=cache_context,
+            local_users_in_room, on_invalidate=cache_context.invalidate,
         )
         user_ids = set(
             uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
         )
 
         users_with_receipts = yield self.get_users_with_read_receipts_in_room(
-            room_id, cache_context=cache_context,
+            room_id, on_invalidate=cache_context.invalidate,
         )
 
         # any users with pushers must be ours: they have pushers
@@ -172,7 +172,7 @@ class PushRuleStore(SQLBaseStore):
                 user_ids.add(uid)
 
         rules_by_user = yield self.bulk_get_push_rules(
-            user_ids, cache_context=cache_context
+            user_ids, on_invalidate=cache_context.invalidate,
         )
 
         rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index e7a74d3da8..e93ff40dc0 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -148,8 +148,8 @@ class CacheDescriptor(object):
 
         @cachedInlineCallbacks(cache_context=True)
         def foo(self, key, cache_context):
-            r1 = yield self.bar1(key, cache_context=cache_context)
-            r2 = yield self.bar2(key, cache_context=cache_context)
+            r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
+            r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
             defer.returnValue(r1 + r2)
 
     """
@@ -208,11 +208,7 @@ class CacheDescriptor(object):
         def wrapped(*args, **kwargs):
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
-            cache_context = kwargs.pop("cache_context", None)
-            if cache_context:
-                context_callback = cache_context.invalidate
-            else:
-                context_callback = None
+            invalidate_callback = kwargs.pop("on_invalidate", None)
 
             # Add our own `cache_context` to argument list if the wrapped function
             # has asked for one
@@ -226,7 +222,7 @@ class CacheDescriptor(object):
             self_context.key = cache_key
 
             try:
-                cached_result_d = cache.get(cache_key, callback=context_callback)
+                cached_result_d = cache.get(cache_key, callback=invalidate_callback)
 
                 observer = cached_result_d.observe()
                 if DEBUG_CACHES:
@@ -263,7 +259,7 @@ class CacheDescriptor(object):
                 ret.addErrback(onErr)
 
                 ret = ObservableDeferred(ret, consumeErrors=True)
-                cache.update(sequence, cache_key, ret, callback=context_callback)
+                cache.update(sequence, cache_key, ret, callback=invalidate_callback)
 
                 return preserve_context_over_deferred(ret.observe())
 
@@ -332,11 +328,9 @@ class CacheListDescriptor(object):
 
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
-            cache_context = kwargs.pop("cache_context", None)
-            if cache_context:
-                context_callback = cache_context.invalidate
-            else:
-                context_callback = None
+            # If we're passed a cache_context then we'll want to call its invalidate()
+            # whenever we are invalidated
+            invalidate_callback = kwargs.pop("on_invalidate", None)
 
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
             keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
@@ -352,7 +346,7 @@ class CacheListDescriptor(object):
                 key[self.list_pos] = arg
 
                 try:
-                    res = cache.get(tuple(key), callback=context_callback)
+                    res = cache.get(tuple(key), callback=invalidate_callback)
                     if not res.has_succeeded():
                         res = res.observe()
                         res.addCallback(lambda r, arg: (arg, r), arg)
@@ -388,7 +382,7 @@ class CacheListDescriptor(object):
                     key[self.list_pos] = arg
                     cache.update(
                         sequence, tuple(key), observer,
-                        callback=context_callback
+                        callback=invalidate_callback
                     )
 
                     def invalidate(f, key):
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index eab0c8d219..4fc3639de0 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -214,7 +214,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
             @cached(cache_context=True)
             def func2(self, key, cache_context):
                 callcount2[0] += 1
-                return self.func(key, cache_context=cache_context)
+                return self.func(key, on_invalidate=cache_context.invalidate)
 
         a = A()
         yield a.func2("foo")
@@ -247,7 +247,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
             @cached(cache_context=True)
             def func2(self, key, cache_context):
                 callcount2[0] += 1
-                return self.func(key, cache_context=cache_context)
+                return self.func(key, on_invalidate=cache_context.invalidate)
 
         a = A()
         yield a.func2("foo")