summary refs log tree commit diff
path: root/tests/storage/test_state.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_state.py')
-rw-r--r--tests/storage/test_state.py145
1 files changed, 40 insertions, 105 deletions
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 2471f1267d..f06b452fa9 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,24 +15,18 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.storage.state import StateFilter
 from synapse.types import RoomID, UserID
 
-import tests.unittest
-import tests.utils
+from tests.unittest import HomeserverTestCase
 
 logger = logging.getLogger(__name__)
 
 
-class StateStoreTestCase(tests.unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
-
+class StateStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
         self.storage = hs.get_storage()
         self.state_datastore = self.storage.state.stores.state
@@ -44,7 +38,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         self.room = RoomID.from_string("!abc123:test")
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_room(
                 self.room.to_string(),
                 room_creator_user_id="@creator:text",
@@ -53,7 +47,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
             )
         )
 
-    @defer.inlineCallbacks
     def inject_state_event(self, room, sender, typ, state_key, content):
         builder = self.event_builder_factory.for_room_version(
             RoomVersions.V1,
@@ -66,13 +59,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
             },
         )
 
-        event, context = yield defer.ensureDeferred(
+        event, context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        yield defer.ensureDeferred(
-            self.storage.persistence.persist_event(event, context)
-        )
+        self.get_success(self.storage.persistence.persist_event(event, context))
 
         return event
 
@@ -82,16 +73,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.assertEqual(s1[t].event_id, s2[t].event_id)
         self.assertEqual(len(s1), len(s2))
 
-    @defer.inlineCallbacks
     def test_get_state_groups_ids(self):
-        e1 = yield self.inject_state_event(
-            self.room, self.u_alice, EventTypes.Create, "", {}
-        )
-        e2 = yield self.inject_state_event(
+        e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+        e2 = self.inject_state_event(
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield defer.ensureDeferred(
+        state_group_map = self.get_success(
             self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
         )
         self.assertEqual(len(state_group_map), 1)
@@ -101,16 +89,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
             {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
         )
 
-    @defer.inlineCallbacks
     def test_get_state_groups(self):
-        e1 = yield self.inject_state_event(
-            self.room, self.u_alice, EventTypes.Create, "", {}
-        )
-        e2 = yield self.inject_state_event(
+        e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+        e2 = self.inject_state_event(
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield defer.ensureDeferred(
+        state_group_map = self.get_success(
             self.storage.state.get_state_groups(self.room, [e2.event_id])
         )
         self.assertEqual(len(state_group_map), 1)
@@ -118,32 +103,29 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
 
-    @defer.inlineCallbacks
     def test_get_state_for_event(self):
 
         # this defaults to a linear DAG as each new injection defaults to whatever
         # forward extremities are currently in the DB for this room.
-        e1 = yield self.inject_state_event(
-            self.room, self.u_alice, EventTypes.Create, "", {}
-        )
-        e2 = yield self.inject_state_event(
+        e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+        e2 = self.inject_state_event(
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
-        e3 = yield self.inject_state_event(
+        e3 = self.inject_state_event(
             self.room,
             self.u_alice,
             EventTypes.Member,
             self.u_alice.to_string(),
             {"membership": Membership.JOIN},
         )
-        e4 = yield self.inject_state_event(
+        e4 = self.inject_state_event(
             self.room,
             self.u_bob,
             EventTypes.Member,
             self.u_bob.to_string(),
             {"membership": Membership.JOIN},
         )
-        e5 = yield self.inject_state_event(
+        e5 = self.inject_state_event(
             self.room,
             self.u_bob,
             EventTypes.Member,
@@ -152,9 +134,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we get the full state as of the final event
-        state = yield defer.ensureDeferred(
-            self.storage.state.get_state_for_event(e5.event_id)
-        )
+        state = self.get_success(self.storage.state.get_state_for_event(e5.event_id))
 
         self.assertIsNotNone(e4)
 
@@ -170,7 +150,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we can filter to the m.room.name event (with a '' state key)
-        state = yield defer.ensureDeferred(
+        state = self.get_success(
             self.storage.state.get_state_for_event(
                 e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
             )
@@ -179,7 +159,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can filter to the m.room.name event (with a wildcard None state key)
-        state = yield defer.ensureDeferred(
+        state = self.get_success(
             self.storage.state.get_state_for_event(
                 e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
             )
@@ -188,7 +168,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can grab the m.room.member events (with a wildcard None state key)
-        state = yield defer.ensureDeferred(
+        state = self.get_success(
             self.storage.state.get_state_for_event(
                 e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
             )
@@ -200,7 +180,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # check we can grab a specific room member without filtering out the
         # other event types
-        state = yield defer.ensureDeferred(
+        state = self.get_success(
             self.storage.state.get_state_for_event(
                 e5.event_id,
                 state_filter=StateFilter(
@@ -220,7 +200,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check that we can grab everything except members
-        state = yield defer.ensureDeferred(
+        state = self.get_success(
             self.storage.state.get_state_for_event(
                 e5.event_id,
                 state_filter=StateFilter(
@@ -238,17 +218,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
         #######################################################
 
         room_id = self.room.to_string()
-        group_ids = yield defer.ensureDeferred(
+        group_ids = self.get_success(
             self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
         )
         group = list(group_ids.keys())[0]
 
         # test _get_state_for_group_using_cache correctly filters out members
         # with types=[]
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -265,10 +242,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -281,10 +255,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with wildcard types
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -301,10 +272,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -324,10 +292,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -344,10 +309,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -360,10 +322,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -413,10 +372,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         # test _get_state_for_group_using_cache correctly filters out members
         # with types=[]
         room_id = self.room.to_string()
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -428,10 +384,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
         room_id = self.room.to_string()
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -444,10 +397,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # wildcard types
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -458,10 +408,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -480,10 +427,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -494,10 +438,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -510,10 +451,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -524,10 +462,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({}, state_dict)
 
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(