diff options
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r-- | synapse/storage/databases/main/events.py | 108 |
1 files changed, 37 insertions, 71 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index a6fda3f43c..1e731d56bd 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -19,6 +19,7 @@ # [This file includes modifications made by New Vector Limited] # # +import collections import itertools import logging from collections import OrderedDict @@ -53,6 +54,7 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) +from synapse.storage.databases.main.event_federation import EventFederationStore from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.storage.databases.main.search import SearchEntry from synapse.storage.engines import PostgresEngine @@ -768,40 +770,26 @@ class PersistEventsStore: # that have the same chain ID as the event. # 2. For each retained auth event we: # a. Add a link from the event's to the auth event's chain - # ID/sequence number; and - # b. Add a link from the event to every chain reachable by the - # auth event. + # ID/sequence number # Step 1, fetch all existing links from all the chains we've seen # referenced. chain_links = _LinkMap() - auth_chain_rows = cast( - List[Tuple[int, int, int, int]], - db_pool.simple_select_many_txn( - txn, - table="event_auth_chain_links", - column="origin_chain_id", - iterable={chain_id for chain_id, _ in chain_map.values()}, - keyvalues={}, - retcols=( - "origin_chain_id", - "origin_sequence_number", - "target_chain_id", - "target_sequence_number", - ), - ), - ) - for ( - origin_chain_id, - origin_sequence_number, - target_chain_id, - target_sequence_number, - ) in auth_chain_rows: - chain_links.add_link( - (origin_chain_id, origin_sequence_number), - (target_chain_id, target_sequence_number), - new=False, - ) + + for links in EventFederationStore._get_chain_links( + txn, {chain_id for chain_id, _ in chain_map.values()} + ): + for origin_chain_id, inner_links in links.items(): + for ( + origin_sequence_number, + target_chain_id, + target_sequence_number, + ) in inner_links: + chain_links.add_link( + (origin_chain_id, origin_sequence_number), + (target_chain_id, target_sequence_number), + new=False, + ) # We do this in toplogical order to avoid adding redundant links. for event_id in sorted_topologically( @@ -836,18 +824,6 @@ class PersistEventsStore: (chain_id, sequence_number), (auth_chain_id, auth_sequence_number) ) - # Step 2b, add a link to chains reachable from the auth - # event. - for target_id, target_seq in chain_links.get_links_from( - (auth_chain_id, auth_sequence_number) - ): - if target_id == chain_id: - continue - - chain_links.add_link( - (chain_id, sequence_number), (target_id, target_seq) - ) - db_pool.simple_insert_many_txn( txn, table="event_auth_chain_links", @@ -2451,31 +2427,6 @@ class _LinkMap: current_links[src_seq] = target_seq return True - def get_links_from( - self, src_tuple: Tuple[int, int] - ) -> Generator[Tuple[int, int], None, None]: - """Gets the chains reachable from the given chain/sequence number. - - Yields: - The chain ID and sequence number the link points to. - """ - src_chain, src_seq = src_tuple - for target_id, sequence_numbers in self.maps.get(src_chain, {}).items(): - for link_src_seq, target_seq in sequence_numbers.items(): - if link_src_seq <= src_seq: - yield target_id, target_seq - - def get_links_between( - self, source_chain: int, target_chain: int - ) -> Generator[Tuple[int, int], None, None]: - """Gets the links between two chains. - - Yields: - The source and target sequence numbers. - """ - - yield from self.maps.get(source_chain, {}).get(target_chain, {}).items() - def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]: """Gets any newly added links. @@ -2502,9 +2453,24 @@ class _LinkMap: if src_chain == target_chain: return target_seq <= src_seq - links = self.get_links_between(src_chain, target_chain) - for link_start_seq, link_end_seq in links: - if link_start_seq <= src_seq and target_seq <= link_end_seq: - return True + # We have to graph traverse the links to check for indirect paths. + visited_chains = collections.Counter() + search = [(src_chain, src_seq)] + while search: + chain, seq = search.pop() + visited_chains[chain] = max(seq, visited_chains[chain]) + for tc, links in self.maps.get(chain, {}).items(): + for ss, ts in links.items(): + # Don't revisit chains we've already seen, unless the target + # sequence number is higher than last time. + if ts <= visited_chains.get(tc, 0): + continue + + if ss <= seq: + if tc == target_chain: + if target_seq <= ts: + return True + else: + search.append((tc, ts)) return False |