summary refs log tree commit diff
path: root/tests/test_state.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_state.py')
-rw-r--r--tests/test_state.py72
1 files changed, 44 insertions, 28 deletions
diff --git a/tests/test_state.py b/tests/test_state.py
index 66f22f6813..4858e8fc59 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -97,17 +97,19 @@ class StateGroupStore(object):
 
         self._group_to_state[state_group] = dict(current_state_ids)
 
-        return state_group
+        return defer.succeed(state_group)
 
     def get_events(self, event_ids, **kwargs):
-        return {
-            e_id: self._event_id_to_event[e_id]
-            for e_id in event_ids
-            if e_id in self._event_id_to_event
-        }
+        return defer.succeed(
+            {
+                e_id: self._event_id_to_event[e_id]
+                for e_id in event_ids
+                if e_id in self._event_id_to_event
+            }
+        )
 
     def get_state_group_delta(self, name):
-        return None, None
+        return defer.succeed((None, None))
 
     def register_events(self, events):
         for e in events:
@@ -120,7 +122,7 @@ class StateGroupStore(object):
         self._event_to_state_group[event_id] = state_group
 
     def get_room_version_id(self, room_id):
-        return RoomVersions.V1.identifier
+        return defer.succeed(RoomVersions.V1.identifier)
 
 
 class DictObj(dict):
@@ -202,7 +204,9 @@ class StateTestCase(unittest.TestCase):
         context_store = {}  # type: dict[str, EventContext]
 
         for event in graph.walk():
-            context = yield self.state.compute_event_context(event)
+            context = yield defer.ensureDeferred(
+                self.state.compute_event_context(event)
+            )
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
@@ -244,7 +248,9 @@ class StateTestCase(unittest.TestCase):
         context_store = {}
 
         for event in graph.walk():
-            context = yield self.state.compute_event_context(event)
+            context = yield defer.ensureDeferred(
+                self.state.compute_event_context(event)
+            )
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
@@ -300,7 +306,9 @@ class StateTestCase(unittest.TestCase):
         context_store = {}
 
         for event in graph.walk():
-            context = yield self.state.compute_event_context(event)
+            context = yield defer.ensureDeferred(
+                self.state.compute_event_context(event)
+            )
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
@@ -373,7 +381,9 @@ class StateTestCase(unittest.TestCase):
         context_store = {}
 
         for event in graph.walk():
-            context = yield self.state.compute_event_context(event)
+            context = yield defer.ensureDeferred(
+                self.state.compute_event_context(event)
+            )
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
@@ -411,12 +421,14 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        context = yield self.state.compute_event_context(event, old_state=old_state)
+        context = yield defer.ensureDeferred(
+            self.state.compute_event_context(event, old_state=old_state)
+        )
 
         prev_state_ids = yield context.get_prev_state_ids()
         self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
         self.assertCountEqual(
             (e.event_id for e in old_state), current_state_ids.values()
         )
@@ -434,12 +446,14 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        context = yield self.state.compute_event_context(event, old_state=old_state)
+        context = yield defer.ensureDeferred(
+            self.state.compute_event_context(event, old_state=old_state)
+        )
 
         prev_state_ids = yield context.get_prev_state_ids()
         self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
         self.assertCountEqual(
             (e.event_id for e in old_state + [event]), current_state_ids.values()
         )
@@ -462,7 +476,7 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        group_name = self.store.store_state_group(
+        group_name = yield self.store.store_state_group(
             prev_event_id,
             event.room_id,
             None,
@@ -471,9 +485,9 @@ class StateTestCase(unittest.TestCase):
         )
         self.store.register_event_id_state_group(prev_event_id, group_name)
 
-        context = yield self.state.compute_event_context(event)
+        context = yield defer.ensureDeferred(self.state.compute_event_context(event))
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(
             {e.event_id for e in old_state}, set(current_state_ids.values())
@@ -494,7 +508,7 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        group_name = self.store.store_state_group(
+        group_name = yield self.store.store_state_group(
             prev_event_id,
             event.room_id,
             None,
@@ -503,7 +517,7 @@ class StateTestCase(unittest.TestCase):
         )
         self.store.register_event_id_state_group(prev_event_id, group_name)
 
-        context = yield self.state.compute_event_context(event)
+        context = yield defer.ensureDeferred(self.state.compute_event_context(event))
 
         prev_state_ids = yield context.get_prev_state_ids()
 
@@ -544,7 +558,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(len(current_state_ids), 6)
 
@@ -586,7 +600,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(len(current_state_ids), 6)
 
@@ -641,7 +655,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
 
@@ -669,14 +683,15 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
 
+    @defer.inlineCallbacks
     def _get_context(
         self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
     ):
-        sg1 = self.store.store_state_group(
+        sg1 = yield self.store.store_state_group(
             prev_event_id_1,
             event.room_id,
             None,
@@ -685,7 +700,7 @@ class StateTestCase(unittest.TestCase):
         )
         self.store.register_event_id_state_group(prev_event_id_1, sg1)
 
-        sg2 = self.store.store_state_group(
+        sg2 = yield self.store.store_state_group(
             prev_event_id_2,
             event.room_id,
             None,
@@ -694,4 +709,5 @@ class StateTestCase(unittest.TestCase):
         )
         self.store.register_event_id_state_group(prev_event_id_2, sg2)
 
-        return self.state.compute_event_context(event)
+        result = yield defer.ensureDeferred(self.state.compute_event_context(event))
+        return result