summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erikj@element.io>2024-04-02 15:33:56 +0100
committerGitHub <noreply@github.com>2024-04-02 15:33:56 +0100
commitec174d047005e4ac976311f4d3730452b2c5710f (patch)
tree936ab8ef3e9534be64000b3ce7264c9f1ded4377
parentFixups to new push stream (#17038) (diff)
downloadsynapse-ec174d047005e4ac976311f4d3730452b2c5710f.tar.xz
Refactor chain fetching (#17044)
Refactor the auth chain fetching queries into a shared helper, since they were duplicated in two places.
-rw-r--r--changelog.d/17044.misc1
-rw-r--r--synapse/storage/databases/main/event_federation.py162
2 files changed, 67 insertions, 96 deletions
diff --git a/changelog.d/17044.misc b/changelog.d/17044.misc
new file mode 100644
index 0000000000..a1439752d3
--- /dev/null
+++ b/changelog.d/17044.misc
@@ -0,0 +1 @@
+Refactor auth chain fetching to reduce duplication.
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 846c3f363a..fb132ef090 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -27,6 +27,7 @@ from typing import (
     Collection,
     Dict,
     FrozenSet,
+    Generator,
     Iterable,
     List,
     Optional,
@@ -279,64 +280,16 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         # Now we look up all links for the chains we have, adding chains that
         # are reachable from any event.
-        #
-        # This query is structured to first get all chain IDs reachable, and
-        # then pull out all links from those chains. This does pull out more
-        # rows than is strictly necessary, however there isn't a way of
-        # structuring the recursive part of query to pull out the links without
-        # also returning large quantities of redundant data (which can make it a
-        # lot slower).
-        sql = """
-            WITH RECURSIVE links(chain_id) AS (
-                SELECT
-                    DISTINCT origin_chain_id
-                FROM event_auth_chain_links WHERE %s
-                UNION
-                SELECT
-                    target_chain_id
-                FROM event_auth_chain_links
-                INNER JOIN links ON (chain_id = origin_chain_id)
-            )
-            SELECT
-                origin_chain_id, origin_sequence_number,
-                target_chain_id, target_sequence_number
-            FROM links
-            INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
-        """
 
         # A map from chain ID to max sequence number *reachable* from any event ID.
         chains: Dict[int, int] = {}
-
-        # Add all linked chains reachable from initial set of chains.
-        chains_to_fetch = set(event_chains.keys())
-        while chains_to_fetch:
-            batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
-            chains_to_fetch.difference_update(batch2)
-            clause, args = make_in_list_sql_clause(
-                txn.database_engine, "origin_chain_id", batch2
-            )
-            txn.execute(sql % (clause,), args)
-
-            links: Dict[int, List[Tuple[int, int, int]]] = {}
-
-            for (
-                origin_chain_id,
-                origin_sequence_number,
-                target_chain_id,
-                target_sequence_number,
-            ) in txn:
-                links.setdefault(origin_chain_id, []).append(
-                    (origin_sequence_number, target_chain_id, target_sequence_number)
-                )
-
+        for links in self._get_chain_links(txn, set(event_chains.keys())):
             for chain_id in links:
                 if chain_id not in event_chains:
                     continue
 
                 _materialize(chain_id, event_chains[chain_id], links, chains)
 
-            chains_to_fetch.difference_update(chains)
-
         # Add the initial set of chains, excluding the sequence corresponding to
         # initial event.
         for chain_id, seq_no in event_chains.items():
@@ -380,6 +333,68 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         return results
 
+    @classmethod
+    def _get_chain_links(
+        cls, txn: LoggingTransaction, chains_to_fetch: Set[int]
+    ) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
+        """Fetch all auth chain links from the given set of chains, and all
+        links from those chains, recursively.
+
+        Note: This may return links that are not reachable from the given
+        chains.
+
+        Returns a generator that produces dicts mapping origin chain ID to a
+        list of 3-tuples of (origin sequence number, target chain ID, target
+        sequence number).
+        """
+
+        # This query is structured to first get all chain IDs reachable, and
+        # then pull out all links from those chains. This does pull out more
+        # rows than is strictly necessary, however there isn't a way of
+        # structuring the recursive part of query to pull out the links without
+        # also returning large quantities of redundant data (which can make it a
+        # lot slower).
+        sql = """
+            WITH RECURSIVE links(chain_id) AS (
+                SELECT
+                    DISTINCT origin_chain_id
+                FROM event_auth_chain_links WHERE %s
+                UNION
+                SELECT
+                    target_chain_id
+                FROM event_auth_chain_links
+                INNER JOIN links ON (chain_id = origin_chain_id)
+            )
+            SELECT
+                origin_chain_id, origin_sequence_number,
+                target_chain_id, target_sequence_number
+            FROM links
+            INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
+        """
+
+        while chains_to_fetch:
+            batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
+            chains_to_fetch.difference_update(batch2)
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "origin_chain_id", batch2
+            )
+            txn.execute(sql % (clause,), args)
+
+            links: Dict[int, List[Tuple[int, int, int]]] = {}
+
+            for (
+                origin_chain_id,
+                origin_sequence_number,
+                target_chain_id,
+                target_sequence_number,
+            ) in txn:
+                links.setdefault(origin_chain_id, []).append(
+                    (origin_sequence_number, target_chain_id, target_sequence_number)
+                )
+
+            chains_to_fetch.difference_update(links)
+
+            yield links
+
     def _get_auth_chain_ids_txn(
         self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
     ) -> Set[str]:
@@ -564,53 +579,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         # Now we look up all links for the chains we have, adding chains that
         # are reachable from any event.
-        #
-        # This query is structured to first get all chain IDs reachable, and
-        # then pull out all links from those chains. This does pull out more
-        # rows than is strictly necessary, however there isn't a way of
-        # structuring the recursive part of query to pull out the links without
-        # also returning large quantities of redundant data (which can make it a
-        # lot slower).
-        sql = """
-            WITH RECURSIVE links(chain_id) AS (
-                SELECT
-                    DISTINCT origin_chain_id
-                FROM event_auth_chain_links WHERE %s
-                UNION
-                SELECT
-                    target_chain_id
-                FROM event_auth_chain_links
-                INNER JOIN links ON (chain_id = origin_chain_id)
-            )
-            SELECT
-                origin_chain_id, origin_sequence_number,
-                target_chain_id, target_sequence_number
-            FROM links
-            INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
-        """
-
-        # (We need to take a copy of `seen_chains` as we want to mutate it in
-        # the loop)
-        chains_to_fetch = set(seen_chains)
-        while chains_to_fetch:
-            batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
-            clause, args = make_in_list_sql_clause(
-                txn.database_engine, "origin_chain_id", batch2
-            )
-            txn.execute(sql % (clause,), args)
-
-            links: Dict[int, List[Tuple[int, int, int]]] = {}
-
-            for (
-                origin_chain_id,
-                origin_sequence_number,
-                target_chain_id,
-                target_sequence_number,
-            ) in txn:
-                links.setdefault(origin_chain_id, []).append(
-                    (origin_sequence_number, target_chain_id, target_sequence_number)
-                )
 
+        # (We need to take a copy of `seen_chains` as the function mutates it)
+        for links in self._get_chain_links(txn, set(seen_chains)):
             for chains in set_to_chain:
                 for chain_id in links:
                     if chain_id not in chains:
@@ -618,7 +589,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
                     _materialize(chain_id, chains[chain_id], links, chains)
 
-                chains_to_fetch.difference_update(chains)
                 seen_chains.update(chains)
 
         # Now for each chain we figure out the maximum sequence number reachable