summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test_event_federation.py76
1 files changed, 73 insertions, 3 deletions
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 06000f81a6..d597d712d6 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -118,8 +118,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
         self.assertTrue(r == [room2] or r == [room3])
 
-    @parameterized.expand([(True,), (False,)])
-    def test_auth_difference(self, use_chain_cover_index: bool):
+    def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
         room_id = "@ROOM:local"
 
         # The silly auth graph we use to test the auth difference algorithm,
@@ -165,7 +164,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             "j": 1,
         }
 
-        # Mark the room as not having a cover index
+        # Mark the room as maybe having a cover index.
 
         def store_room(txn):
             self.store.db_pool.simple_insert_txn(
@@ -222,6 +221,77 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             )
         )
 
+        return room_id
+
+    @parameterized.expand([(True,), (False,)])
+    def test_auth_chain_ids(self, use_chain_cover_index: bool):
+        room_id = self._setup_auth_chain(use_chain_cover_index)
+
+        # a and b have the same auth chain.
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"]))
+        self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"]))
+        self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+        auth_chain_ids = self.get_success(
+            self.store.get_auth_chain_ids(room_id, ["a", "b"])
+        )
+        self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"]))
+        self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
+
+        # d and e have the same auth chain.
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"]))
+        self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"]))
+        self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
+
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"]))
+        self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
+
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"]))
+        self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"])
+
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"]))
+        self.assertEqual(auth_chain_ids, ["k"])
+
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"]))
+        self.assertEqual(auth_chain_ids, ["j"])
+
+        # j and k have no parents.
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"]))
+        self.assertEqual(auth_chain_ids, [])
+        auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"]))
+        self.assertEqual(auth_chain_ids, [])
+
+        # More complex input sequences.
+        auth_chain_ids = self.get_success(
+            self.store.get_auth_chain_ids(room_id, ["b", "c", "d"])
+        )
+        self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+        auth_chain_ids = self.get_success(
+            self.store.get_auth_chain_ids(room_id, ["h", "i"])
+        )
+        self.assertCountEqual(auth_chain_ids, ["k", "j"])
+
+        # e gets returned even though include_given is false, but it is in the
+        # auth chain of b.
+        auth_chain_ids = self.get_success(
+            self.store.get_auth_chain_ids(room_id, ["b", "e"])
+        )
+        self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+        # Test include_given.
+        auth_chain_ids = self.get_success(
+            self.store.get_auth_chain_ids(room_id, ["i"], include_given=True)
+        )
+        self.assertCountEqual(auth_chain_ids, ["i", "j"])
+
+    @parameterized.expand([(True,), (False,)])
+    def test_auth_difference(self, use_chain_cover_index: bool):
+        room_id = self._setup_auth_chain(use_chain_cover_index)
+
         # Now actually test that various combinations give the right result:
 
         difference = self.get_success(