diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 1a88c7fb80..0b5204654c 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -366,7 +366,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
state_handler = self.hs.get_state_handler()
context = self.get_success(state_handler.compute_event_context(event))
- self.master_store.add_push_actions_to_staging(
- event.event_id, {user_id: actions for user_id, actions in push_actions}
+ self.get_success(
+ self.master_store.add_push_actions_to_staging(
+ event.event_id, {user_id: actions for user_id, actions in push_actions}
+ )
)
return event, context
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index b45bc9c115..43dbeb42c5 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -72,8 +72,10 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.internal_metadata.stream_ordering = stream
event.depth = stream
- yield self.store.add_push_actions_to_staging(
- event.event_id, {user_id: action}
+ yield defer.ensureDeferred(
+ self.store.add_push_actions_to_staging(
+ event.event_id, {user_id: action}
+ )
)
yield self.store.db.runInteraction(
"",
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index db3667dc43..0f0e1cd09b 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -237,7 +237,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def build(self, prev_event_ids):
- built_event = yield self._base_builder.build(prev_event_ids)
+ built_event = yield defer.ensureDeferred(
+ self._base_builder.build(prev_event_ids)
+ )
built_event._event_id = self._event_id
built_event._dict["event_id"] = self._event_id
diff --git a/tests/test_state.py b/tests/test_state.py
index 4858e8fc59..b5c3667d2a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -213,7 +213,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertEqual(2, len(prev_state_ids))
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -259,7 +259,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -318,7 +318,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_e = context_store["E"]
- prev_state_ids = yield ctx_e.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
@@ -393,7 +393,7 @@ class StateTestCase(unittest.TestCase):
ctx_b = context_store["B"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
@@ -425,7 +425,7 @@ class StateTestCase(unittest.TestCase):
self.state.compute_event_context(event, old_state=old_state)
)
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
@@ -450,7 +450,7 @@ class StateTestCase(unittest.TestCase):
self.state.compute_event_context(event, old_state=old_state)
)
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
@@ -519,7 +519,7 @@ class StateTestCase(unittest.TestCase):
context = yield defer.ensureDeferred(self.state.compute_event_context(event))
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index b371efc0df..a7a36174ea 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -64,8 +64,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
evt = yield self.inject_room_member(user, extra_content={"a": "b"})
events_to_filter.append(evt)
- filtered = yield filter_events_for_server(
- self.storage, "test_server", events_to_filter
+ filtered = yield defer.ensureDeferred(
+ filter_events_for_server(self.storage, "test_server", events_to_filter)
)
# the result should be 5 redacted events, and 5 unredacted events.
@@ -102,8 +102,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")
# ... and the filtering happens.
- filtered = yield filter_events_for_server(
- self.storage, "test_server", events_to_filter
+ filtered = yield defer.ensureDeferred(
+ filter_events_for_server(self.storage, "test_server", events_to_filter)
)
for i in range(0, len(events_to_filter)):
@@ -265,8 +265,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
storage.main = test_store
storage.state = test_store
- filtered = yield filter_events_for_server(
- test_store, "test_server", events_to_filter
+ filtered = yield defer.ensureDeferred(
+ filter_events_for_server(test_store, "test_server", events_to_filter)
)
logger.info("Filtering took %f seconds", time.time() - start)
|