summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2024-05-17 10:34:39 +0100
committerErik Johnston <erik@matrix.org>2024-05-17 10:35:03 +0100
commitfe9fa90af4a3c3fbcbd6454c06434ddda2b4b3d6 (patch)
tree0d29c4945bd2f4d7211d0ad9778bdcb90f4a88e9
parentUp batch size (diff)
downloadsynapse-fe9fa90af4a3c3fbcbd6454c06434ddda2b4b3d6.tar.xz
Add a cache to auth links
-rw-r--r--synapse/storage/databases/main/event_federation.py82
1 file changed, 79 insertions, 3 deletions
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index f66492b02e..a5d9a25d29 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -120,6 +120,11 @@ class BackfillQueueNavigationItem:
     type: str
 
 
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class _ChainLinksCacheEntry:
+    links: List[Tuple[int, int, int, "_ChainLinksCacheEntry"]] = attr.Factory(list)
+
+
 class _NoChainCoverIndex(Exception):
     def __init__(self, room_id: str):
         super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
@@ -140,6 +145,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         self.hs = hs
 
+        self._chain_links_cache: LruCache[int, _ChainLinksCacheEntry] = LruCache(
+            max_size=10000, cache_name="chain_links_cache"
+        )
+
         if hs.config.worker.run_background_tasks:
             hs.get_clock().looping_call(
                 self._delete_old_forward_extrem_cache, 60 * 60 * 1000
@@ -285,7 +294,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         # A map from chain ID to max sequence number *reachable* from any event ID.
         chains: Dict[int, int] = {}
-        for links in self._get_chain_links(txn, event_chains.keys()):
+        for links in self._get_chain_links(
+            txn, event_chains.keys(), self._chain_links_cache
+        ):
             for chain_id in links:
                 if chain_id not in event_chains:
                     continue
@@ -337,7 +348,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
     @classmethod
     def _get_chain_links(
-        cls, txn: LoggingTransaction, chains_to_fetch: Collection[int]
+        cls,
+        txn: LoggingTransaction,
+        chains_to_fetch: Collection[int],
+        cache: Optional[LruCache[int, _ChainLinksCacheEntry]] = None,
     ) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
         """Fetch all auth chain links from the given set of chains, and all
         links from those chains, recursively.
@@ -349,6 +363,44 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         of origin sequence number, target chain ID and target sequence number.
         """
 
+        found_cached_chains = set()
+        if cache:
+            entries: Dict[int, _ChainLinksCacheEntry] = {}
+            for chain_id in chains_to_fetch:
+                entry = cache.get(chain_id)
+                if entry:
+                    entries[chain_id] = entry
+            # Chains found in the cache (and everything reachable from them,
+            # walked below) can be served entirely from the cache, so they
+            # don't need to be fetched from the database.
+            found_cached_chains.update(entries)
+            cached_links: Dict[int, List[Tuple[int, int, int]]] = {}
+            while entries:
+                origin_chain_id, entry = entries.popitem()
+                for (
+                    origin_sequence_number,
+                    target_chain_id,
+                    target_sequence_number,
+                    target_entry,
+                ) in entry.links:
+                    if target_chain_id in found_cached_chains:
+                        continue
+                    found_cached_chains.add(target_chain_id)
+                    # Bump the target entry in the LRU cache, and queue it up
+                    # so that we follow its links too.
+                    cache.get(target_chain_id)
+                    entries[target_chain_id] = target_entry
+                    cached_links.setdefault(origin_chain_id, []).append(
+                        (
+                            origin_sequence_number,
+                            target_chain_id,
+                            target_sequence_number,
+                        )
+                    )
+
+            yield cached_links
+
+        logger.info("CHAINS: Found cached chain links %d", len(found_cached_chains))
+
         # This query is structured to first get all chain IDs reachable, and
         # then pull out all links from those chains. This does pull out more
         # rows than is strictly necessary, however there isn't a way of
@@ -385,6 +437,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # likely to depend on lower chain IDs than vice versa).
         BATCH_SIZE = 5000
         chains_to_fetch_sorted = SortedSet(chains_to_fetch)
+        chains_to_fetch_sorted.difference_update(found_cached_chains)
 
         logger.info("CHAINS: Fetching chain links %d", len(chains_to_fetch_sorted))
 
@@ -406,6 +459,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
             links: Dict[int, List[Tuple[int, int, int]]] = {}
 
+            cache_entries: Dict[int, _ChainLinksCacheEntry] = {}
+
             for (
                 origin_chain_id,
                 origin_sequence_number,
@@ -416,6 +471,27 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                     (origin_sequence_number, target_chain_id, target_sequence_number)
                 )
 
+                if cache:
+                    origin_entry = cache_entries.setdefault(
+                        origin_chain_id, _ChainLinksCacheEntry()
+                    )
+                    target_entry = cache_entries.setdefault(
+                        target_chain_id, _ChainLinksCacheEntry()
+                    )
+                    origin_entry.links.append(
+                        (
+                            origin_sequence_number,
+                            target_chain_id,
+                            target_sequence_number,
+                            target_entry,
+                        )
+                    )
+
+            if cache:
+                for chain_id, entry in cache_entries.items():
+                    if chain_id not in cache:
+                        cache[chain_id] = entry
+
             chains_to_fetch_sorted.difference_update(links)
 
             logger.info("CHAINS: returned %d", len(links))
@@ -614,7 +690,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         # (We need to take a copy of `seen_chains` as the function mutates it)
         logger.info("CHAINS: for room %s", room_id)
-        for links in self._get_chain_links(txn, seen_chains):
+        for links in self._get_chain_links(txn, seen_chains, self._chain_links_cache):
             for chains in set_to_chain:
                 for chain_id in links:
                     if chain_id not in chains: