summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-06-03 14:45:17 +0100
committerErik Johnston <erik@matrix.org>2015-06-03 14:45:17 +0100
commitd8866d72771fa04bd58a77a03aded45bf3ff2293 (patch)
tree4225c55da7f1f5af6db4c8a31ba1ab1f40c8b0aa
parentLog where a request came from in federation (diff)
downloadsynapse-d8866d72771fa04bd58a77a03aded45bf3ff2293.tar.xz
Caches should be bound to instances.
Before, caches were global and so different instances of the stores
would share caches. This caused problems in the unit tests.
Diffstat (limited to '')
-rw-r--r--synapse/storage/_base.py45
-rw-r--r--tests/storage/test__base.py84
-rw-r--r--tests/storage/test_registration.py2
3 files changed, 81 insertions, 50 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 39884c2afe..8d33def6c6 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -127,7 +127,7 @@ class Cache(object):
         self.cache.clear()
 
 
-def cached(max_entries=1000, num_args=1, lru=False):
+class CacheDescriptor(object):
     """ A method decorator that applies a memoizing cache around the function.
 
     The function is presumed to take zero or more arguments, which are used in
@@ -141,25 +141,32 @@ def cached(max_entries=1000, num_args=1, lru=False):
     which can be used to insert values into the cache specifically, without
     calling the calculation function.
     """
-    def wrap(orig):
+    def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
+        self.orig = orig
+
+        self.max_entries = max_entries
+        self.num_args = num_args
+        self.lru = lru
+
+    def __get__(self, obj, objtype=None):
         cache = Cache(
-            name=orig.__name__,
-            max_entries=max_entries,
-            keylen=num_args,
-            lru=lru,
+            name=self.orig.__name__,
+            max_entries=self.max_entries,
+            keylen=self.num_args,
+            lru=self.lru,
         )
 
-        @functools.wraps(orig)
+        @functools.wraps(self.orig)
         @defer.inlineCallbacks
-        def wrapped(self, *keyargs):
+        def wrapped(*keyargs):
             try:
-                cached_result = cache.get(*keyargs)
+                cached_result = cache.get(*keyargs[:self.num_args])
                 if DEBUG_CACHES:
-                    actual_result = yield orig(self, *keyargs)
+                    actual_result = yield self.orig(obj, *keyargs)
                     if actual_result != cached_result:
                         logger.error(
                             "Stale cache entry %s%r: cached: %r, actual %r",
-                            orig.__name__, keyargs,
+                            self.orig.__name__, keyargs,
                             cached_result, actual_result,
                         )
                         raise ValueError("Stale cache entry")
@@ -170,18 +177,28 @@ def cached(max_entries=1000, num_args=1, lru=False):
                 # while the SELECT is executing (SYN-369)
                 sequence = cache.sequence
 
-                ret = yield orig(self, *keyargs)
+                ret = yield self.orig(obj, *keyargs)
 
-                cache.update(sequence, *keyargs + (ret,))
+                cache.update(sequence, *keyargs[:self.num_args] + (ret,))
 
                 defer.returnValue(ret)
 
         wrapped.invalidate = cache.invalidate
         wrapped.invalidate_all = cache.invalidate_all
         wrapped.prefill = cache.prefill
+
+        obj.__dict__[self.orig.__name__] = wrapped
+
         return wrapped
 
-    return wrap
+
+def cached(max_entries=1000, num_args=1, lru=False):
+    return lambda orig: CacheDescriptor(
+        orig,
+        max_entries=max_entries,
+        num_args=num_args,
+        lru=lru
+    )
 
 
 class LoggingTransaction(object):
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 96caf8c4c1..8c3d2952bd 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -96,73 +96,84 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_passthrough(self):
-        @cached()
-        def func(self, key):
-            return key
+        class A(object):
+            @cached()
+            def func(self, key):
+                return key
 
-        self.assertEquals((yield func(self, "foo")), "foo")
-        self.assertEquals((yield func(self, "bar")), "bar")
+        a = A()
+
+        self.assertEquals((yield a.func("foo")), "foo")
+        self.assertEquals((yield a.func("bar")), "bar")
 
     @defer.inlineCallbacks
     def test_hit(self):
         callcount = [0]
 
-        @cached()
-        def func(self, key):
-            callcount[0] += 1
-            return key
+        class A(object):
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
 
-        yield func(self, "foo")
+        a = A()
+        yield a.func("foo")
 
         self.assertEquals(callcount[0], 1)
 
-        self.assertEquals((yield func(self, "foo")), "foo")
+        self.assertEquals((yield a.func("foo")), "foo")
         self.assertEquals(callcount[0], 1)
 
     @defer.inlineCallbacks
     def test_invalidate(self):
         callcount = [0]
 
-        @cached()
-        def func(self, key):
-            callcount[0] += 1
-            return key
+        class A(object):
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
 
-        yield func(self, "foo")
+        a = A()
+        yield a.func("foo")
 
         self.assertEquals(callcount[0], 1)
 
-        func.invalidate("foo")
+        a.func.invalidate("foo")
 
-        yield func(self, "foo")
+        yield a.func("foo")
 
         self.assertEquals(callcount[0], 2)
 
     def test_invalidate_missing(self):
-        @cached()
-        def func(self, key):
-            return key
+        class A(object):
+            @cached()
+            def func(self, key):
+                return key
 
-        func.invalidate("what")
+        A().func.invalidate("what")
 
     @defer.inlineCallbacks
     def test_max_entries(self):
         callcount = [0]
 
-        @cached(max_entries=10)
-        def func(self, key):
-            callcount[0] += 1
-            return key
+        class A(object):
+            @cached(max_entries=10)
+            def func(self, key):
+                callcount[0] += 1
+                return key
 
-        for k in range(0,12):
-            yield func(self, k)
+        a = A()
+
+        for k in range(0, 12):
+            yield a.func(k)
 
         self.assertEquals(callcount[0], 12)
 
         # There must have been at least 2 evictions, meaning if we calculate
         # all 12 values again, we must get called at least 2 more times
         for k in range(0,12):
-            yield func(self, k)
+            yield a.func(k)
 
         self.assertTrue(callcount[0] >= 14,
             msg="Expected callcount >= 14, got %d" % (callcount[0]))
@@ -171,12 +182,15 @@ class CacheDecoratorTestCase(unittest.TestCase):
     def test_prefill(self):
         callcount = [0]
 
-        @cached()
-        def func(self, key):
-            callcount[0] += 1
-            return key
+        class A(object):
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
 
-        func.prefill("foo", 123)
+        a.func.prefill("foo", 123)
 
-        self.assertEquals((yield func(self, "foo")), 123)
+        self.assertEquals((yield a.func("foo")), 123)
         self.assertEquals(callcount[0], 0)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 78f6004204..2702291178 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -46,7 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
             (yield self.store.get_user_by_id(self.user_id))
         )
 
-        result = yield self.store.get_user_by_token(self.tokens[1])
+        result = yield self.store.get_user_by_token(self.tokens[0])
 
         self.assertDictContainsSubset(
             {