summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test_event_chain.py9
-rw-r--r--tests/storage/test_event_federation.py44
2 files changed, 41 insertions, 12 deletions
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 81feb3ec29..c4e216c308 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -447,7 +447,14 @@ class EventChainStoreTestCase(HomeserverTestCase):
             )
 
             # Actually call the function that calculates the auth chain stuff.
-            persist_events_store._persist_event_auth_chain_txn(txn, events)
+            new_event_links = (
+                persist_events_store.calculate_chain_cover_index_for_events_txn(
+                    txn, events[0].room_id, [e for e in events if e.is_state()]
+                )
+            )
+            persist_events_store._persist_event_auth_chain_txn(
+                txn, events, new_event_links
+            )
 
         self.get_success(
             persist_events_store.db_pool.runInteraction(
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 0a6253e22c..088f0d24f9 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -365,12 +365,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                     },
                 )
 
+            events = [
+                cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
+                for event_id in AUTH_GRAPH
+            ]
+            new_event_links = (
+                self.persist_events.calculate_chain_cover_index_for_events_txn(
+                    txn, room_id, [e for e in events if e.is_state()]
+                )
+            )
             self.persist_events._persist_event_auth_chain_txn(
                 txn,
-                [
-                    cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
-                    for event_id in AUTH_GRAPH
-                ],
+                events,
+                new_event_links,
             )
 
         self.get_success(
@@ -544,6 +551,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         rooms.
         """
 
+        # We allow partial covers for this test
+        self.hs.get_datastores().main.tests_allow_no_chain_cover_index = True
+
         room_id = "@ROOM:local"
 
         # The silly auth graph we use to test the auth difference algorithm,
@@ -628,13 +638,20 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 )
 
             # Insert all events apart from 'B'
+            events = [
+                cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
+                for event_id in auth_graph
+                if event_id != "b"
+            ]
+            new_event_links = (
+                self.persist_events.calculate_chain_cover_index_for_events_txn(
+                    txn, room_id, [e for e in events if e.is_state()]
+                )
+            )
             self.persist_events._persist_event_auth_chain_txn(
                 txn,
-                [
-                    cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
-                    for event_id in auth_graph
-                    if event_id != "b"
-                ],
+                events,
+                new_event_links,
             )
 
             # Now we insert the event 'B' without a chain cover, by temporarily
@@ -647,9 +664,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 updatevalues={"has_auth_chain_index": False},
             )
 
+            events = [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))]
+            new_event_links = (
+                self.persist_events.calculate_chain_cover_index_for_events_txn(
+                    txn, room_id, [e for e in events if e.is_state()]
+                )
+            )
             self.persist_events._persist_event_auth_chain_txn(
-                txn,
-                [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
+                txn, events, new_event_links
             )
 
             self.store.db_pool.simple_update_txn(