diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 9e4e73832e..27d5b0125f 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -21,6 +21,8 @@
from typing import Dict, List, Set, Tuple, cast
+from parameterized import parameterized
+
from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest
@@ -45,7 +47,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
self.store = hs.get_datastores().main
self._next_stream_ordering = 1
- def test_simple(self) -> None:
+ @parameterized.expand([(False,), (True,)])
+ def test_simple(self, batched: bool) -> None:
"""Test that the example in `docs/auth_chain_difference_algorithm.md`
works.
"""
@@ -53,6 +56,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
event_factory = self.hs.get_event_builder_factory()
bob = "@creator:test"
alice = "@alice:test"
+ charlie = "@charlie:test"
room_id = "!room:test"
# Ensure that we have a rooms entry so that we generate the chain index.
@@ -191,6 +195,26 @@ class EventChainStoreTestCase(HomeserverTestCase):
)
)
+ charlie_invite = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": charlie,
+ "sender": alice,
+ "room_id": room_id,
+ "content": {"tag": "charlie_invite"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[
+ create.event_id,
+ alice_join2.event_id,
+ power_2.event_id,
+ ],
+ )
+ )
+
events = [
create,
bob_join,
@@ -200,33 +224,41 @@ class EventChainStoreTestCase(HomeserverTestCase):
bob_join_2,
power_2,
alice_join2,
+ charlie_invite,
]
expected_links = [
(bob_join, create),
- (power, create),
(power, bob_join),
- (alice_invite, create),
(alice_invite, power),
- (alice_invite, bob_join),
(bob_join_2, power),
(alice_join2, power_2),
+ (charlie_invite, alice_join2),
]
- self.persist(events)
+ # We either persist as a batch or one-by-one depending on test
+ # parameter.
+ if batched:
+ self.persist(events)
+ else:
+ for event in events:
+ self.persist([event])
+
chain_map, link_map = self.fetch_chains(events)
# Check that the expected links and only the expected links have been
# added.
- self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
-
- for start, end in expected_links:
- start_id, start_seq = chain_map[start.event_id]
- end_id, end_seq = chain_map[end.event_id]
+ event_map = {e.event_id: e for e in events}
+ reverse_chain_map = {v: event_map[k] for k, v in chain_map.items()}
- self.assertIn(
- (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
- )
+ self.maxDiff = None
+ self.assertCountEqual(
+ expected_links,
+ [
+ (reverse_chain_map[(s1, s2)], reverse_chain_map[(t1, t2)])
+ for s1, s2, t1, t2 in link_map.get_additions()
+ ],
+ )
# Test that everything can reach the create event, but the create event
# can't reach anything.
@@ -368,24 +400,23 @@ class EventChainStoreTestCase(HomeserverTestCase):
expected_links = [
(bob_join, create),
- (power, create),
(power, bob_join),
- (alice_invite, create),
(alice_invite, power),
- (alice_invite, bob_join),
]
# Check that the expected links and only the expected links have been
# added.
- self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+ event_map = {e.event_id: e for e in events}
+ reverse_chain_map = {v: event_map[k] for k, v in chain_map.items()}
- for start, end in expected_links:
- start_id, start_seq = chain_map[start.event_id]
- end_id, end_seq = chain_map[end.event_id]
-
- self.assertIn(
- (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
- )
+ self.maxDiff = None
+ self.assertCountEqual(
+ expected_links,
+ [
+ (reverse_chain_map[(s1, s2)], reverse_chain_map[(t1, t2)])
+ for s1, s2, t1, t2 in link_map.get_additions()
+ ],
+ )
def persist(
self,
@@ -489,8 +520,6 @@ class LinkMapTestCase(unittest.TestCase):
link_map = _LinkMap()
link_map.add_link((1, 1), (2, 1), new=False)
- self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
- self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
self.assertCountEqual(link_map.get_additions(), [])
self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
@@ -499,18 +528,31 @@ class LinkMapTestCase(unittest.TestCase):
# Attempting to add a redundant link is ignored.
self.assertFalse(link_map.add_link((1, 4), (2, 1)))
- self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+ self.assertCountEqual(link_map.get_additions(), [])
# Adding new non-redundant links works
self.assertTrue(link_map.add_link((1, 3), (2, 3)))
- self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+ self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3)])
self.assertTrue(link_map.add_link((2, 5), (1, 3)))
- self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
- self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
-
self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
+ def test_exists_path_from(self) -> None:
+ "Check that `exists_path_from` can handle non-direct links"
+ link_map = _LinkMap()
+
+ link_map.add_link((1, 1), (2, 1), new=False)
+ link_map.add_link((2, 1), (3, 1), new=False)
+
+ self.assertTrue(link_map.exists_path_from((1, 4), (3, 1)))
+ self.assertFalse(link_map.exists_path_from((1, 4), (3, 2)))
+
+ link_map.add_link((1, 5), (2, 3), new=False)
+ link_map.add_link((2, 2), (3, 3), new=False)
+
+ self.assertTrue(link_map.exists_path_from((1, 6), (3, 2)))
+ self.assertFalse(link_map.exists_path_from((1, 4), (3, 2)))
+
class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
servlets = [
|