summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/util/caches/descriptors.py20
-rw-r--r--synapse/util/caches/lrucache.py16
-rw-r--r--tests/storage/test__base.py48
-rw-r--r--tests/util/test_lrucache.py40
4 files changed, 104 insertions, 20 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index e93ff40dc0..8dba61d49f 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -25,6 +25,7 @@ from synapse.util.logcontext import (
 from . import DEBUG_CACHES, register_cache
 
 from twisted.internet import defer
+from collections import namedtuple
 
 import os
 import functools
@@ -210,16 +211,17 @@ class CacheDescriptor(object):
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
 
-            # Add our own `cache_context` to argument list if the wrapped function
-            # has asked for one
-            self_context = _CacheContext(cache, None)
+            # Add temp cache_context so inspect.getcallargs doesn't explode
             if self.add_cache_context:
-                kwargs["cache_context"] = self_context
+                kwargs["cache_context"] = None
 
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
             cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
 
-            self_context.key = cache_key
+            # Add our own `cache_context` to argument list if the wrapped function
+            # has asked for one
+            if self.add_cache_context:
+                kwargs["cache_context"] = _CacheContext(cache, cache_key)
 
             try:
                 cached_result_d = cache.get(cache_key, callback=invalidate_callback)
@@ -414,13 +416,7 @@ class CacheListDescriptor(object):
         return wrapped
 
 
-class _CacheContext(object):
-    __slots__ = ["cache", "key"]
-
-    def __init__(self, cache, key):
-        self.cache = cache
-        self.key = key
-
+class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
     def invalidate(self):
         self.cache.invalidate(self.key)
 
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index a5a827b4d1..9c4c679175 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -32,7 +32,7 @@ def enumerate_leaves(node, depth):
 class _Node(object):
     __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
 
-    def __init__(self, prev_node, next_node, key, value, callbacks=[]):
+    def __init__(self, prev_node, next_node, key, value, callbacks=set()):
         self.prev_node = prev_node
         self.next_node = next_node
         self.key = key
@@ -66,7 +66,7 @@ class LruCache(object):
 
             return inner
 
-        def add_node(key, value, callbacks=[]):
+        def add_node(key, value, callbacks=set()):
             prev_node = list_root
             next_node = prev_node.next_node
             node = _Node(prev_node, next_node, key, value, callbacks)
@@ -94,7 +94,7 @@ class LruCache(object):
 
             for cb in node.callbacks:
                 cb()
-            node.callbacks = []
+            node.callbacks.clear()
 
         @synchronized
         def cache_get(key, default=None, callback=None):
@@ -102,7 +102,7 @@ class LruCache(object):
             if node is not None:
                 move_node_to_front(node)
                 if callback:
-                    node.callbacks.append(callback)
+                    node.callbacks.add(callback)
                 return node.value
             else:
                 return default
@@ -114,18 +114,18 @@ class LruCache(object):
                 if value != node.value:
                     for cb in node.callbacks:
                         cb()
-                    node.callbacks = []
+                    node.callbacks.clear()
 
                 if callback:
-                    node.callbacks.append(callback)
+                    node.callbacks.add(callback)
 
                 move_node_to_front(node)
                 node.value = value
             else:
                 if callback:
-                    callbacks = [callback]
+                    callbacks = set([callback])
                 else:
-                    callbacks = []
+                    callbacks = set()
                 add_node(key, value, callbacks)
                 if len(cache) > max_size:
                     todelete = list_root.prev_node
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 4fc3639de0..ab6095564a 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -17,6 +17,8 @@
 from tests import unittest
 from twisted.internet import defer
 
+from mock import Mock
+
 from synapse.util.async import ObservableDeferred
 
 from synapse.util.caches.descriptors import Cache, cached
@@ -265,3 +267,49 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
         self.assertEquals(callcount[0], 4)
         self.assertEquals(callcount2[0], 3)
+
+    @defer.inlineCallbacks
+    def test_double_get(self):
+        callcount = [0]
+        callcount2 = [0]
+
+        class A(object):
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+            @cached(cache_context=True)
+            def func2(self, key, cache_context):
+                callcount2[0] += 1
+                return self.func(key, on_invalidate=cache_context.invalidate)
+
+        a = A()
+        a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 1)
+
+        a.func2.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
+
+        yield a.func2("foo")
+        a.func2.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
+
+        self.assertEquals(callcount[0], 1)
+        self.assertEquals(callcount2[0], 2)
+
+        a.func.invalidate(("foo",))
+        self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
+        yield a.func("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 3)
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index bacec2f465..1eba5b535e 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -50,6 +50,8 @@ class LruCacheTestCase(unittest.TestCase):
         self.assertEquals(cache.get("key"), 1)
         self.assertEquals(cache.setdefault("key", 2), 1)
         self.assertEquals(cache.get("key"), 1)
+        cache["key"] = 2  # Make sure overriding works.
+        self.assertEquals(cache.get("key"), 2)
 
     def test_pop(self):
         cache = LruCache(1)
@@ -84,6 +86,44 @@ class LruCacheTestCase(unittest.TestCase):
 
 
 class LruCacheCallbacksTestCase(unittest.TestCase):
+    def test_get(self):
+        m = Mock()
+        cache = LruCache(1)
+
+        cache.set("key", "value")
+        self.assertFalse(m.called)
+
+        cache.get("key", callback=m)
+        self.assertFalse(m.called)
+
+        cache.get("key", "value")
+        self.assertFalse(m.called)
+
+        cache.set("key", "value2")
+        self.assertEquals(m.call_count, 1)
+
+        cache.set("key", "value")
+        self.assertEquals(m.call_count, 1)
+
+    def test_multi_get(self):
+        m = Mock()
+        cache = LruCache(1)
+
+        cache.set("key", "value")
+        self.assertFalse(m.called)
+
+        cache.get("key", callback=m)
+        self.assertFalse(m.called)
+
+        cache.get("key", callback=m)
+        self.assertFalse(m.called)
+
+        cache.set("key", "value2")
+        self.assertEquals(m.call_count, 1)
+
+        cache.set("key", "value")
+        self.assertEquals(m.call_count, 1)
+
     def test_set(self):
         m = Mock()
         cache = LruCache(1)