author     Erik Johnston <erik@matrix.org>  2021-05-06 14:42:42 +0100
committer  Erik Johnston <erik@matrix.org>  2021-05-06 14:42:42 +0100
commit     9d1118dde8781e5a7389289fb12af6b3357777de (patch)
tree       0213fa7f13b00408c7542937ee4054d6343f85a7
parent     Merge remote-tracking branch 'origin/master' into develop (diff)
download   synapse-9d1118dde8781e5a7389289fb12af6b3357777de.tar.xz
Ensure we only have one copy of an event in memory at a time
This ensures that if the get-event cache overflows we don't end up with
multiple copies of the same event in RAM at the same time (which could
lead to memory bloat).
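
The core of the change pairs the bounded get-event LRU cache with a
WeakValueDictionary that weakly indexes every event entry still alive anywhere
in the process; on a cache miss the weak index is consulted before re-loading
the event from the database. Below is a minimal, self-contained sketch of that
pattern, using a plain OrderedDict as a stand-in for Synapse's LruCache;
DedupingEventCache and Event are illustrative names, not the actual classes.

# Hypothetical, minimal sketch of the caching pattern this commit introduces.
from collections import OrderedDict
from weakref import WeakValueDictionary


class Event:
    """Stand-in for EventBase; only needs to be weakly referenceable."""

    def __init__(self, event_id: str):
        self.event_id = event_id


class DedupingEventCache:
    def __init__(self, max_size: int):
        self._max_size = max_size
        # Bounded cache holding strong references (plays the role of _get_event_cache).
        self._lru: "OrderedDict[str, Event]" = OrderedDict()
        # Weak secondary index over every entry still alive anywhere in the
        # process (plays the role of _in_memory_events).
        self._in_memory: "WeakValueDictionary[str, Event]" = WeakValueDictionary()

    def set(self, event_id: str, event: Event) -> None:
        self._lru[event_id] = event
        self._lru.move_to_end(event_id)
        self._in_memory[event_id] = event
        if len(self._lru) > self._max_size:
            # Evict the oldest entry; its weak-index entry disappears
            # automatically once the last strong reference is dropped.
            self._lru.popitem(last=False)

    def get(self, event_id: str):
        event = self._lru.get(event_id)
        if event is None:
            # Cache miss: reuse any copy that is still referenced elsewhere,
            # rather than loading a duplicate from the database.
            event = self._in_memory.get(event_id)
        return event

    def invalidate(self, event_id: str) -> None:
        self._lru.pop(event_id, None)
        self._in_memory.pop(event_id, None)
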
-rw-r--r--  synapse/storage/databases/main/censor_events.py |  2
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 26
2 files changed, 23 insertions, 5 deletions
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index f22c1f241b..b41948b0c0 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -181,7 +181,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
             # changed its content in the database. We can't call
             # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
             # right type.
-            txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+            txn.call_after(self._invalidate_get_event_cache, event.event_id)
             # Send that invalidation to replication so that other workers also invalidate
             # the event cache.
             self._send_invalidation_to_replication(
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 2c823e09cf..66eaf946d7 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -14,7 +14,6 @@
 
 import logging
 import threading
-from collections import namedtuple
 from typing import (
     Collection,
     Container,
@@ -25,7 +24,9 @@ from typing import (
     Tuple,
     overload,
 )
+from weakref import WeakValueDictionary
 
+import attr
 from constantly import NamedConstant, Names
 from typing_extensions import Literal
 
@@ -73,7 +74,10 @@ EVENT_QUEUE_ITERATIONS = 3  # No. times we block waiting for requests for events
 EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
 
 
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventCacheEntry:
+    event: EventBase
+    redacted_event: Optional[EventBase]
 
 
 class EventRedactBehaviour(Names):
@@ -157,9 +161,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         self._get_event_cache = LruCache(
             cache_name="*getEvent*",
-            keylen=3,
             max_size=hs.config.caches.event_cache_size,
         )
+        # We separately track which events we have in memory. This is mainly to
+        # guard against loading the same event into memory multiple times when
+        # `_get_event_cache` overflows.
+        self._in_memory_events = (
+            WeakValueDictionary()
+        )  # type: WeakValueDictionary[str, _EventCacheEntry]
 
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list = []
@@ -519,6 +528,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     def _invalidate_get_event_cache(self, event_id):
         self._get_event_cache.invalidate((event_id,))
+        self._in_memory_events.pop(event_id, None)
 
     def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
         """Fetch events from the caches
@@ -540,6 +550,9 @@ class EventsWorkerStore(SQLBaseStore):
                 (event_id,), None, update_metrics=update_metrics
             )
             if not ret:
+                ret = self._in_memory_events.get(event_id)
+
+            if not ret:
                 continue
 
             if allow_rejected or not ret.event.rejected_reason:
@@ -825,6 +838,7 @@ class EventsWorkerStore(SQLBaseStore):
             )
 
             self._get_event_cache.set((event_id,), cache_entry)
+            self._in_memory_events[event_id] = cache_entry
             result_map[event_id] = cache_entry
 
         return result_map
@@ -1056,7 +1070,11 @@ class EventsWorkerStore(SQLBaseStore):
             set[str]: The events we have already seen.
         """
         # if the event cache contains the event, obviously we've seen it.
-        results = {x for x in event_ids if self._get_event_cache.contains(x)}
+        results = {
+            x
+            for x in event_ids
+            if self._get_event_cache.contains((x,)) or x in self._in_memory_events
+        }
 
         def have_seen_events_txn(txn, chunk):
             sql = "SELECT event_id FROM events as e WHERE "