summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/_base.py45
-rw-r--r--tests/storage/test__base.py84
-rw-r--r--tests/storage/test_registration.py2
3 files changed, 81 insertions, 50 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 39884c2afe..8d33def6c6 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -127,7 +127,7 @@ class Cache(object):
         self.cache.clear()
 
 
-def cached(max_entries=1000, num_args=1, lru=False):
+class CacheDescriptor(object):
     """ A method decorator that applies a memoizing cache around the function.
 
     The function is presumed to take zero or more arguments, which are used in
@@ -141,25 +141,32 @@ def cached(max_entries=1000, num_args=1, lru=False):
     which can be used to insert values into the cache specifically, without
     calling the calculation function.
     """
-    def wrap(orig):
+    def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
+        self.orig = orig
+
+        self.max_entries = max_entries
+        self.num_args = num_args
+        self.lru = lru
+
+    def __get__(self, obj, objtype=None):
         cache = Cache(
-            name=orig.__name__,
-            max_entries=max_entries,
-            keylen=num_args,
-            lru=lru,
+            name=self.orig.__name__,
+            max_entries=self.max_entries,
+            keylen=self.num_args,
+            lru=self.lru,
         )
 
-        @functools.wraps(orig)
+        @functools.wraps(self.orig)
         @defer.inlineCallbacks
-        def wrapped(self, *keyargs):
+        def wrapped(*keyargs):
             try:
-                cached_result = cache.get(*keyargs)
+                cached_result = cache.get(*keyargs[:self.num_args])
                 if DEBUG_CACHES:
-                    actual_result = yield orig(self, *keyargs)
+                    actual_result = yield self.orig(obj, *keyargs)
                     if actual_result != cached_result:
                         logger.error(
                             "Stale cache entry %s%r: cached: %r, actual %r",
-                            orig.__name__, keyargs,
+                            self.orig.__name__, keyargs,
                             cached_result, actual_result,
                         )
                         raise ValueError("Stale cache entry")
@@ -170,18 +177,28 @@ def cached(max_entries=1000, num_args=1, lru=False):
                 # while the SELECT is executing (SYN-369)
                 sequence = cache.sequence
 
-                ret = yield orig(self, *keyargs)
+                ret = yield self.orig(obj, *keyargs)
 
-                cache.update(sequence, *keyargs + (ret,))
+                cache.update(sequence, *keyargs[:self.num_args] + (ret,))
 
                 defer.returnValue(ret)
 
         wrapped.invalidate = cache.invalidate
         wrapped.invalidate_all = cache.invalidate_all
         wrapped.prefill = cache.prefill
+
+        obj.__dict__[self.orig.__name__] = wrapped
+
         return wrapped
 
-    return wrap
+
+def cached(max_entries=1000, num_args=1, lru=False):
+    return lambda orig: CacheDescriptor(
+        orig,
+        max_entries=max_entries,
+        num_args=num_args,
+        lru=lru
+    )
 
 
 class LoggingTransaction(object):
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 96caf8c4c1..8c3d2952bd 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -96,73 +96,84 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_passthrough(self):
-        @cached()
-        def func(self, key):
-            return key
+        class A(object):
+            @cached()
+            def func(self, key):
+                return key
 
-        self.assertEquals((yield func(self, "foo")), "foo")
-        self.assertEquals((yield func(self, "bar")), "bar")
+        a = A()
+
+        self.assertEquals((yield a.func("foo")), "foo")
+        self.assertEquals((yield a.func("bar")), "bar")
 
     @defer.inlineCallbacks
     def test_hit(self):
         callcount = [0]
 
-        @cached()
-        def func(self, key):
-            callcount[0] += 1
-            return key
+        class A(object):
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
 
-        yield func(self, "foo")
+        a = A()
+        yield a.func("foo")
 
         self.assertEquals(callcount[0], 1)
 
-        self.assertEquals((yield func(self, "foo")), "foo")
+        self.assertEquals((yield a.func("foo")), "foo")
         self.assertEquals(callcount[0], 1)
 
     @defer.inlineCallbacks
     def test_invalidate(self):
         callcount = [0]
 
-        @cached()
-        def func(self, key):
-            callcount[0] += 1
-            return key
+        class A(object):
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
 
-        yield func(self, "foo")
+        a = A()
+        yield a.func("foo")
 
         self.assertEquals(callcount[0], 1)
 
-        func.invalidate("foo")
+        a.func.invalidate("foo")
 
-        yield func(self, "foo")
+        yield a.func("foo")
 
         self.assertEquals(callcount[0], 2)
 
     def test_invalidate_missing(self):
-        @cached()
-        def func(self, key):
-            return key
+        class A(object):
+            @cached()
+            def func(self, key):
+                return key
 
-        func.invalidate("what")
+        A().func.invalidate("what")
 
     @defer.inlineCallbacks
     def test_max_entries(self):
         callcount = [0]
 
-        @cached(max_entries=10)
-        def func(self, key):
-            callcount[0] += 1
-            return key
+        class A(object):
+            @cached(max_entries=10)
+            def func(self, key):
+                callcount[0] += 1
+                return key
 
-        for k in range(0,12):
-            yield func(self, k)
+        a = A()
+
+        for k in range(0, 12):
+            yield a.func(k)
 
         self.assertEquals(callcount[0], 12)
 
         # There must have been at least 2 evictions, meaning if we calculate
         # all 12 values again, we must get called at least 2 more times
         for k in range(0,12):
-            yield func(self, k)
+            yield a.func(k)
 
         self.assertTrue(callcount[0] >= 14,
             msg="Expected callcount >= 14, got %d" % (callcount[0]))
@@ -171,12 +182,15 @@ class CacheDecoratorTestCase(unittest.TestCase):
     def test_prefill(self):
         callcount = [0]
 
-        @cached()
-        def func(self, key):
-            callcount[0] += 1
-            return key
+        class A(object):
+            @cached()
+            def func(self, key):
+                callcount[0] += 1
+                return key
+
+        a = A()
 
-        func.prefill("foo", 123)
+        a.func.prefill("foo", 123)
 
-        self.assertEquals((yield func(self, "foo")), 123)
+        self.assertEquals((yield a.func("foo")), 123)
         self.assertEquals(callcount[0], 0)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 78f6004204..2702291178 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -46,7 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
             (yield self.store.get_user_by_id(self.user_id))
         )
 
-        result = yield self.store.get_user_by_token(self.tokens[1])
+        result = yield self.store.get_user_by_token(self.tokens[0])
 
         self.assertDictContainsSubset(
             {