-rw-r--r--  changelog.d/9379.feature                        |   1
-rw-r--r--  docs/sample_config.yaml                         |   7
-rwxr-xr-x  scripts/synapse_port_db                         |   9
-rw-r--r--  stubs/txredisapi.pyi                            |   2
-rw-r--r--  synapse/config/cache.py                         |  15
-rw-r--r--  synapse/replication/tcp/external_cache.py       |  42
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 111
-rw-r--r--  synapse/storage/databases/main/roommember.py    |   2
-rw-r--r--  tests/replication/_base.py                      |   4
9 files changed, 182 insertions, 11 deletions
diff --git a/changelog.d/9379.feature b/changelog.d/9379.feature
new file mode 100644
index 0000000000..d5e6fce5b4
--- /dev/null
+++ b/changelog.d/9379.feature
@@ -0,0 +1 @@
+Store cached events in the external Redis cache when Redis is enabled.
\ No newline at end of file
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 4dbef41b7e..42cb17123e 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -727,6 +727,13 @@ acme:
 #
 #event_cache_size: 10K
 
+# The expiry time of an event stored in the external cache (Redis). This
+# time will be reset each time the event is accessed.
+# This is only used when Redis is configured.
+# Defaults to 30 minutes.
+#
+#external_event_cache_expiry_ms: 1800000
+
 caches:
    # Controls the global cache factor, which is the default cache factor
    # for all caches if a specific factor for that cache is not otherwise
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 58edf6af6c..4a5b6433ae 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -36,6 +36,7 @@ from synapse.logging.context import (
     make_deferred_yieldable,
     run_in_background,
 )
+from synapse.replication.tcp.external_cache import ExternalCache
 from synapse.storage.database import DatabasePool, make_conn
 from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
 from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
@@ -208,13 +209,19 @@ class Store(
             "Attempt to set room_is_public during port_db: database not empty?"
         )
 
-
 class MockHomeserver:
     def __init__(self, config):
         self.clock = Clock(reactor)
         self.config = config
         self.hostname = config.server_name
         self.version_string = "Synapse/" + get_version_string(synapse)
+        self.external_cache = ExternalCache(self)
+
+    def get_outbound_redis_connection(self):
+        return None
+
+    def get_external_cache(self):
+        return self.external_cache
 
     def get_clock(self):
         return self.clock
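
Note on the port script change: returning None from get_outbound_redis_connection() is what keeps the external cache inert while porting. ExternalCache treats a missing connection as "disabled" (see is_enabled() later in this diff), so every cache call short-circuits. A minimal sketch, assuming a loaded config object:

    hs = MockHomeserver(config)
    assert hs.get_external_cache().is_enabled() is False
    # With is_enabled() False, set()/get()/delete() are all no-ops,
    # so synapse_port_db never needs a running Redis.
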
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index 618548a305..68c11d3b8b 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -30,6 +30,8 @@ class RedisProtocol:
         only_if_exists: bool = False,
     ) -> None: ...
     async def get(self, key: str) -> Any: ...
+    async def delete(self, key: str) -> None: ...
+    async def expire(self, key: str, expire: int) -> None: ...
 
 class SubscriberProtocol(RedisProtocol):
     def __init__(self, *args, **kwargs): ...
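
The two new stubs correspond to the Redis DEL and EXPIRE commands. Note that expire takes whole seconds, which is why the cache code below divides its millisecond expiry by 1000. A rough usage sketch, assuming an already-connected RedisProtocol named redis:

    async def touch_then_drop(redis: RedisProtocol, key: str) -> None:
        await redis.expire(key, 1800)  # push the TTL out to 30 minutes (seconds)
        await redis.delete(key)        # then remove the key entirely
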
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 8e03f14005..c65eb9b13e 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -31,6 +31,7 @@ _CACHES_LOCK = threading.Lock()
 
 _DEFAULT_FACTOR_SIZE = 0.5
 _DEFAULT_EVENT_CACHE_SIZE = "10K"
+_DEFAULT_EXTERNAL_CACHE_EXPIRY_MS = 30 * 60 * 1000  # 30 minutes
 
 
 class CacheProperties:
@@ -112,6 +113,13 @@ class CacheConfig(Config):
         #
         #event_cache_size: 10K
 
+        # The expiry time of an event stored in the external cache (Redis). This
+        # time will be reset each time the event is accessed.
+        # This is only used when Redis is configured.
+        # Defaults to 30 minutes.
+        #
+        #external_event_cache_expiry_ms: 1800000
+
         caches:
            # Controls the global cache factor, which is the default cache factor
            # for all caches if a specific factor for that cache is not otherwise
@@ -148,6 +156,13 @@ class CacheConfig(Config):
         self.event_cache_size = self.parse_size(
             config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE)
         )
+
+        self.external_event_cache_expiry_ms = config.get(
+            "external_event_cache_expiry_ms", _DEFAULT_EXTERNAL_CACHE_EXPIRY_MS
+        )
+        if not isinstance(self.external_event_cache_expiry_ms, (int, float)):
+            raise ConfigError("external_event_cache_expiry_ms must be a number.")
+
         self.cache_factors = {}  # type: Dict[str, float]
 
         cache_config = config.get("caches") or {}
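
As a quick cross-check, the new default agrees with the value quoted in the sample config above:

    # 30 minutes expressed in milliseconds:
    assert _DEFAULT_EXTERNAL_CACHE_EXPIRY_MS == 30 * 60 * 1000 == 1800000
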
diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
index d89a36f25a..9db6c3ac9c 100644
--- a/synapse/replication/tcp/external_cache.py
+++ b/synapse/replication/tcp/external_cache.py
@@ -36,6 +36,12 @@ get_counter = Counter(
     labelnames=["cache_name", "hit"],
 )
 
+delete_counter = Counter(
+    "synapse_external_cache_delete",
+    "Number of times we deleted keys from a cache",
+    labelnames=["cache_name"],
+)
+
 
 logger = logging.getLogger(__name__)
 
@@ -59,7 +65,24 @@ class ExternalCache:
         """
         return self._redis_connection is not None
 
-    async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
+    async def delete(self, cache_name: str, key: str) -> None:
+        """Delete a key from the named cache."""
+
+        if self._redis_connection is None:
+            return
+        delete_counter.labels(cache_name).inc()
+
+        logger.debug("Deleting %s %s", cache_name, key)
+
+        return await make_deferred_yieldable(
+            self._redis_connection.delete(
+                self._get_redis_key(cache_name, key),
+            )
+        )
+
+    async def set(
+        self, cache_name: str, key: str, value: Any, expiry_ms: Optional[int] = None
+    ) -> None:
         """Add the key/value to the named cache, with the expiry time given."""
 
         if self._redis_connection is None:
@@ -81,15 +104,17 @@ class ExternalCache:
             )
         )
 
-    async def get(self, cache_name: str, key: str) -> Optional[Any]:
+    async def get(
+        self, cache_name: str, key: str, expiry_ms: Optional[int] = None
+    ) -> Optional[Any]:
         """Look up a key/value in the named cache."""
 
         if self._redis_connection is None:
             return None
 
-        result = await make_deferred_yieldable(
-            self._redis_connection.get(self._get_redis_key(cache_name, key))
-        )
+        cache_key = self._get_redis_key(cache_name, key)
+
+        result = await make_deferred_yieldable(self._redis_connection.get(cache_key))
 
         logger.debug("Got cache result %s %s: %r", cache_name, key, result)
 
@@ -98,6 +123,13 @@ class ExternalCache:
         if not result:
             return None
 
+        if expiry_ms:
+            # If we are using this key, bump the expiry time
+            # NOTE: txredisapi does not support PEXPIRE, so we use EXPIRE with whole seconds
+            await make_deferred_yieldable(
+                self._redis_connection.expire(cache_key, expiry_ms // 1000)
+            )
+
         # For some reason integers get magically converted back from strings to integers
         if isinstance(result, int):
             return result
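
Taken together these changes give get() a sliding expiry: every hit pushes the key's TTL back out, so only genuinely idle events fall out of Redis. A standalone sketch of the pattern against a bare txredisapi connection (connection setup assumed, not shown):

    async def get_with_sliding_ttl(redis, cache_key: str, expiry_ms: int):
        value = await redis.get(cache_key)
        if not value:
            return None
        # txredisapi exposes EXPIRE (whole seconds) but not PEXPIRE,
        # hence the millisecond-to-second division
        await redis.expire(cache_key, expiry_ms // 1000)
        return value
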
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index edbe42f2bf..36df56ba2c 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -63,7 +63,7 @@ logger = logging.getLogger(__name__)
 EVENT_QUEUE_THREADS = 3  # Max number of threads that will fetch events
 EVENT_QUEUE_ITERATIONS = 3  # No. times we block waiting for requests for events
 EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
-
+GET_EVENT_CACHE_NAME = "getEvent"
 
 _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
 
@@ -147,11 +147,15 @@ class EventsWorkerStore(SQLBaseStore):
                 5 * 60 * 1000,
             )
 
+        self._external_cache = hs.get_external_cache()
         self._get_event_cache = LruCache(
             cache_name="*getEvent*",
             keylen=3,
             max_size=hs.config.caches.event_cache_size,
         )
+        self._external_cache_event_expiry_ms = (
+            hs.config.caches.external_event_cache_expiry_ms
+        )
 
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list = []
@@ -486,7 +490,7 @@ class EventsWorkerStore(SQLBaseStore):
             Dict[str, _EventCacheEntry]:
                 map from event id to result
         """
-        event_entry_map = self._get_events_from_cache(
+        event_entry_map = await self._get_events_from_cache(
             event_ids, allow_rejected=allow_rejected
         )
 
@@ -511,8 +515,77 @@ class EventsWorkerStore(SQLBaseStore):
 
     def _invalidate_get_event_cache(self, event_id):
         self._get_event_cache.invalidate((event_id,))
+        if self._external_cache.is_enabled():
+            # XXX: Is there danger in doing this?
+            # We could hold a set of recently evicted keys in memory if
+            # we need this to be synchronous?
+            run_as_background_process(
+                "getEvent_external_cache_delete",
+                self._external_cache.delete,
+                GET_EVENT_CACHE_NAME,
+                event_id,
+            )
+
+    def create_external_cache_event_from_event(self, event, redacted_event=None):
+        if redacted_event:
+            redacted_event = self.create_external_cache_event_from_event(
+                redacted_event
+            )[0]
+
+        event_dict = event.get_dict()
+
+        for key, value in event.unsigned.items():
+            if isinstance(value, EventBase):
+                event_dict["unsigned"][key] = {"_cache_event_id": value.event_id}
+
+        return _EventCacheEntry(
+            event={
+                "event_dict": event_dict,
+                "room_version": event.room_version.identifier,
+                "internal_metadata_dict": event.get_internal_metadata_dict(),
+                "rejected_reason": event.rejected_reason,
+                "stream_ordering": event.internal_metadata.stream_ordering,
+            },
+            redacted_event=redacted_event,
+        )
 
-    def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
+    async def _create_event_cache_entry_from_external_cache_entry(
+        self, external_entry: Tuple[JsonDict, Optional[JsonDict]]
+    ) -> Optional[_EventCacheEntry]:
+        """Create a _EventCacheEntry from a tuple of dicts
+        Args:
+            external_entry: A tuple of event, redacted_event
+        Returns:
+            A _EventCacheEntry containing the frozen event(s)
+        """
+        event_dict = external_entry[0].get("event_dict")
+        for key, value in event_dict.get("unsigned", {}).items():
+            # If unsigned contained any events, get them now
+            if isinstance(value, dict) and value.get("_cache_event_id"):
+                event_dict["unsigned"][key] = await self.get_event(
+                    value["_cache_event_id"]
+                )
+
+        original_ev = make_event_from_dict(
+            event_dict=event_dict,
+            room_version=KNOWN_ROOM_VERSIONS[external_entry[0].get("room_version")],
+            internal_metadata_dict=external_entry[0].get("internal_metadata_dict"),
+            rejected_reason=external_entry[0].get("rejected_reason"),
+        )
+        original_ev.internal_metadata.stream_ordering = external_entry[0].get(
+            "stream_ordering"
+        )
+        redacted_ev = None
+        if external_entry[1]:
+            redacted_ev = make_event_from_dict(
+                event_dict=external_entry[1].get("event_dict"),
+                room_version=KNOWN_ROOM_VERSIONS[external_entry[1].get("room_version")],
+                internal_metadata_dict=external_entry[1].get("internal_metadata_dict"),
+                rejected_reason=external_entry[1].get("rejected_reason"),
+            )
+        return _EventCacheEntry(event=original_ev, redacted_event=redacted_ev)
+
+    async def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
         """Fetch events from the caches
 
         Args:
@@ -528,9 +601,27 @@ class EventsWorkerStore(SQLBaseStore):
         event_map = {}
 
         for event_id in events:
+            # L1 cache - internal
             ret = self._get_event_cache.get(
                 (event_id,), None, update_metrics=update_metrics
             )
+
+            if not ret and self._external_cache.is_enabled():
+                # L2 cache - external
+                cache_result = await self._external_cache.get(
+                    GET_EVENT_CACHE_NAME,
+                    event_id,
+                    self._external_cache_event_expiry_ms,
+                )
+                if cache_result:
+                    ret = (
+                        await self._create_event_cache_entry_from_external_cache_entry(
+                            cache_result
+                        )
+                    )
+                    # We got a hit here, store it in the L1 cache
+                    self._get_event_cache.set((event_id,), ret)
+
             if not ret:
                 continue
 
@@ -814,10 +905,22 @@ class EventsWorkerStore(SQLBaseStore):
             cache_entry = _EventCacheEntry(
                 event=original_ev, redacted_event=redacted_event
             )
-
             self._get_event_cache.set((event_id,), cache_entry)
             result_map[event_id] = cache_entry
 
+            if self._external_cache.is_enabled():
+                # Store in the L2 cache
+                # Redis cannot store a FrozenEvent, so we transform these
+                # into two dicts
+                redis_cache_entry = self.create_external_cache_event_from_event(
+                    original_ev, redacted_event
+                )
+                await self._external_cache.set(
+                    GET_EVENT_CACHE_NAME,
+                    event_id,
+                    redis_cache_entry,
+                )
+
         return result_map
 
     async def _enqueue_events(self, events):
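
The serialisation scheme above is worth illustrating. Redis can only hold JSON, and an event's unsigned section may embed another EventBase (for example the event a redaction replaced), so embedded events are swapped for a marker on write and re-fetched by event ID on read. Roughly, the stored shape looks like this (the event ID and unsigned key are illustrative):

    {
        "event_dict": {
            ...,
            "unsigned": {"redacted_because": {"_cache_event_id": "$abc123:example.org"}},
        },
        "room_version": "6",
        "internal_metadata_dict": {...},
        "rejected_reason": None,
        "stream_ordering": 12345,
    }

On the way back, _create_event_cache_entry_from_external_cache_entry spots the _cache_event_id marker and calls get_event() to rehydrate the embedded event before rebuilding the frozen event.
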
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index a9216ca9ae..5800755ac0 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -582,7 +582,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # We don't update the event cache hit ratio as it completely throws off
         # the hit ratio counts. After all, we don't populate the cache if we
         # miss it here
-        event_map = self._get_events_from_cache(
+        event_map = await self._get_events_from_cache(
             member_event_ids, allow_rejected=False, update_metrics=False
         )
 
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index f6a6aed35e..a81607ed04 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -629,6 +629,10 @@ class FakeRedisPubSubProtocol(Protocol):
             self.send("OK")
         elif command == b"GET":
             self.send(None)
+        elif command == b"DEL":
+            self.send("OK")
+        elif command == b"EXPIRE":
+            self.send("OK")
         else:
             raise Exception("Unknown command")
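
One simplification to flag: real Redis replies to DEL and EXPIRE with integers (the number of keys removed, and whether a timeout was set), not with OK. The fake only needs to not blow up here, but if a test ever asserts on the reply value, a more faithful stand-in would be:

    elif command == b"DEL":
        self.send(1)  # number of keys deleted
    elif command == b"EXPIRE":
        self.send(1)  # 1 = timeout set, 0 = key did not exist
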