diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ddc2baf95d..e00e7ebf76 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -159,6 +159,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
unique_columns=("event_id", "room_id"),
)
+ self.db_pool.updates.register_background_index_update(
+ update_name="event_auth_chain_links_origin_index",
+ index_name="event_auth_chain_links_origin_index",
+ table="event_auth_chain_links",
+ columns=("origin_chain_id", "origin_sequence_number"),
+ )
+
async def get_auth_chain(
self, room_id: str, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]:
@@ -271,38 +278,63 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Now we look up all links for the chains we have, adding chains that
# are reachable from any event.
+ #
+ # This query is structured to first get all chain IDs reachable, and
+ # then pull out all links from those chains. This does pull out more
+ # rows than is strictly necessary, however there isn't a way of
+ # structuring the recursive part of query to pull out the links without
+ # also returning large quantities of redundant data (which can make it a
+ # lot slower).
sql = """
+ WITH RECURSIVE links(chain_id) AS (
+ SELECT
+ DISTINCT origin_chain_id
+ FROM event_auth_chain_links WHERE %s
+ UNION
+ SELECT
+ target_chain_id
+ FROM event_auth_chain_links
+ INNER JOIN links ON (chain_id = origin_chain_id)
+ )
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
- FROM event_auth_chain_links
- WHERE %s
+ FROM links
+ INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""
# A map from chain ID to max sequence number *reachable* from any event ID.
chains: Dict[int, int] = {}
# Add all linked chains reachable from initial set of chains.
- for batch2 in batch_iter(event_chains, 1000):
+ chains_to_fetch = set(event_chains.keys())
+ while chains_to_fetch:
+ batch2 = tuple(itertools.islice(chains_to_fetch, 100))
+ chains_to_fetch.difference_update(batch2)
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)
+ links: Dict[int, List[Tuple[int, int, int]]] = {}
+
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
- # chains are only reachable if the origin sequence number of
- # the link is less than the max sequence number in the
- # origin chain.
- if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
- chains[target_chain_id] = max(
- target_sequence_number,
- chains.get(target_chain_id, 0),
- )
+ links.setdefault(origin_chain_id, []).append(
+ (origin_sequence_number, target_chain_id, target_sequence_number)
+ )
+
+ for chain_id in links:
+ if chain_id not in event_chains:
+ continue
+
+ _materialize(chain_id, event_chains[chain_id], links, chains)
+
+ chains_to_fetch.difference_update(chains)
# Add the initial set of chains, excluding the sequence corresponding to
# initial event.
@@ -529,41 +561,64 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
- # Now we look up all links for the chains we have, adding chains to
- # set_to_chain that are reachable from each set.
+ # Now we look up all links for the chains we have, adding chains that
+ # are reachable from any event.
+ #
+ # This query is structured to first get all chain IDs reachable, and
+ # then pull out all links from those chains. This does pull out more
+ # rows than is strictly necessary, however there isn't a way of
+ # structuring the recursive part of query to pull out the links without
+ # also returning large quantities of redundant data (which can make it a
+ # lot slower).
sql = """
+ WITH RECURSIVE links(chain_id) AS (
+ SELECT
+ DISTINCT origin_chain_id
+ FROM event_auth_chain_links WHERE %s
+ UNION
+ SELECT
+ target_chain_id
+ FROM event_auth_chain_links
+ INNER JOIN links ON (chain_id = origin_chain_id)
+ )
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
- FROM event_auth_chain_links
- WHERE %s
+ FROM links
+ INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""
# (We need to take a copy of `seen_chains` as we want to mutate it in
# the loop)
- for batch2 in batch_iter(set(seen_chains), 1000):
+ chains_to_fetch = set(seen_chains)
+ while chains_to_fetch:
+ batch2 = tuple(itertools.islice(chains_to_fetch, 100))
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)
+ links: Dict[int, List[Tuple[int, int, int]]] = {}
+
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
- for chains in set_to_chain:
- # chains are only reachable if the origin sequence number of
- # the link is less than the max sequence number in the
- # origin chain.
- if origin_sequence_number <= chains.get(origin_chain_id, 0):
- chains[target_chain_id] = max(
- target_sequence_number,
- chains.get(target_chain_id, 0),
- )
+ links.setdefault(origin_chain_id, []).append(
+ (origin_sequence_number, target_chain_id, target_sequence_number)
+ )
+
+ for chains in set_to_chain:
+ for chain_id in links:
+ if chain_id not in chains:
+ continue
- seen_chains.add(target_chain_id)
+ _materialize(chain_id, chains[chain_id], links, chains)
+
+ chains_to_fetch.difference_update(chains)
+ seen_chains.update(chains)
# Now for each chain we figure out the maximum sequence number reachable
# from *any* state set and the minimum sequence number reachable from
@@ -2103,3 +2158,49 @@ class EventFederationStore(EventFederationWorkerStore):
)
return batch_size
+
+
+def _materialize(
+ origin_chain_id: int,
+ origin_sequence_number: int,
+ links: Dict[int, List[Tuple[int, int, int]]],
+ materialized: Dict[int, int],
+) -> None:
+ """Helper function for fetching auth chain links. For a given origin chain
+ ID / sequence number and a dictionary of links, updates the materialized
+ dict with the reachable chains.
+
+ To get a dict of all chains reachable from a set of chains this function can
+ be called in a loop, once per origin chain with the same links and
+ materialized args. The materialized dict will the result.
+
+ Args:
+ origin_chain_id, origin_sequence_number
+ links: map of the links between chains as a dict from origin chain ID
+ to list of 3-tuples of origin sequence number, target chain ID and
+ target sequence number.
+ materialized: dict to update with new reachability information, as a
+ map from chain ID to max sequence number reachable.
+ """
+
+ # Do a standard graph traversal.
+ stack = [(origin_chain_id, origin_sequence_number)]
+
+ while stack:
+ c, s = stack.pop()
+
+ chain_links = links.get(c, [])
+ for (
+ sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ ) in chain_links:
+ # Ignore any links that are higher up the chain
+ if sequence_number > s:
+ continue
+
+ # Check if we have already visited the target chain before, if so we
+ # can skip it.
+ if materialized.get(target_chain_id, 0) < target_sequence_number:
+ stack.append((target_chain_id, target_sequence_number))
+ materialized[target_chain_id] = target_sequence_number
|