diff options
author | Erik Johnston <erikj@element.io> | 2024-01-23 11:26:27 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-23 11:26:27 +0000 |
commit | 14c725f73b1e4da37566def0491670009d718539 (patch) | |
tree | 6ec594ef8162a4d1a465733f5a99fb18ac31bbea /synapse/storage/databases/main | |
parent | Add a `--generate-only` option to the Complement launcher. (#16828) (diff) | |
download | synapse-14c725f73b1e4da37566def0491670009d718539.tar.xz |
Preparatory work for tweaking performance of auth chain lookups (#16833)
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r-- | synapse/storage/databases/main/event_federation.py | 153 |
1 files changed, 127 insertions, 26 deletions
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ddc2baf95d..e00e7ebf76 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -159,6 +159,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas unique_columns=("event_id", "room_id"), ) + self.db_pool.updates.register_background_index_update( + update_name="event_auth_chain_links_origin_index", + index_name="event_auth_chain_links_origin_index", + table="event_auth_chain_links", + columns=("origin_chain_id", "origin_sequence_number"), + ) + async def get_auth_chain( self, room_id: str, event_ids: Collection[str], include_given: bool = False ) -> List[EventBase]: @@ -271,38 +278,63 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # Now we look up all links for the chains we have, adding chains that # are reachable from any event. + # + # This query is structured to first get all chain IDs reachable, and + # then pull out all links from those chains. This does pull out more + # rows than is strictly necessary, however there isn't a way of + # structuring the recursive part of query to pull out the links without + # also returning large quantities of redundant data (which can make it a + # lot slower). sql = """ + WITH RECURSIVE links(chain_id) AS ( + SELECT + DISTINCT origin_chain_id + FROM event_auth_chain_links WHERE %s + UNION + SELECT + target_chain_id + FROM event_auth_chain_links + INNER JOIN links ON (chain_id = origin_chain_id) + ) SELECT origin_chain_id, origin_sequence_number, target_chain_id, target_sequence_number - FROM event_auth_chain_links - WHERE %s + FROM links + INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id) """ # A map from chain ID to max sequence number *reachable* from any event ID. chains: Dict[int, int] = {} # Add all linked chains reachable from initial set of chains. 
- for batch2 in batch_iter(event_chains, 1000): + chains_to_fetch = set(event_chains.keys()) + while chains_to_fetch: + batch2 = tuple(itertools.islice(chains_to_fetch, 100)) + chains_to_fetch.difference_update(batch2) clause, args = make_in_list_sql_clause( txn.database_engine, "origin_chain_id", batch2 ) txn.execute(sql % (clause,), args) + links: Dict[int, List[Tuple[int, int, int]]] = {} + for ( origin_chain_id, origin_sequence_number, target_chain_id, target_sequence_number, ) in txn: - # chains are only reachable if the origin sequence number of - # the link is less than the max sequence number in the - # origin chain. - if origin_sequence_number <= event_chains.get(origin_chain_id, 0): - chains[target_chain_id] = max( - target_sequence_number, - chains.get(target_chain_id, 0), - ) + links.setdefault(origin_chain_id, []).append( + (origin_sequence_number, target_chain_id, target_sequence_number) + ) + + for chain_id in links: + if chain_id not in event_chains: + continue + + _materialize(chain_id, event_chains[chain_id], links, chains) + + chains_to_fetch.difference_update(chains) # Add the initial set of chains, excluding the sequence corresponding to # initial event. @@ -529,41 +561,64 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas chains[chain_id] = max(seq_no, chains.get(chain_id, 0)) - # Now we look up all links for the chains we have, adding chains to - # set_to_chain that are reachable from each set. + # Now we look up all links for the chains we have, adding chains that + # are reachable from any event. + # + # This query is structured to first get all chain IDs reachable, and + # then pull out all links from those chains. This does pull out more + # rows than is strictly necessary, however there isn't a way of + # structuring the recursive part of query to pull out the links without + # also returning large quantities of redundant data (which can make it a + # lot slower). 
sql = """ + WITH RECURSIVE links(chain_id) AS ( + SELECT + DISTINCT origin_chain_id + FROM event_auth_chain_links WHERE %s + UNION + SELECT + target_chain_id + FROM event_auth_chain_links + INNER JOIN links ON (chain_id = origin_chain_id) + ) SELECT origin_chain_id, origin_sequence_number, target_chain_id, target_sequence_number - FROM event_auth_chain_links - WHERE %s + FROM links + INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id) """ # (We need to take a copy of `seen_chains` as we want to mutate it in # the loop) - for batch2 in batch_iter(set(seen_chains), 1000): + chains_to_fetch = set(seen_chains) + while chains_to_fetch: + batch2 = tuple(itertools.islice(chains_to_fetch, 100)) clause, args = make_in_list_sql_clause( txn.database_engine, "origin_chain_id", batch2 ) txn.execute(sql % (clause,), args) + links: Dict[int, List[Tuple[int, int, int]]] = {} + for ( origin_chain_id, origin_sequence_number, target_chain_id, target_sequence_number, ) in txn: - for chains in set_to_chain: - # chains are only reachable if the origin sequence number of - # the link is less than the max sequence number in the - # origin chain. 
- if origin_sequence_number <= chains.get(origin_chain_id, 0): - chains[target_chain_id] = max( - target_sequence_number, - chains.get(target_chain_id, 0), - ) + links.setdefault(origin_chain_id, []).append( + (origin_sequence_number, target_chain_id, target_sequence_number) + ) + + for chains in set_to_chain: + for chain_id in links: + if chain_id not in chains: + continue - seen_chains.add(target_chain_id) + _materialize(chain_id, chains[chain_id], links, chains) + + chains_to_fetch.difference_update(chains) + seen_chains.update(chains) # Now for each chain we figure out the maximum sequence number reachable # from *any* state set and the minimum sequence number reachable from @@ -2103,3 +2158,49 @@ class EventFederationStore(EventFederationWorkerStore): ) return batch_size + + +def _materialize( + origin_chain_id: int, + origin_sequence_number: int, + links: Dict[int, List[Tuple[int, int, int]]], + materialized: Dict[int, int], +) -> None: + """Helper function for fetching auth chain links. For a given origin chain + ID / sequence number and a dictionary of links, updates the materialized + dict with the reachable chains. + + To get a dict of all chains reachable from a set of chains this function can + be called in a loop, once per origin chain with the same links and + materialized args. The materialized dict will be the result. + + Args: + origin_chain_id, origin_sequence_number + links: map of the links between chains as a dict from origin chain ID + to list of 3-tuples of origin sequence number, target chain ID and + target sequence number. + materialized: dict to update with new reachability information, as a + map from chain ID to max sequence number reachable. + """ + + # Do a standard graph traversal. 
+ stack = [(origin_chain_id, origin_sequence_number)] + + while stack: + c, s = stack.pop() + + chain_links = links.get(c, []) + for ( + sequence_number, + target_chain_id, + target_sequence_number, + ) in chain_links: + # Ignore any links that are higher up the chain + if sequence_number > s: + continue + + # Check if we have already visited the target chain before, if so we + # can skip it. + if materialized.get(target_chain_id, 0) < target_sequence_number: + stack.append((target_chain_id, target_sequence_number)) + materialized[target_chain_id] = target_sequence_number |