summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/events.py108
1 files changed, 37 insertions, 71 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index a6fda3f43c..1e731d56bd 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -19,6 +19,7 @@
 # [This file includes modifications made by New Vector Limited]
 #
 #
+import collections
 import itertools
 import logging
 from collections import OrderedDict
@@ -53,6 +54,7 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
+from synapse.storage.databases.main.event_federation import EventFederationStore
 from synapse.storage.databases.main.events_worker import EventCacheEntry
 from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.engines import PostgresEngine
@@ -768,40 +770,26 @@ class PersistEventsStore:
         #      that have the same chain ID as the event.
         #   2. For each retained auth event we:
         #       a. Add a link from the event's to the auth event's chain
-        #          ID/sequence number; and
-        #       b. Add a link from the event to every chain reachable by the
-        #          auth event.
+        #          ID/sequence number
 
         # Step 1, fetch all existing links from all the chains we've seen
         # referenced.
         chain_links = _LinkMap()
-        auth_chain_rows = cast(
-            List[Tuple[int, int, int, int]],
-            db_pool.simple_select_many_txn(
-                txn,
-                table="event_auth_chain_links",
-                column="origin_chain_id",
-                iterable={chain_id for chain_id, _ in chain_map.values()},
-                keyvalues={},
-                retcols=(
-                    "origin_chain_id",
-                    "origin_sequence_number",
-                    "target_chain_id",
-                    "target_sequence_number",
-                ),
-            ),
-        )
-        for (
-            origin_chain_id,
-            origin_sequence_number,
-            target_chain_id,
-            target_sequence_number,
-        ) in auth_chain_rows:
-            chain_links.add_link(
-                (origin_chain_id, origin_sequence_number),
-                (target_chain_id, target_sequence_number),
-                new=False,
-            )
+
+        for links in EventFederationStore._get_chain_links(
+            txn, {chain_id for chain_id, _ in chain_map.values()}
+        ):
+            for origin_chain_id, inner_links in links.items():
+                for (
+                    origin_sequence_number,
+                    target_chain_id,
+                    target_sequence_number,
+                ) in inner_links:
+                    chain_links.add_link(
+                        (origin_chain_id, origin_sequence_number),
+                        (target_chain_id, target_sequence_number),
+                        new=False,
+                    )
 
         # We do this in toplogical order to avoid adding redundant links.
         for event_id in sorted_topologically(
@@ -836,18 +824,6 @@ class PersistEventsStore:
                     (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
                 )
 
-                # Step 2b, add a link to chains reachable from the auth
-                # event.
-                for target_id, target_seq in chain_links.get_links_from(
-                    (auth_chain_id, auth_sequence_number)
-                ):
-                    if target_id == chain_id:
-                        continue
-
-                    chain_links.add_link(
-                        (chain_id, sequence_number), (target_id, target_seq)
-                    )
-
         db_pool.simple_insert_many_txn(
             txn,
             table="event_auth_chain_links",
@@ -2451,31 +2427,6 @@ class _LinkMap:
         current_links[src_seq] = target_seq
         return True
 
-    def get_links_from(
-        self, src_tuple: Tuple[int, int]
-    ) -> Generator[Tuple[int, int], None, None]:
-        """Gets the chains reachable from the given chain/sequence number.
-
-        Yields:
-            The chain ID and sequence number the link points to.
-        """
-        src_chain, src_seq = src_tuple
-        for target_id, sequence_numbers in self.maps.get(src_chain, {}).items():
-            for link_src_seq, target_seq in sequence_numbers.items():
-                if link_src_seq <= src_seq:
-                    yield target_id, target_seq
-
-    def get_links_between(
-        self, source_chain: int, target_chain: int
-    ) -> Generator[Tuple[int, int], None, None]:
-        """Gets the links between two chains.
-
-        Yields:
-            The source and target sequence numbers.
-        """
-
-        yield from self.maps.get(source_chain, {}).get(target_chain, {}).items()
-
     def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]:
         """Gets any newly added links.
 
@@ -2502,9 +2453,24 @@ class _LinkMap:
         if src_chain == target_chain:
             return target_seq <= src_seq
 
-        links = self.get_links_between(src_chain, target_chain)
-        for link_start_seq, link_end_seq in links:
-            if link_start_seq <= src_seq and target_seq <= link_end_seq:
-                return True
+        # We have to graph traverse the links to check for indirect paths.
+        visited_chains = collections.Counter()
+        search = [(src_chain, src_seq)]
+        while search:
+            chain, seq = search.pop()
+            visited_chains[chain] = max(seq, visited_chains[chain])
+            for tc, links in self.maps.get(chain, {}).items():
+                for ss, ts in links.items():
+                    # Don't revisit chains we've already seen, unless the target
+                    # sequence number is higher than last time.
+                    if ts <= visited_chains.get(tc, 0):
+                        continue
+
+                    if ss <= seq:
+                        if tc == target_chain:
+                            if target_seq <= ts:
+                                return True
+                        else:
+                            search.append((tc, ts))
 
         return False