From b770435389a9c827582884912b0a2761d0eed812 Mon Sep 17 00:00:00 2001
From: Erik Johnston <erik@matrix.org>
Date: Fri, 19 Aug 2016 10:19:29 +0100
Subject: Make get_new_events_for_appservice use indices

---
 synapse/storage/appservice.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

(limited to 'synapse/storage')

diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index b496b918b7..a854a87eab 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -366,8 +366,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
         def get_new_events_for_appservice_txn(txn):
             sql = (
                 "SELECT e.stream_ordering, e.event_id"
-                " FROM events AS e, appservice_stream_position AS a"
-                " WHERE a.stream_ordering < e.stream_ordering AND e.stream_ordering <= ?"
+                " FROM events AS e"
+                " WHERE"
+                " (SELECT stream_ordering FROM appservice_stream_position)"
+                "     < e.stream_ordering"
+                " AND e.stream_ordering <= ?"
                 " ORDER BY e.stream_ordering ASC"
                 " LIMIT ?"
             )
-- 
cgit 1.5.1
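
A note on the query rewrite above: the old form joined events against appservice_stream_position with no selective predicate on the join, while the new form folds the stored stream position into a scalar subquery, which makes it much easier for the planner to bound the scan of events using an index on stream_ordering (hence the commit subject). The following is a rough, standalone sketch of how the two query shapes compare under SQLite's EXPLAIN QUERY PLAN; the toy schema and index name are illustrative stand-ins, not the real Synapse schema.

    # Toy comparison of the old and new query shapes; table/index names are
    # illustrative stand-ins, not the real Synapse schema.
    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE events (stream_ordering INTEGER, event_id TEXT);
        CREATE INDEX events_stream_ordering ON events (stream_ordering);
        CREATE TABLE appservice_stream_position (stream_ordering INTEGER);
        INSERT INTO appservice_stream_position VALUES (0);
    """)

    old_sql = (
        "SELECT e.stream_ordering, e.event_id"
        " FROM events AS e, appservice_stream_position AS a"
        " WHERE a.stream_ordering < e.stream_ordering AND e.stream_ordering <= ?"
        " ORDER BY e.stream_ordering ASC LIMIT ?"
    )
    new_sql = (
        "SELECT e.stream_ordering, e.event_id"
        " FROM events AS e"
        " WHERE (SELECT stream_ordering FROM appservice_stream_position)"
        "     < e.stream_ordering AND e.stream_ordering <= ?"
        " ORDER BY e.stream_ordering ASC LIMIT ?"
    )

    for label, sql in (("old", old_sql), ("new", new_sql)):
        # EXPLAIN QUERY PLAN shows whether the index on stream_ordering is used
        plan = conn.execute("EXPLAIN QUERY PLAN " + sql, (100, 10)).fetchall()
        print(label, plan)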


From ba214a5e325adbf8ab430cb15f55d2c7544eba8b Mon Sep 17 00:00:00 2001
From: Erik Johnston <erik@matrix.org>
Date: Fri, 19 Aug 2016 11:59:29 +0100
Subject: Remove lru option

---
 synapse/storage/_base.py              |  2 +-
 synapse/storage/event_push_actions.py |  2 +-
 synapse/storage/push_rule.py          |  4 ++--
 synapse/storage/pusher.py             |  2 +-
 synapse/storage/receipts.py           |  2 +-
 synapse/storage/signatures.py         |  2 +-
 synapse/storage/state.py              |  4 ++--
 synapse/util/caches/descriptors.py    | 31 ++++++++-----------------------
 tests/storage/test__base.py           |  2 +-
 9 files changed, 18 insertions(+), 33 deletions(-)

(limited to 'synapse/storage')

diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 029f6612e6..49fa8614f2 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -166,7 +166,7 @@ class SQLBaseStore(object):
         self._txn_perf_counters = PerformanceCounters()
         self._get_event_counters = PerformanceCounters()
 
-        self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
+        self._get_event_cache = Cache("*getEvent*", keylen=3,
                                       max_entries=hs.config.event_cache_size)
 
         self._state_group_cache = DictionaryCache(
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index df4000d0da..c65c9c9c47 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -56,7 +56,7 @@ class EventPushActionsStore(SQLBaseStore):
             )
         self._simple_insert_many_txn(txn, "event_push_actions", values)
 
-    @cachedInlineCallbacks(num_args=3, lru=True, tree=True, max_entries=5000)
+    @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
     def get_unread_event_push_actions_by_room_for_user(
             self, room_id, user_id, last_read_event_id
     ):
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 8183b7f1b0..86e4a3a81d 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -48,7 +48,7 @@ def _load_rules(rawrules, enabled_map):
 
 
 class PushRuleStore(SQLBaseStore):
-    @cachedInlineCallbacks(lru=True)
+    @cachedInlineCallbacks()
     def get_push_rules_for_user(self, user_id):
         rows = yield self._simple_select_list(
             table="push_rules",
@@ -72,7 +72,7 @@ class PushRuleStore(SQLBaseStore):
 
         defer.returnValue(rules)
 
-    @cachedInlineCallbacks(lru=True)
+    @cachedInlineCallbacks()
     def get_push_rules_enabled_for_user(self, user_id):
         results = yield self._simple_select_list(
             table="push_rules_enable",
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index a7d7c54d7e..8f5f8f24a9 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -135,7 +135,7 @@ class PusherStore(SQLBaseStore):
             "get_all_updated_pushers", get_all_updated_pushers_txn
         )
 
-    @cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000)
+    @cachedInlineCallbacks(num_args=1, max_entries=15000)
     def get_if_user_has_pusher(self, user_id):
         result = yield self._simple_select_many_batch(
             table='pushers',
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 8c26f39fbb..3ad916103f 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -120,7 +120,7 @@ class ReceiptsStore(SQLBaseStore):
 
         defer.returnValue([ev for res in results.values() for ev in res])
 
-    @cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True)
+    @cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True)
     def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
         """Get receipts for a single room for sending to clients.
 
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index ea6823f18d..e1dca927d7 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -25,7 +25,7 @@ from synapse.util.caches.descriptors import cached, cachedList
 class SignatureStore(SQLBaseStore):
     """Persistence for event signatures and hashes"""
 
-    @cached(lru=True)
+    @cached()
     def get_event_reference_hash(self, event_id):
         return self._get_event_reference_hashes_txn(event_id)
 
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 5b743db67a..0e8fa93e1f 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -174,7 +174,7 @@ class StateStore(SQLBaseStore):
             return [r[0] for r in results]
         return self.runInteraction("get_current_state_for_key", f)
 
-    @cached(num_args=2, lru=True, max_entries=1000)
+    @cached(num_args=2, max_entries=1000)
     def _get_state_group_from_group(self, group, types):
         raise NotImplementedError()
 
@@ -272,7 +272,7 @@ class StateStore(SQLBaseStore):
         state_map = yield self.get_state_for_events([event_id], types)
         defer.returnValue(state_map[event_id])
 
-    @cached(num_args=2, lru=True, max_entries=10000)
+    @cached(num_args=2, max_entries=10000)
     def _get_state_group_for_event(self, room_id, event_id):
         return self._simple_select_one_onecol(
             table="event_to_state_groups",
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 5cd277f2f2..c38f01ead0 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -26,8 +26,6 @@ from . import DEBUG_CACHES, register_cache
 
 from twisted.internet import defer
 
-from collections import OrderedDict
-
 import os
 import functools
 import inspect
@@ -54,16 +52,11 @@ class Cache(object):
         "metrics",
     )
 
-    def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
-        if True:
-            cache_type = TreeCache if tree else dict
-            self.cache = LruCache(
-                max_size=max_entries, keylen=keylen, cache_type=cache_type
-            )
-            self.max_entries = None
-        else:
-            self.cache = OrderedDict()
-            self.max_entries = max_entries
+    def __init__(self, name, max_entries=1000, keylen=1, tree=False):
+        cache_type = TreeCache if tree else dict
+        self.cache = LruCache(
+            max_size=max_entries, keylen=keylen, cache_type=cache_type
+        )
 
         self.name = name
         self.keylen = keylen
@@ -102,10 +95,6 @@ class Cache(object):
             self.prefill(key, value, callback=callback)
 
     def prefill(self, key, value, callback=None):
-        if self.max_entries is not None:
-            while len(self.cache) >= self.max_entries:
-                self.cache.popitem(last=False, callback=None)
-
         self.cache.set(key, value, callback=callback)
 
     def invalidate(self, key):
@@ -164,7 +153,7 @@ class CacheDescriptor(object):
             defer.returnValue(r1 + r2)
 
     """
-    def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
+    def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
                  inlineCallbacks=False):
         max_entries = int(max_entries * CACHE_SIZE_FACTOR)
 
@@ -177,7 +166,6 @@ class CacheDescriptor(object):
 
         self.max_entries = max_entries
         self.num_args = num_args
-        self.lru = lru
         self.tree = tree
 
         all_args = inspect.getargspec(orig)
@@ -200,7 +188,6 @@ class CacheDescriptor(object):
             name=self.orig.__name__,
             max_entries=self.max_entries,
             keylen=self.num_args,
-            lru=self.lru,
             tree=self.tree,
         )
 
@@ -427,22 +414,20 @@ class _CacheContext(object):
         self.cache.invalidate(self.key)
 
 
-def cached(max_entries=1000, num_args=1, lru=True, tree=False):
+def cached(max_entries=1000, num_args=1, tree=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
-        lru=lru,
         tree=tree,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
-        lru=lru,
         tree=tree,
         inlineCallbacks=True,
     )
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 9d99eea8d0..ed074ce9ec 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -72,7 +72,7 @@ class CacheTestCase(unittest.TestCase):
         cache.get(3)
 
     def test_eviction_lru(self):
-        cache = Cache("test", max_entries=2, lru=True)
+        cache = Cache("test", max_entries=2)
 
         cache.prefill(1, "one")
         cache.prefill(2, "two")
-- 
cgit 1.5.1
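
With the lru option gone, Cache always wraps an LruCache, and the dead OrderedDict branch (along with its manual eviction loop in prefill) disappears. The following is a minimal, standalone sketch of the least-recently-used eviction that every Cache now gets; TinyLru is invented for illustration and is not the real LruCache.

    # Minimal LRU cache, written from scratch for illustration; it shows the
    # eviction policy that is now always in effect.
    from collections import OrderedDict

    class TinyLru(object):
        def __init__(self, max_size):
            self.max_size = max_size
            self.data = OrderedDict()

        def set(self, key, value):
            self.data.pop(key, None)            # re-inserting refreshes recency
            self.data[key] = value
            while len(self.data) > self.max_size:
                self.data.popitem(last=False)   # evict the least recently used

        def get(self, key):
            value = self.data.pop(key)          # a read also refreshes recency
            self.data[key] = value
            return value

    cache = TinyLru(max_size=2)
    cache.set(1, "one")
    cache.set(2, "two")
    cache.set(3, "three")                       # key 1 is oldest and gets evicted
    assert 1 not in cache.data and 2 in cache.data and 3 in cache.data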


From f164fd922024308e702269a881328f7de980e9eb Mon Sep 17 00:00:00 2001
From: Erik Johnston <erik@matrix.org>
Date: Fri, 19 Aug 2016 14:07:27 +0100
Subject: Move _bulk_get_push_rules_for_room to storage layer

---
 synapse/push/action_generator.py         |  2 +-
 synapse/push/bulk_push_rule_evaluator.py | 41 +++++------------------
 synapse/storage/push_rule.py             | 56 ++++++++++++++++++++++++++++++++
 3 files changed, 65 insertions(+), 34 deletions(-)

(limited to 'synapse/storage')

diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index b2c94bfaac..ed2ccc4dfb 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -40,7 +40,7 @@ class ActionGenerator:
     def handle_push_actions_for_event(self, event, context):
         with Measure(self.clock, "evaluator_for_event"):
             bulk_evaluator = yield evaluator_for_event(
-                event, self.hs, self.store, context.current_state
+                event, self.hs, self.store, context.state_group, context.current_state
             )
 
         with Measure(self.clock, "action_for_event_by_user"):
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 756e5da513..004eded61f 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -36,35 +36,11 @@ def _get_rules(room_id, user_ids, store):
 
 
 @defer.inlineCallbacks
-def evaluator_for_event(event, hs, store, current_state):
-    room_id = event.room_id
-    # We also will want to generate notifs for other people in the room so
-    # their unread countss are correct in the event stream, but to avoid
-    # generating them for bot / AS users etc, we only do so for people who've
-    # sent a read receipt into the room.
-
-    local_users_in_room = set(
-        e.state_key for e in current_state.values()
-        if e.type == EventTypes.Member and e.membership == Membership.JOIN
-        and hs.is_mine_id(e.state_key)
+def evaluator_for_event(event, hs, store, state_group, current_state):
+    rules_by_user = yield store.bulk_get_push_rules_for_room(
+        event.room_id, state_group, current_state
     )
 
-    # users in the room who have pushers need to get push rules run because
-    # that's how their pushers work
-    if_users_with_pushers = yield store.get_if_users_have_pushers(
-        local_users_in_room
-    )
-    user_ids = set(
-        uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
-    )
-
-    users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
-
-    # any users with pushers must be ours: they have pushers
-    for uid in users_with_receipts:
-        if uid in local_users_in_room:
-            user_ids.add(uid)
-
     # if this event is an invite event, we may need to run rules for the user
     # who's been invited, otherwise they won't get told they've been invited
     if event.type == 'm.room.member' and event.content['membership'] == 'invite':
@@ -72,12 +48,12 @@ def evaluator_for_event(event, hs, store, current_state):
         if invited_user and hs.is_mine_id(invited_user):
             has_pusher = yield store.user_has_pusher(invited_user)
             if has_pusher:
-                user_ids.add(invited_user)
-
-    rules_by_user = yield _get_rules(room_id, user_ids, store)
+                rules_by_user[invited_user] = yield store.get_push_rules_for_user(
+                    invited_user
+                )
 
     defer.returnValue(BulkPushRuleEvaluator(
-        room_id, rules_by_user, user_ids, store
+        event.room_id, rules_by_user, store
     ))
 
 
@@ -90,10 +66,9 @@ class BulkPushRuleEvaluator:
     the same logic to run the actual rules, but could be optimised further
     (see https://matrix.org/jira/browse/SYN-562)
     """
-    def __init__(self, room_id, rules_by_user, users_in_room, store):
+    def __init__(self, room_id, rules_by_user, store):
         self.room_id = room_id
         self.rules_by_user = rules_by_user
-        self.users_in_room = users_in_room
         self.store = store
 
     @defer.inlineCallbacks
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 86e4a3a81d..ca929bc239 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -16,6 +16,7 @@
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
 from synapse.push.baserules import list_with_base_rules
+from synapse.api.constants import EventTypes, Membership
 from twisted.internet import defer
 
 import logging
@@ -123,6 +124,61 @@ class PushRuleStore(SQLBaseStore):
 
         defer.returnValue(results)
 
+    def bulk_get_push_rules_for_room(self, room_id, state_group, current_state):
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
+
+    @cachedInlineCallbacks(num_args=2)
+    def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
+                                      cache_context):
+        # We don't use `state_group`, it's there so that we can cache based
+        # on it. However, it's important that it's never None, since two
+        # current_states with a state_group of None are likely to be different.
+        # See bulk_get_push_rules_for_room for how we work around this.
+        assert state_group is not None
+
+        # We will also want to generate notifs for other people in the room so
+        # their unread counts are correct in the event stream, but to avoid
+        # generating them for bot / AS users etc, we only do so for people who've
+        # sent a read receipt into the room.
+        local_users_in_room = set(
+            e.state_key for e in current_state.values()
+            if e.type == EventTypes.Member and e.membership == Membership.JOIN
+            and self.hs.is_mine_id(e.state_key)
+        )
+
+        # users in the room who have pushers need to get push rules run because
+        # that's how their pushers work
+        if_users_with_pushers = yield self.get_if_users_have_pushers(
+            local_users_in_room, cache_context=cache_context,
+        )
+        user_ids = set(
+            uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
+        )
+
+        users_with_receipts = yield self.get_users_with_read_receipts_in_room(
+            room_id, cache_context=cache_context,
+        )
+
+        # any users with pushers must be ours: they have pushers
+        for uid in users_with_receipts:
+            if uid in local_users_in_room:
+                user_ids.add(uid)
+
+        rules_by_user = yield self.bulk_get_push_rules(
+            user_ids, cache_context=cache_context
+        )
+
+        rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
+
+        defer.returnValue(rules_by_user)
+
     @cachedList(cached_method_name="get_push_rules_enabled_for_user",
                 list_name="user_ids", num_args=1, inlineCallbacks=True)
     def bulk_get_push_rules_enabled(self, user_ids):
-- 
cgit 1.5.1
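
The interesting part of the move above is the cache key: _bulk_get_push_rules_for_room is cached on (room_id, state_group), and when the event has not yet been assigned a state group a fresh object() is substituted, so two "unknown" state groups can never share an entry. Below is a rough standalone sketch of that sentinel trick using a plain dict; compute_rules is a hypothetical stand-in for the real push-rule lookup.

    # Plain-dict version of the (room_id, state_group) cache key; compute_rules
    # is a hypothetical stand-in for the real work.
    cache = {}

    def compute_rules(current_state):
        return {"members": sorted(current_state)}

    def rules_for_room(room_id, state_group, current_state):
        if not state_group:
            # a fresh object() is unique, so two calls with an unassigned
            # state group can never collide on the same cache entry
            state_group = object()
        key = (room_id, state_group)    # mirrors @cachedInlineCallbacks(num_args=2)
        if key not in cache:
            cache[key] = compute_rules(current_state)
        return cache[key]

    a = rules_for_room("!room:hs", None, {"@a:hs"})
    b = rules_for_room("!room:hs", None, {"@b:hs"})    # not served from a's entry
    assert a is not b
    c = rules_for_room("!room:hs", 42, {"@a:hs"})
    d = rules_for_room("!room:hs", 42, {"@b:hs"})      # same state group: cached
    assert c is d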


From dc76a3e909535d99f0b6b4a76279a14685324dc4 Mon Sep 17 00:00:00 2001
From: Erik Johnston <erik@matrix.org>
Date: Fri, 19 Aug 2016 15:02:38 +0100
Subject: Make cache_context an explicit option

---
 synapse/storage/push_rule.py       |  2 +-
 synapse/util/caches/descriptors.py | 35 +++++++++++++++++++++++++++--------
 tests/storage/test__base.py        |  4 ++--
 3 files changed, 30 insertions(+), 11 deletions(-)

(limited to 'synapse/storage')

diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index ca929bc239..247dd15694 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -134,7 +134,7 @@ class PushRuleStore(SQLBaseStore):
 
         return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
 
-    @cachedInlineCallbacks(num_args=2)
+    @cachedInlineCallbacks(num_args=2, cache_context=True)
     def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
                                       cache_context):
         # We don't use `state_group`, it's there so that we can cache based
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index c38f01ead0..e7a74d3da8 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -146,7 +146,7 @@ class CacheDescriptor(object):
     invalidated) by adding a special "cache_context" argument to the function
     and passing that as a kwarg to all caches called. For example::
 
-        @cachedInlineCallbacks()
+        @cachedInlineCallbacks(cache_context=True)
         def foo(self, key, cache_context):
             r1 = yield self.bar1(key, cache_context=cache_context)
             r2 = yield self.bar2(key, cache_context=cache_context)
@@ -154,7 +154,7 @@ class CacheDescriptor(object):
 
     """
     def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
-                 inlineCallbacks=False):
+                 inlineCallbacks=False, cache_context=False):
         max_entries = int(max_entries * CACHE_SIZE_FACTOR)
 
         self.orig = orig
@@ -171,15 +171,28 @@ class CacheDescriptor(object):
         all_args = inspect.getargspec(orig)
         self.arg_names = all_args.args[1:num_args + 1]
 
-        if "cache_context" in self.arg_names:
-            self.arg_names.remove("cache_context")
+        if "cache_context" in all_args.args:
+            if not cache_context:
+                raise ValueError(
+                    "Cannot have a 'cache_context' arg without setting"
+                    " cache_context=True"
+                )
+            try:
+                self.arg_names.remove("cache_context")
+            except ValueError:
+                pass
+        elif cache_context:
+            raise ValueError(
+                "Cannot have cache_context=True without having an arg"
+                " named `cache_context`"
+            )
 
-        self.add_cache_context = "cache_context" in all_args.args
+        self.add_cache_context = cache_context
 
         if len(self.arg_names) < self.num_args:
             raise Exception(
                 "Not enough explicit positional arguments to key off of for %r."
-                " (@cached cannot key off of *args or **kwars)"
+                " (@cached cannot key off of *args or **kwargs)"
                 % (orig.__name__,)
             )
 
@@ -193,12 +206,16 @@ class CacheDescriptor(object):
 
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
+            # If we're passed a cache_context then we'll want to call its invalidate()
+            # whenever we are invalidated
             cache_context = kwargs.pop("cache_context", None)
             if cache_context:
                 context_callback = cache_context.invalidate
             else:
                 context_callback = None
 
+            # Add our own `cache_context` to argument list if the wrapped function
+            # has asked for one
             self_context = _CacheContext(cache, None)
             if self.add_cache_context:
                 kwargs["cache_context"] = self_context
@@ -414,22 +431,24 @@ class _CacheContext(object):
         self.cache.invalidate(self.key)
 
 
-def cached(max_entries=1000, num_args=1, tree=False):
+def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
         tree=tree,
+        cache_context=cache_context,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
         tree=tree,
         inlineCallbacks=True,
+        cache_context=cache_context,
     )
 
 
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index ed074ce9ec..eab0c8d219 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -211,7 +211,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
                 callcount[0] += 1
                 return key
 
-            @cached()
+            @cached(cache_context=True)
             def func2(self, key, cache_context):
                 callcount2[0] += 1
                 return self.func(key, cache_context=cache_context)
@@ -244,7 +244,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
                 callcount[0] += 1
                 return key
 
-            @cached()
+            @cached(cache_context=True)
             def func2(self, key, cache_context):
                 callcount2[0] += 1
                 return self.func(key, cache_context=cache_context)
-- 
cgit 1.5.1
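
cache_context is now opt-in: the decorator inspects the wrapped function's signature and rejects any mismatch between a cache_context parameter and the cache_context=True flag. The following is a rough standalone re-creation of that check; the patch above uses inspect.getargspec, while the sketch uses the Python 3 getfullargspec.

    # Rough re-creation of the new consistency check between the decorator
    # flag and the wrapped function's arguments.
    import inspect

    def check_cache_context(orig, cache_context):
        all_args = inspect.getfullargspec(orig).args
        if "cache_context" in all_args and not cache_context:
            raise ValueError(
                "Cannot have a 'cache_context' arg without setting"
                " cache_context=True"
            )
        if cache_context and "cache_context" not in all_args:
            raise ValueError(
                "Cannot have cache_context=True without having an arg"
                " named `cache_context`"
            )

    def foo(self, key, cache_context):
        return key

    check_cache_context(foo, cache_context=True)        # consistent: no error
    try:
        check_cache_context(foo, cache_context=False)   # mismatch: raises
    except ValueError as err:
        print(err)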


From c0d7d9d6429584f51a8174a331e72a894009f3c8 Mon Sep 17 00:00:00 2001
From: Erik Johnston <erik@matrix.org>
Date: Fri, 19 Aug 2016 15:13:58 +0100
Subject: Rename to on_invalidate

---
 synapse/storage/push_rule.py       |  6 +++---
 synapse/util/caches/descriptors.py | 26 ++++++++++----------------
 tests/storage/test__base.py        |  4 ++--
 3 files changed, 15 insertions(+), 21 deletions(-)

(limited to 'synapse/storage')

diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 247dd15694..78334a98cf 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -156,14 +156,14 @@ class PushRuleStore(SQLBaseStore):
         # users in the room who have pushers need to get push rules run because
         # that's how their pushers work
         if_users_with_pushers = yield self.get_if_users_have_pushers(
-            local_users_in_room, cache_context=cache_context,
+            local_users_in_room, on_invalidate=cache_context.invalidate,
         )
         user_ids = set(
             uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
         )
 
         users_with_receipts = yield self.get_users_with_read_receipts_in_room(
-            room_id, cache_context=cache_context,
+            room_id, on_invalidate=cache_context.invalidate,
         )
 
         # any users with pushers must be ours: they have pushers
@@ -172,7 +172,7 @@ class PushRuleStore(SQLBaseStore):
                 user_ids.add(uid)
 
         rules_by_user = yield self.bulk_get_push_rules(
-            user_ids, cache_context=cache_context
+            user_ids, on_invalidate=cache_context.invalidate,
         )
 
         rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index e7a74d3da8..e93ff40dc0 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -148,8 +148,8 @@ class CacheDescriptor(object):
 
         @cachedInlineCallbacks(cache_context=True)
         def foo(self, key, cache_context):
-            r1 = yield self.bar1(key, cache_context=cache_context)
-            r2 = yield self.bar2(key, cache_context=cache_context)
+            r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
+            r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
             defer.returnValue(r1 + r2)
 
     """
@@ -208,11 +208,7 @@ class CacheDescriptor(object):
         def wrapped(*args, **kwargs):
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
-            cache_context = kwargs.pop("cache_context", None)
-            if cache_context:
-                context_callback = cache_context.invalidate
-            else:
-                context_callback = None
+            invalidate_callback = kwargs.pop("on_invalidate", None)
 
             # Add our own `cache_context` to argument list if the wrapped function
             # has asked for one
@@ -226,7 +222,7 @@ class CacheDescriptor(object):
             self_context.key = cache_key
 
             try:
-                cached_result_d = cache.get(cache_key, callback=context_callback)
+                cached_result_d = cache.get(cache_key, callback=invalidate_callback)
 
                 observer = cached_result_d.observe()
                 if DEBUG_CACHES:
@@ -263,7 +259,7 @@ class CacheDescriptor(object):
                 ret.addErrback(onErr)
 
                 ret = ObservableDeferred(ret, consumeErrors=True)
-                cache.update(sequence, cache_key, ret, callback=context_callback)
+                cache.update(sequence, cache_key, ret, callback=invalidate_callback)
 
                 return preserve_context_over_deferred(ret.observe())
 
@@ -332,11 +328,9 @@ class CacheListDescriptor(object):
 
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
-            cache_context = kwargs.pop("cache_context", None)
-            if cache_context:
-                context_callback = cache_context.invalidate
-            else:
-                context_callback = None
+            # If we're passed a cache_context then we'll want to call its invalidate()
+            # whenever we are invalidated
+            invalidate_callback = kwargs.pop("on_invalidate", None)
 
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
             keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
@@ -352,7 +346,7 @@ class CacheListDescriptor(object):
                 key[self.list_pos] = arg
 
                 try:
-                    res = cache.get(tuple(key), callback=context_callback)
+                    res = cache.get(tuple(key), callback=invalidate_callback)
                     if not res.has_succeeded():
                         res = res.observe()
                         res.addCallback(lambda r, arg: (arg, r), arg)
@@ -388,7 +382,7 @@ class CacheListDescriptor(object):
                     key[self.list_pos] = arg
                     cache.update(
                         sequence, tuple(key), observer,
-                        callback=context_callback
+                        callback=invalidate_callback
                     )
 
                     def invalidate(f, key):
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index eab0c8d219..4fc3639de0 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -214,7 +214,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
             @cached(cache_context=True)
             def func2(self, key, cache_context):
                 callcount2[0] += 1
-                return self.func(key, cache_context=cache_context)
+                return self.func(key, on_invalidate=cache_context.invalidate)
 
         a = A()
         yield a.func2("foo")
@@ -247,7 +247,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
             @cached(cache_context=True)
             def func2(self, key, cache_context):
                 callcount2[0] += 1
-                return self.func(key, cache_context=cache_context)
+                return self.func(key, on_invalidate=cache_context.invalidate)
 
         a = A()
         yield a.func2("foo")
-- 
cgit 1.5.1
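
After this rename, callers hand the cache a plain on_invalidate callback (typically cache_context.invalidate) rather than the context object itself, so invalidating an inner cached entry also invalidates any outer entries that were computed from it. The following is a minimal standalone sketch of that callback chaining; TinyCache and its methods are invented for illustration and are not the Synapse Cache API.

    # Toy caches showing why the callback is useful: the outer cached function
    # passes its own invalidation as on_invalidate to the inner cache, so
    # dropping the inner entry also drops the outer one.
    class TinyCache(object):
        def __init__(self):
            self.entries = {}           # key -> (value, invalidation callbacks)

        def get_or_compute(self, key, compute, on_invalidate=None):
            if key not in self.entries:
                self.entries[key] = (compute(), [])
            value, callbacks = self.entries[key]
            if on_invalidate is not None:
                callbacks.append(on_invalidate)
            return value

        def invalidate(self, key):
            _, callbacks = self.entries.pop(key, (None, []))
            for callback in callbacks:
                callback()

    inner, outer = TinyCache(), TinyCache()

    def outer_func(key):
        # the outer entry registers its own invalidation with the inner cache,
        # much like passing on_invalidate=cache_context.invalidate above
        return outer.get_or_compute(
            key,
            lambda: inner.get_or_compute(
                key, lambda: key.upper(),
                on_invalidate=lambda: outer.invalidate(key),
            ),
        )

    outer_func("foo")
    inner.invalidate("foo")             # cascades into the outer cache
    assert "foo" not in outer.entries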