diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
new file mode 100644
index 0000000000..83c377824b
--- /dev/null
+++ b/tests/storage/test_event_chain.py
@@ -0,0 +1,472 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Tuple
+
+from twisted.trial import unittest
+
+from synapse.api.constants import EventTypes
+from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
+from synapse.storage.databases.main.events import _LinkMap
+
+from tests.unittest import HomeserverTestCase
+
+
+class EventChainStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self._next_stream_ordering = 1
+
+ def test_simple(self):
+ """Test that the example in `docs/auth_chain_difference_algorithm.md`
+ works.
+ """
+
+ event_factory = self.hs.get_event_builder_factory()
+ bob = "@creator:test"
+ alice = "@alice:test"
+ room_id = "!room:test"
+
+ # Ensure that we have a rooms entry so that we generate the chain index.
+ self.get_success(
+ self.store.store_room(
+ room_id=room_id,
+ room_creator_user_id="",
+ is_public=True,
+ room_version=RoomVersions.V6,
+ )
+ )
+
+ create = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Create,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "create"},
+ },
+ ).build(prev_event_ids=[], auth_event_ids=[])
+ )
+
+ bob_join = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": bob,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "bob_join"},
+ },
+ ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
+ )
+
+ power = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.PowerLevels,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "power"},
+ },
+ ).build(
+ prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+ )
+ )
+
+ alice_invite = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "alice_invite"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+ )
+ )
+
+ alice_join = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": alice,
+ "room_id": room_id,
+ "content": {"tag": "alice_join"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
+ )
+ )
+
+ power_2 = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.PowerLevels,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "power_2"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+ )
+ )
+
+ bob_join_2 = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": bob,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "bob_join_2"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+ )
+ )
+
+ alice_join2 = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": alice,
+ "room_id": room_id,
+ "content": {"tag": "alice_join2"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[
+ create.event_id,
+ alice_join.event_id,
+ power_2.event_id,
+ ],
+ )
+ )
+
+ events = [
+ create,
+ bob_join,
+ power,
+ alice_invite,
+ alice_join,
+ bob_join_2,
+ power_2,
+ alice_join2,
+ ]
+
+ expected_links = [
+ (bob_join, create),
+ (power, create),
+ (power, bob_join),
+ (alice_invite, create),
+ (alice_invite, power),
+ (alice_invite, bob_join),
+ (bob_join_2, power),
+ (alice_join2, power_2),
+ ]
+
+ self.persist(events)
+ chain_map, link_map = self.fetch_chains(events)
+
+ # Check that the expected links and only the expected links have been
+ # added.
+ self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+
+ for start, end in expected_links:
+ start_id, start_seq = chain_map[start.event_id]
+ end_id, end_seq = chain_map[end.event_id]
+
+ self.assertIn(
+ (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+ )
+
+ # Test that everything can reach the create event, but the create event
+ # can't reach anything.
+ for event in events[1:]:
+ self.assertTrue(
+ link_map.exists_path_from(
+ chain_map[event.event_id], chain_map[create.event_id]
+ ),
+ )
+
+ self.assertFalse(
+ link_map.exists_path_from(
+ chain_map[create.event_id], chain_map[event.event_id],
+ ),
+ )
+
+ def test_out_of_order_events(self):
+ """Test that we handle persisting events that we don't have the full
+ auth chain for yet (which should only happen for out of band memberships).
+ """
+ event_factory = self.hs.get_event_builder_factory()
+ bob = "@creator:test"
+ alice = "@alice:test"
+ room_id = "!room:test"
+
+ # Ensure that we have a rooms entry so that we generate the chain index.
+ self.get_success(
+ self.store.store_room(
+ room_id=room_id,
+ room_creator_user_id="",
+ is_public=True,
+ room_version=RoomVersions.V6,
+ )
+ )
+
+ # First persist the base room.
+ create = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Create,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "create"},
+ },
+ ).build(prev_event_ids=[], auth_event_ids=[])
+ )
+
+ bob_join = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": bob,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "bob_join"},
+ },
+ ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
+ )
+
+ power = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.PowerLevels,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "power"},
+ },
+ ).build(
+ prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+ )
+ )
+
+ self.persist([create, bob_join, power])
+
+ # Now persist an invite and a couple of memberships out of order.
+ alice_invite = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "alice_invite"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+ )
+ )
+
+ alice_join = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": alice,
+ "room_id": room_id,
+ "content": {"tag": "alice_join"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
+ )
+ )
+
+ alice_join2 = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": alice,
+ "room_id": room_id,
+ "content": {"tag": "alice_join2"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, alice_join.event_id, power.event_id],
+ )
+ )
+
+ self.persist([alice_join])
+ self.persist([alice_join2])
+ self.persist([alice_invite])
+
+ # The end result should be sane.
+ events = [create, bob_join, power, alice_invite, alice_join]
+
+ chain_map, link_map = self.fetch_chains(events)
+
+ expected_links = [
+ (bob_join, create),
+ (power, create),
+ (power, bob_join),
+ (alice_invite, create),
+ (alice_invite, power),
+ (alice_invite, bob_join),
+ ]
+
+ # Check that the expected links and only the expected links have been
+ # added.
+ self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+
+ for start, end in expected_links:
+ start_id, start_seq = chain_map[start.event_id]
+ end_id, end_seq = chain_map[end.event_id]
+
+ self.assertIn(
+ (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+ )
+
+ def persist(
+ self, events: List[EventBase],
+ ):
+ """Persist the given events and check that the links generated match
+ those given.
+ """
+
+ persist_events_store = self.hs.get_datastores().persist_events
+
+ for e in events:
+ e.internal_metadata.stream_ordering = self._next_stream_ordering
+ self._next_stream_ordering += 1
+
+ def _persist(txn):
+ # We need to persist the events to the events and state_events
+ # tables.
+ persist_events_store._store_event_txn(txn, [(e, {}) for e in events])
+
+ # Actually call the function that calculates the auth chain stuff.
+ persist_events_store._persist_event_auth_chain_txn(txn, events)
+
+ self.get_success(
+ persist_events_store.db_pool.runInteraction("_persist", _persist,)
+ )
+
+ def fetch_chains(
+ self, events: List[EventBase]
+ ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
+
+ # Fetch the map from event ID -> (chain ID, sequence number)
+ rows = self.get_success(
+ self.store.db_pool.simple_select_many_batch(
+ table="event_auth_chains",
+ column="event_id",
+ iterable=[e.event_id for e in events],
+ retcols=("event_id", "chain_id", "sequence_number"),
+ keyvalues={},
+ )
+ )
+
+ chain_map = {
+ row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
+ }
+
+ # Fetch all the links and pass them to the _LinkMap.
+ rows = self.get_success(
+ self.store.db_pool.simple_select_many_batch(
+ table="event_auth_chain_links",
+ column="origin_chain_id",
+ iterable=[chain_id for chain_id, _ in chain_map.values()],
+ retcols=(
+ "origin_chain_id",
+ "origin_sequence_number",
+ "target_chain_id",
+ "target_sequence_number",
+ ),
+ keyvalues={},
+ )
+ )
+
+ link_map = _LinkMap()
+ for row in rows:
+ added = link_map.add_link(
+ (row["origin_chain_id"], row["origin_sequence_number"]),
+ (row["target_chain_id"], row["target_sequence_number"]),
+ )
+
+ # We shouldn't have persisted any redundant links
+ self.assertTrue(added)
+
+ return chain_map, link_map
+
+
+class LinkMapTestCase(unittest.TestCase):
+ def test_simple(self):
+ """Basic tests for the LinkMap.
+ """
+ link_map = _LinkMap()
+
+ link_map.add_link((1, 1), (2, 1), new=False)
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+ self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
+ self.assertCountEqual(link_map.get_additions(), [])
+ self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
+ self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
+ self.assertTrue(link_map.exists_path_from((1, 5), (1, 1)))
+ self.assertFalse(link_map.exists_path_from((1, 1), (1, 5)))
+
+ # Attempting to add a redundant link is ignored.
+ self.assertFalse(link_map.add_link((1, 4), (2, 1)))
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+
+ # Adding new non-redundant links works
+ self.assertTrue(link_map.add_link((1, 3), (2, 3)))
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+
+ self.assertTrue(link_map.add_link((2, 5), (1, 3)))
+ self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+
+ self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 482506d731..9d04a066d8 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,6 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import attr
+from parameterized import parameterized
+
+from synapse.events import _EventInternalMetadata
+
import tests.unittest
import tests.utils
@@ -113,7 +118,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
self.assertTrue(r == [room2] or r == [room3])
- def test_auth_difference(self):
+ @parameterized.expand([(True,), (False,)])
+ def test_auth_difference(self, use_chain_cover_index: bool):
room_id = "@ROOM:local"
# The silly auth graph we use to test the auth difference algorithm,
@@ -159,46 +165,223 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"j": 1,
}
+ # Mark the room as not having a cover index
+
+ def store_room(txn):
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ "rooms",
+ {
+ "room_id": room_id,
+ "creator": "room_creator_user_id",
+ "is_public": True,
+ "room_version": "6",
+ "has_auth_chain_index": use_chain_cover_index,
+ },
+ )
+
+ self.get_success(self.store.db_pool.runInteraction("store_room", store_room))
+
# We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly.
- def insert_event(txn, event_id, stream_ordering):
+ def insert_event(txn):
+ stream_ordering = 0
+
+ for event_id in auth_graph:
+ stream_ordering += 1
+ depth = depth_map[event_id]
+
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ table="events",
+ values={
+ "event_id": event_id,
+ "room_id": room_id,
+ "depth": depth,
+ "topological_ordering": depth,
+ "type": "m.test",
+ "processed": True,
+ "outlier": False,
+ "stream_ordering": stream_ordering,
+ },
+ )
+
+ self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ txn,
+ [
+ FakeEvent(event_id, room_id, auth_graph[event_id])
+ for event_id in auth_graph
+ ],
+ )
+
+ self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+
+ # Now actually test that various combinations give the right result:
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "d", "e"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}])
+ )
+ self.assertSetEqual(difference, set())
+
+ def test_auth_difference_partial_cover(self):
+ """Test that we correctly handle rooms where not all events have a chain
+ cover calculated. This can happen in some obscure edge cases, including
+ during the background update that calculates the chain cover for old
+ rooms.
+ """
+
+ room_id = "@ROOM:local"
+
+ # The silly auth graph we use to test the auth difference algorithm,
+ # where the top are the most recent events.
+ #
+ # A B
+ # \ /
+ # D E
+ # \ |
+ # ` F C
+ # | /|
+ # G ยด |
+ # | \ |
+ # H I
+ # | |
+ # K J
+
+ auth_graph = {
+ "a": ["e"],
+ "b": ["e"],
+ "c": ["g", "i"],
+ "d": ["f"],
+ "e": ["f"],
+ "f": ["g"],
+ "g": ["h", "i"],
+ "h": ["k"],
+ "i": ["j"],
+ "k": [],
+ "j": [],
+ }
+
+ depth_map = {
+ "a": 7,
+ "b": 7,
+ "c": 4,
+ "d": 6,
+ "e": 6,
+ "f": 5,
+ "g": 3,
+ "h": 2,
+ "i": 2,
+ "k": 1,
+ "j": 1,
+ }
- depth = depth_map[event_id]
+ # We rudely fiddle with the appropriate tables directly, as that's much
+ # easier than constructing events properly.
+ def insert_event(txn):
+ # First insert the room and mark it as having a chain cover.
self.store.db_pool.simple_insert_txn(
txn,
- table="events",
- values={
- "event_id": event_id,
+ "rooms",
+ {
"room_id": room_id,
- "depth": depth,
- "topological_ordering": depth,
- "type": "m.test",
- "processed": True,
- "outlier": False,
- "stream_ordering": stream_ordering,
+ "creator": "room_creator_user_id",
+ "is_public": True,
+ "room_version": "6",
+ "has_auth_chain_index": True,
},
)
- self.store.db_pool.simple_insert_many_txn(
+ stream_ordering = 0
+
+ for event_id in auth_graph:
+ stream_ordering += 1
+ depth = depth_map[event_id]
+
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ table="events",
+ values={
+ "event_id": event_id,
+ "room_id": room_id,
+ "depth": depth,
+ "topological_ordering": depth,
+ "type": "m.test",
+ "processed": True,
+ "outlier": False,
+ "stream_ordering": stream_ordering,
+ },
+ )
+
+ # Insert all events apart from 'B'
+ self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn,
- table="event_auth",
- values=[
- {"event_id": event_id, "room_id": room_id, "auth_id": a}
- for a in auth_graph[event_id]
+ [
+ FakeEvent(event_id, room_id, auth_graph[event_id])
+ for event_id in auth_graph
+ if event_id != "b"
],
)
- next_stream_ordering = 0
- for event_id in auth_graph:
- next_stream_ordering += 1
- self.get_success(
- self.store.db_pool.runInteraction(
- "insert", insert_event, event_id, next_stream_ordering
- )
+ # Now we insert the event 'B' without a chain cover, by temporarily
+ # pretending the room doesn't have a chain cover.
+
+ self.store.db_pool.simple_update_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"has_auth_chain_index": False},
+ )
+
+ self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ txn, [FakeEvent("b", room_id, auth_graph["b"])],
+ )
+
+ self.store.db_pool.simple_update_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"has_auth_chain_index": True},
)
+ self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+
# Now actually test that various combinations give the right result:
difference = self.get_success(
@@ -240,3 +423,21 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.store.get_auth_chain_difference(room_id, [{"a"}])
)
self.assertSetEqual(difference, set())
+
+
+@attr.s
+class FakeEvent:
+ event_id = attr.ib()
+ room_id = attr.ib()
+ auth_events = attr.ib()
+
+ type = "foo"
+ state_key = "foo"
+
+ internal_metadata = _EventInternalMetadata({})
+
+ def auth_event_ids(self):
+ return self.auth_events
+
+ def is_state(self):
+ return True
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 0ab0a91483..1184cea5a3 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.iterutils import chunk_seq
+from typing import Dict, List
+
+from synapse.util.iterutils import chunk_seq, sorted_topologically
from tests.unittest import TestCase
@@ -45,3 +47,40 @@ class ChunkSeqTests(TestCase):
self.assertEqual(
list(parts), [],
)
+
+
+class SortTopologically(TestCase):
+ def test_empty(self):
+ "Test that an empty graph works correctly"
+
+ graph = {} # type: Dict[int, List[int]]
+ self.assertEqual(list(sorted_topologically([], graph)), [])
+
+ def test_disconnected(self):
+ "Test that a graph with no edges work"
+
+ graph = {1: [], 2: []} # type: Dict[int, List[int]]
+
+ # For disconnected nodes the output is simply sorted.
+ self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
+
+ def test_linear(self):
+ "Test that a simple `4 -> 3 -> 2 -> 1` graph works"
+
+ graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]]
+
+ self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+
+ def test_subset(self):
+ "Test that only sorting a subset of the graph works"
+ graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]]
+
+ self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
+
+ def test_fork(self):
+ "Test that a forked graph works"
+ graph = {1: [], 2: [1], 3: [1], 4: [2, 3]} # type: Dict[int, List[int]]
+
+ # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should
+ # always get the same one.
+ self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
|