diff --git a/tests/test_state.py b/tests/test_state.py
index 66f22f6813..4858e8fc59 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -97,17 +97,19 @@ class StateGroupStore(object):
self._group_to_state[state_group] = dict(current_state_ids)
- return state_group
+ return defer.succeed(state_group)
def get_events(self, event_ids, **kwargs):
- return {
- e_id: self._event_id_to_event[e_id]
- for e_id in event_ids
- if e_id in self._event_id_to_event
- }
+ return defer.succeed(
+ {
+ e_id: self._event_id_to_event[e_id]
+ for e_id in event_ids
+ if e_id in self._event_id_to_event
+ }
+ )
def get_state_group_delta(self, name):
- return None, None
+ return defer.succeed((None, None))
def register_events(self, events):
for e in events:
@@ -120,7 +122,7 @@ class StateGroupStore(object):
self._event_to_state_group[event_id] = state_group
def get_room_version_id(self, room_id):
- return RoomVersions.V1.identifier
+ return defer.succeed(RoomVersions.V1.identifier)
class DictObj(dict):
@@ -202,7 +204,9 @@ class StateTestCase(unittest.TestCase):
context_store = {} # type: dict[str, EventContext]
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -244,7 +248,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -300,7 +306,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -373,7 +381,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -411,12 +421,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- context = yield self.state.compute_event_context(event, old_state=old_state)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event, old_state=old_state)
+ )
prev_state_ids = yield context.get_prev_state_ids()
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertCountEqual(
(e.event_id for e in old_state), current_state_ids.values()
)
@@ -434,12 +446,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- context = yield self.state.compute_event_context(event, old_state=old_state)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event, old_state=old_state)
+ )
prev_state_ids = yield context.get_prev_state_ids()
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertCountEqual(
(e.event_id for e in old_state + [event]), current_state_ids.values()
)
@@ -462,7 +476,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = self.store.store_state_group(
+ group_name = yield self.store.store_state_group(
prev_event_id,
event.room_id,
None,
@@ -471,9 +485,9 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id, group_name)
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(self.state.compute_event_context(event))
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(
{e.event_id for e in old_state}, set(current_state_ids.values())
@@ -494,7 +508,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = self.store.store_state_group(
+ group_name = yield self.store.store_state_group(
prev_event_id,
event.room_id,
None,
@@ -503,7 +517,7 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id, group_name)
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(self.state.compute_event_context(event))
prev_state_ids = yield context.get_prev_state_ids()
@@ -544,7 +558,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(len(current_state_ids), 6)
@@ -586,7 +600,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(len(current_state_ids), 6)
@@ -641,7 +655,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
@@ -669,14 +683,15 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
+ @defer.inlineCallbacks
def _get_context(
self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
):
- sg1 = self.store.store_state_group(
+ sg1 = yield self.store.store_state_group(
prev_event_id_1,
event.room_id,
None,
@@ -685,7 +700,7 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id_1, sg1)
- sg2 = self.store.store_state_group(
+ sg2 = yield self.store.store_state_group(
prev_event_id_2,
event.room_id,
None,
@@ -694,4 +709,5 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id_2, sg2)
- return self.state.compute_event_context(event)
+ result = yield defer.ensureDeferred(self.state.compute_event_context(event))
+ return result
|