diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 06000f81a6..d597d712d6 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -118,8 +118,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
self.assertTrue(r == [room2] or r == [room3])
- @parameterized.expand([(True,), (False,)])
- def test_auth_difference(self, use_chain_cover_index: bool):
+ def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
room_id = "@ROOM:local"
# The silly auth graph we use to test the auth difference algorithm,
@@ -165,7 +164,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"j": 1,
}
- # Mark the room as not having a cover index
+ # Mark the room as maybe having a cover index.
def store_room(txn):
self.store.db_pool.simple_insert_txn(
@@ -222,6 +221,77 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
)
+ return room_id
+
+ @parameterized.expand([(True,), (False,)])
+ def test_auth_chain_ids(self, use_chain_cover_index: bool):
+ room_id = self._setup_auth_chain(use_chain_cover_index)
+
+ # a and b have the same auth chain.
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"]))
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"]))
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["a", "b"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"]))
+ self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
+
+ # d and e have the same auth chain.
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"]))
+ self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"]))
+ self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"]))
+ self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"]))
+ self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"]))
+ self.assertEqual(auth_chain_ids, ["k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"]))
+ self.assertEqual(auth_chain_ids, ["j"])
+
+ # j and k have no parents.
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"]))
+ self.assertEqual(auth_chain_ids, [])
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"]))
+ self.assertEqual(auth_chain_ids, [])
+
+ # More complex input sequences.
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["b", "c", "d"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["h", "i"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["k", "j"])
+
+ # e gets returned even though include_given is false, but it is in the
+ # auth chain of b.
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["b", "e"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+ # Test include_given.
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["i"], include_given=True)
+ )
+ self.assertCountEqual(auth_chain_ids, ["i", "j"])
+
+ @parameterized.expand([(True,), (False,)])
+ def test_auth_difference(self, use_chain_cover_index: bool):
+ room_id = self._setup_auth_chain(use_chain_cover_index)
+
# Now actually test that various combinations give the right result:
difference = self.get_success(
|