summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-01-11 16:09:22 +0000
committerGitHub <noreply@github.com>2021-01-11 16:09:22 +0000
commit1315a2e8be702a513d49c1142e9e52b642286635 (patch)
tree2c9aca9e27a2fd4ac1dda844015cefb26a021939 /tests
parentClean up exception handling in the startup code (#9059) (diff)
downloadsynapse-1315a2e8be702a513d49c1142e9e52b642286635.tar.xz
Use a chain cover index to efficiently calculate auth chain difference (#8868)
Diffstat (limited to 'tests')
-rw-r--r--tests/storage/test_event_chain.py472
-rw-r--r--tests/storage/test_event_federation.py249
-rw-r--r--tests/util/test_itertools.py41
3 files changed, 737 insertions, 25 deletions
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
new file mode 100644
index 0000000000..83c377824b
--- /dev/null
+++ b/tests/storage/test_event_chain.py
@@ -0,0 +1,472 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Tuple
+
+from twisted.trial import unittest
+
+from synapse.api.constants import EventTypes
+from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
+from synapse.storage.databases.main.events import _LinkMap
+
+from tests.unittest import HomeserverTestCase
+
+
+class EventChainStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self._next_stream_ordering = 1
+
+    def test_simple(self):
+        """Test that the example in `docs/auth_chain_difference_algorithm.md`
+        works.
+        """
+
+        event_factory = self.hs.get_event_builder_factory()
+        bob = "@creator:test"
+        alice = "@alice:test"
+        room_id = "!room:test"
+
+        # Ensure that we have a rooms entry so that we generate the chain index.
+        self.get_success(
+            self.store.store_room(
+                room_id=room_id,
+                room_creator_user_id="",
+                is_public=True,
+                room_version=RoomVersions.V6,
+            )
+        )
+
+        create = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Create,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "create"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[])
+        )
+
+        bob_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": bob,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "bob_join"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
+        )
+
+        power = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.PowerLevels,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "power"},
+                },
+            ).build(
+                prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+            )
+        )
+
+        alice_invite = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_invite"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        alice_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
+            )
+        )
+
+        power_2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.PowerLevels,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "power_2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        bob_join_2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": bob,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "bob_join_2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        alice_join2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[
+                    create.event_id,
+                    alice_join.event_id,
+                    power_2.event_id,
+                ],
+            )
+        )
+
+        events = [
+            create,
+            bob_join,
+            power,
+            alice_invite,
+            alice_join,
+            bob_join_2,
+            power_2,
+            alice_join2,
+        ]
+
+        expected_links = [
+            (bob_join, create),
+            (power, create),
+            (power, bob_join),
+            (alice_invite, create),
+            (alice_invite, power),
+            (alice_invite, bob_join),
+            (bob_join_2, power),
+            (alice_join2, power_2),
+        ]
+
+        self.persist(events)
+        chain_map, link_map = self.fetch_chains(events)
+
+        # Check that the expected links and only the expected links have been
+        # added.
+        self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+
+        for start, end in expected_links:
+            start_id, start_seq = chain_map[start.event_id]
+            end_id, end_seq = chain_map[end.event_id]
+
+            self.assertIn(
+                (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+            )
+
+        # Test that everything can reach the create event, but the create event
+        # can't reach anything.
+        for event in events[1:]:
+            self.assertTrue(
+                link_map.exists_path_from(
+                    chain_map[event.event_id], chain_map[create.event_id]
+                ),
+            )
+
+            self.assertFalse(
+                link_map.exists_path_from(
+                    chain_map[create.event_id], chain_map[event.event_id],
+                ),
+            )
+
+    def test_out_of_order_events(self):
+        """Test that we handle persisting events that we don't have the full
+        auth chain for yet (which should only happen for out of band memberships).
+        """
+        event_factory = self.hs.get_event_builder_factory()
+        bob = "@creator:test"
+        alice = "@alice:test"
+        room_id = "!room:test"
+
+        # Ensure that we have a rooms entry so that we generate the chain index.
+        self.get_success(
+            self.store.store_room(
+                room_id=room_id,
+                room_creator_user_id="",
+                is_public=True,
+                room_version=RoomVersions.V6,
+            )
+        )
+
+        # First persist the base room.
+        create = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Create,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "create"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[])
+        )
+
+        bob_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": bob,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "bob_join"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
+        )
+
+        power = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.PowerLevels,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "power"},
+                },
+            ).build(
+                prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+            )
+        )
+
+        self.persist([create, bob_join, power])
+
+        # Now persist an invite and a couple of memberships out of order.
+        alice_invite = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_invite"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        alice_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
+            )
+        )
+
+        alice_join2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, alice_join.event_id, power.event_id],
+            )
+        )
+
+        self.persist([alice_join])
+        self.persist([alice_join2])
+        self.persist([alice_invite])
+
+        # The end result should be sane.
+        events = [create, bob_join, power, alice_invite, alice_join]
+
+        chain_map, link_map = self.fetch_chains(events)
+
+        expected_links = [
+            (bob_join, create),
+            (power, create),
+            (power, bob_join),
+            (alice_invite, create),
+            (alice_invite, power),
+            (alice_invite, bob_join),
+        ]
+
+        # Check that the expected links and only the expected links have been
+        # added.
+        self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+
+        for start, end in expected_links:
+            start_id, start_seq = chain_map[start.event_id]
+            end_id, end_seq = chain_map[end.event_id]
+
+            self.assertIn(
+                (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+            )
+
+    def persist(
+        self, events: List[EventBase],
+    ):
+        """Persist the given events and check that the links generated match
+        those given.
+        """
+
+        persist_events_store = self.hs.get_datastores().persist_events
+
+        for e in events:
+            e.internal_metadata.stream_ordering = self._next_stream_ordering
+            self._next_stream_ordering += 1
+
+        def _persist(txn):
+            # We need to persist the events to the events and state_events
+            # tables.
+            persist_events_store._store_event_txn(txn, [(e, {}) for e in events])
+
+            # Actually call the function that calculates the auth chain stuff.
+            persist_events_store._persist_event_auth_chain_txn(txn, events)
+
+        self.get_success(
+            persist_events_store.db_pool.runInteraction("_persist", _persist,)
+        )
+
+    def fetch_chains(
+        self, events: List[EventBase]
+    ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
+
+        # Fetch the map from event ID -> (chain ID, sequence number)
+        rows = self.get_success(
+            self.store.db_pool.simple_select_many_batch(
+                table="event_auth_chains",
+                column="event_id",
+                iterable=[e.event_id for e in events],
+                retcols=("event_id", "chain_id", "sequence_number"),
+                keyvalues={},
+            )
+        )
+
+        chain_map = {
+            row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
+        }
+
+        # Fetch all the links and pass them to the _LinkMap.
+        rows = self.get_success(
+            self.store.db_pool.simple_select_many_batch(
+                table="event_auth_chain_links",
+                column="origin_chain_id",
+                iterable=[chain_id for chain_id, _ in chain_map.values()],
+                retcols=(
+                    "origin_chain_id",
+                    "origin_sequence_number",
+                    "target_chain_id",
+                    "target_sequence_number",
+                ),
+                keyvalues={},
+            )
+        )
+
+        link_map = _LinkMap()
+        for row in rows:
+            added = link_map.add_link(
+                (row["origin_chain_id"], row["origin_sequence_number"]),
+                (row["target_chain_id"], row["target_sequence_number"]),
+            )
+
+            # We shouldn't have persisted any redundant links
+            self.assertTrue(added)
+
+        return chain_map, link_map
+
+
+class LinkMapTestCase(unittest.TestCase):
+    def test_simple(self):
+        """Basic tests for the LinkMap.
+        """
+        link_map = _LinkMap()
+
+        link_map.add_link((1, 1), (2, 1), new=False)
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+        self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
+        self.assertCountEqual(link_map.get_additions(), [])
+        self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
+        self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
+        self.assertTrue(link_map.exists_path_from((1, 5), (1, 1)))
+        self.assertFalse(link_map.exists_path_from((1, 1), (1, 5)))
+
+        # Attempting to add a redundant link is ignored.
+        self.assertFalse(link_map.add_link((1, 4), (2, 1)))
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+
+        # Adding new non-redundant links works
+        self.assertTrue(link_map.add_link((1, 3), (2, 3)))
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+
+        self.assertTrue(link_map.add_link((2, 5), (1, 3)))
+        self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+
+        self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 482506d731..9d04a066d8 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,6 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import attr
+from parameterized import parameterized
+
+from synapse.events import _EventInternalMetadata
+
 import tests.unittest
 import tests.utils
 
@@ -113,7 +118,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
         self.assertTrue(r == [room2] or r == [room3])
 
-    def test_auth_difference(self):
+    @parameterized.expand([(True,), (False,)])
+    def test_auth_difference(self, use_chain_cover_index: bool):
         room_id = "@ROOM:local"
 
         # The silly auth graph we use to test the auth difference algorithm,
@@ -159,46 +165,223 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             "j": 1,
         }
 
+        # Mark the room as not having a cover index
+
+        def store_room(txn):
+            self.store.db_pool.simple_insert_txn(
+                txn,
+                "rooms",
+                {
+                    "room_id": room_id,
+                    "creator": "room_creator_user_id",
+                    "is_public": True,
+                    "room_version": "6",
+                    "has_auth_chain_index": use_chain_cover_index,
+                },
+            )
+
+        self.get_success(self.store.db_pool.runInteraction("store_room", store_room))
+
         # We rudely fiddle with the appropriate tables directly, as that's much
         # easier than constructing events properly.
 
-        def insert_event(txn, event_id, stream_ordering):
+        def insert_event(txn):
+            stream_ordering = 0
+
+            for event_id in auth_graph:
+                stream_ordering += 1
+                depth = depth_map[event_id]
+
+                self.store.db_pool.simple_insert_txn(
+                    txn,
+                    table="events",
+                    values={
+                        "event_id": event_id,
+                        "room_id": room_id,
+                        "depth": depth,
+                        "topological_ordering": depth,
+                        "type": "m.test",
+                        "processed": True,
+                        "outlier": False,
+                        "stream_ordering": stream_ordering,
+                    },
+                )
+
+            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+                txn,
+                [
+                    FakeEvent(event_id, room_id, auth_graph[event_id])
+                    for event_id in auth_graph
+                ],
+            )
+
+        self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+
+        # Now actually test that various combinations give the right result:
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "d", "e"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}])
+        )
+        self.assertSetEqual(difference, set())
+
+    def test_auth_difference_partial_cover(self):
+        """Test that we correctly handle rooms where not all events have a chain
+        cover calculated. This can happen in some obscure edge cases, including
+        during the background update that calculates the chain cover for old
+        rooms.
+        """
+
+        room_id = "@ROOM:local"
+
+        # The silly auth graph we use to test the auth difference algorithm,
+        # where the top are the most recent events.
+        #
+        #   A   B
+        #    \ /
+        #  D  E
+        #  \  |
+        #   ` F   C
+        #     |  /|
+        #     G ยด |
+        #     | \ |
+        #     H   I
+        #     |   |
+        #     K   J
+
+        auth_graph = {
+            "a": ["e"],
+            "b": ["e"],
+            "c": ["g", "i"],
+            "d": ["f"],
+            "e": ["f"],
+            "f": ["g"],
+            "g": ["h", "i"],
+            "h": ["k"],
+            "i": ["j"],
+            "k": [],
+            "j": [],
+        }
+
+        depth_map = {
+            "a": 7,
+            "b": 7,
+            "c": 4,
+            "d": 6,
+            "e": 6,
+            "f": 5,
+            "g": 3,
+            "h": 2,
+            "i": 2,
+            "k": 1,
+            "j": 1,
+        }
 
-            depth = depth_map[event_id]
+        # We rudely fiddle with the appropriate tables directly, as that's much
+        # easier than constructing events properly.
 
+        def insert_event(txn):
+            # First insert the room and mark it as having a chain cover.
             self.store.db_pool.simple_insert_txn(
                 txn,
-                table="events",
-                values={
-                    "event_id": event_id,
+                "rooms",
+                {
                     "room_id": room_id,
-                    "depth": depth,
-                    "topological_ordering": depth,
-                    "type": "m.test",
-                    "processed": True,
-                    "outlier": False,
-                    "stream_ordering": stream_ordering,
+                    "creator": "room_creator_user_id",
+                    "is_public": True,
+                    "room_version": "6",
+                    "has_auth_chain_index": True,
                 },
             )
 
-            self.store.db_pool.simple_insert_many_txn(
+            stream_ordering = 0
+
+            for event_id in auth_graph:
+                stream_ordering += 1
+                depth = depth_map[event_id]
+
+                self.store.db_pool.simple_insert_txn(
+                    txn,
+                    table="events",
+                    values={
+                        "event_id": event_id,
+                        "room_id": room_id,
+                        "depth": depth,
+                        "topological_ordering": depth,
+                        "type": "m.test",
+                        "processed": True,
+                        "outlier": False,
+                        "stream_ordering": stream_ordering,
+                    },
+                )
+
+            # Insert all events apart from 'B'
+            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
                 txn,
-                table="event_auth",
-                values=[
-                    {"event_id": event_id, "room_id": room_id, "auth_id": a}
-                    for a in auth_graph[event_id]
+                [
+                    FakeEvent(event_id, room_id, auth_graph[event_id])
+                    for event_id in auth_graph
+                    if event_id != "b"
                 ],
             )
 
-        next_stream_ordering = 0
-        for event_id in auth_graph:
-            next_stream_ordering += 1
-            self.get_success(
-                self.store.db_pool.runInteraction(
-                    "insert", insert_event, event_id, next_stream_ordering
-                )
+            # Now we insert the event 'B' without a chain cover, by temporarily
+            # pretending the room doesn't have a chain cover.
+
+            self.store.db_pool.simple_update_txn(
+                txn,
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"has_auth_chain_index": False},
+            )
+
+            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+                txn, [FakeEvent("b", room_id, auth_graph["b"])],
+            )
+
+            self.store.db_pool.simple_update_txn(
+                txn,
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"has_auth_chain_index": True},
             )
 
+        self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+
         # Now actually test that various combinations give the right result:
 
         difference = self.get_success(
@@ -240,3 +423,21 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             self.store.get_auth_chain_difference(room_id, [{"a"}])
         )
         self.assertSetEqual(difference, set())
+
+
+@attr.s
+class FakeEvent:
+    event_id = attr.ib()
+    room_id = attr.ib()
+    auth_events = attr.ib()
+
+    type = "foo"
+    state_key = "foo"
+
+    internal_metadata = _EventInternalMetadata({})
+
+    def auth_event_ids(self):
+        return self.auth_events
+
+    def is_state(self):
+        return True
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 0ab0a91483..1184cea5a3 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -12,7 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from synapse.util.iterutils import chunk_seq
+from typing import Dict, List
+
+from synapse.util.iterutils import chunk_seq, sorted_topologically
 
 from tests.unittest import TestCase
 
@@ -45,3 +47,40 @@ class ChunkSeqTests(TestCase):
         self.assertEqual(
             list(parts), [],
         )
+
+
+class SortTopologically(TestCase):
+    def test_empty(self):
+        "Test that an empty graph works correctly"
+
+        graph = {}  # type: Dict[int, List[int]]
+        self.assertEqual(list(sorted_topologically([], graph)), [])
+
+    def test_disconnected(self):
+        "Test that a graph with no edges work"
+
+        graph = {1: [], 2: []}  # type: Dict[int, List[int]]
+
+        # For disconnected nodes the output is simply sorted.
+        self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
+
+    def test_linear(self):
+        "Test that a simple `4 -> 3 -> 2 -> 1` graph works"
+
+        graph = {1: [], 2: [1], 3: [2], 4: [3]}  # type: Dict[int, List[int]]
+
+        self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+
+    def test_subset(self):
+        "Test that only sorting a subset of the graph works"
+        graph = {1: [], 2: [1], 3: [2], 4: [3]}  # type: Dict[int, List[int]]
+
+        self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
+
+    def test_fork(self):
+        "Test that a forked graph works"
+        graph = {1: [], 2: [1], 3: [1], 4: [2, 3]}  # type: Dict[int, List[int]]
+
+        # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should
+        # always get the same one.
+        self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])