summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2024-05-08 16:04:35 +0100
committerErik Johnston <erik@matrix.org>2024-05-08 16:04:35 +0100
commitdb25e30a256047454a9b09b579856d6cce0a6a7b (patch)
tree55b82e2f6583dc406bc218f5655a05f241e242bd
parentOptional whitespace support in Authorization (#1350) (#17145) (diff)
downloadsynapse-db25e30a256047454a9b09b579856d6cce0a6a7b.tar.xz
Perf improvement to getting auth chains
-rw-r--r--synapse/storage/databases/main/event_federation.py33
1 files changed, 26 insertions, 7 deletions
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index fb132ef090..68f30d893c 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -283,7 +283,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         # A map from chain ID to max sequence number *reachable* from any event ID.
         chains: Dict[int, int] = {}
-        for links in self._get_chain_links(txn, set(event_chains.keys())):
+        for links in self._get_chain_links(txn, event_chains.keys()):
             for chain_id in links:
                 if chain_id not in event_chains:
                     continue
@@ -335,7 +335,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
     @classmethod
     def _get_chain_links(
-        cls, txn: LoggingTransaction, chains_to_fetch: Set[int]
+        cls, txn: LoggingTransaction, chains_to_fetch: Collection[int]
     ) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
         """Fetch all auth chain links from the given set of chains, and all
         links from those chains, recursively.
@@ -371,9 +371,27 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
         """
 
-        while chains_to_fetch:
-            batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
-            chains_to_fetch.difference_update(batch2)
+        # We fetch the links in batches. Separate batches will likely fetch the
+        # same set of links (e.g. they'll always pull in the links to create
+        # event). To try and minimize the amount of redundant links, we sort the
+        # chain IDs in reverse, as there will be a correlation between the order
+        # of chain IDs and links (i.e., higher chain IDs are more likely to
+        # depend on lower chain IDs than vice versa).
+        BATCH_SIZE = 1000
+        chains_to_fetch_list = list(chains_to_fetch)
+        chains_to_fetch_list.sort(reverse=True)
+
+        seen_chains: Set[int] = set()
+        while chains_to_fetch_list:
+            batch2 = [
+                c for c in chains_to_fetch_list[-BATCH_SIZE:] if c not in seen_chains
+            ]
+            chains_to_fetch_list = chains_to_fetch_list[:-BATCH_SIZE]
+            while len(batch2) < BATCH_SIZE and chains_to_fetch_list:
+                chain_id = chains_to_fetch_list.pop()
+                if chain_id not in seen_chains:
+                    batch2.append(chain_id)
+
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "origin_chain_id", batch2
             )
@@ -391,7 +409,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                     (origin_sequence_number, target_chain_id, target_sequence_number)
                 )
 
-            chains_to_fetch.difference_update(links)
+            seen_chains.update(links)
+            seen_chains.update(batch2)
 
             yield links
 
@@ -581,7 +600,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # are reachable from any event.
 
         # (We need to take a copy of `seen_chains` as the function mutates it)
-        for links in self._get_chain_links(txn, set(seen_chains)):
+        for links in self._get_chain_links(txn, seen_chains):
             for chains in set_to_chain:
                 for chain_id in links:
                     if chain_id not in chains: