summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/databases/main/event_federation.py82
1 files changed, 70 insertions, 12 deletions
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index a6279a6c13..2e07c37340 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -26,6 +26,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
 from synapse.types import Collection
 from synapse.util.caches.descriptors import cached
+from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 
 logger = logging.getLogger(__name__)
@@ -40,6 +41,11 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
                 self._delete_old_forward_extrem_cache, 60 * 60 * 1000
             )
 
+        # Cache of event ID to list of auth event IDs and their depths.
+        self._event_auth_cache = LruCache(
+            500000, "_event_auth_cache", size_callback=len
+        )  # type: LruCache[str, List[Tuple[str, int]]]
+
     async def get_auth_chain(
         self, event_ids: Collection[str], include_given: bool = False
     ) -> List[EventBase]:
@@ -84,17 +90,45 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         else:
             results = set()
 
-        base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "
+        # We pull out the depth simply so that we can populate the
+        # `_event_auth_cache` cache.
+        base_sql = """
+            SELECT a.event_id, auth_id, depth
+            FROM event_auth AS a
+            INNER JOIN events AS e ON (e.event_id = a.auth_id)
+            WHERE
+        """
 
         front = set(event_ids)
         while front:
             new_front = set()
             for chunk in batch_iter(front, 100):
-                clause, args = make_in_list_sql_clause(
-                    txn.database_engine, "event_id", chunk
-                )
-                txn.execute(base_sql + clause, args)
-                new_front.update(r[0] for r in txn)
+                # Pull the auth events either from the cache or DB.
+                to_fetch = []  # Event IDs to fetch from DB  # type: List[str]
+                for event_id in chunk:
+                    res = self._event_auth_cache.get(event_id)
+                    if res is None:
+                        to_fetch.append(event_id)
+                    else:
+                        new_front.update(auth_id for auth_id, depth in res)
+
+                if to_fetch:
+                    clause, args = make_in_list_sql_clause(
+                        txn.database_engine, "a.event_id", to_fetch
+                    )
+                    txn.execute(base_sql + clause, args)
+
+                    # Note we need to batch up the results by event ID before
+                    # adding to the cache.
+                    to_cache = {}
+                    for event_id, auth_event_id, auth_event_depth in txn:
+                        to_cache.setdefault(event_id, []).append(
+                            (auth_event_id, auth_event_depth)
+                        )
+                        new_front.add(auth_event_id)
+
+                    for event_id, auth_events in to_cache.items():
+                        self._event_auth_cache.set(event_id, auth_events)
 
             new_front -= results
 
@@ -213,14 +247,38 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
                 break
 
             # Fetch the auth events and their depths of the N last events we're
-            # currently walking
+            # currently walking, either from cache or DB.
             search, chunk = search[:-100], search[-100:]
-            clause, args = make_in_list_sql_clause(
-                txn.database_engine, "a.event_id", [e_id for _, e_id in chunk]
-            )
-            txn.execute(base_sql + clause, args)
 
-            for event_id, auth_event_id, auth_event_depth in txn:
+            found = []  # Results found  # type: List[Tuple[str, str, int]]
+            to_fetch = []  # Event IDs to fetch from DB  # type: List[str]
+            for _, event_id in chunk:
+                res = self._event_auth_cache.get(event_id)
+                if res is None:
+                    to_fetch.append(event_id)
+                else:
+                    found.extend((event_id, auth_id, depth) for auth_id, depth in res)
+
+            if to_fetch:
+                clause, args = make_in_list_sql_clause(
+                    txn.database_engine, "a.event_id", to_fetch
+                )
+                txn.execute(base_sql + clause, args)
+
+                # We parse the results and add the to the `found` set and the
+                # cache (note we need to batch up the results by event ID before
+                # adding to the cache).
+                to_cache = {}
+                for event_id, auth_event_id, auth_event_depth in txn:
+                    to_cache.setdefault(event_id, []).append(
+                        (auth_event_id, auth_event_depth)
+                    )
+                    found.append((event_id, auth_event_id, auth_event_depth))
+
+                for event_id, auth_events in to_cache.items():
+                    self._event_auth_cache.set(event_id, auth_events)
+
+            for event_id, auth_event_id, auth_event_depth in found:
                 event_to_auth_events.setdefault(event_id, set()).add(auth_event_id)
 
                 sets = event_to_missing_sets.get(auth_event_id)