summary refs log tree commit diff
path: root/synapse/storage/databases/main/events_worker.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/events_worker.py')
-rw-r--r--synapse/storage/databases/main/events_worker.py34
1 files changed, 24 insertions, 10 deletions
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 621f92e238..f3935bfead 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -79,7 +79,7 @@ from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
 from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.lrucache import AsyncLruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -238,7 +238,9 @@ class EventsWorkerStore(SQLBaseStore):
                 5 * 60 * 1000,
             )
 
-        self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
+        self._get_event_cache: AsyncLruCache[
+            Tuple[str], EventCacheEntry
+        ] = AsyncLruCache(
             cache_name="*getEvent*",
             max_size=hs.config.caches.event_cache_size,
         )
@@ -598,7 +600,7 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             map from event id to result
         """
-        event_entry_map = self._get_events_from_cache(
+        event_entry_map = await self._get_events_from_cache(
             event_ids,
         )
 
@@ -710,12 +712,22 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_entry_map
 
-    def _invalidate_get_event_cache(self, event_id: str) -> None:
-        self._get_event_cache.invalidate((event_id,))
+    async def _invalidate_get_event_cache(self, event_id: str) -> None:
+        # First we invalidate the asynchronous cache instance. This may include
+        # out-of-process caches such as Redis/memcache. Once complete we can
+        # invalidate any in memory cache. The ordering is important here to
+        # ensure we don't pull in any remote invalid value after we invalidate
+        # the in-memory cache.
+        await self._get_event_cache.invalidate((event_id,))
         self._event_ref.pop(event_id, None)
         self._current_event_fetches.pop(event_id, None)
 
-    def _get_events_from_cache(
+    def _invalidate_local_get_event_cache(self, event_id: str) -> None:
+        self._get_event_cache.invalidate_local((event_id,))
+        self._event_ref.pop(event_id, None)
+        self._current_event_fetches.pop(event_id, None)
+
+    async def _get_events_from_cache(
         self, events: Iterable[str], update_metrics: bool = True
     ) -> Dict[str, EventCacheEntry]:
         """Fetch events from the caches.
@@ -730,7 +742,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         for event_id in events:
             # First check if it's in the event cache
-            ret = self._get_event_cache.get(
+            ret = await self._get_event_cache.get(
                 (event_id,), None, update_metrics=update_metrics
             )
             if ret:
@@ -752,7 +764,7 @@ class EventsWorkerStore(SQLBaseStore):
 
                 # We add the entry back into the cache as we want to keep
                 # recently queried events in the cache.
-                self._get_event_cache.set((event_id,), cache_entry)
+                await self._get_event_cache.set((event_id,), cache_entry)
 
         return event_map
 
@@ -1129,7 +1141,7 @@ class EventsWorkerStore(SQLBaseStore):
                 event=original_ev, redacted_event=redacted_event
             )
 
-            self._get_event_cache.set((event_id,), cache_entry)
+            await self._get_event_cache.set((event_id,), cache_entry)
             result_map[event_id] = cache_entry
 
             if not redacted_event:
@@ -1363,7 +1375,9 @@ class EventsWorkerStore(SQLBaseStore):
         # if the event cache contains the event, obviously we've seen it.
 
         cache_results = {
-            (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,))
+            (rid, eid)
+            for (rid, eid) in keys
+            if await self._get_event_cache.contains((eid,))
         }
         results = dict.fromkeys(cache_results, True)
         remaining = [k for k in keys if k not in cache_results]