diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 6687c28e8f..b5e42f9600 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -101,8 +101,7 @@ class TestEventContext(unittest.HomeserverTestCase):
self.assertEqual(
context.state_group_before_event, d_context.state_group_before_event
)
- self.assertEqual(context.prev_group, d_context.prev_group)
- self.assertEqual(context.delta_ids, d_context.delta_ids)
+ self.assertEqual(context.state_group_deltas, d_context.state_group_deltas)
self.assertEqual(context.app_service, d_context.app_service)
self.assertEqual(
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index e39b63edac..48ebfadaab 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -401,7 +401,10 @@ class EventChainStoreTestCase(HomeserverTestCase):
assert persist_events_store is not None
persist_events_store._store_event_txn(
txn,
- [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
+ [
+ (e, EventContext(self.hs.get_storage_controllers(), {}))
+ for e in events
+ ],
)
# Actually call the function that calculates the auth chain stuff.
diff --git a/tests/test_state.py b/tests/test_state.py
index 7a49b87953..eded38c766 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -555,10 +555,15 @@ class StateTestCase(unittest.TestCase):
(e.event_id for e in old_state + [event]), current_state_ids.values()
)
- self.assertIsNotNone(context.state_group_before_event)
+ assert context.state_group_before_event is not None
+ assert context.state_group is not None
+ self.assertEqual(
+ context.state_group_deltas.get(
+ (context.state_group_before_event, context.state_group)
+ ),
+ {(event.type, event.state_key): event.event_id},
+ )
self.assertNotEqual(context.state_group_before_event, context.state_group)
- self.assertEqual(context.state_group_before_event, context.prev_group)
- self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
@defer.inlineCallbacks
def test_trivial_annotate_message(
|