diff options
Diffstat (limited to 'tests/storage/test_state.py')
-rw-r--r-- | tests/storage/test_state.py | 145 |
1 files changed, 40 insertions, 105 deletions
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 2471f1267d..f06b452fa9 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,24 +15,18 @@ import logging -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.storage.state import StateFilter from synapse.types import RoomID, UserID -import tests.unittest -import tests.utils +from tests.unittest import HomeserverTestCase logger = logging.getLogger(__name__) -class StateStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver(self.addCleanup) - +class StateStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.storage = hs.get_storage() self.state_datastore = self.storage.state.stores.state @@ -44,7 +38,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room = RoomID.from_string("!abc123:test") - yield defer.ensureDeferred( + self.get_success( self.store.store_room( self.room.to_string(), room_creator_user_id="@creator:text", @@ -53,7 +47,6 @@ class StateStoreTestCase(tests.unittest.TestCase): ) ) - @defer.inlineCallbacks def inject_state_event(self, room, sender, typ, state_key, content): builder = self.event_builder_factory.for_room_version( RoomVersions.V1, @@ -66,13 +59,11 @@ class StateStoreTestCase(tests.unittest.TestCase): }, ) - event, context = yield defer.ensureDeferred( + event, context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) - yield defer.ensureDeferred( - self.storage.persistence.persist_event(event, context) - ) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -82,16 +73,13 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(s1[t].event_id, s2[t].event_id) self.assertEqual(len(s1), len(s2)) - @defer.inlineCallbacks def test_get_state_groups_ids(self): - e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, "", {} - ) - e2 = yield self.inject_state_event( + e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) + e2 = self.inject_state_event( self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield defer.ensureDeferred( + state_group_map = self.get_success( self.storage.state.get_state_groups_ids(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) @@ -101,16 +89,13 @@ class StateStoreTestCase(tests.unittest.TestCase): {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id}, ) - @defer.inlineCallbacks def test_get_state_groups(self): - e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, "", {} - ) - e2 = yield self.inject_state_event( + e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) + e2 = self.inject_state_event( self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield defer.ensureDeferred( + state_group_map = self.get_success( self.storage.state.get_state_groups(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) @@ -118,32 +103,29 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id}) - @defer.inlineCallbacks def test_get_state_for_event(self): # this defaults to a linear DAG as each new injection defaults to whatever # forward extremities are currently in the DB for this room. - e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, "", {} - ) - e2 = yield self.inject_state_event( + e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) + e2 = self.inject_state_event( self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - e3 = yield self.inject_state_event( + e3 = self.inject_state_event( self.room, self.u_alice, EventTypes.Member, self.u_alice.to_string(), {"membership": Membership.JOIN}, ) - e4 = yield self.inject_state_event( + e4 = self.inject_state_event( self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), {"membership": Membership.JOIN}, ) - e5 = yield self.inject_state_event( + e5 = self.inject_state_event( self.room, self.u_bob, EventTypes.Member, @@ -152,9 +134,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield defer.ensureDeferred( - self.storage.state.get_state_for_event(e5.event_id) - ) + state = self.get_success(self.storage.state.get_state_for_event(e5.event_id)) self.assertIsNotNone(e4) @@ -170,7 +150,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we can filter to the m.room.name event (with a '' state key) - state = yield defer.ensureDeferred( + state = self.get_success( self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) ) @@ -179,7 +159,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) - state = yield defer.ensureDeferred( + state = self.get_success( self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) ) @@ -188,7 +168,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) - state = yield defer.ensureDeferred( + state = self.get_success( self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) ) @@ -200,7 +180,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can grab a specific room member without filtering out the # other event types - state = yield defer.ensureDeferred( + state = self.get_success( self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( @@ -220,7 +200,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check that we can grab everything except members - state = yield defer.ensureDeferred( + state = self.get_success( self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( @@ -238,17 +218,14 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### room_id = self.room.to_string() - group_ids = yield defer.ensureDeferred( + group_ids = self.get_success( self.storage.state.get_state_groups_ids(room_id, [e5.event_id]) ) group = list(group_ids.keys())[0] # test _get_state_for_group_using_cache correctly filters out members # with types=[] - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -265,10 +242,7 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -281,10 +255,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with wildcard types - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -301,10 +272,7 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -324,10 +292,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -344,10 +309,7 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -360,10 +322,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -413,10 +372,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters out members # with types=[] room_id = self.room.to_string() - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -428,10 +384,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) room_id = self.room.to_string() - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -444,10 +397,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # wildcard types - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -458,10 +408,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -480,10 +427,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -494,10 +438,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -510,10 +451,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -524,10 +462,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( |