diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 8043bdbde2..5564161750 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -369,8 +369,8 @@ class StateStoreTestCase(HomeserverTestCase):
state_dict_ids = cache_entry.value
self.assertEqual(cache_entry.full, False)
- self.assertEqual(cache_entry.known_absent, {(e1.type, e1.state_key)})
- self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
+ self.assertEqual(cache_entry.known_absent, set())
+ self.assertDictEqual(state_dict_ids, {})
############################################
# test that things work with a partial cache
@@ -387,7 +387,7 @@ class StateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(is_all, False)
- self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+ self.assertDictEqual({}, state_dict)
room_id = self.room.to_string()
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
@@ -412,7 +412,7 @@ class StateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(is_all, False)
- self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+ self.assertDictEqual({}, state_dict)
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
@@ -443,7 +443,7 @@ class StateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(is_all, False)
- self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+ self.assertDictEqual({}, state_dict)
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
index bee66dee43..e8b6246ab5 100644
--- a/tests/util/test_dict_cache.py
+++ b/tests/util/test_dict_cache.py
@@ -20,7 +20,7 @@ from tests import unittest
class DictCacheTestCase(unittest.TestCase):
def setUp(self):
- self.cache = DictionaryCache("foobar")
+ self.cache = DictionaryCache("foobar", max_entries=10)
def test_simple_cache_hit_full(self):
key = "test_simple_cache_hit_full"
@@ -76,13 +76,13 @@ class DictCacheTestCase(unittest.TestCase):
seq = self.cache.sequence
test_value_1 = {"test": "test_simple_cache_hit_miss_partial"}
- self.cache.update(seq, key, test_value_1, fetched_keys=set("test"))
+ self.cache.update(seq, key, test_value_1, fetched_keys={"test"})
seq = self.cache.sequence
test_value_2 = {"test2": "test_simple_cache_hit_miss_partial2"}
- self.cache.update(seq, key, test_value_2, fetched_keys=set("test2"))
+ self.cache.update(seq, key, test_value_2, fetched_keys={"test2"})
- c = self.cache.get(key)
+ c = self.cache.get(key, dict_keys=["test", "test2"])
self.assertEqual(
{
"test": "test_simple_cache_hit_miss_partial",
@@ -90,3 +90,30 @@ class DictCacheTestCase(unittest.TestCase):
},
c.value,
)
+ self.assertEqual(c.full, False)
+
+ def test_invalidation(self):
+ """Test that the partial dict and full dicts get invalidated
+ separately.
+ """
+ key = "some_key"
+
+ seq = self.cache.sequence
+ # start by populating a "full dict" entry
+ self.cache.update(seq, key, {"a": "b", "c": "d"})
+
+ # add a bunch of individual entries, also keeping the individual
+ # entry for "a" warm.
+ for i in range(20):
+ self.cache.get(key, ["a"])
+ self.cache.update(seq, f"key{i}", {1: 2})
+
+ # We should have evicted the full dict...
+ r = self.cache.get(key)
+ self.assertFalse(r.full)
+ self.assertTrue("c" not in r.value)
+
+ # ... but kept the "a" entry that we kept querying.
+ r = self.cache.get(key, dict_keys=["a"])
+ self.assertFalse(r.full)
+ self.assertEqual(r.value, {"a": "b"})
|