diff options
Diffstat (limited to 'tests/storage/test_event_chain.py')
-rw-r--r-- | tests/storage/test_event_chain.py | 104 |
1 files changed, 73 insertions, 31 deletions
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 9e4e73832e..27d5b0125f 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -21,6 +21,8 @@ from typing import Dict, List, Set, Tuple, cast +from parameterized import parameterized + from twisted.test.proto_helpers import MemoryReactor from twisted.trial import unittest @@ -45,7 +47,8 @@ class EventChainStoreTestCase(HomeserverTestCase): self.store = hs.get_datastores().main self._next_stream_ordering = 1 - def test_simple(self) -> None: + @parameterized.expand([(False,), (True,)]) + def test_simple(self, batched: bool) -> None: """Test that the example in `docs/auth_chain_difference_algorithm.md` works. """ @@ -53,6 +56,7 @@ class EventChainStoreTestCase(HomeserverTestCase): event_factory = self.hs.get_event_builder_factory() bob = "@creator:test" alice = "@alice:test" + charlie = "@charlie:test" room_id = "!room:test" # Ensure that we have a rooms entry so that we generate the chain index. @@ -191,6 +195,26 @@ class EventChainStoreTestCase(HomeserverTestCase): ) ) + charlie_invite = self.get_success( + event_factory.for_room_version( + RoomVersions.V6, + { + "type": EventTypes.Member, + "state_key": charlie, + "sender": alice, + "room_id": room_id, + "content": {"tag": "charlie_invite"}, + }, + ).build( + prev_event_ids=[], + auth_event_ids=[ + create.event_id, + alice_join2.event_id, + power_2.event_id, + ], + ) + ) + events = [ create, bob_join, @@ -200,33 +224,41 @@ class EventChainStoreTestCase(HomeserverTestCase): bob_join_2, power_2, alice_join2, + charlie_invite, ] expected_links = [ (bob_join, create), - (power, create), (power, bob_join), - (alice_invite, create), (alice_invite, power), - (alice_invite, bob_join), (bob_join_2, power), (alice_join2, power_2), + (charlie_invite, alice_join2), ] - self.persist(events) + # We either persist as a batch or one-by-one depending on test + # parameter. + if batched: + self.persist(events) + else: + for event in events: + self.persist([event]) + chain_map, link_map = self.fetch_chains(events) # Check that the expected links and only the expected links have been # added. - self.assertEqual(len(expected_links), len(list(link_map.get_additions()))) - - for start, end in expected_links: - start_id, start_seq = chain_map[start.event_id] - end_id, end_seq = chain_map[end.event_id] + event_map = {e.event_id: e for e in events} + reverse_chain_map = {v: event_map[k] for k, v in chain_map.items()} - self.assertIn( - (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id)) - ) + self.maxDiff = None + self.assertCountEqual( + expected_links, + [ + (reverse_chain_map[(s1, s2)], reverse_chain_map[(t1, t2)]) + for s1, s2, t1, t2 in link_map.get_additions() + ], + ) # Test that everything can reach the create event, but the create event # can't reach anything. @@ -368,24 +400,23 @@ class EventChainStoreTestCase(HomeserverTestCase): expected_links = [ (bob_join, create), - (power, create), (power, bob_join), - (alice_invite, create), (alice_invite, power), - (alice_invite, bob_join), ] # Check that the expected links and only the expected links have been # added. - self.assertEqual(len(expected_links), len(list(link_map.get_additions()))) + event_map = {e.event_id: e for e in events} + reverse_chain_map = {v: event_map[k] for k, v in chain_map.items()} - for start, end in expected_links: - start_id, start_seq = chain_map[start.event_id] - end_id, end_seq = chain_map[end.event_id] - - self.assertIn( - (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id)) - ) + self.maxDiff = None + self.assertCountEqual( + expected_links, + [ + (reverse_chain_map[(s1, s2)], reverse_chain_map[(t1, t2)]) + for s1, s2, t1, t2 in link_map.get_additions() + ], + ) def persist( self, @@ -489,8 +520,6 @@ class LinkMapTestCase(unittest.TestCase): link_map = _LinkMap() link_map.add_link((1, 1), (2, 1), new=False) - self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)]) - self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)]) self.assertCountEqual(link_map.get_additions(), []) self.assertTrue(link_map.exists_path_from((1, 5), (2, 1))) self.assertFalse(link_map.exists_path_from((1, 5), (2, 2))) @@ -499,18 +528,31 @@ class LinkMapTestCase(unittest.TestCase): # Attempting to add a redundant link is ignored. self.assertFalse(link_map.add_link((1, 4), (2, 1))) - self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)]) + self.assertCountEqual(link_map.get_additions(), []) # Adding new non-redundant links works self.assertTrue(link_map.add_link((1, 3), (2, 3))) - self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)]) + self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3)]) self.assertTrue(link_map.add_link((2, 5), (1, 3))) - self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)]) - self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)]) - self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)]) + def test_exists_path_from(self) -> None: + "Check that `exists_path_from` can handle non-direct links" + link_map = _LinkMap() + + link_map.add_link((1, 1), (2, 1), new=False) + link_map.add_link((2, 1), (3, 1), new=False) + + self.assertTrue(link_map.exists_path_from((1, 4), (3, 1))) + self.assertFalse(link_map.exists_path_from((1, 4), (3, 2))) + + link_map.add_link((1, 5), (2, 3), new=False) + link_map.add_link((2, 2), (3, 3), new=False) + + self.assertTrue(link_map.exists_path_from((1, 6), (3, 2))) + self.assertFalse(link_map.exists_path_from((1, 4), (3, 2))) + class EventChainBackgroundUpdateTestCase(HomeserverTestCase): servlets = [ |