diff --git a/tests/test_state.py b/tests/test_state.py
index 4858e8fc59..80b0ccbc40 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -71,7 +71,7 @@ def create_event(
return event
-class StateGroupStore(object):
+class StateGroupStore:
def __init__(self):
self._event_to_state_group = {}
self._group_to_state = {}
@@ -80,16 +80,16 @@ class StateGroupStore(object):
self._next_group = 1
- def get_state_groups_ids(self, room_id, event_ids):
+ async def get_state_groups_ids(self, room_id, event_ids):
groups = {}
for event_id in event_ids:
group = self._event_to_state_group.get(event_id)
if group:
groups[group] = self._group_to_state[group]
- return defer.succeed(groups)
+ return groups
- def store_state_group(
+ async def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
):
state_group = self._next_group
@@ -97,19 +97,17 @@ class StateGroupStore(object):
self._group_to_state[state_group] = dict(current_state_ids)
- return defer.succeed(state_group)
+ return state_group
- def get_events(self, event_ids, **kwargs):
- return defer.succeed(
- {
- e_id: self._event_id_to_event[e_id]
- for e_id in event_ids
- if e_id in self._event_id_to_event
- }
- )
+ async def get_events(self, event_ids, **kwargs):
+ return {
+ e_id: self._event_id_to_event[e_id]
+ for e_id in event_ids
+ if e_id in self._event_id_to_event
+ }
- def get_state_group_delta(self, name):
- return defer.succeed((None, None))
+ async def get_state_group_delta(self, name):
+ return (None, None)
def register_events(self, events):
for e in events:
@@ -121,17 +119,17 @@ class StateGroupStore(object):
def register_event_id_state_group(self, event_id, state_group):
self._event_to_state_group[event_id] = state_group
- def get_room_version_id(self, room_id):
- return defer.succeed(RoomVersions.V1.identifier)
+ async def get_room_version_id(self, room_id):
+ return RoomVersions.V1.identifier
class DictObj(dict):
def __init__(self, **kwargs):
- super(DictObj, self).__init__(kwargs)
+ super().__init__(kwargs)
self.__dict__ = self
-class Graph(object):
+class Graph:
def __init__(self, nodes, edges):
events = {}
clobbered = set(events.keys())
@@ -213,7 +211,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertEqual(2, len(prev_state_ids))
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -259,7 +257,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -318,7 +316,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_e = context_store["E"]
- prev_state_ids = yield ctx_e.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
@@ -393,7 +391,7 @@ class StateTestCase(unittest.TestCase):
ctx_b = context_store["B"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
@@ -425,7 +423,7 @@ class StateTestCase(unittest.TestCase):
self.state.compute_event_context(event, old_state=old_state)
)
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
@@ -450,7 +448,7 @@ class StateTestCase(unittest.TestCase):
self.state.compute_event_context(event, old_state=old_state)
)
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
@@ -476,12 +474,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = yield self.store.store_state_group(
- prev_event_id,
- event.room_id,
- None,
- None,
- {(e.type, e.state_key): e.event_id for e in old_state},
+ group_name = yield defer.ensureDeferred(
+ self.store.store_state_group(
+ prev_event_id,
+ event.room_id,
+ None,
+ None,
+ {(e.type, e.state_key): e.event_id for e in old_state},
+ )
)
self.store.register_event_id_state_group(prev_event_id, group_name)
@@ -508,18 +508,20 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = yield self.store.store_state_group(
- prev_event_id,
- event.room_id,
- None,
- None,
- {(e.type, e.state_key): e.event_id for e in old_state},
+ group_name = yield defer.ensureDeferred(
+ self.store.store_state_group(
+ prev_event_id,
+ event.room_id,
+ None,
+ None,
+ {(e.type, e.state_key): e.event_id for e in old_state},
+ )
)
self.store.register_event_id_state_group(prev_event_id, group_name)
context = yield defer.ensureDeferred(self.state.compute_event_context(event))
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
@@ -691,21 +693,25 @@ class StateTestCase(unittest.TestCase):
def _get_context(
self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
):
- sg1 = yield self.store.store_state_group(
- prev_event_id_1,
- event.room_id,
- None,
- None,
- {(e.type, e.state_key): e.event_id for e in old_state_1},
+ sg1 = yield defer.ensureDeferred(
+ self.store.store_state_group(
+ prev_event_id_1,
+ event.room_id,
+ None,
+ None,
+ {(e.type, e.state_key): e.event_id for e in old_state_1},
+ )
)
self.store.register_event_id_state_group(prev_event_id_1, sg1)
- sg2 = yield self.store.store_state_group(
- prev_event_id_2,
- event.room_id,
- None,
- None,
- {(e.type, e.state_key): e.event_id for e in old_state_2},
+ sg2 = yield defer.ensureDeferred(
+ self.store.store_state_group(
+ prev_event_id_2,
+ event.room_id,
+ None,
+ None,
+ {(e.type, e.state_key): e.event_id for e in old_state_2},
+ )
)
self.store.register_event_id_state_group(prev_event_id_2, sg2)
|