summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/replication/slave/storage/test_events.py5
-rw-r--r--tests/storage/test_event_push_actions.py10
2 files changed, 10 insertions, 5 deletions
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index f430cce931..4780f2ab72 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -230,7 +230,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             state_handler = self.hs.get_state_handler()
             context = yield state_handler.compute_event_context(event)
 
-        context.push_actions = push_actions
+        for user_id, actions in push_actions:
+            yield self.master_store.add_push_actions_to_staging(
+                event.event_id, user_id, actions,
+            )
 
         ordering = None
         if backfill:
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 3135488353..d483e7cf9e 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -62,6 +62,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
                 {"notify_count": noitf_count, "highlight_count": highlight_count}
             )
 
+        @defer.inlineCallbacks
         def _inject_actions(stream, action):
             event = Mock()
             event.room_id = room_id
@@ -69,11 +70,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             event.internal_metadata.stream_ordering = stream
             event.depth = stream
 
-            tuples = [(user_id, action)]
-
-            return self.store.runInteraction(
+            yield self.store.add_push_actions_to_staging(
+                event.event_id, user_id, action,
+            )
+            yield self.store.runInteraction(
                 "", self.store._set_push_actions_for_event_and_users_txn,
-                event, tuples
+                event,
             )
 
         def _rotate(stream):