diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 105e1228bb..f430cce931 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -226,11 +226,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
context = EventContext()
context.current_state_ids = state_ids
context.prev_state_ids = state_ids
- elif not backfill:
+ else:
state_handler = self.hs.get_state_handler()
context = yield state_handler.compute_event_context(event)
- else:
- context = EventContext()
context.push_actions = push_actions
diff --git a/tests/test_state.py b/tests/test_state.py
index feb84f3d48..253aa62f2a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -99,6 +99,10 @@ class StateGroupStore(object):
for e in events:
self._event_id_to_event[e.event_id] = e
+ def store_state_group(self, *args, **kwargs):
+ self._next_group += 1
+ return self._next_group
+
class DictObj(dict):
def __init__(self, **kwargs):
@@ -144,6 +148,7 @@ class StateTestCase(unittest.TestCase):
"get_events",
"get_next_state_group",
"get_state_group_delta",
+ "store_state_group",
]
)
hs = Mock(spec_set=[
@@ -316,6 +321,7 @@ class StateTestCase(unittest.TestCase):
store = StateGroupStore()
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
self.store.get_events = store.get_events
+ self.store.store_state_group = store.store_state_group
store.register_events(graph.walk())
context_store = {}
@@ -399,6 +405,7 @@ class StateTestCase(unittest.TestCase):
store = StateGroupStore()
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
self.store.get_events = store.get_events
+ self.store.store_state_group = store.store_state_group
store.register_events(graph.walk())
context_store = {}
|