diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 81feb3ec29..c4e216c308 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -447,7 +447,14 @@ class EventChainStoreTestCase(HomeserverTestCase):
)
# Actually call the function that calculates the auth chain stuff.
- persist_events_store._persist_event_auth_chain_txn(txn, events)
+ new_event_links = (
+ persist_events_store.calculate_chain_cover_index_for_events_txn(
+ txn, events[0].room_id, [e for e in events if e.is_state()]
+ )
+ )
+ persist_events_store._persist_event_auth_chain_txn(
+ txn, events, new_event_links
+ )
self.get_success(
persist_events_store.db_pool.runInteraction(
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 0a6253e22c..1832a23714 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -365,12 +365,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
},
)
+ events = [
+ cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
+ for event_id in AUTH_GRAPH
+ ]
+ new_event_links = (
+ self.persist_events.calculate_chain_cover_index_for_events_txn(
+ txn, room_id, [e for e in events if e.is_state()]
+ )
+ )
self.persist_events._persist_event_auth_chain_txn(
txn,
- [
- cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
- for event_id in AUTH_GRAPH
- ],
+ events,
+ new_event_links,
)
self.get_success(
@@ -628,13 +635,20 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
# Insert all events apart from 'B'
+ events = [
+ cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
+ for event_id in auth_graph
+ if event_id != "b"
+ ]
+ new_event_links = (
+ self.persist_events.calculate_chain_cover_index_for_events_txn(
+ txn, room_id, [e for e in events if e.is_state()]
+ )
+ )
self.persist_events._persist_event_auth_chain_txn(
txn,
- [
- cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
- for event_id in auth_graph
- if event_id != "b"
- ],
+ events,
+ new_event_links,
)
# Now we insert the event 'B' without a chain cover, by temporarily
@@ -647,9 +661,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
updatevalues={"has_auth_chain_index": False},
)
+ events = [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))]
+ new_event_links = (
+ self.persist_events.calculate_chain_cover_index_for_events_txn(
+ txn, room_id, [e for e in events if e.is_state()]
+ )
+ )
self.persist_events._persist_event_auth_chain_txn(
- txn,
- [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
+ txn, events, new_event_links
)
self.store.db_pool.simple_update_txn(
|