summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/storage/test__base.py23
-rw-r--r--tests/test_distributor.py4
-rw-r--r--tests/test_state.py2
-rw-r--r--tests/util/test_dict_cache.py101
-rw-r--r--tests/util/test_lrucache.py4
5 files changed, 118 insertions, 16 deletions
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 8c3d2952bd..e72cace8ff 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -17,7 +17,9 @@
 from tests import unittest
 from twisted.internet import defer
 
-from synapse.storage._base import Cache, cached
+from synapse.util.async import ObservableDeferred
+
+from synapse.util.caches.descriptors import Cache, cached
 
 
 class CacheTestCase(unittest.TestCase):
@@ -40,12 +42,12 @@ class CacheTestCase(unittest.TestCase):
         self.assertEquals(self.cache.get("foo"), 123)
 
     def test_invalidate(self):
-        self.cache.prefill("foo", 123)
-        self.cache.invalidate("foo")
+        self.cache.prefill(("foo",), 123)
+        self.cache.invalidate(("foo",))
 
         failed = False
         try:
-            self.cache.get("foo")
+            self.cache.get(("foo",))
         except KeyError:
             failed = True
 
@@ -139,7 +141,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
         self.assertEquals(callcount[0], 1)
 
-        a.func.invalidate("foo")
+        a.func.invalidate(("foo",))
 
         yield a.func("foo")
 
@@ -151,7 +153,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
             def func(self, key):
                 return key
 
-        A().func.invalidate("what")
+        A().func.invalidate(("what",))
 
     @defer.inlineCallbacks
     def test_max_entries(self):
@@ -178,19 +180,20 @@ class CacheDecoratorTestCase(unittest.TestCase):
         self.assertTrue(callcount[0] >= 14,
             msg="Expected callcount >= 14, got %d" % (callcount[0]))
 
-    @defer.inlineCallbacks
     def test_prefill(self):
         callcount = [0]
 
+        d = defer.succeed(123)
+
         class A(object):
             @cached()
             def func(self, key):
                 callcount[0] += 1
-                return key
+                return d
 
         a = A()
 
-        a.func.prefill("foo", 123)
+        a.func.prefill(("foo",), ObservableDeferred(d))
 
-        self.assertEquals((yield a.func("foo")), 123)
+        self.assertEquals(a.func("foo").result, d.result)
         self.assertEquals(callcount[0], 0)
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 6a0095d850..8ed48cfb70 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -73,8 +73,8 @@ class DistributorTestCase(unittest.TestCase):
             yield d
             self.assertTrue(d.called)
 
-            observers[0].assert_called_once("Go")
-            observers[1].assert_called_once("Go")
+            observers[0].assert_called_once_with("Go")
+            observers[1].assert_called_once_with("Go")
 
             self.assertEquals(mock_logger.warning.call_count, 1)
             self.assertIsInstance(mock_logger.warning.call_args[0][0],
diff --git a/tests/test_state.py b/tests/test_state.py
index fea25f7021..5845358754 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -69,7 +69,7 @@ class StateGroupStore(object):
 
         self._next_group = 1
 
-    def get_state_groups(self, event_ids):
+    def get_state_groups(self, room_id, event_ids):
         groups = {}
         for event_id in event_ids:
             group = self._event_to_state_group.get(event_id)
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
new file mode 100644
index 0000000000..54ff26cd97
--- /dev/null
+++ b/tests/util/test_dict_cache.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+from tests import unittest
+
+from synapse.util.caches.dictionary_cache import DictionaryCache
+
+
+class DictCacheTestCase(unittest.TestCase):
+
+    def setUp(self):
+        self.cache = DictionaryCache("foobar")
+
+    def test_simple_cache_hit_full(self):
+        key = "test_simple_cache_hit_full"
+
+        v = self.cache.get(key)
+        self.assertEqual((False, {}), v)
+
+        seq = self.cache.sequence
+        test_value = {"test": "test_simple_cache_hit_full"}
+        self.cache.update(seq, key, test_value, full=True)
+
+        c = self.cache.get(key)
+        self.assertEqual(test_value, c.value)
+
+    def test_simple_cache_hit_partial(self):
+        key = "test_simple_cache_hit_partial"
+
+        seq = self.cache.sequence
+        test_value = {
+            "test": "test_simple_cache_hit_partial"
+        }
+        self.cache.update(seq, key, test_value, full=True)
+
+        c = self.cache.get(key, ["test"])
+        self.assertEqual(test_value, c.value)
+
+    def test_simple_cache_miss_partial(self):
+        key = "test_simple_cache_miss_partial"
+
+        seq = self.cache.sequence
+        test_value = {
+            "test": "test_simple_cache_miss_partial"
+        }
+        self.cache.update(seq, key, test_value, full=True)
+
+        c = self.cache.get(key, ["test2"])
+        self.assertEqual({}, c.value)
+
+    def test_simple_cache_hit_miss_partial(self):
+        key = "test_simple_cache_hit_miss_partial"
+
+        seq = self.cache.sequence
+        test_value = {
+            "test": "test_simple_cache_hit_miss_partial",
+            "test2": "test_simple_cache_hit_miss_partial2",
+            "test3": "test_simple_cache_hit_miss_partial3",
+        }
+        self.cache.update(seq, key, test_value, full=True)
+
+        c = self.cache.get(key, ["test2"])
+        self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
+
+    def test_multi_insert(self):
+        key = "test_simple_cache_hit_miss_partial"
+
+        seq = self.cache.sequence
+        test_value_1 = {
+            "test": "test_simple_cache_hit_miss_partial",
+        }
+        self.cache.update(seq, key, test_value_1, full=False)
+
+        seq = self.cache.sequence
+        test_value_2 = {
+            "test2": "test_simple_cache_hit_miss_partial2",
+        }
+        self.cache.update(seq, key, test_value_2, full=False)
+
+        c = self.cache.get(key)
+        self.assertEqual(
+            {
+                "test": "test_simple_cache_hit_miss_partial",
+                "test2": "test_simple_cache_hit_miss_partial2",
+            },
+            c.value
+        )
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index ab934bf928..fc5a904323 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -16,7 +16,7 @@
 
 from .. import unittest
 
-from synapse.util.lrucache import LruCache
+from synapse.util.caches.lrucache import LruCache
 
 class LruCacheTestCase(unittest.TestCase):
 
@@ -52,5 +52,3 @@ class LruCacheTestCase(unittest.TestCase):
         cache["key"] = 1
         self.assertEquals(cache.pop("key"), 1)
         self.assertEquals(cache.pop("key"), None)
-
-