summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
authorErik Johnston <erikj@element.io>2024-01-23 11:26:27 +0000
committerGitHub <noreply@github.com>2024-01-23 11:26:27 +0000
commit14c725f73b1e4da37566def0491670009d718539 (patch)
tree6ec594ef8162a4d1a465733f5a99fb18ac31bbea /synapse/storage/databases/main
parentAdd a `--generate-only` option to the Complement launcher. (#16828) (diff)
downloadsynapse-14c725f73b1e4da37566def0491670009d718539.tar.xz
Preparatory work for tweaking performance of auth chain lookups (#16833)
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/event_federation.py153
1 files changed, 127 insertions, 26 deletions
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ddc2baf95d..e00e7ebf76 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -159,6 +159,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 unique_columns=("event_id", "room_id"),
             )
 
+        self.db_pool.updates.register_background_index_update(
+            update_name="event_auth_chain_links_origin_index",
+            index_name="event_auth_chain_links_origin_index",
+            table="event_auth_chain_links",
+            columns=("origin_chain_id", "origin_sequence_number"),
+        )
+
     async def get_auth_chain(
         self, room_id: str, event_ids: Collection[str], include_given: bool = False
     ) -> List[EventBase]:
@@ -271,38 +278,63 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         # Now we look up all links for the chains we have, adding chains that
         # are reachable from any event.
+        #
+        # This query is structured to first get all chain IDs reachable, and
+        # then pull out all links from those chains. This does pull out more
+        # rows than is strictly necessary, however there isn't a way of
+        # structuring the recursive part of query to pull out the links without
+        # also returning large quantities of redundant data (which can make it a
+        # lot slower).
         sql = """
+            WITH RECURSIVE links(chain_id) AS (
+                SELECT
+                    DISTINCT origin_chain_id
+                FROM event_auth_chain_links WHERE %s
+                UNION
+                SELECT
+                    target_chain_id
+                FROM event_auth_chain_links
+                INNER JOIN links ON (chain_id = origin_chain_id)
+            )
             SELECT
                 origin_chain_id, origin_sequence_number,
                 target_chain_id, target_sequence_number
-            FROM event_auth_chain_links
-            WHERE %s
+            FROM links
+            INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
         """
 
         # A map from chain ID to max sequence number *reachable* from any event ID.
         chains: Dict[int, int] = {}
 
         # Add all linked chains reachable from initial set of chains.
-        for batch2 in batch_iter(event_chains, 1000):
+        chains_to_fetch = set(event_chains.keys())
+        while chains_to_fetch:
+            batch2 = tuple(itertools.islice(chains_to_fetch, 100))
+            chains_to_fetch.difference_update(batch2)
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "origin_chain_id", batch2
             )
             txn.execute(sql % (clause,), args)
 
+            links: Dict[int, List[Tuple[int, int, int]]] = {}
+
             for (
                 origin_chain_id,
                 origin_sequence_number,
                 target_chain_id,
                 target_sequence_number,
             ) in txn:
-                # chains are only reachable if the origin sequence number of
-                # the link is less than the max sequence number in the
-                # origin chain.
-                if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
-                    chains[target_chain_id] = max(
-                        target_sequence_number,
-                        chains.get(target_chain_id, 0),
-                    )
+                links.setdefault(origin_chain_id, []).append(
+                    (origin_sequence_number, target_chain_id, target_sequence_number)
+                )
+
+            for chain_id in links:
+                if chain_id not in event_chains:
+                    continue
+
+                _materialize(chain_id, event_chains[chain_id], links, chains)
+
+            chains_to_fetch.difference_update(chains)
 
         # Add the initial set of chains, excluding the sequence corresponding to
         # initial event.
@@ -529,41 +561,64 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
                 chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
 
-        # Now we look up all links for the chains we have, adding chains to
-        # set_to_chain that are reachable from each set.
+        # Now we look up all links for the chains we have, adding chains that
+        # are reachable from any event.
+        #
+        # This query is structured to first get all chain IDs reachable, and
+        # then pull out all links from those chains. This does pull out more
+        # rows than is strictly necessary, however there isn't a way of
+        # structuring the recursive part of query to pull out the links without
+        # also returning large quantities of redundant data (which can make it a
+        # lot slower).
         sql = """
+            WITH RECURSIVE links(chain_id) AS (
+                SELECT
+                    DISTINCT origin_chain_id
+                FROM event_auth_chain_links WHERE %s
+                UNION
+                SELECT
+                    target_chain_id
+                FROM event_auth_chain_links
+                INNER JOIN links ON (chain_id = origin_chain_id)
+            )
             SELECT
                 origin_chain_id, origin_sequence_number,
                 target_chain_id, target_sequence_number
-            FROM event_auth_chain_links
-            WHERE %s
+            FROM links
+            INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
         """
 
         # (We need to take a copy of `seen_chains` as we want to mutate it in
         # the loop)
-        for batch2 in batch_iter(set(seen_chains), 1000):
+        chains_to_fetch = set(seen_chains)
+        while chains_to_fetch:
+            batch2 = tuple(itertools.islice(chains_to_fetch, 100))
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "origin_chain_id", batch2
             )
             txn.execute(sql % (clause,), args)
 
+            links: Dict[int, List[Tuple[int, int, int]]] = {}
+
             for (
                 origin_chain_id,
                 origin_sequence_number,
                 target_chain_id,
                 target_sequence_number,
             ) in txn:
-                for chains in set_to_chain:
-                    # chains are only reachable if the origin sequence number of
-                    # the link is less than the max sequence number in the
-                    # origin chain.
-                    if origin_sequence_number <= chains.get(origin_chain_id, 0):
-                        chains[target_chain_id] = max(
-                            target_sequence_number,
-                            chains.get(target_chain_id, 0),
-                        )
+                links.setdefault(origin_chain_id, []).append(
+                    (origin_sequence_number, target_chain_id, target_sequence_number)
+                )
+
+            for chains in set_to_chain:
+                for chain_id in links:
+                    if chain_id not in chains:
+                        continue
 
-                seen_chains.add(target_chain_id)
+                    _materialize(chain_id, chains[chain_id], links, chains)
+
+                chains_to_fetch.difference_update(chains)
+                seen_chains.update(chains)
 
         # Now for each chain we figure out the maximum sequence number reachable
         # from *any* state set and the minimum sequence number reachable from
@@ -2103,3 +2158,49 @@ class EventFederationStore(EventFederationWorkerStore):
             )
 
         return batch_size
+
+
+def _materialize(
+    origin_chain_id: int,
+    origin_sequence_number: int,
+    links: Dict[int, List[Tuple[int, int, int]]],
+    materialized: Dict[int, int],
+) -> None:
+    """Helper function for fetching auth chain links. For a given origin chain
+    ID / sequence number and a dictionary of links, updates the materialized
+    dict with the reachable chains.
+
+    To get a dict of all chains reachable from a set of chains this function can
+    be called in a loop, once per origin chain with the same links and
+    materialized args. The materialized dict will the result.
+
+    Args:
+        origin_chain_id, origin_sequence_number
+        links: map of the links between chains as a dict from origin chain ID
+            to list of 3-tuples of origin sequence number, target chain ID and
+            target sequence number.
+        materialized: dict to update with new reachability information, as a
+            map from chain ID to max sequence number reachable.
+    """
+
+    # Do a standard graph traversal.
+    stack = [(origin_chain_id, origin_sequence_number)]
+
+    while stack:
+        c, s = stack.pop()
+
+        chain_links = links.get(c, [])
+        for (
+            sequence_number,
+            target_chain_id,
+            target_sequence_number,
+        ) in chain_links:
+            # Ignore any links that are higher up the chain
+            if sequence_number > s:
+                continue
+
+            # Check if we have already visited the target chain before, if so we
+            # can skip it.
+            if materialized.get(target_chain_id, 0) < target_sequence_number:
+                stack.append((target_chain_id, target_sequence_number))
+                materialized[target_chain_id] = target_sequence_number