summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erikj@jki.re>2016-08-19 16:29:58 +0100
committerGitHub <noreply@github.com>2016-08-19 16:29:58 +0100
commite6784daf0777c24a8efd2fe0c2d72f730369ad2b (patch)
treebdf2f54eaedc0d8ec3c5ecb2efde6e35521ae179
parentMerge pull request #1029 from matrix-org/erikj/appservice_stream (diff)
parentEnsure invalidation list does not grow unboundedly (diff)
downloadsynapse-e6784daf0777c24a8efd2fe0c2d72f730369ad2b.tar.xz
Merge pull request #1030 from matrix-org/erikj/cache_contexts
Add concept of cache contexts
-rw-r--r--synapse/push/action_generator.py2
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py41
-rw-r--r--synapse/storage/_base.py2
-rw-r--r--synapse/storage/event_push_actions.py2
-rw-r--r--synapse/storage/push_rule.py60
-rw-r--r--synapse/storage/pusher.py2
-rw-r--r--synapse/storage/receipts.py2
-rw-r--r--synapse/storage/signatures.py2
-rw-r--r--synapse/storage/state.py4
-rw-r--r--synapse/util/caches/descriptors.py117
-rw-r--r--synapse/util/caches/lrucache.py39
-rw-r--r--synapse/util/caches/treecache.py3
-rw-r--r--tests/storage/test__base.py116
-rw-r--r--tests/util/test_lrucache.py153
14 files changed, 458 insertions, 87 deletions
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index b2c94bfaac..ed2ccc4dfb 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -40,7 +40,7 @@ class ActionGenerator:
     def handle_push_actions_for_event(self, event, context):
         with Measure(self.clock, "evaluator_for_event"):
             bulk_evaluator = yield evaluator_for_event(
-                event, self.hs, self.store, context.current_state
+                event, self.hs, self.store, context.state_group, context.current_state
             )
 
         with Measure(self.clock, "action_for_event_by_user"):
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 756e5da513..004eded61f 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -36,35 +36,11 @@ def _get_rules(room_id, user_ids, store):
 
 
 @defer.inlineCallbacks
-def evaluator_for_event(event, hs, store, current_state):
-    room_id = event.room_id
-    # We also will want to generate notifs for other people in the room so
-    # their unread countss are correct in the event stream, but to avoid
-    # generating them for bot / AS users etc, we only do so for people who've
-    # sent a read receipt into the room.
-
-    local_users_in_room = set(
-        e.state_key for e in current_state.values()
-        if e.type == EventTypes.Member and e.membership == Membership.JOIN
-        and hs.is_mine_id(e.state_key)
+def evaluator_for_event(event, hs, store, state_group, current_state):
+    rules_by_user = yield store.bulk_get_push_rules_for_room(
+        event.room_id, state_group, current_state
     )
 
-    # users in the room who have pushers need to get push rules run because
-    # that's how their pushers work
-    if_users_with_pushers = yield store.get_if_users_have_pushers(
-        local_users_in_room
-    )
-    user_ids = set(
-        uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
-    )
-
-    users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
-
-    # any users with pushers must be ours: they have pushers
-    for uid in users_with_receipts:
-        if uid in local_users_in_room:
-            user_ids.add(uid)
-
     # if this event is an invite event, we may need to run rules for the user
     # who's been invited, otherwise they won't get told they've been invited
     if event.type == 'm.room.member' and event.content['membership'] == 'invite':
@@ -72,12 +48,12 @@ def evaluator_for_event(event, hs, store, current_state):
         if invited_user and hs.is_mine_id(invited_user):
             has_pusher = yield store.user_has_pusher(invited_user)
             if has_pusher:
-                user_ids.add(invited_user)
-
-    rules_by_user = yield _get_rules(room_id, user_ids, store)
+                rules_by_user[invited_user] = yield store.get_push_rules_for_user(
+                    invited_user
+                )
 
     defer.returnValue(BulkPushRuleEvaluator(
-        room_id, rules_by_user, user_ids, store
+        event.room_id, rules_by_user, store
     ))
 
 
@@ -90,10 +66,9 @@ class BulkPushRuleEvaluator:
     the same logic to run the actual rules, but could be optimised further
     (see https://matrix.org/jira/browse/SYN-562)
     """
-    def __init__(self, room_id, rules_by_user, users_in_room, store):
+    def __init__(self, room_id, rules_by_user, store):
         self.room_id = room_id
         self.rules_by_user = rules_by_user
-        self.users_in_room = users_in_room
         self.store = store
 
     @defer.inlineCallbacks
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 029f6612e6..49fa8614f2 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -166,7 +166,7 @@ class SQLBaseStore(object):
         self._txn_perf_counters = PerformanceCounters()
         self._get_event_counters = PerformanceCounters()
 
-        self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
+        self._get_event_cache = Cache("*getEvent*", keylen=3,
                                       max_entries=hs.config.event_cache_size)
 
         self._state_group_cache = DictionaryCache(
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index df4000d0da..c65c9c9c47 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -56,7 +56,7 @@ class EventPushActionsStore(SQLBaseStore):
             )
         self._simple_insert_many_txn(txn, "event_push_actions", values)
 
-    @cachedInlineCallbacks(num_args=3, lru=True, tree=True, max_entries=5000)
+    @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
     def get_unread_event_push_actions_by_room_for_user(
             self, room_id, user_id, last_read_event_id
     ):
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 8183b7f1b0..78334a98cf 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -16,6 +16,7 @@
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
 from synapse.push.baserules import list_with_base_rules
+from synapse.api.constants import EventTypes, Membership
 from twisted.internet import defer
 
 import logging
@@ -48,7 +49,7 @@ def _load_rules(rawrules, enabled_map):
 
 
 class PushRuleStore(SQLBaseStore):
-    @cachedInlineCallbacks(lru=True)
+    @cachedInlineCallbacks()
     def get_push_rules_for_user(self, user_id):
         rows = yield self._simple_select_list(
             table="push_rules",
@@ -72,7 +73,7 @@ class PushRuleStore(SQLBaseStore):
 
         defer.returnValue(rules)
 
-    @cachedInlineCallbacks(lru=True)
+    @cachedInlineCallbacks()
     def get_push_rules_enabled_for_user(self, user_id):
         results = yield self._simple_select_list(
             table="push_rules_enable",
@@ -123,6 +124,61 @@ class PushRuleStore(SQLBaseStore):
 
         defer.returnValue(results)
 
+    def bulk_get_push_rules_for_room(self, room_id, state_group, current_state):
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
+
+    @cachedInlineCallbacks(num_args=2, cache_context=True)
+    def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
+                                      cache_context):
+        # We don't use `state_group`, its there so that we can cache based
+        # on it. However, its important that its never None, since two current_state's
+        # with a state_group of None are likely to be different.
+        # See bulk_get_push_rules_for_room for how we work around this.
+        assert state_group is not None
+
+        # We also will want to generate notifs for other people in the room so
+        # their unread countss are correct in the event stream, but to avoid
+        # generating them for bot / AS users etc, we only do so for people who've
+        # sent a read receipt into the room.
+        local_users_in_room = set(
+            e.state_key for e in current_state.values()
+            if e.type == EventTypes.Member and e.membership == Membership.JOIN
+            and self.hs.is_mine_id(e.state_key)
+        )
+
+        # users in the room who have pushers need to get push rules run because
+        # that's how their pushers work
+        if_users_with_pushers = yield self.get_if_users_have_pushers(
+            local_users_in_room, on_invalidate=cache_context.invalidate,
+        )
+        user_ids = set(
+            uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
+        )
+
+        users_with_receipts = yield self.get_users_with_read_receipts_in_room(
+            room_id, on_invalidate=cache_context.invalidate,
+        )
+
+        # any users with pushers must be ours: they have pushers
+        for uid in users_with_receipts:
+            if uid in local_users_in_room:
+                user_ids.add(uid)
+
+        rules_by_user = yield self.bulk_get_push_rules(
+            user_ids, on_invalidate=cache_context.invalidate,
+        )
+
+        rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
+
+        defer.returnValue(rules_by_user)
+
     @cachedList(cached_method_name="get_push_rules_enabled_for_user",
                 list_name="user_ids", num_args=1, inlineCallbacks=True)
     def bulk_get_push_rules_enabled(self, user_ids):
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index a7d7c54d7e..8f5f8f24a9 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -135,7 +135,7 @@ class PusherStore(SQLBaseStore):
             "get_all_updated_pushers", get_all_updated_pushers_txn
         )
 
-    @cachedInlineCallbacks(lru=True, num_args=1, max_entries=15000)
+    @cachedInlineCallbacks(num_args=1, max_entries=15000)
     def get_if_user_has_pusher(self, user_id):
         result = yield self._simple_select_many_batch(
             table='pushers',
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 8c26f39fbb..3ad916103f 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -120,7 +120,7 @@ class ReceiptsStore(SQLBaseStore):
 
         defer.returnValue([ev for res in results.values() for ev in res])
 
-    @cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True)
+    @cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True)
     def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
         """Get receipts for a single room for sending to clients.
 
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index ea6823f18d..e1dca927d7 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -25,7 +25,7 @@ from synapse.util.caches.descriptors import cached, cachedList
 class SignatureStore(SQLBaseStore):
     """Persistence for event signatures and hashes"""
 
-    @cached(lru=True)
+    @cached()
     def get_event_reference_hash(self, event_id):
         return self._get_event_reference_hashes_txn(event_id)
 
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 5b743db67a..0e8fa93e1f 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -174,7 +174,7 @@ class StateStore(SQLBaseStore):
             return [r[0] for r in results]
         return self.runInteraction("get_current_state_for_key", f)
 
-    @cached(num_args=2, lru=True, max_entries=1000)
+    @cached(num_args=2, max_entries=1000)
     def _get_state_group_from_group(self, group, types):
         raise NotImplementedError()
 
@@ -272,7 +272,7 @@ class StateStore(SQLBaseStore):
         state_map = yield self.get_state_for_events([event_id], types)
         defer.returnValue(state_map[event_id])
 
-    @cached(num_args=2, lru=True, max_entries=10000)
+    @cached(num_args=2, max_entries=10000)
     def _get_state_group_for_event(self, room_id, event_id):
         return self._simple_select_one_onecol(
             table="event_to_state_groups",
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index f31dfb22b7..8dba61d49f 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -25,8 +25,7 @@ from synapse.util.logcontext import (
 from . import DEBUG_CACHES, register_cache
 
 from twisted.internet import defer
-
-from collections import OrderedDict
+from collections import namedtuple
 
 import os
 import functools
@@ -54,16 +53,11 @@ class Cache(object):
         "metrics",
     )
 
-    def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
-        if lru:
-            cache_type = TreeCache if tree else dict
-            self.cache = LruCache(
-                max_size=max_entries, keylen=keylen, cache_type=cache_type
-            )
-            self.max_entries = None
-        else:
-            self.cache = OrderedDict()
-            self.max_entries = max_entries
+    def __init__(self, name, max_entries=1000, keylen=1, tree=False):
+        cache_type = TreeCache if tree else dict
+        self.cache = LruCache(
+            max_size=max_entries, keylen=keylen, cache_type=cache_type
+        )
 
         self.name = name
         self.keylen = keylen
@@ -81,8 +75,8 @@ class Cache(object):
                     "Cache objects can only be accessed from the main thread"
                 )
 
-    def get(self, key, default=_CacheSentinel):
-        val = self.cache.get(key, _CacheSentinel)
+    def get(self, key, default=_CacheSentinel, callback=None):
+        val = self.cache.get(key, _CacheSentinel, callback=callback)
         if val is not _CacheSentinel:
             self.metrics.inc_hits()
             return val
@@ -94,19 +88,15 @@ class Cache(object):
         else:
             return default
 
-    def update(self, sequence, key, value):
+    def update(self, sequence, key, value, callback=None):
         self.check_thread()
         if self.sequence == sequence:
             # Only update the cache if the caches sequence number matches the
             # number that the cache had before the SELECT was started (SYN-369)
-            self.prefill(key, value)
-
-    def prefill(self, key, value):
-        if self.max_entries is not None:
-            while len(self.cache) >= self.max_entries:
-                self.cache.popitem(last=False)
+            self.prefill(key, value, callback=callback)
 
-        self.cache[key] = value
+    def prefill(self, key, value, callback=None):
+        self.cache.set(key, value, callback=callback)
 
     def invalidate(self, key):
         self.check_thread()
@@ -151,9 +141,21 @@ class CacheDescriptor(object):
     The wrapped function has another additional callable, called "prefill",
     which can be used to insert values into the cache specifically, without
     calling the calculation function.
+
+    Cached functions can be "chained" (i.e. a cached function can call other cached
+    functions and get appropriately invalidated when they called caches are
+    invalidated) by adding a special "cache_context" argument to the function
+    and passing that as a kwarg to all caches called. For example::
+
+        @cachedInlineCallbacks(cache_context=True)
+        def foo(self, key, cache_context):
+            r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
+            r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
+            defer.returnValue(r1 + r2)
+
     """
-    def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
-                 inlineCallbacks=False):
+    def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
+                 inlineCallbacks=False, cache_context=False):
         max_entries = int(max_entries * CACHE_SIZE_FACTOR)
 
         self.orig = orig
@@ -165,15 +167,33 @@ class CacheDescriptor(object):
 
         self.max_entries = max_entries
         self.num_args = num_args
-        self.lru = lru
         self.tree = tree
 
-        self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
+        all_args = inspect.getargspec(orig)
+        self.arg_names = all_args.args[1:num_args + 1]
+
+        if "cache_context" in all_args.args:
+            if not cache_context:
+                raise ValueError(
+                    "Cannot have a 'cache_context' arg without setting"
+                    " cache_context=True"
+                )
+            try:
+                self.arg_names.remove("cache_context")
+            except ValueError:
+                pass
+        elif cache_context:
+            raise ValueError(
+                "Cannot have cache_context=True without having an arg"
+                " named `cache_context`"
+            )
+
+        self.add_cache_context = cache_context
 
         if len(self.arg_names) < self.num_args:
             raise Exception(
                 "Not enough explicit positional arguments to key off of for %r."
-                " (@cached cannot key off of *args or **kwars)"
+                " (@cached cannot key off of *args or **kwargs)"
                 % (orig.__name__,)
             )
 
@@ -182,16 +202,29 @@ class CacheDescriptor(object):
             name=self.orig.__name__,
             max_entries=self.max_entries,
             keylen=self.num_args,
-            lru=self.lru,
             tree=self.tree,
         )
 
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
+            # If we're passed a cache_context then we'll want to call its invalidate()
+            # whenever we are invalidated
+            invalidate_callback = kwargs.pop("on_invalidate", None)
+
+            # Add temp cache_context so inspect.getcallargs doesn't explode
+            if self.add_cache_context:
+                kwargs["cache_context"] = None
+
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
             cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+
+            # Add our own `cache_context` to argument list if the wrapped function
+            # has asked for one
+            if self.add_cache_context:
+                kwargs["cache_context"] = _CacheContext(cache, cache_key)
+
             try:
-                cached_result_d = cache.get(cache_key)
+                cached_result_d = cache.get(cache_key, callback=invalidate_callback)
 
                 observer = cached_result_d.observe()
                 if DEBUG_CACHES:
@@ -228,7 +261,7 @@ class CacheDescriptor(object):
                 ret.addErrback(onErr)
 
                 ret = ObservableDeferred(ret, consumeErrors=True)
-                cache.update(sequence, cache_key, ret)
+                cache.update(sequence, cache_key, ret, callback=invalidate_callback)
 
                 return preserve_context_over_deferred(ret.observe())
 
@@ -297,6 +330,10 @@ class CacheListDescriptor(object):
 
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
+            # If we're passed a cache_context then we'll want to call its invalidate()
+            # whenever we are invalidated
+            invalidate_callback = kwargs.pop("on_invalidate", None)
+
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
             keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
             list_args = arg_dict[self.list_name]
@@ -311,7 +348,7 @@ class CacheListDescriptor(object):
                 key[self.list_pos] = arg
 
                 try:
-                    res = cache.get(tuple(key))
+                    res = cache.get(tuple(key), callback=invalidate_callback)
                     if not res.has_succeeded():
                         res = res.observe()
                         res.addCallback(lambda r, arg: (arg, r), arg)
@@ -345,7 +382,10 @@ class CacheListDescriptor(object):
 
                     key = list(keyargs)
                     key[self.list_pos] = arg
-                    cache.update(sequence, tuple(key), observer)
+                    cache.update(
+                        sequence, tuple(key), observer,
+                        callback=invalidate_callback
+                    )
 
                     def invalidate(f, key):
                         cache.invalidate(key)
@@ -376,24 +416,29 @@ class CacheListDescriptor(object):
         return wrapped
 
 
-def cached(max_entries=1000, num_args=1, lru=True, tree=False):
+class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
+    def invalidate(self):
+        self.cache.invalidate(self.key)
+
+
+def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
-        lru=lru,
         tree=tree,
+        cache_context=cache_context,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
-        lru=lru,
         tree=tree,
         inlineCallbacks=True,
+        cache_context=cache_context,
     )
 
 
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index f9df445a8d..9c4c679175 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -30,13 +30,14 @@ def enumerate_leaves(node, depth):
 
 
 class _Node(object):
-    __slots__ = ["prev_node", "next_node", "key", "value"]
+    __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
 
-    def __init__(self, prev_node, next_node, key, value):
+    def __init__(self, prev_node, next_node, key, value, callbacks=set()):
         self.prev_node = prev_node
         self.next_node = next_node
         self.key = key
         self.value = value
+        self.callbacks = callbacks
 
 
 class LruCache(object):
@@ -44,6 +45,9 @@ class LruCache(object):
     Least-recently-used cache.
     Supports del_multi only if cache_type=TreeCache
     If cache_type=TreeCache, all keys must be tuples.
+
+    Can also set callbacks on objects when getting/setting which are fired
+    when that key gets invalidated/evicted.
     """
     def __init__(self, max_size, keylen=1, cache_type=dict):
         cache = cache_type()
@@ -62,10 +66,10 @@ class LruCache(object):
 
             return inner
 
-        def add_node(key, value):
+        def add_node(key, value, callbacks=set()):
             prev_node = list_root
             next_node = prev_node.next_node
-            node = _Node(prev_node, next_node, key, value)
+            node = _Node(prev_node, next_node, key, value, callbacks)
             prev_node.next_node = node
             next_node.prev_node = node
             cache[key] = node
@@ -88,23 +92,41 @@ class LruCache(object):
             prev_node.next_node = next_node
             next_node.prev_node = prev_node
 
+            for cb in node.callbacks:
+                cb()
+            node.callbacks.clear()
+
         @synchronized
-        def cache_get(key, default=None):
+        def cache_get(key, default=None, callback=None):
             node = cache.get(key, None)
             if node is not None:
                 move_node_to_front(node)
+                if callback:
+                    node.callbacks.add(callback)
                 return node.value
             else:
                 return default
 
         @synchronized
-        def cache_set(key, value):
+        def cache_set(key, value, callback=None):
             node = cache.get(key, None)
             if node is not None:
+                if value != node.value:
+                    for cb in node.callbacks:
+                        cb()
+                    node.callbacks.clear()
+
+                if callback:
+                    node.callbacks.add(callback)
+
                 move_node_to_front(node)
                 node.value = value
             else:
-                add_node(key, value)
+                if callback:
+                    callbacks = set([callback])
+                else:
+                    callbacks = set()
+                add_node(key, value, callbacks)
                 if len(cache) > max_size:
                     todelete = list_root.prev_node
                     delete_node(todelete)
@@ -148,6 +170,9 @@ class LruCache(object):
         def cache_clear():
             list_root.next_node = list_root
             list_root.prev_node = list_root
+            for node in cache.values():
+                for cb in node.callbacks:
+                    cb()
             cache.clear()
 
         @synchronized
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 03bc1401b7..c31585aea3 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -64,6 +64,9 @@ class TreeCache(object):
         self.size -= cnt
         return popped
 
+    def values(self):
+        return [e.value for e in self.root.values()]
+
     def __len__(self):
         return self.size
 
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 96b7dba5fe..ab6095564a 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -17,6 +17,8 @@
 from tests import unittest
 from twisted.internet import defer
 
+from mock import Mock
+
 from synapse.util.async import ObservableDeferred
 
 from synapse.util.caches.descriptors import Cache, cached
@@ -72,7 +74,7 @@ class CacheTestCase(unittest.TestCase):
         cache.get(3)
 
     def test_eviction_lru(self):
-        cache = Cache("test", max_entries=2, lru=True)
+        cache = Cache("test", max_entries=2)
 
         cache.prefill(1, "one")
         cache.prefill(2, "two")
@@ -199,3 +201,115 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
         self.assertEquals(a.func("foo").result, d.result)
         self.assertEquals(callcount[0], 0)
+
+    @defer.inlineCallbacks
+    def test_invalidate_context(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A(object):
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 1)
+
+        a.func.invalidate(("foo",))
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 1)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+    @defer.inlineCallbacks
+    def test_eviction_context(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A(object):
+            @cached(max_entries=2)
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        yield a.func2("foo")
+        yield a.func2("foo2")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func("foo3")
+
+        self.assertEquals(callcount[0], 3)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 4)
+        self.assertEquals(callcount2[0], 3)
+
+    @defer.inlineCallbacks
+    def test_double_get(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A(object):
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 1)
+
+        a.func2.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
+
+        yield a.func2("foo")
+        a.func2.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 2)
+
+        a.func.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 3)
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index bab366fb7f..1eba5b535e 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -19,6 +19,8 @@ from .. import unittest
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache
 
+from mock import Mock
+
 
 class LruCacheTestCase(unittest.TestCase):
 
@@ -48,6 +50,8 @@ class LruCacheTestCase(unittest.TestCase):
         self.assertEquals(cache.get("key"), 1)
         self.assertEquals(cache.setdefault("key", 2), 1)
         self.assertEquals(cache.get("key"), 1)
+        cache["key"] = 2  # Make sure overriding works.
+        self.assertEquals(cache.get("key"), 2)
 
     def test_pop(self):
         cache = LruCache(1)
@@ -79,3 +83,152 @@ class LruCacheTestCase(unittest.TestCase):
         cache["key"] = 1
         cache.clear()
         self.assertEquals(len(cache), 0)
+
+
+class LruCacheCallbacksTestCase(unittest.TestCase):
+    def test_get(self):
+        m = Mock()
+        cache = LruCache(1)
+
+        cache.set("key", "value")
+        self.assertFalse(m.called)
+
+        cache.get("key", callback=m)
+        self.assertFalse(m.called)
+
+        cache.get("key", "value")
+        self.assertFalse(m.called)
+
+        cache.set("key", "value2")
+        self.assertEquals(m.call_count, 1)
+
+        cache.set("key", "value")
+        self.assertEquals(m.call_count, 1)
+
+    def test_multi_get(self):
+        m = Mock()
+        cache = LruCache(1)
+
+        cache.set("key", "value")
+        self.assertFalse(m.called)
+
+        cache.get("key", callback=m)
+        self.assertFalse(m.called)
+
+        cache.get("key", callback=m)
+        self.assertFalse(m.called)
+
+        cache.set("key", "value2")
+        self.assertEquals(m.call_count, 1)
+
+        cache.set("key", "value")
+        self.assertEquals(m.call_count, 1)
+
+    def test_set(self):
+        m = Mock()
+        cache = LruCache(1)
+
+        cache.set("key", "value", m)
+        self.assertFalse(m.called)
+
+        cache.set("key", "value")
+        self.assertFalse(m.called)
+
+        cache.set("key", "value2")
+        self.assertEquals(m.call_count, 1)
+
+        cache.set("key", "value")
+        self.assertEquals(m.call_count, 1)
+
+    def test_pop(self):
+        m = Mock()
+        cache = LruCache(1)
+
+        cache.set("key", "value", m)
+        self.assertFalse(m.called)
+
+        cache.pop("key")
+        self.assertEquals(m.call_count, 1)
+
+        cache.set("key", "value")
+        self.assertEquals(m.call_count, 1)
+
+        cache.pop("key")
+        self.assertEquals(m.call_count, 1)
+
+    def test_del_multi(self):
+        m1 = Mock()
+        m2 = Mock()
+        m3 = Mock()
+        m4 = Mock()
+        cache = LruCache(4, 2, cache_type=TreeCache)
+
+        cache.set(("a", "1"), "value", m1)
+        cache.set(("a", "2"), "value", m2)
+        cache.set(("b", "1"), "value", m3)
+        cache.set(("b", "2"), "value", m4)
+
+        self.assertEquals(m1.call_count, 0)
+        self.assertEquals(m2.call_count, 0)
+        self.assertEquals(m3.call_count, 0)
+        self.assertEquals(m4.call_count, 0)
+
+        cache.del_multi(("a",))
+
+        self.assertEquals(m1.call_count, 1)
+        self.assertEquals(m2.call_count, 1)
+        self.assertEquals(m3.call_count, 0)
+        self.assertEquals(m4.call_count, 0)
+
+    def test_clear(self):
+        m1 = Mock()
+        m2 = Mock()
+        cache = LruCache(5)
+
+        cache.set("key1", "value", m1)
+        cache.set("key2", "value", m2)
+
+        self.assertEquals(m1.call_count, 0)
+        self.assertEquals(m2.call_count, 0)
+
+        cache.clear()
+
+        self.assertEquals(m1.call_count, 1)
+        self.assertEquals(m2.call_count, 1)
+
+    def test_eviction(self):
+        m1 = Mock(name="m1")
+        m2 = Mock(name="m2")
+        m3 = Mock(name="m3")
+        cache = LruCache(2)
+
+        cache.set("key1", "value", m1)
+        cache.set("key2", "value", m2)
+
+        self.assertEquals(m1.call_count, 0)
+        self.assertEquals(m2.call_count, 0)
+        self.assertEquals(m3.call_count, 0)
+
+        cache.set("key3", "value", m3)
+
+        self.assertEquals(m1.call_count, 1)
+        self.assertEquals(m2.call_count, 0)
+        self.assertEquals(m3.call_count, 0)
+
+        cache.set("key3", "value")
+
+        self.assertEquals(m1.call_count, 1)
+        self.assertEquals(m2.call_count, 0)
+        self.assertEquals(m3.call_count, 0)
+
+        cache.get("key2")
+
+        self.assertEquals(m1.call_count, 1)
+        self.assertEquals(m2.call_count, 0)
+        self.assertEquals(m3.call_count, 0)
+
+        cache.set("key1", "value", m1)
+
+        self.assertEquals(m1.call_count, 1)
+        self.assertEquals(m2.call_count, 0)
+        self.assertEquals(m3.call_count, 1)