summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/storage/test_state.py10
-rw-r--r--tests/util/test_dict_cache.py35
2 files changed, 36 insertions, 9 deletions
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 8043bdbde2..5564161750 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -369,8 +369,8 @@ class StateStoreTestCase(HomeserverTestCase):
         state_dict_ids = cache_entry.value
 
         self.assertEqual(cache_entry.full, False)
-        self.assertEqual(cache_entry.known_absent, {(e1.type, e1.state_key)})
-        self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
+        self.assertEqual(cache_entry.known_absent, set())
+        self.assertDictEqual(state_dict_ids, {})
 
         ############################################
         # test that things work with a partial cache
@@ -387,7 +387,7 @@ class StateStoreTestCase(HomeserverTestCase):
         )
 
         self.assertEqual(is_all, False)
-        self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+        self.assertDictEqual({}, state_dict)
 
         room_id = self.room.to_string()
         (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
@@ -412,7 +412,7 @@ class StateStoreTestCase(HomeserverTestCase):
         )
 
         self.assertEqual(is_all, False)
-        self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+        self.assertDictEqual({}, state_dict)
 
         (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
@@ -443,7 +443,7 @@ class StateStoreTestCase(HomeserverTestCase):
         )
 
         self.assertEqual(is_all, False)
-        self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+        self.assertDictEqual({}, state_dict)
 
         (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
index bee66dee43..e8b6246ab5 100644
--- a/tests/util/test_dict_cache.py
+++ b/tests/util/test_dict_cache.py
@@ -20,7 +20,7 @@ from tests import unittest
 
 class DictCacheTestCase(unittest.TestCase):
     def setUp(self):
-        self.cache = DictionaryCache("foobar")
+        self.cache = DictionaryCache("foobar", max_entries=10)
 
     def test_simple_cache_hit_full(self):
         key = "test_simple_cache_hit_full"
@@ -76,13 +76,13 @@ class DictCacheTestCase(unittest.TestCase):
 
         seq = self.cache.sequence
         test_value_1 = {"test": "test_simple_cache_hit_miss_partial"}
-        self.cache.update(seq, key, test_value_1, fetched_keys=set("test"))
+        self.cache.update(seq, key, test_value_1, fetched_keys={"test"})
 
         seq = self.cache.sequence
         test_value_2 = {"test2": "test_simple_cache_hit_miss_partial2"}
-        self.cache.update(seq, key, test_value_2, fetched_keys=set("test2"))
+        self.cache.update(seq, key, test_value_2, fetched_keys={"test2"})
 
-        c = self.cache.get(key)
+        c = self.cache.get(key, dict_keys=["test", "test2"])
         self.assertEqual(
             {
                 "test": "test_simple_cache_hit_miss_partial",
@@ -90,3 +90,30 @@ class DictCacheTestCase(unittest.TestCase):
             },
             c.value,
         )
+        self.assertEqual(c.full, False)
+
+    def test_invalidation(self):
+        """Test that the partial dict and full dicts get invalidated
+        separately.
+        """
+        key = "some_key"
+
+        seq = self.cache.sequence
+        # start by populating a "full dict" entry
+        self.cache.update(seq, key, {"a": "b", "c": "d"})
+
+        # add a bunch of individual entries, also keeping the individual
+        # entry for "a" warm.
+        for i in range(20):
+            self.cache.get(key, ["a"])
+            self.cache.update(seq, f"key{i}", {1: 2})
+
+        # We should have evicted the full dict...
+        r = self.cache.get(key)
+        self.assertFalse(r.full)
+        self.assertTrue("c" not in r.value)
+
+        # ... but kept the "a" entry that we kept querying.
+        r = self.cache.get(key, dict_keys=["a"])
+        self.assertFalse(r.full)
+        self.assertEqual(r.value, {"a": "b"})