summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_federation.py6
-rw-r--r--tests/storage/test_events.py43
-rw-r--r--tests/test_state.py14
3 files changed, 45 insertions, 18 deletions
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index bef6c2b776..ec00900621 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -276,7 +276,11 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
             # federation handler wanting to backfill the fake event.
             self.get_success(
                 federation_event_handler._process_received_pdu(
-                    self.OTHER_SERVER_NAME, event, state=current_state
+                    self.OTHER_SERVER_NAME,
+                    event,
+                    state_ids={
+                        (e.type, e.state_key): e.event_id for e in current_state
+                    },
                 )
             )
 
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index ef5e25873c..aaa3189b16 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -69,7 +69,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
     def persist_event(self, event, state=None):
         """Persist the event, with optional state"""
         context = self.get_success(
-            self.state.compute_event_context(event, old_state=state)
+            self.state.compute_event_context(event, state_ids_before_event=state)
         )
         self.get_success(self.persistence.persist_event(event, context))
 
@@ -103,9 +103,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+        state_before_gap = self.get_success(
+            self.state.get_current_state_ids(self.room_id)
+        )
 
-        self.persist_event(remote_event_2, state=state_before_gap.values())
+        self.persist_event(remote_event_2, state=state_before_gap)
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([remote_event_2.event_id])
@@ -135,13 +137,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
         # setting. The state resolution across the old and new event will then
         # include it, and so the resolved state won't match the new state.
         state_before_gap = dict(
-            self.get_success(self.state.get_current_state(self.room_id))
+            self.get_success(self.state.get_current_state_ids(self.room_id))
         )
         state_before_gap.pop(("m.room.history_visibility", ""))
 
         context = self.get_success(
             self.state.compute_event_context(
-                remote_event_2, old_state=state_before_gap.values()
+                remote_event_2,
+                state_ids_before_event=state_before_gap,
             )
         )
 
@@ -177,9 +180,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+        state_before_gap = self.get_success(
+            self.state.get_current_state_ids(self.room_id)
+        )
 
-        self.persist_event(remote_event_2, state=state_before_gap.values())
+        self.persist_event(remote_event_2, state=state_before_gap)
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([remote_event_2.event_id])
@@ -207,9 +212,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+        state_before_gap = self.get_success(
+            self.state.get_current_state_ids(self.room_id)
+        )
 
-        self.persist_event(remote_event_2, state=state_before_gap.values())
+        self.persist_event(remote_event_2, state=state_before_gap)
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@@ -247,9 +254,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+        state_before_gap = self.get_success(
+            self.state.get_current_state_ids(self.room_id)
+        )
 
-        self.persist_event(remote_event_2, state=state_before_gap.values())
+        self.persist_event(remote_event_2, state=state_before_gap)
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([remote_event_2.event_id])
@@ -289,9 +298,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+        state_before_gap = self.get_success(
+            self.state.get_current_state_ids(self.room_id)
+        )
 
-        self.persist_event(remote_event_2, state=state_before_gap.values())
+        self.persist_event(remote_event_2, state=state_before_gap)
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([remote_event_2.event_id, local_message_event_id])
@@ -323,9 +334,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+        state_before_gap = self.get_success(
+            self.state.get_current_state_ids(self.room_id)
+        )
 
-        self.persist_event(remote_event_2, state=state_before_gap.values())
+        self.persist_event(remote_event_2, state=state_before_gap)
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([local_message_event_id, remote_event_2.event_id])
diff --git a/tests/test_state.py b/tests/test_state.py
index c6baea3d76..84694d368d 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -442,7 +442,12 @@ class StateTestCase(unittest.TestCase):
         ]
 
         context = yield defer.ensureDeferred(
-            self.state.compute_event_context(event, old_state=old_state)
+            self.state.compute_event_context(
+                event,
+                state_ids_before_event={
+                    (e.type, e.state_key): e.event_id for e in old_state
+                },
+            )
         )
 
         prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
@@ -467,7 +472,12 @@ class StateTestCase(unittest.TestCase):
         ]
 
         context = yield defer.ensureDeferred(
-            self.state.compute_event_context(event, old_state=old_state)
+            self.state.compute_event_context(
+                event,
+                state_ids_before_event={
+                    (e.type, e.state_key): e.event_id for e in old_state
+                },
+            )
         )
 
         prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())