summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/events/test_snapshot.py3
-rw-r--r--tests/storage/test_event_chain.py5
-rw-r--r--tests/test_state.py11
3 files changed, 13 insertions, 6 deletions
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 6687c28e8f..b5e42f9600 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -101,8 +101,7 @@ class TestEventContext(unittest.HomeserverTestCase):
         self.assertEqual(
             context.state_group_before_event, d_context.state_group_before_event
         )
-        self.assertEqual(context.prev_group, d_context.prev_group)
-        self.assertEqual(context.delta_ids, d_context.delta_ids)
+        self.assertEqual(context.state_group_deltas, d_context.state_group_deltas)
         self.assertEqual(context.app_service, d_context.app_service)
 
         self.assertEqual(
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index e39b63edac..48ebfadaab 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -401,7 +401,10 @@ class EventChainStoreTestCase(HomeserverTestCase):
             assert persist_events_store is not None
             persist_events_store._store_event_txn(
                 txn,
-                [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
+                [
+                    (e, EventContext(self.hs.get_storage_controllers(), {}))
+                    for e in events
+                ],
             )
 
             # Actually call the function that calculates the auth chain stuff.
diff --git a/tests/test_state.py b/tests/test_state.py
index 7a49b87953..eded38c766 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -555,10 +555,15 @@ class StateTestCase(unittest.TestCase):
             (e.event_id for e in old_state + [event]), current_state_ids.values()
         )
 
-        self.assertIsNotNone(context.state_group_before_event)
+        assert context.state_group_before_event is not None
+        assert context.state_group is not None
+        self.assertEqual(
+            context.state_group_deltas.get(
+                (context.state_group_before_event, context.state_group)
+            ),
+            {(event.type, event.state_key): event.event_id},
+        )
         self.assertNotEqual(context.state_group_before_event, context.state_group)
-        self.assertEqual(context.state_group_before_event, context.prev_group)
-        self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
 
     @defer.inlineCallbacks
     def test_trivial_annotate_message(