summary refs log tree commit diff
path: root/tests/storage/test_state.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_state.py')
-rw-r--r--tests/storage/test_state.py64
1 files changed, 39 insertions, 25 deletions
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index a0e133cd4a..6a48b9d3b3 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -68,7 +68,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        yield self.storage.persistence.persist_event(event, context)
+        yield defer.ensureDeferred(
+            self.storage.persistence.persist_event(event, context)
+        )
 
         return event
 
@@ -87,8 +89,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield self.storage.state.get_state_groups_ids(
-            self.room, [e2.event_id]
+        state_group_map = yield defer.ensureDeferred(
+            self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
         )
         self.assertEqual(len(state_group_map), 1)
         state_map = list(state_group_map.values())[0]
@@ -106,8 +108,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield self.storage.state.get_state_groups(
-            self.room, [e2.event_id]
+        state_group_map = yield defer.ensureDeferred(
+            self.storage.state.get_state_groups(self.room, [e2.event_id])
         )
         self.assertEqual(len(state_group_map), 1)
         state_list = list(state_group_map.values())[0]
@@ -148,7 +150,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we get the full state as of the final event
-        state = yield self.storage.state.get_state_for_event(e5.event_id)
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(e5.event_id)
+        )
 
         self.assertIsNotNone(e4)
 
@@ -164,22 +168,28 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we can filter to the m.room.name event (with a '' state key)
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
+            )
         )
 
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can filter to the m.room.name event (with a wildcard None state key)
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
+            )
         )
 
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can grab the m.room.member events (with a wildcard None state key)
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
+            )
         )
 
         self.assertStateMapEqual(
@@ -188,12 +198,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # check we can grab a specific room member without filtering out the
         # other event types
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id,
-            state_filter=StateFilter(
-                types={EventTypes.Member: {self.u_alice.to_string()}},
-                include_others=True,
-            ),
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id,
+                state_filter=StateFilter(
+                    types={EventTypes.Member: {self.u_alice.to_string()}},
+                    include_others=True,
+                ),
+            )
         )
 
         self.assertStateMapEqual(
@@ -206,11 +218,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check that we can grab everything except members
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id,
-            state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
-            ),
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id,
+                state_filter=StateFilter(
+                    types={EventTypes.Member: set()}, include_others=True
+                ),
+            )
         )
 
         self.assertStateMapEqual(
@@ -222,8 +236,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
         #######################################################
 
         room_id = self.room.to_string()
-        group_ids = yield self.storage.state.get_state_groups_ids(
-            room_id, [e5.event_id]
+        group_ids = yield defer.ensureDeferred(
+            self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
         )
         group = list(group_ids.keys())[0]