diff --git a/changelog.d/17044.misc b/changelog.d/17044.misc
new file mode 100644
index 0000000000..a1439752d3
--- /dev/null
+++ b/changelog.d/17044.misc
@@ -0,0 +1 @@
+Refactor auth chain fetching to reduce duplication.
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 846c3f363a..fb132ef090 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -27,6 +27,7 @@ from typing import (
Collection,
Dict,
FrozenSet,
+ Generator,
Iterable,
List,
Optional,
@@ -279,64 +280,16 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Now we look up all links for the chains we have, adding chains that
# are reachable from any event.
- #
- # This query is structured to first get all chain IDs reachable, and
- # then pull out all links from those chains. This does pull out more
- # rows than is strictly necessary, however there isn't a way of
- # structuring the recursive part of query to pull out the links without
- # also returning large quantities of redundant data (which can make it a
- # lot slower).
- sql = """
- WITH RECURSIVE links(chain_id) AS (
- SELECT
- DISTINCT origin_chain_id
- FROM event_auth_chain_links WHERE %s
- UNION
- SELECT
- target_chain_id
- FROM event_auth_chain_links
- INNER JOIN links ON (chain_id = origin_chain_id)
- )
- SELECT
- origin_chain_id, origin_sequence_number,
- target_chain_id, target_sequence_number
- FROM links
- INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
- """
# A map from chain ID to max sequence number *reachable* from any event ID.
chains: Dict[int, int] = {}
-
- # Add all linked chains reachable from initial set of chains.
- chains_to_fetch = set(event_chains.keys())
- while chains_to_fetch:
- batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
- chains_to_fetch.difference_update(batch2)
- clause, args = make_in_list_sql_clause(
- txn.database_engine, "origin_chain_id", batch2
- )
- txn.execute(sql % (clause,), args)
-
- links: Dict[int, List[Tuple[int, int, int]]] = {}
-
- for (
- origin_chain_id,
- origin_sequence_number,
- target_chain_id,
- target_sequence_number,
- ) in txn:
- links.setdefault(origin_chain_id, []).append(
- (origin_sequence_number, target_chain_id, target_sequence_number)
- )
-
+ for links in self._get_chain_links(txn, set(event_chains.keys())):
for chain_id in links:
if chain_id not in event_chains:
continue
_materialize(chain_id, event_chains[chain_id], links, chains)
- chains_to_fetch.difference_update(chains)
-
# Add the initial set of chains, excluding the sequence corresponding to
# initial event.
for chain_id, seq_no in event_chains.items():
@@ -380,6 +333,68 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return results
+ @classmethod
+ def _get_chain_links(
+ cls, txn: LoggingTransaction, chains_to_fetch: Set[int]
+ ) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
+ """Fetch all auth chain links from the given set of chains, and all
+ links from those chains, recursively.
+
+ Note: This may return links that are not reachable from the given
+ chains. `chains_to_fetch` is mutated in place; pass a copy if needed.
+
+ Returns a generator that produces dicts from origin chain ID to a list of
+ 3-tuples of (origin sequence number, target chain ID, target sequence number).
+ """
+
+ # This query is structured to first get all chain IDs reachable, and
+ # then pull out all links from those chains. This does pull out more
+ # rows than is strictly necessary, however there isn't a way of
+ # structuring the recursive part of query to pull out the links without
+ # also returning large quantities of redundant data (which can make it a
+ # lot slower).
+ sql = """
+ WITH RECURSIVE links(chain_id) AS (
+ SELECT
+ DISTINCT origin_chain_id
+ FROM event_auth_chain_links WHERE %s
+ UNION
+ SELECT
+ target_chain_id
+ FROM event_auth_chain_links
+ INNER JOIN links ON (chain_id = origin_chain_id)
+ )
+ SELECT
+ origin_chain_id, origin_sequence_number,
+ target_chain_id, target_sequence_number
+ FROM links
+ INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
+ """
+
+ while chains_to_fetch:
+ batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
+ chains_to_fetch.difference_update(batch2)
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "origin_chain_id", batch2
+ )
+ txn.execute(sql % (clause,), args)
+
+ links: Dict[int, List[Tuple[int, int, int]]] = {}
+
+ for (
+ origin_chain_id,
+ origin_sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ ) in txn:
+ links.setdefault(origin_chain_id, []).append(
+ (origin_sequence_number, target_chain_id, target_sequence_number)
+ )
+
+ chains_to_fetch.difference_update(links)
+
+ yield links
+
def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> Set[str]:
@@ -564,53 +579,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Now we look up all links for the chains we have, adding chains that
# are reachable from any event.
- #
- # This query is structured to first get all chain IDs reachable, and
- # then pull out all links from those chains. This does pull out more
- # rows than is strictly necessary, however there isn't a way of
- # structuring the recursive part of query to pull out the links without
- # also returning large quantities of redundant data (which can make it a
- # lot slower).
- sql = """
- WITH RECURSIVE links(chain_id) AS (
- SELECT
- DISTINCT origin_chain_id
- FROM event_auth_chain_links WHERE %s
- UNION
- SELECT
- target_chain_id
- FROM event_auth_chain_links
- INNER JOIN links ON (chain_id = origin_chain_id)
- )
- SELECT
- origin_chain_id, origin_sequence_number,
- target_chain_id, target_sequence_number
- FROM links
- INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
- """
-
- # (We need to take a copy of `seen_chains` as we want to mutate it in
- # the loop)
- chains_to_fetch = set(seen_chains)
- while chains_to_fetch:
- batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
- clause, args = make_in_list_sql_clause(
- txn.database_engine, "origin_chain_id", batch2
- )
- txn.execute(sql % (clause,), args)
-
- links: Dict[int, List[Tuple[int, int, int]]] = {}
-
- for (
- origin_chain_id,
- origin_sequence_number,
- target_chain_id,
- target_sequence_number,
- ) in txn:
- links.setdefault(origin_chain_id, []).append(
- (origin_sequence_number, target_chain_id, target_sequence_number)
- )
+ # (We pass a copy of `seen_chains` as `_get_chain_links` mutates the set it is given)
+ for links in self._get_chain_links(txn, set(seen_chains)):
for chains in set_to_chain:
for chain_id in links:
if chain_id not in chains:
@@ -618,7 +589,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
_materialize(chain_id, chains[chain_id], links, chains)
- chains_to_fetch.difference_update(chains)
seen_chains.update(chains)
# Now for each chain we figure out the maximum sequence number reachable
|