diff options
author | Erik Johnston <erik@matrix.org> | 2024-05-17 10:34:39 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2024-05-17 10:35:03 +0100 |
commit | fe9fa90af4a3c3fbcbd6454c06434ddda2b4b3d6 (patch) | |
tree | 0d29c4945bd2f4d7211d0ad9778bdcb90f4a88e9 | |
parent | Up batch size (diff) | |
download | synapse-fe9fa90af4a3c3fbcbd6454c06434ddda2b4b3d6.tar.xz |
Add a cache to auth links
-rw-r--r-- | synapse/storage/databases/main/event_federation.py | 82 |
1 file changed, 79 insertions, 3 deletions
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index f66492b02e..a5d9a25d29 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -120,6 +120,11 @@ class BackfillQueueNavigationItem: type: str +@attr.s(frozen=True, slots=True, auto_attribs=True) +class _ChainLinksCacheEntry: + links: List[Tuple[int, int, int, "_ChainLinksCacheEntry"]] = attr.Factory(list) + + class _NoChainCoverIndex(Exception): def __init__(self, room_id: str): super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,)) @@ -140,6 +145,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas self.hs = hs + self._chain_links_cache: LruCache[int, _ChainLinksCacheEntry] = LruCache( + max_size=10000, cache_name="chain_links_cache" + ) + if hs.config.worker.run_background_tasks: hs.get_clock().looping_call( self._delete_old_forward_extrem_cache, 60 * 60 * 1000 @@ -285,7 +294,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # A map from chain ID to max sequence number *reachable* from any event ID. chains: Dict[int, int] = {} - for links in self._get_chain_links(txn, event_chains.keys()): + for links in self._get_chain_links( + txn, event_chains.keys(), self._chain_links_cache + ): for chain_id in links: if chain_id not in event_chains: continue @@ -337,7 +348,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas @classmethod def _get_chain_links( - cls, txn: LoggingTransaction, chains_to_fetch: Collection[int] + cls, + txn: LoggingTransaction, + chains_to_fetch: Collection[int], + cache: Optional[LruCache[int, _ChainLinksCacheEntry]] = None, ) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]: """Fetch all auth chain links from the given set of chains, and all links from those chains, recursively. 
@@ -349,6 +363,44 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas of origin sequence number, target chain ID and target sequence number. """ + found_cached_chains = set() + if cache: + entries: Dict[int, _ChainLinksCacheEntry] = {} + for chain_id in chains_to_fetch: + entry = cache.get(chain_id) + if entry: + entries[chain_id] = entry + + cached_links: Dict[int, List[Tuple[int, int, int]]] = {} + while entries: + origin_chain_id, entry = entries.popitem() + + for ( + origin_sequence_number, + target_chain_id, + target_sequence_number, + target_entry, + ) in entry.links: + if target_chain_id in found_cached_chains: + continue + + found_cached_chains.add(target_chain_id) + + cache.get(chain_id) + + entries[chain_id] = target_entry + cached_links.setdefault(origin_chain_id, []).append( + ( + origin_sequence_number, + target_chain_id, + target_sequence_number, + ) + ) + + yield cached_links + + logger.info("CHAINS: Found cached chain links %d", len(found_cached_chains)) + # This query is structured to first get all chain IDs reachable, and # then pull out all links from those chains. This does pull out more # rows than is strictly necessary, however there isn't a way of @@ -385,6 +437,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # likely to depend on lower chain IDs than vice versa). 
BATCH_SIZE = 5000 chains_to_fetch_sorted = SortedSet(chains_to_fetch) + chains_to_fetch_sorted.difference_update(found_cached_chains) logger.info("CHAINS: Fetching chain links %d", len(chains_to_fetch_sorted)) @@ -406,6 +459,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas links: Dict[int, List[Tuple[int, int, int]]] = {} + cache_entries: Dict[int, _ChainLinksCacheEntry] = {} + for ( origin_chain_id, origin_sequence_number, @@ -416,6 +471,27 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas (origin_sequence_number, target_chain_id, target_sequence_number) ) + if cache: + origin_entry = cache_entries.setdefault( + origin_chain_id, _ChainLinksCacheEntry() + ) + target_entry = cache_entries.setdefault( + target_chain_id, _ChainLinksCacheEntry() + ) + origin_entry.links.append( + ( + origin_sequence_number, + target_chain_id, + target_sequence_number, + target_entry, + ) + ) + + if cache: + for chain_id, entry in cache_entries.items(): + if chain_id not in cache: + cache[chain_id] = entry + chains_to_fetch_sorted.difference_update(links) logger.info("CHAINS: returned %d", len(links)) @@ -614,7 +690,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # (We need to take a copy of `seen_chains` as the function mutates it) logger.info("CHAINS: for room %s", room_id) - for links in self._get_chain_links(txn, seen_chains): + for links in self._get_chain_links(txn, seen_chains, self._chain_links_cache): for chains in set_to_chain: for chain_id in links: if chain_id not in chains: |