diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index b2f314e9db..b7797c45e8 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -159,6 +159,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
(e3.type, e3.state_key): e3,
}, state)
+ # check that types=[], filtered_types=[EventTypes.Member]
+ # doesn't return all members
state = yield self.store.get_state_for_event(
e5.event_id, [], filtered_types=[EventTypes.Member],
)
@@ -167,3 +169,151 @@ class StateStoreTestCase(tests.unittest.TestCase):
(e1.type, e1.state_key): e1,
(e2.type, e2.state_key): e2,
}, state)
+
+ ##################################
+ # _get_some_state_from_cache tests
+ ##################################
+
+ room_id = self.room.to_string()
+ group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
+ group = group_ids.keys()[0]
+
+ # test that _get_some_state_from_cache correctly filters out members with types=[]
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ group, [], filtered_types=[EventTypes.Member]
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual({
+ (e1.type, e1.state_key): e1.event_id,
+ (e2.type, e2.state_key): e2.event_id,
+ }, state_dict)
+
+ # test that _get_some_state_from_cache correctly filters out members with types=wildcard
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual({
+ (e1.type, e1.state_key): e1.event_id,
+ (e2.type, e2.state_key): e2.event_id,
+ (e3.type, e3.state_key): e3.event_id,
+ # e4 is overwritten by e5
+ (e5.type, e5.state_key): e5.event_id,
+ }, state_dict)
+
+ # test that _get_some_state_from_cache correctly filters out members with types=specific
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member]
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual({
+ (e1.type, e1.state_key): e1.event_id,
+ (e2.type, e2.state_key): e2.event_id,
+ (e5.type, e5.state_key): e5.event_id,
+ }, state_dict)
+
+ # test that _get_some_state_from_cache correctly filters out members with types=specific
+ # and no filtered_types
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ group, [(EventTypes.Member, e5.state_key)], filtered_types=None
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual({
+ (e5.type, e5.state_key): e5.event_id,
+ }, state_dict)
+
+ #######################################################
+ # deliberately remove e2 (room name) from the _state_group_cache
+
+ (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group)
+
+ self.assertEqual(is_all, True)
+ self.assertEqual(known_absent, set())
+ self.assertDictEqual(state_dict_ids, {
+ (e1.type, e1.state_key): e1.event_id,
+ (e2.type, e2.state_key): e2.event_id,
+ (e3.type, e3.state_key): e3.event_id,
+ # e4 is overwritten by e5
+ (e5.type, e5.state_key): e5.event_id,
+ })
+
+ state_dict_ids.pop((e2.type, e2.state_key))
+ self.store._state_group_cache.invalidate(group)
+ self.store._state_group_cache.update(
+ sequence=self.store._state_group_cache.sequence,
+ key=group,
+ value=state_dict_ids,
+ # list fetched keys so it knows it's partial
+ fetched_keys=(
+ (e1.type, e1.state_key),
+ (e3.type, e3.state_key),
+ (e5.type, e5.state_key),
+ )
+ )
+
+ (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group)
+
+ self.assertEqual(is_all, False)
+ self.assertEqual(known_absent, set([
+ (e1.type, e1.state_key),
+ (e3.type, e3.state_key),
+ (e5.type, e5.state_key),
+ ]))
+ self.assertDictEqual(state_dict_ids, {
+ (e1.type, e1.state_key): e1.event_id,
+ (e3.type, e3.state_key): e3.event_id,
+ (e5.type, e5.state_key): e5.event_id,
+ })
+
+ ###################################################
+ # test that things work with a partial cache
+
+ # test that _get_some_state_from_cache correctly filters out members with types=[]
+ room_id = self.room.to_string()
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ group, [], filtered_types=[EventTypes.Member]
+ )
+
+ self.assertEqual(is_all, False)
+ self.assertDictEqual({
+ (e1.type, e1.state_key): e1.event_id,
+ }, state_dict)
+
+ # test that _get_some_state_from_cache correctly filters out members with types=wildcard
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
+ )
+
+ self.assertEqual(is_all, False)
+ self.assertDictEqual({
+ (e1.type, e1.state_key): e1.event_id,
+ (e3.type, e3.state_key): e3.event_id,
+ # e4 is overwritten by e5
+ (e5.type, e5.state_key): e5.event_id,
+ }, state_dict)
+
+ # test that _get_some_state_from_cache correctly filters out members with types=specific
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member]
+ )
+
+ self.assertEqual(is_all, False)
+ self.assertDictEqual({
+ (e1.type, e1.state_key): e1.event_id,
+ (e5.type, e5.state_key): e5.event_id,
+ }, state_dict)
+
+ # test that _get_some_state_from_cache correctly filters out members with types=specific
+ # and no filtered_types
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ group, [(EventTypes.Member, e5.state_key)], filtered_types=None
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual({
+ (e5.type, e5.state_key): e5.event_id,
+ }, state_dict)
|