summary refs log tree commit diff
path: root/tests/test_state.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/test_state.py86
1 files changed, 46 insertions, 40 deletions
diff --git a/tests/test_state.py b/tests/test_state.py
index b5c3667d2a..56ba0fecf5 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -80,16 +80,16 @@ class StateGroupStore(object):
 
         self._next_group = 1
 
-    def get_state_groups_ids(self, room_id, event_ids):
+    async def get_state_groups_ids(self, room_id, event_ids):
         groups = {}
         for event_id in event_ids:
             group = self._event_to_state_group.get(event_id)
             if group:
                 groups[group] = self._group_to_state[group]
 
-        return defer.succeed(groups)
+        return groups
 
-    def store_state_group(
+    async def store_state_group(
         self, event_id, room_id, prev_group, delta_ids, current_state_ids
     ):
         state_group = self._next_group
@@ -97,19 +97,17 @@ class StateGroupStore(object):
 
         self._group_to_state[state_group] = dict(current_state_ids)
 
-        return defer.succeed(state_group)
+        return state_group
 
-    def get_events(self, event_ids, **kwargs):
-        return defer.succeed(
-            {
-                e_id: self._event_id_to_event[e_id]
-                for e_id in event_ids
-                if e_id in self._event_id_to_event
-            }
-        )
+    async def get_events(self, event_ids, **kwargs):
+        return {
+            e_id: self._event_id_to_event[e_id]
+            for e_id in event_ids
+            if e_id in self._event_id_to_event
+        }
 
-    def get_state_group_delta(self, name):
-        return defer.succeed((None, None))
+    async def get_state_group_delta(self, name):
+        return (None, None)
 
     def register_events(self, events):
         for e in events:
@@ -121,8 +119,8 @@ class StateGroupStore(object):
     def register_event_id_state_group(self, event_id, state_group):
         self._event_to_state_group[event_id] = state_group
 
-    def get_room_version_id(self, room_id):
-        return defer.succeed(RoomVersions.V1.identifier)
+    async def get_room_version_id(self, room_id):
+        return RoomVersions.V1.identifier
 
 
 class DictObj(dict):
@@ -476,12 +474,14 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        group_name = yield self.store.store_state_group(
-            prev_event_id,
-            event.room_id,
-            None,
-            None,
-            {(e.type, e.state_key): e.event_id for e in old_state},
+        group_name = yield defer.ensureDeferred(
+            self.store.store_state_group(
+                prev_event_id,
+                event.room_id,
+                None,
+                None,
+                {(e.type, e.state_key): e.event_id for e in old_state},
+            )
         )
         self.store.register_event_id_state_group(prev_event_id, group_name)
 
@@ -508,12 +508,14 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        group_name = yield self.store.store_state_group(
-            prev_event_id,
-            event.room_id,
-            None,
-            None,
-            {(e.type, e.state_key): e.event_id for e in old_state},
+        group_name = yield defer.ensureDeferred(
+            self.store.store_state_group(
+                prev_event_id,
+                event.room_id,
+                None,
+                None,
+                {(e.type, e.state_key): e.event_id for e in old_state},
+            )
         )
         self.store.register_event_id_state_group(prev_event_id, group_name)
 
@@ -691,21 +693,25 @@ class StateTestCase(unittest.TestCase):
     def _get_context(
         self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
     ):
-        sg1 = yield self.store.store_state_group(
-            prev_event_id_1,
-            event.room_id,
-            None,
-            None,
-            {(e.type, e.state_key): e.event_id for e in old_state_1},
+        sg1 = yield defer.ensureDeferred(
+            self.store.store_state_group(
+                prev_event_id_1,
+                event.room_id,
+                None,
+                None,
+                {(e.type, e.state_key): e.event_id for e in old_state_1},
+            )
         )
         self.store.register_event_id_state_group(prev_event_id_1, sg1)
 
-        sg2 = yield self.store.store_state_group(
-            prev_event_id_2,
-            event.room_id,
-            None,
-            None,
-            {(e.type, e.state_key): e.event_id for e in old_state_2},
+        sg2 = yield defer.ensureDeferred(
+            self.store.store_state_group(
+                prev_event_id_2,
+                event.room_id,
+                None,
+                None,
+                {(e.type, e.state_key): e.event_id for e in old_state_2},
+            )
         )
         self.store.register_event_id_state_group(prev_event_id_2, sg2)