summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-08-04 13:54:51 +0100
committerGitHub <noreply@github.com>2021-08-04 13:54:51 +0100
commitc37dad67ab04980ac934554399f52a27e54292ab (patch)
treea985ffee098a29c8c894b056b9437143a9d961e7
parentFix `could not serialize access` errors for `claim_e2e_one_time_keys` (#10504) (diff)
downloadsynapse-c37dad67ab04980ac934554399f52a27e54292ab.tar.xz
Improve event caching code (#10119)
Ensure we only load an event from the DB once when the same event is requested multiple times at once.
-rw-r--r--changelog.d/10119.misc1
-rw-r--r--synapse/storage/databases/main/events_worker.py144
-rw-r--r--synapse/storage/databases/main/roommember.py6
-rw-r--r--tests/storage/databases/main/test_events_worker.py50
4 files changed, 158 insertions, 43 deletions
diff --git a/changelog.d/10119.misc b/changelog.d/10119.misc
new file mode 100644
index 0000000000..f70dc6496f
--- /dev/null
+++ b/changelog.d/10119.misc
@@ -0,0 +1 @@
+Improve event caching mechanism to avoid having multiple copies of an event in memory at a time.
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 3c86adab56..375463e4e9 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -14,7 +14,6 @@
 
 import logging
 import threading
-from collections import namedtuple
 from typing import (
     Collection,
     Container,
@@ -27,6 +26,7 @@ from typing import (
     overload,
 )
 
+import attr
 from constantly import NamedConstant, Names
 from typing_extensions import Literal
 
@@ -42,7 +42,11 @@ from synapse.api.room_versions import (
 from synapse.events import EventBase, make_event_from_dict
 from synapse.events.snapshot import EventContext
 from synapse.events.utils import prune_event
-from synapse.logging.context import PreserveLoggingContext, current_context
+from synapse.logging.context import (
+    PreserveLoggingContext,
+    current_context,
+    make_deferred_yieldable,
+)
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
@@ -56,6 +60,8 @@ from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
+from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
@@ -74,7 +80,10 @@ EVENT_QUEUE_ITERATIONS = 3  # No. times we block waiting for requests for events
 EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
 
 
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+@attr.s(slots=True, auto_attribs=True)
+class _EventCacheEntry:
+    event: EventBase
+    redacted_event: Optional[EventBase]
 
 
 class EventRedactBehaviour(Names):
@@ -161,6 +170,13 @@ class EventsWorkerStore(SQLBaseStore):
             max_size=hs.config.caches.event_cache_size,
         )
 
+        # Map from event ID to a deferred that will result in a map from event
+        # ID to cache entry. Note that the returned dict may not have the
+        # requested event in it if the event isn't in the DB.
+        self._current_event_fetches: Dict[
+            str, ObservableDeferred[Dict[str, _EventCacheEntry]]
+        ] = {}
+
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0
@@ -476,7 +492,9 @@ class EventsWorkerStore(SQLBaseStore):
 
         return events
 
-    async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+    async def _get_events_from_cache_or_db(
+        self, event_ids: Iterable[str], allow_rejected: bool = False
+    ) -> Dict[str, _EventCacheEntry]:
         """Fetch a bunch of events from the cache or the database.
 
         If events are pulled from the database, they will be cached for future lookups.
@@ -485,53 +503,107 @@ class EventsWorkerStore(SQLBaseStore):
 
         Args:
 
-            event_ids (Iterable[str]): The event_ids of the events to fetch
+            event_ids: The event_ids of the events to fetch
 
-            allow_rejected (bool): Whether to include rejected events. If False,
+            allow_rejected: Whether to include rejected events. If False,
                 rejected events are omitted from the response.
 
         Returns:
-            Dict[str, _EventCacheEntry]:
-                map from event id to result
+            map from event id to result
         """
         event_entry_map = self._get_events_from_cache(
-            event_ids, allow_rejected=allow_rejected
+            event_ids,
         )
 
-        missing_events_ids = [e for e in event_ids if e not in event_entry_map]
+        missing_events_ids = {e for e in event_ids if e not in event_entry_map}
+
+        # We now look up if we're already fetching some of the events in the DB,
+        # if so we wait for those lookups to finish instead of pulling the same
+        # events out of the DB multiple times.
+        already_fetching: Dict[str, defer.Deferred] = {}
+
+        for event_id in missing_events_ids:
+            deferred = self._current_event_fetches.get(event_id)
+            if deferred is not None:
+                # We're already pulling the event out of the DB. Add the deferred
+                # to the collection of deferreds to wait on.
+                already_fetching[event_id] = deferred.observe()
+
+        missing_events_ids.difference_update(already_fetching)
 
         if missing_events_ids:
             log_ctx = current_context()
             log_ctx.record_event_fetch(len(missing_events_ids))
 
+            # Add entries to `self._current_event_fetches` for each event we're
+            # going to pull from the DB. We use a single deferred that resolves
+            # to all the events we pulled from the DB (this will result in this
+            # function returning more events than requested, but that can happen
+            # already due to `_get_events_from_db`).
+            fetching_deferred: ObservableDeferred[
+                Dict[str, _EventCacheEntry]
+            ] = ObservableDeferred(defer.Deferred())
+            for event_id in missing_events_ids:
+                self._current_event_fetches[event_id] = fetching_deferred
+
             # Note that _get_events_from_db is also responsible for turning db rows
             # into FrozenEvents (via _get_event_from_row), which involves seeing if
             # the events have been redacted, and if so pulling the redaction event out
             # of the database to check it.
             #
-            missing_events = await self._get_events_from_db(
-                missing_events_ids, allow_rejected=allow_rejected
-            )
+            try:
+                missing_events = await self._get_events_from_db(
+                    missing_events_ids,
+                )
 
-            event_entry_map.update(missing_events)
+                event_entry_map.update(missing_events)
+            except Exception as e:
+                with PreserveLoggingContext():
+                    fetching_deferred.errback(e)
+                raise e
+            finally:
+                # Ensure that we mark these events as no longer being fetched.
+                for event_id in missing_events_ids:
+                    self._current_event_fetches.pop(event_id, None)
+
+            with PreserveLoggingContext():
+                fetching_deferred.callback(missing_events)
+
+        if already_fetching:
+            # Wait for the other event requests to finish and add their results
+            # to ours.
+            results = await make_deferred_yieldable(
+                defer.gatherResults(
+                    already_fetching.values(),
+                    consumeErrors=True,
+                )
+            ).addErrback(unwrapFirstError)
+
+            for result in results:
+                event_entry_map.update(result)
+
+        if not allow_rejected:
+            event_entry_map = {
+                event_id: entry
+                for event_id, entry in event_entry_map.items()
+                if not entry.event.rejected_reason
+            }
 
         return event_entry_map
 
     def _invalidate_get_event_cache(self, event_id):
         self._get_event_cache.invalidate((event_id,))
 
-    def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
-        """Fetch events from the caches
+    def _get_events_from_cache(
+        self, events: Iterable[str], update_metrics: bool = True
+    ) -> Dict[str, _EventCacheEntry]:
+        """Fetch events from the caches.
 
-        Args:
-            events (Iterable[str]): list of event_ids to fetch
-            allow_rejected (bool): Whether to return events that were rejected
-            update_metrics (bool): Whether to update the cache hit ratio metrics
+        May return rejected events.
 
-        Returns:
-            dict of event_id -> _EventCacheEntry for each event_id in cache. If
-            allow_rejected is `False` then there will still be an entry but it
-            will be `None`
+        Args:
+            events: list of event_ids to fetch
+            update_metrics: Whether to update the cache hit ratio metrics
         """
         event_map = {}
 
@@ -542,10 +614,7 @@ class EventsWorkerStore(SQLBaseStore):
             if not ret:
                 continue
 
-            if allow_rejected or not ret.event.rejected_reason:
-                event_map[event_id] = ret
-            else:
-                event_map[event_id] = None
+            event_map[event_id] = ret
 
         return event_map
 
@@ -672,23 +741,23 @@ class EventsWorkerStore(SQLBaseStore):
                 with PreserveLoggingContext():
                     self.hs.get_reactor().callFromThread(fire, event_list, e)
 
-    async def _get_events_from_db(self, event_ids, allow_rejected=False):
+    async def _get_events_from_db(
+        self, event_ids: Iterable[str]
+    ) -> Dict[str, _EventCacheEntry]:
         """Fetch a bunch of events from the database.
 
+        May return rejected events.
+
         Returned events will be added to the cache for future lookups.
 
         Unknown events are omitted from the response.
 
         Args:
-            event_ids (Iterable[str]): The event_ids of the events to fetch
-
-            allow_rejected (bool): Whether to include rejected events. If False,
-                rejected events are omitted from the response.
+            event_ids: The event_ids of the events to fetch
 
         Returns:
-            Dict[str, _EventCacheEntry]:
-                map from event id to result. May return extra events which
-                weren't asked for.
+            map from event id to result. May return extra events which
+            weren't asked for.
         """
         fetched_events = {}
         events_to_fetch = event_ids
@@ -717,9 +786,6 @@ class EventsWorkerStore(SQLBaseStore):
 
             rejected_reason = row["rejected_reason"]
 
-            if not allow_rejected and rejected_reason:
-                continue
-
             # If the event or metadata cannot be parsed, log the error and act
             # as if the event is unknown.
             try:
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 68f1b40ea6..e8157ba3d4 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -629,14 +629,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # We don't update the event cache hit ratio as it completely throws off
         # the hit ratio counts. After all, we don't populate the cache if we
         # miss it here
-        event_map = self._get_events_from_cache(
-            member_event_ids, allow_rejected=False, update_metrics=False
-        )
+        event_map = self._get_events_from_cache(member_event_ids, update_metrics=False)
 
         missing_member_event_ids = []
         for event_id in member_event_ids:
             ev_entry = event_map.get(event_id)
-            if ev_entry:
+            if ev_entry and not ev_entry.event.rejected_reason:
                 if ev_entry.event.membership == Membership.JOIN:
                     users_in_room[ev_entry.event.state_key] = ProfileInfo(
                         display_name=ev_entry.event.content.get("displayname", None),
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 932970fd9a..d05d367685 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -14,7 +14,10 @@
 import json
 
 from synapse.logging.context import LoggingContext
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util.async_helpers import yieldable_gather_results
 
 from tests import unittest
 
@@ -94,3 +97,50 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
             res = self.get_success(self.store.have_seen_events("room1", ["event10"]))
             self.assertEquals(res, {"event10"})
             self.assertEquals(ctx.get_resource_usage().db_txn_count, 0)
+
+
+class EventCacheTestCase(unittest.HomeserverTestCase):
+    """Test that the various layers of event cache works."""
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store: EventsWorkerStore = hs.get_datastore()
+
+        self.user = self.register_user("user", "pass")
+        self.token = self.login(self.user, "pass")
+
+        self.room = self.helper.create_room_as(self.user, tok=self.token)
+
+        res = self.helper.send(self.room, tok=self.token)
+        self.event_id = res["event_id"]
+
+        # Reset the event cache so the tests start with it empty
+        self.store._get_event_cache.clear()
+
+    def test_simple(self):
+        """Test that we cache events that we pull from the DB."""
+
+        with LoggingContext("test") as ctx:
+            self.get_success(self.store.get_event(self.event_id))
+
+            # We should have fetched the event from the DB
+            self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+
+    def test_dedupe(self):
+        """Test that if we request the same event multiple times we only pull it
+        out once.
+        """
+
+        with LoggingContext("test") as ctx:
+            d = yieldable_gather_results(
+                self.store.get_event, [self.event_id, self.event_id]
+            )
+            self.get_success(d)
+
+            # We should have fetched the event from the DB
+            self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)