diff --git a/changelog.d/10119.misc b/changelog.d/10119.misc
new file mode 100644
index 0000000000..f70dc6496f
--- /dev/null
+++ b/changelog.d/10119.misc
@@ -0,0 +1 @@
+Improve event caching mechanism to avoid having multiple copies of an event in memory at a time.
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 3c86adab56..375463e4e9 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -14,7 +14,6 @@
import logging
import threading
-from collections import namedtuple
from typing import (
Collection,
Container,
@@ -27,6 +26,7 @@ from typing import (
overload,
)
+import attr
from constantly import NamedConstant, Names
from typing_extensions import Literal
@@ -42,7 +42,11 @@ from synapse.api.room_versions import (
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.events.utils import prune_event
-from synapse.logging.context import PreserveLoggingContext, current_context
+from synapse.logging.context import (
+ PreserveLoggingContext,
+ current_context,
+ make_deferred_yieldable,
+)
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
@@ -56,6 +60,8 @@ from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
+from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
@@ -74,7 +80,10 @@ EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+@attr.s(slots=True, auto_attribs=True)
+class _EventCacheEntry:
+ event: EventBase
+ redacted_event: Optional[EventBase]
class EventRedactBehaviour(Names):
@@ -161,6 +170,13 @@ class EventsWorkerStore(SQLBaseStore):
max_size=hs.config.caches.event_cache_size,
)
+ # Map from event ID to a deferred that will resolve to a map from event
+ # ID to cache entry. Note that the resolved dict may not contain the
+ # requested event if the event isn't in the DB.
+ self._current_event_fetches: Dict[
+ str, ObservableDeferred[Dict[str, _EventCacheEntry]]
+ ] = {}
+
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_ongoing = 0
@@ -476,7 +492,9 @@ class EventsWorkerStore(SQLBaseStore):
return events
- async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+ async def _get_events_from_cache_or_db(
+ self, event_ids: Iterable[str], allow_rejected: bool = False
+ ) -> Dict[str, _EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@@ -485,53 +503,107 @@ class EventsWorkerStore(SQLBaseStore):
Args:
- event_ids (Iterable[str]): The event_ids of the events to fetch
+ event_ids: The event_ids of the events to fetch
- allow_rejected (bool): Whether to include rejected events. If False,
+ allow_rejected: Whether to include rejected events. If False,
rejected events are omitted from the response.
Returns:
- Dict[str, _EventCacheEntry]:
- map from event id to result
+ map from event id to result
"""
event_entry_map = self._get_events_from_cache(
- event_ids, allow_rejected=allow_rejected
+ event_ids,
)
- missing_events_ids = [e for e in event_ids if e not in event_entry_map]
+ missing_events_ids = {e for e in event_ids if e not in event_entry_map}
+
+ # We now check whether we're already fetching some of the events from the
+ # DB; if so, we wait for those lookups to finish instead of pulling the
+ # same events out of the DB multiple times.
+ already_fetching: Dict[str, defer.Deferred] = {}
+
+ for event_id in missing_events_ids:
+ deferred = self._current_event_fetches.get(event_id)
+ if deferred is not None:
+ # We're already pulling the event out of the DB. Add the deferred
+ # to the collection of deferreds to wait on.
+ already_fetching[event_id] = deferred.observe()
+
+ missing_events_ids.difference_update(already_fetching)
if missing_events_ids:
log_ctx = current_context()
log_ctx.record_event_fetch(len(missing_events_ids))
+ # Add entries to `self._current_event_fetches` for each event we're
+ # going to pull from the DB. We use a single deferred that resolves
+ # to all the events we pulled from the DB (this will result in this
+ # function returning more events than requested, but that can happen
+ # already due to `_get_events_from_db`).
+ fetching_deferred: ObservableDeferred[
+ Dict[str, _EventCacheEntry]
+ ] = ObservableDeferred(defer.Deferred())
+ for event_id in missing_events_ids:
+ self._current_event_fetches[event_id] = fetching_deferred
+
# Note that _get_events_from_db is also responsible for turning db rows
# into FrozenEvents (via _get_event_from_row), which involves seeing if
# the events have been redacted, and if so pulling the redaction event out
# of the database to check it.
#
- missing_events = await self._get_events_from_db(
- missing_events_ids, allow_rejected=allow_rejected
- )
+ try:
+ missing_events = await self._get_events_from_db(
+ missing_events_ids,
+ )
- event_entry_map.update(missing_events)
+ event_entry_map.update(missing_events)
+ except Exception as e:
+ with PreserveLoggingContext():
+ fetching_deferred.errback(e)
+ raise e
+ finally:
+ # Ensure that we mark these events as no longer being fetched.
+ for event_id in missing_events_ids:
+ self._current_event_fetches.pop(event_id, None)
+
+ with PreserveLoggingContext():
+ fetching_deferred.callback(missing_events)
+
+ if already_fetching:
+ # Wait for the other event requests to finish and add their results
+ # to ours.
+ results = await make_deferred_yieldable(
+ defer.gatherResults(
+ already_fetching.values(),
+ consumeErrors=True,
+ )
+ ).addErrback(unwrapFirstError)
+
+ for result in results:
+ event_entry_map.update(result)
+
+ if not allow_rejected:
+ event_entry_map = {
+ event_id: entry
+ for event_id, entry in event_entry_map.items()
+ if not entry.event.rejected_reason
+ }
return event_entry_map
def _invalidate_get_event_cache(self, event_id):
self._get_event_cache.invalidate((event_id,))
- def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
- """Fetch events from the caches
+ def _get_events_from_cache(
+ self, events: Iterable[str], update_metrics: bool = True
+ ) -> Dict[str, _EventCacheEntry]:
+ """Fetch events from the caches.
- Args:
- events (Iterable[str]): list of event_ids to fetch
- allow_rejected (bool): Whether to return events that were rejected
- update_metrics (bool): Whether to update the cache hit ratio metrics
+ May return rejected events.
- Returns:
- dict of event_id -> _EventCacheEntry for each event_id in cache. If
- allow_rejected is `False` then there will still be an entry but it
- will be `None`
+ Args:
+ events: list of event_ids to fetch
+ update_metrics: Whether to update the cache hit ratio metrics
"""
event_map = {}
@@ -542,10 +614,7 @@ class EventsWorkerStore(SQLBaseStore):
if not ret:
continue
- if allow_rejected or not ret.event.rejected_reason:
- event_map[event_id] = ret
- else:
- event_map[event_id] = None
+ event_map[event_id] = ret
return event_map
@@ -672,23 +741,23 @@ class EventsWorkerStore(SQLBaseStore):
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, e)
- async def _get_events_from_db(self, event_ids, allow_rejected=False):
+ async def _get_events_from_db(
+ self, event_ids: Iterable[str]
+ ) -> Dict[str, _EventCacheEntry]:
"""Fetch a bunch of events from the database.
+ May return rejected events.
+
Returned events will be added to the cache for future lookups.
Unknown events are omitted from the response.
Args:
- event_ids (Iterable[str]): The event_ids of the events to fetch
-
- allow_rejected (bool): Whether to include rejected events. If False,
- rejected events are omitted from the response.
+ event_ids: The event_ids of the events to fetch
Returns:
- Dict[str, _EventCacheEntry]:
- map from event id to result. May return extra events which
- weren't asked for.
+ map from event id to result. May return extra events which
+ weren't asked for.
"""
fetched_events = {}
events_to_fetch = event_ids
@@ -717,9 +786,6 @@ class EventsWorkerStore(SQLBaseStore):
rejected_reason = row["rejected_reason"]
- if not allow_rejected and rejected_reason:
- continue
-
# If the event or metadata cannot be parsed, log the error and act
# as if the event is unknown.
try:
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 68f1b40ea6..e8157ba3d4 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -629,14 +629,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# We don't update the event cache hit ratio as it completely throws off
# the hit ratio counts. After all, we don't populate the cache if we
# miss it here
- event_map = self._get_events_from_cache(
- member_event_ids, allow_rejected=False, update_metrics=False
- )
+ event_map = self._get_events_from_cache(member_event_ids, update_metrics=False)
missing_member_event_ids = []
for event_id in member_event_ids:
ev_entry = event_map.get(event_id)
- if ev_entry:
+ if ev_entry and not ev_entry.event.rejected_reason:
if ev_entry.event.membership == Membership.JOIN:
users_in_room[ev_entry.event.state_key] = ProfileInfo(
display_name=ev_entry.event.content.get("displayname", None),
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 932970fd9a..d05d367685 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -14,7 +14,10 @@
import json
from synapse.logging.context import LoggingContext
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util.async_helpers import yieldable_gather_results
from tests import unittest
@@ -94,3 +97,50 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
res = self.get_success(self.store.have_seen_events("room1", ["event10"]))
self.assertEquals(res, {"event10"})
self.assertEquals(ctx.get_resource_usage().db_txn_count, 0)
+
+
+class EventCacheTestCase(unittest.HomeserverTestCase):
+ """Test that the various layers of event cache work."""
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store: EventsWorkerStore = hs.get_datastore()
+
+ self.user = self.register_user("user", "pass")
+ self.token = self.login(self.user, "pass")
+
+ self.room = self.helper.create_room_as(self.user, tok=self.token)
+
+ res = self.helper.send(self.room, tok=self.token)
+ self.event_id = res["event_id"]
+
+ # Reset the event cache so the tests start with it empty
+ self.store._get_event_cache.clear()
+
+ def test_simple(self):
+ """Test that we cache events that we pull from the DB."""
+
+ with LoggingContext("test") as ctx:
+ self.get_success(self.store.get_event(self.event_id))
+
+ # We should have fetched the event from the DB
+ self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+
+ def test_dedupe(self):
+ """Test that if we request the same event multiple times we only pull it
+ out once.
+ """
+
+ with LoggingContext("test") as ctx:
+ d = yieldable_gather_results(
+ self.store.get_event, [self.event_id, self.event_id]
+ )
+ self.get_success(d)
+
+ # We should have fetched the event from the DB only once, despite requesting it twice
+ self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
|