summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17044.misc1
-rw-r--r--synapse/storage/databases/main/events.py108
-rw-r--r--synapse/storage/schema/__init__.py8
-rw-r--r--tests/storage/test_event_chain.py104
4 files changed, 117 insertions, 104 deletions
diff --git a/changelog.d/17044.misc b/changelog.d/17044.misc
new file mode 100644
index 0000000000..a1439752d3
--- /dev/null
+++ b/changelog.d/17044.misc
@@ -0,0 +1 @@
+Refactor auth chain fetching to reduce duplication.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index a6fda3f43c..1e731d56bd 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -19,6 +19,7 @@
 # [This file includes modifications made by New Vector Limited]
 #
 #
+import collections
 import itertools
 import logging
 from collections import OrderedDict
@@ -53,6 +54,7 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
+from synapse.storage.databases.main.event_federation import EventFederationStore
 from synapse.storage.databases.main.events_worker import EventCacheEntry
 from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.engines import PostgresEngine
@@ -768,40 +770,26 @@ class PersistEventsStore:
         #      that have the same chain ID as the event.
         #   2. For each retained auth event we:
         #       a. Add a link from the event's to the auth event's chain
-        #          ID/sequence number; and
-        #       b. Add a link from the event to every chain reachable by the
-        #          auth event.
+        #          ID/sequence number
 
         # Step 1, fetch all existing links from all the chains we've seen
         # referenced.
         chain_links = _LinkMap()
-        auth_chain_rows = cast(
-            List[Tuple[int, int, int, int]],
-            db_pool.simple_select_many_txn(
-                txn,
-                table="event_auth_chain_links",
-                column="origin_chain_id",
-                iterable={chain_id for chain_id, _ in chain_map.values()},
-                keyvalues={},
-                retcols=(
-                    "origin_chain_id",
-                    "origin_sequence_number",
-                    "target_chain_id",
-                    "target_sequence_number",
-                ),
-            ),
-        )
-        for (
-            origin_chain_id,
-            origin_sequence_number,
-            target_chain_id,
-            target_sequence_number,
-        ) in auth_chain_rows:
-            chain_links.add_link(
-                (origin_chain_id, origin_sequence_number),
-                (target_chain_id, target_sequence_number),
-                new=False,
-            )
+
+        for links in EventFederationStore._get_chain_links(
+            txn, {chain_id for chain_id, _ in chain_map.values()}
+        ):
+            for origin_chain_id, inner_links in links.items():
+                for (
+                    origin_sequence_number,
+                    target_chain_id,
+                    target_sequence_number,
+                ) in inner_links:
+                    chain_links.add_link(
+                        (origin_chain_id, origin_sequence_number),
+                        (target_chain_id, target_sequence_number),
+                        new=False,
+                    )
 
         # We do this in toplogical order to avoid adding redundant links.
         for event_id in sorted_topologically(
@@ -836,18 +824,6 @@ class PersistEventsStore:
                     (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
                 )
 
-                # Step 2b, add a link to chains reachable from the auth
-                # event.
-                for target_id, target_seq in chain_links.get_links_from(
-                    (auth_chain_id, auth_sequence_number)
-                ):
-                    if target_id == chain_id:
-                        continue
-
-                    chain_links.add_link(
-                        (chain_id, sequence_number), (target_id, target_seq)
-                    )
-
         db_pool.simple_insert_many_txn(
             txn,
             table="event_auth_chain_links",
@@ -2451,31 +2427,6 @@ class _LinkMap:
         current_links[src_seq] = target_seq
         return True
 
-    def get_links_from(
-        self, src_tuple: Tuple[int, int]
-    ) -> Generator[Tuple[int, int], None, None]:
-        """Gets the chains reachable from the given chain/sequence number.
-
-        Yields:
-            The chain ID and sequence number the link points to.
-        """
-        src_chain, src_seq = src_tuple
-        for target_id, sequence_numbers in self.maps.get(src_chain, {}).items():
-            for link_src_seq, target_seq in sequence_numbers.items():
-                if link_src_seq <= src_seq:
-                    yield target_id, target_seq
-
-    def get_links_between(
-        self, source_chain: int, target_chain: int
-    ) -> Generator[Tuple[int, int], None, None]:
-        """Gets the links between two chains.
-
-        Yields:
-            The source and target sequence numbers.
-        """
-
-        yield from self.maps.get(source_chain, {}).get(target_chain, {}).items()
-
     def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]:
         """Gets any newly added links.
 
@@ -2502,9 +2453,24 @@ class _LinkMap:
         if src_chain == target_chain:
             return target_seq <= src_seq
 
-        links = self.get_links_between(src_chain, target_chain)
-        for link_start_seq, link_end_seq in links:
-            if link_start_seq <= src_seq and target_seq <= link_end_seq:
-                return True
+        # We have to graph traverse the links to check for indirect paths.
+        visited_chains = collections.Counter()
+        search = [(src_chain, src_seq)]
+        while search:
+            chain, seq = search.pop()
+            visited_chains[chain] = max(seq, visited_chains[chain])
+            for tc, links in self.maps.get(chain, {}).items():
+                for ss, ts in links.items():
+                    # Don't revisit chains we've already seen, unless the target
+                    # sequence number is higher than last time.
+                    if ts <= visited_chains.get(tc, 0):
+                        continue
+
+                    if ss <= seq:
+                        if tc == target_chain:
+                            if target_seq <= ts:
+                                return True
+                        else:
+                            search.append((tc, ts))
 
         return False
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index c0b925444f..039aa91b92 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -132,12 +132,16 @@ Changes in SCHEMA_VERSION = 82
 
 Changes in SCHEMA_VERSION = 83
     - The event_txn_id is no longer used.
+
+Changes in SCHEMA_VERSION = 84
+    - No longer assumes that `event_auth_chain_links` holds transitive links, and
+      so read operations must do graph traversal.
 """
 
 
 SCHEMA_COMPAT_VERSION = (
-    # The event_txn_id table and tables from MSC2716 no longer exist.
-    83
+    # Transitive links are no longer written to `event_auth_chain_links`
+    84
 )
 """Limit on how far the synapse codebase can be rolled back without breaking db compat
 
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 9e4e73832e..27d5b0125f 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -21,6 +21,8 @@
 
 from typing import Dict, List, Set, Tuple, cast
 
+from parameterized import parameterized
+
 from twisted.test.proto_helpers import MemoryReactor
 from twisted.trial import unittest
 
@@ -45,7 +47,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
         self.store = hs.get_datastores().main
         self._next_stream_ordering = 1
 
-    def test_simple(self) -> None:
+    @parameterized.expand([(False,), (True,)])
+    def test_simple(self, batched: bool) -> None:
         """Test that the example in `docs/auth_chain_difference_algorithm.md`
         works.
         """
@@ -53,6 +56,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
         event_factory = self.hs.get_event_builder_factory()
         bob = "@creator:test"
         alice = "@alice:test"
+        charlie = "@charlie:test"
         room_id = "!room:test"
 
         # Ensure that we have a rooms entry so that we generate the chain index.
@@ -191,6 +195,26 @@ class EventChainStoreTestCase(HomeserverTestCase):
             )
         )
 
+        charlie_invite = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": charlie,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "charlie_invite"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[
+                    create.event_id,
+                    alice_join2.event_id,
+                    power_2.event_id,
+                ],
+            )
+        )
+
         events = [
             create,
             bob_join,
@@ -200,33 +224,41 @@ class EventChainStoreTestCase(HomeserverTestCase):
             bob_join_2,
             power_2,
             alice_join2,
+            charlie_invite,
         ]
 
         expected_links = [
             (bob_join, create),
-            (power, create),
             (power, bob_join),
-            (alice_invite, create),
             (alice_invite, power),
-            (alice_invite, bob_join),
             (bob_join_2, power),
             (alice_join2, power_2),
+            (charlie_invite, alice_join2),
         ]
 
-        self.persist(events)
+        # We either persist as a batch or one-by-one depending on test
+        # parameter.
+        if batched:
+            self.persist(events)
+        else:
+            for event in events:
+                self.persist([event])
+
         chain_map, link_map = self.fetch_chains(events)
 
         # Check that the expected links and only the expected links have been
         # added.
-        self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
-
-        for start, end in expected_links:
-            start_id, start_seq = chain_map[start.event_id]
-            end_id, end_seq = chain_map[end.event_id]
+        event_map = {e.event_id: e for e in events}
+        reverse_chain_map = {v: event_map[k] for k, v in chain_map.items()}
 
-            self.assertIn(
-                (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
-            )
+        self.maxDiff = None
+        self.assertCountEqual(
+            expected_links,
+            [
+                (reverse_chain_map[(s1, s2)], reverse_chain_map[(t1, t2)])
+                for s1, s2, t1, t2 in link_map.get_additions()
+            ],
+        )
 
         # Test that everything can reach the create event, but the create event
         # can't reach anything.
@@ -368,24 +400,23 @@ class EventChainStoreTestCase(HomeserverTestCase):
 
         expected_links = [
             (bob_join, create),
-            (power, create),
             (power, bob_join),
-            (alice_invite, create),
             (alice_invite, power),
-            (alice_invite, bob_join),
         ]
 
         # Check that the expected links and only the expected links have been
         # added.
-        self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+        event_map = {e.event_id: e for e in events}
+        reverse_chain_map = {v: event_map[k] for k, v in chain_map.items()}
 
-        for start, end in expected_links:
-            start_id, start_seq = chain_map[start.event_id]
-            end_id, end_seq = chain_map[end.event_id]
-
-            self.assertIn(
-                (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
-            )
+        self.maxDiff = None
+        self.assertCountEqual(
+            expected_links,
+            [
+                (reverse_chain_map[(s1, s2)], reverse_chain_map[(t1, t2)])
+                for s1, s2, t1, t2 in link_map.get_additions()
+            ],
+        )
 
     def persist(
         self,
@@ -489,8 +520,6 @@ class LinkMapTestCase(unittest.TestCase):
         link_map = _LinkMap()
 
         link_map.add_link((1, 1), (2, 1), new=False)
-        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
-        self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
         self.assertCountEqual(link_map.get_additions(), [])
         self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
         self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
@@ -499,18 +528,31 @@ class LinkMapTestCase(unittest.TestCase):
 
         # Attempting to add a redundant link is ignored.
         self.assertFalse(link_map.add_link((1, 4), (2, 1)))
-        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+        self.assertCountEqual(link_map.get_additions(), [])
 
         # Adding new non-redundant links works
         self.assertTrue(link_map.add_link((1, 3), (2, 3)))
-        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+        self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3)])
 
         self.assertTrue(link_map.add_link((2, 5), (1, 3)))
-        self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
-        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
-
         self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
 
+    def test_exists_path_from(self) -> None:
+        "Check that `exists_path_from` can handle non-direct links"
+        link_map = _LinkMap()
+
+        link_map.add_link((1, 1), (2, 1), new=False)
+        link_map.add_link((2, 1), (3, 1), new=False)
+
+        self.assertTrue(link_map.exists_path_from((1, 4), (3, 1)))
+        self.assertFalse(link_map.exists_path_from((1, 4), (3, 2)))
+
+        link_map.add_link((1, 5), (2, 3), new=False)
+        link_map.add_link((2, 2), (3, 3), new=False)
+
+        self.assertTrue(link_map.exists_path_from((1, 6), (3, 2)))
+        self.assertFalse(link_map.exists_path_from((1, 4), (3, 2)))
+
 
 class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
     servlets = [