diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index f66492b02e..a5d9a25d29 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -120,6 +120,11 @@ class BackfillQueueNavigationItem:
type: str
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class _ChainLinksCacheEntry:
+ links: List[Tuple[int, int, int, "_ChainLinksCacheEntry"]] = attr.Factory(list)
+
+
class _NoChainCoverIndex(Exception):
def __init__(self, room_id: str):
super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
@@ -140,6 +145,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
self.hs = hs
+ self._chain_links_cache: LruCache[int, _ChainLinksCacheEntry] = LruCache(
+ max_size=10000, cache_name="chain_links_cache"
+ )
+
if hs.config.worker.run_background_tasks:
hs.get_clock().looping_call(
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
@@ -285,7 +294,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# A map from chain ID to max sequence number *reachable* from any event ID.
chains: Dict[int, int] = {}
- for links in self._get_chain_links(txn, event_chains.keys()):
+ for links in self._get_chain_links(
+ txn, event_chains.keys(), self._chain_links_cache
+ ):
for chain_id in links:
if chain_id not in event_chains:
continue
@@ -337,7 +348,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@classmethod
def _get_chain_links(
- cls, txn: LoggingTransaction, chains_to_fetch: Collection[int]
+ cls,
+ txn: LoggingTransaction,
+ chains_to_fetch: Collection[int],
+ cache: Optional[LruCache[int, _ChainLinksCacheEntry]] = None,
) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
"""Fetch all auth chain links from the given set of chains, and all
links from those chains, recursively.
@@ -349,6 +363,44 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
of origin sequence number, target chain ID and target sequence number.
"""
+ found_cached_chains = set()
+ if cache:
+ entries: Dict[int, _ChainLinksCacheEntry] = {}
+ for chain_id in chains_to_fetch:
+ entry = cache.get(chain_id)
+ if entry:
+ entries[chain_id] = entry
+
+ cached_links: Dict[int, List[Tuple[int, int, int]]] = {}
+ while entries:
+ origin_chain_id, entry = entries.popitem()
+
+ for (
+ origin_sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ target_entry,
+ ) in entry.links:
+ if target_chain_id in found_cached_chains:
+ continue
+
+ found_cached_chains.add(target_chain_id)
+
+ cache.get(chain_id)
+
+ entries[chain_id] = target_entry
+ cached_links.setdefault(origin_chain_id, []).append(
+ (
+ origin_sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ )
+ )
+
+ yield cached_links
+
+ logger.info("CHAINS: Found cached chain links %d", len(found_cached_chains))
+
# This query is structured to first get all chain IDs reachable, and
# then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, however there isn't a way of
@@ -385,6 +437,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# likely to depend on lower chain IDs than vice versa).
BATCH_SIZE = 5000
chains_to_fetch_sorted = SortedSet(chains_to_fetch)
+ chains_to_fetch_sorted.difference_update(found_cached_chains)
logger.info("CHAINS: Fetching chain links %d", len(chains_to_fetch_sorted))
@@ -406,6 +459,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
links: Dict[int, List[Tuple[int, int, int]]] = {}
+ cache_entries: Dict[int, _ChainLinksCacheEntry] = {}
+
for (
origin_chain_id,
origin_sequence_number,
@@ -416,6 +471,27 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
(origin_sequence_number, target_chain_id, target_sequence_number)
)
+ if cache:
+ origin_entry = cache_entries.setdefault(
+ origin_chain_id, _ChainLinksCacheEntry()
+ )
+ target_entry = cache_entries.setdefault(
+ target_chain_id, _ChainLinksCacheEntry()
+ )
+ origin_entry.links.append(
+ (
+ origin_sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ target_entry,
+ )
+ )
+
+ if cache:
+ for chain_id, entry in cache_entries.items():
+ if chain_id not in cache:
+ cache[chain_id] = entry
+
chains_to_fetch_sorted.difference_update(links)
logger.info("CHAINS: returned %d", len(links))
@@ -614,7 +690,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# (We need to take a copy of `seen_chains` as the function mutates it)
logger.info("CHAINS: for room %s", room_id)
- for links in self._get_chain_links(txn, seen_chains):
+ for links in self._get_chain_links(txn, seen_chains, self._chain_links_cache):
for chains in set_to_chain:
for chain_id in links:
if chain_id not in chains:
|