author    Erik Johnston <erikj@element.io>  2024-01-23 11:26:27 +0000
committer GitHub <noreply@github.com>       2024-01-23 11:26:27 +0000
commit    14c725f73b1e4da37566def0491670009d718539 (patch)
tree      6ec594ef8162a4d1a465733f5a99fb18ac31bbea
parent    Add a `--generate-only` option to the Complement launcher. (#16828) (diff)
download  synapse-14c725f73b1e4da37566def0491670009d718539.tar.xz
Preparatory work for tweaking performance of auth chain lookups (#16833)
-rw-r--r--  changelog.d/16833.misc                                                   1
-rw-r--r--  synapse/storage/databases/main/event_federation.py                     153
-rw-r--r--  synapse/storage/schema/__init__.py                                       2
-rw-r--r--  synapse/storage/schema/main/delta/84/01_auth_links_stats.sql.postgres  18
-rw-r--r--  synapse/storage/schema/main/delta/84/02_auth_links_index.sql           16
5 files changed, 163 insertions, 27 deletions
diff --git a/changelog.d/16833.misc b/changelog.d/16833.misc
new file mode 100644
index 0000000000..9714c97a7d
--- /dev/null
+++ b/changelog.d/16833.misc
@@ -0,0 +1 @@
+Preparatory work for tweaking performance of auth chain lookups.
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ddc2baf95d..e00e7ebf76 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -159,6 +159,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 unique_columns=("event_id", "room_id"),
             )
 
+        self.db_pool.updates.register_background_index_update(
+            update_name="event_auth_chain_links_origin_index",
+            index_name="event_auth_chain_links_origin_index",
+            table="event_auth_chain_links",
+            columns=("origin_chain_id", "origin_sequence_number"),
+        )
+
     async def get_auth_chain(
         self, room_id: str, event_ids: Collection[str], include_given: bool = False
     ) -> List[EventBase]:
@@ -271,38 +278,63 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         # Now we look up all links for the chains we have, adding chains that
         # are reachable from any event.
+        #
+        # This query is structured to first get all chain IDs reachable, and
+        # then pull out all links from those chains. This does pull out more
+        # rows than is strictly necessary; however, there isn't a way of
+        # structuring the recursive part of the query to pull out the links without
+        # also returning large quantities of redundant data (which can make it a
+        # lot slower).
         sql = """
+            WITH RECURSIVE links(chain_id) AS (
+                SELECT
+                    DISTINCT origin_chain_id
+                FROM event_auth_chain_links WHERE %s
+                UNION
+                SELECT
+                    target_chain_id
+                FROM event_auth_chain_links
+                INNER JOIN links ON (chain_id = origin_chain_id)
+            )
             SELECT
                 origin_chain_id, origin_sequence_number,
                 target_chain_id, target_sequence_number
-            FROM event_auth_chain_links
-            WHERE %s
+            FROM links
+            INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
         """
 
         # A map from chain ID to max sequence number *reachable* from any event ID.
         chains: Dict[int, int] = {}
 
         # Add all linked chains reachable from initial set of chains.
-        for batch2 in batch_iter(event_chains, 1000):
+        chains_to_fetch = set(event_chains.keys())
+        while chains_to_fetch:
+            batch2 = tuple(itertools.islice(chains_to_fetch, 100))
+            chains_to_fetch.difference_update(batch2)
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "origin_chain_id", batch2
             )
             txn.execute(sql % (clause,), args)
 
+            links: Dict[int, List[Tuple[int, int, int]]] = {}
+
             for (
                 origin_chain_id,
                 origin_sequence_number,
                 target_chain_id,
                 target_sequence_number,
             ) in txn:
-                # chains are only reachable if the origin sequence number of
-                # the link is less than the max sequence number in the
-                # origin chain.
-                if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
-                    chains[target_chain_id] = max(
-                        target_sequence_number,
-                        chains.get(target_chain_id, 0),
-                    )
+                links.setdefault(origin_chain_id, []).append(
+                    (origin_sequence_number, target_chain_id, target_sequence_number)
+                )
+
+            for chain_id in links:
+                if chain_id not in event_chains:
+                    continue
+
+                _materialize(chain_id, event_chains[chain_id], links, chains)
+
+            chains_to_fetch.difference_update(chains)
 
         # Add the initial set of chains, excluding the sequence corresponding to
         # initial event.
@@ -529,41 +561,64 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
                 chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
 
-        # Now we look up all links for the chains we have, adding chains to
-        # set_to_chain that are reachable from each set.
+        # Now we look up all links for the chains we have, adding chains that
+        # are reachable from any event.
+        #
+        # This query is structured to first get all chain IDs reachable, and
+        # then pull out all links from those chains. This does pull out more
+        # rows than is strictly necessary; however, there isn't a way of
+        # structuring the recursive part of the query to pull out the links without
+        # also returning large quantities of redundant data (which can make it a
+        # lot slower).
         sql = """
+            WITH RECURSIVE links(chain_id) AS (
+                SELECT
+                    DISTINCT origin_chain_id
+                FROM event_auth_chain_links WHERE %s
+                UNION
+                SELECT
+                    target_chain_id
+                FROM event_auth_chain_links
+                INNER JOIN links ON (chain_id = origin_chain_id)
+            )
             SELECT
                 origin_chain_id, origin_sequence_number,
                 target_chain_id, target_sequence_number
-            FROM event_auth_chain_links
-            WHERE %s
+            FROM links
+            INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
         """
 
         # (We need to take a copy of `seen_chains` as we want to mutate it in
         # the loop)
-        for batch2 in batch_iter(set(seen_chains), 1000):
+        chains_to_fetch = set(seen_chains)
+        while chains_to_fetch:
+            batch2 = tuple(itertools.islice(chains_to_fetch, 100))
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "origin_chain_id", batch2
             )
             txn.execute(sql % (clause,), args)
 
+            links: Dict[int, List[Tuple[int, int, int]]] = {}
+
             for (
                 origin_chain_id,
                 origin_sequence_number,
                 target_chain_id,
                 target_sequence_number,
             ) in txn:
-                for chains in set_to_chain:
-                    # chains are only reachable if the origin sequence number of
-                    # the link is less than the max sequence number in the
-                    # origin chain.
-                    if origin_sequence_number <= chains.get(origin_chain_id, 0):
-                        chains[target_chain_id] = max(
-                            target_sequence_number,
-                            chains.get(target_chain_id, 0),
-                        )
+                links.setdefault(origin_chain_id, []).append(
+                    (origin_sequence_number, target_chain_id, target_sequence_number)
+                )
+
+            for chains in set_to_chain:
+                for chain_id in links:
+                    if chain_id not in chains:
+                        continue
 
-                seen_chains.add(target_chain_id)
+                    _materialize(chain_id, chains[chain_id], links, chains)
+
+                chains_to_fetch.difference_update(chains)
+                seen_chains.update(chains)
 
         # Now for each chain we figure out the maximum sequence number reachable
         # from *any* state set and the minimum sequence number reachable from
@@ -2103,3 +2158,49 @@ class EventFederationStore(EventFederationWorkerStore):
             )
 
         return batch_size
+
+
+def _materialize(
+    origin_chain_id: int,
+    origin_sequence_number: int,
+    links: Dict[int, List[Tuple[int, int, int]]],
+    materialized: Dict[int, int],
+) -> None:
+    """Helper function for fetching auth chain links. For a given origin chain
+    ID / sequence number and a dictionary of links, updates the materialized
+    dict with the reachable chains.
+
+    To get a dict of all chains reachable from a set of chains, this function can
+    be called in a loop, once per origin chain, passing the same links and
+    materialized args each time. The materialized dict will then hold the result.
+
+    Args:
+        origin_chain_id, origin_sequence_number: the chain and sequence number to start from.
+        links: map of the links between chains as a dict from origin chain ID
+            to list of 3-tuples of origin sequence number, target chain ID and
+            target sequence number.
+        materialized: dict to update with new reachability information, as a
+            map from chain ID to max sequence number reachable.
+    """
+
+    # Do a standard graph traversal.
+    stack = [(origin_chain_id, origin_sequence_number)]
+
+    while stack:
+        c, s = stack.pop()
+
+        chain_links = links.get(c, [])
+        for (
+            sequence_number,
+            target_chain_id,
+            target_sequence_number,
+        ) in chain_links:
+            # Ignore any links that are higher up the chain
+            if sequence_number > s:
+                continue
+
+            # Only descend into the target chain if we haven't already reached
+            # at least this far along it.
+            if materialized.get(target_chain_id, 0) < target_sequence_number:
+                stack.append((target_chain_id, target_sequence_number))
+                materialized[target_chain_id] = target_sequence_number
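
A minimal usage sketch of _materialize (not part of the commit; the chain IDs and sequence numbers are invented for illustration). The links value mirrors the format documented above, mapping an origin chain ID to (origin sequence number, target chain ID, target sequence number) tuples, and the helper is called once per starting chain while sharing the same links and output dicts:

    from typing import Dict, List, Tuple

    # Hypothetical link map: chain 1 @ seq 2 points at chain 2 up to seq 5,
    # and chain 2 @ seq 3 points at chain 3 up to seq 7.
    links: Dict[int, List[Tuple[int, int, int]]] = {
        1: [(2, 2, 5)],
        2: [(3, 3, 7)],
    }

    # Max sequence number we can reach on each starting chain (illustrative).
    event_chains: Dict[int, int] = {1: 10}

    chains: Dict[int, int] = {}
    for chain_id, max_seq in event_chains.items():
        _materialize(chain_id, max_seq, links, chains)

    # chains is now {2: 5, 3: 7}: both links are usable because their origin
    # sequence numbers (2 and 3) do not exceed the reachable point on each chain.

The same shape of loop appears at the two call sites above, which is why the traversal lives in a module-level helper rather than being repeated inline.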
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 132c781f51..ebdf6d95af 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -18,7 +18,7 @@
 #
 #
 
-SCHEMA_VERSION = 83  # remember to update the list below when updating
+SCHEMA_VERSION = 84  # remember to update the list below when updating
 """Represents the expectations made by the codebase about the database schema
 
 This should be incremented whenever the codebase changes its requirements on the
diff --git a/synapse/storage/schema/main/delta/84/01_auth_links_stats.sql.postgres b/synapse/storage/schema/main/delta/84/01_auth_links_stats.sql.postgres
new file mode 100644
index 0000000000..b0b41bd106
--- /dev/null
+++ b/synapse/storage/schema/main/delta/84/01_auth_links_stats.sql.postgres
@@ -0,0 +1,18 @@
+--
+-- This file is licensed under the Affero General Public License (AGPL) version 3.
+--
+-- Copyright (C) 2023 New Vector, Ltd
+--
+-- This program is free software: you can redistribute it and/or modify
+-- it under the terms of the GNU Affero General Public License as
+-- published by the Free Software Foundation, either version 3 of the
+-- License, or (at your option) any later version.
+--
+-- See the GNU Affero General Public License for more details:
+-- <https://www.gnu.org/licenses/agpl-3.0.html>.
+
+-- Force the statistics for these tables to show that the number of distinct
+-- chain IDs is proportional to the total number of rows, as Postgres has
+-- trouble figuring that out by itself.
+ALTER TABLE event_auth_chain_links ALTER origin_chain_id SET (n_distinct = -0.5);
+ALTER TABLE event_auth_chain_links ALTER target_chain_id SET (n_distinct = -0.5);
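
A hedged aside, not part of the commit: in Postgres a negative n_distinct is read as a fraction of the table's row count, so -0.5 asks the planner to assume roughly one distinct chain ID per two rows, and the override only influences statistics gathered after it is set. Assuming a psql session against the Synapse database, one way to confirm the override and refresh the stats is:

    -- Sketch only: show the stored per-column override, then re-gather statistics.
    SELECT attname, attoptions
      FROM pg_attribute
     WHERE attrelid = 'event_auth_chain_links'::regclass
       AND attoptions IS NOT NULL;

    ANALYZE event_auth_chain_links;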
diff --git a/synapse/storage/schema/main/delta/84/02_auth_links_index.sql b/synapse/storage/schema/main/delta/84/02_auth_links_index.sql
new file mode 100644
index 0000000000..6936e3d05b
--- /dev/null
+++ b/synapse/storage/schema/main/delta/84/02_auth_links_index.sql
@@ -0,0 +1,16 @@
+--
+-- This file is licensed under the Affero General Public License (AGPL) version 3.
+--
+-- Copyright (C) 2023 New Vector, Ltd
+--
+-- This program is free software: you can redistribute it and/or modify
+-- it under the terms of the GNU Affero General Public License as
+-- published by the Free Software Foundation, either version 3 of the
+-- License, or (at your option) any later version.
+--
+-- See the GNU Affero General Public License for more details:
+-- <https://www.gnu.org/licenses/agpl-3.0.html>.
+
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+    (8402, 'event_auth_chain_links_origin_index', '{}');
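
For context (an inference, not text from the commit): the update_name above matches the register_background_index_update() call added to event_federation.py, so once the background update completes the schema should end up with roughly the index below. The exact DDL (including whether the index is built concurrently) is decided by Synapse's background update machinery, so treat this as a sketch:

    -- Approximate end state of the 'event_auth_chain_links_origin_index' update.
    CREATE INDEX IF NOT EXISTS event_auth_chain_links_origin_index
        ON event_auth_chain_links (origin_chain_id, origin_sequence_number);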