summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/cache.py33
-rw-r--r--synapse/metrics/jemalloc.py114
-rw-r--r--synapse/util/caches/lrucache.py79
3 files changed, 175 insertions, 51 deletions
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 58b2fe5519..d2f55534d7 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -176,6 +176,24 @@ class CacheConfig(Config):
           #
           #cache_entry_ttl: 30m
 
+          # This flag enables cache autotuning, and is further specified by the sub-options `max_cache_memory_usage`,
+          # `target_cache_memory_usage`, `min_cache_ttl`. These flags work in conjunction with each other to maintain
+          # a balance between cache memory usage and cache entry availability. You must be using jemalloc to utilize
+          # this option, and all three of the options must be specified for this feature to work.
+          #cache_autotuning:
+            # This flag sets a ceiling on much memory the cache can use before caches begin to be continuously evicted.
+            # They will continue to be evicted until the memory usage drops below the `target_memory_usage`, set in
+            # the flag below, or until the `min_cache_ttl` is hit.
+            #max_cache_memory_usage: 1024M
+
+            # This flag sets a rough target for the desired memory usage of the caches.
+            #target_cache_memory_usage: 758M
+
+            # 'min_cache_ttl` sets a limit under which newer cache entries are not evicted and is only applied when
+            # caches are actively being evicted/`max_cache_memory_usage` has been exceeded. This is to protect hot caches
+            # from being emptied while Synapse is evicting due to memory.
+            #min_cache_ttl: 5m
+
           # Controls how long the results of a /sync request are cached for after
           # a successful response is returned. A higher duration can help clients with
           # intermittent connections, at the cost of higher memory usage.
@@ -263,6 +281,21 @@ class CacheConfig(Config):
             )
             self.expiry_time_msec = self.parse_duration(expiry_time)
 
+        self.cache_autotuning = cache_config.get("cache_autotuning")
+        if self.cache_autotuning:
+            max_memory_usage = self.cache_autotuning.get("max_cache_memory_usage")
+            self.cache_autotuning["max_cache_memory_usage"] = self.parse_size(
+                max_memory_usage
+            )
+
+            target_mem_size = self.cache_autotuning.get("target_cache_memory_usage")
+            self.cache_autotuning["target_cache_memory_usage"] = self.parse_size(
+                target_mem_size
+            )
+
+            min_cache_ttl = self.cache_autotuning.get("min_cache_ttl")
+            self.cache_autotuning["min_cache_ttl"] = self.parse_duration(min_cache_ttl)
+
         self.sync_response_cache_duration = self.parse_duration(
             cache_config.get("sync_response_cache_duration", 0)
         )
diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py
index 6bc329f04a..1fc8a0e888 100644
--- a/synapse/metrics/jemalloc.py
+++ b/synapse/metrics/jemalloc.py
@@ -18,6 +18,7 @@ import os
 import re
 from typing import Iterable, Optional, overload
 
+import attr
 from prometheus_client import REGISTRY, Metric
 from typing_extensions import Literal
 
@@ -27,52 +28,24 @@ from synapse.metrics._types import Collector
 logger = logging.getLogger(__name__)
 
 
-def _setup_jemalloc_stats() -> None:
-    """Checks to see if jemalloc is loaded, and hooks up a collector to record
-    statistics exposed by jemalloc.
-    """
-
-    # Try to find the loaded jemalloc shared library, if any. We need to
-    # introspect into what is loaded, rather than loading whatever is on the
-    # path, as if we load a *different* jemalloc version things will seg fault.
-
-    # We look in `/proc/self/maps`, which only exists on linux.
-    if not os.path.exists("/proc/self/maps"):
-        logger.debug("Not looking for jemalloc as no /proc/self/maps exist")
-        return
-
-    # We're looking for a path at the end of the line that includes
-    # "libjemalloc".
-    regex = re.compile(r"/\S+/libjemalloc.*$")
-
-    jemalloc_path = None
-    with open("/proc/self/maps") as f:
-        for line in f:
-            match = regex.search(line.strip())
-            if match:
-                jemalloc_path = match.group()
-
-    if not jemalloc_path:
-        # No loaded jemalloc was found.
-        logger.debug("jemalloc not found")
-        return
-
-    logger.debug("Found jemalloc at %s", jemalloc_path)
-
-    jemalloc = ctypes.CDLL(jemalloc_path)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class JemallocStats:
+    jemalloc: ctypes.CDLL
 
     @overload
     def _mallctl(
-        name: str, read: Literal[True] = True, write: Optional[int] = None
+        self, name: str, read: Literal[True] = True, write: Optional[int] = None
     ) -> int:
         ...
 
     @overload
-    def _mallctl(name: str, read: Literal[False], write: Optional[int] = None) -> None:
+    def _mallctl(
+        self, name: str, read: Literal[False], write: Optional[int] = None
+    ) -> None:
         ...
 
     def _mallctl(
-        name: str, read: bool = True, write: Optional[int] = None
+        self, name: str, read: bool = True, write: Optional[int] = None
     ) -> Optional[int]:
         """Wrapper around `mallctl` for reading and writing integers to
         jemalloc.
@@ -120,7 +93,7 @@ def _setup_jemalloc_stats() -> None:
         # Where oldp/oldlenp is a buffer where the old value will be written to
         # (if not null), and newp/newlen is the buffer with the new value to set
         # (if not null). Note that they're all references *except* newlen.
-        result = jemalloc.mallctl(
+        result = self.jemalloc.mallctl(
             name.encode("ascii"),
             input_var_ref,
             input_len_ref,
@@ -136,21 +109,80 @@ def _setup_jemalloc_stats() -> None:
 
         return input_var.value
 
-    def _jemalloc_refresh_stats() -> None:
+    def refresh_stats(self) -> None:
         """Request that jemalloc updates its internal statistics. This needs to
         be called before querying for stats, otherwise it will return stale
         values.
         """
         try:
-            _mallctl("epoch", read=False, write=1)
+            self._mallctl("epoch", read=False, write=1)
         except Exception as e:
             logger.warning("Failed to reload jemalloc stats: %s", e)
 
+    def get_stat(self, name: str) -> int:
+        """Request the stat of the given name at the time of the last
+        `refresh_stats` call. This may throw if we fail to read
+        the stat.
+        """
+        return self._mallctl(f"stats.{name}")
+
+
+_JEMALLOC_STATS: Optional[JemallocStats] = None
+
+
+def get_jemalloc_stats() -> Optional[JemallocStats]:
+    """Returns an interface to jemalloc, if it is being used.
+
+    Note that this will always return None until `setup_jemalloc_stats` has been
+    called.
+    """
+    return _JEMALLOC_STATS
+
+
+def _setup_jemalloc_stats() -> None:
+    """Checks to see if jemalloc is loaded, and hooks up a collector to record
+    statistics exposed by jemalloc.
+    """
+
+    global _JEMALLOC_STATS
+
+    # Try to find the loaded jemalloc shared library, if any. We need to
+    # introspect into what is loaded, rather than loading whatever is on the
+    # path, as if we load a *different* jemalloc version things will seg fault.
+
+    # We look in `/proc/self/maps`, which only exists on linux.
+    if not os.path.exists("/proc/self/maps"):
+        logger.debug("Not looking for jemalloc as no /proc/self/maps exist")
+        return
+
+    # We're looking for a path at the end of the line that includes
+    # "libjemalloc".
+    regex = re.compile(r"/\S+/libjemalloc.*$")
+
+    jemalloc_path = None
+    with open("/proc/self/maps") as f:
+        for line in f:
+            match = regex.search(line.strip())
+            if match:
+                jemalloc_path = match.group()
+
+    if not jemalloc_path:
+        # No loaded jemalloc was found.
+        logger.debug("jemalloc not found")
+        return
+
+    logger.debug("Found jemalloc at %s", jemalloc_path)
+
+    jemalloc_dll = ctypes.CDLL(jemalloc_path)
+
+    stats = JemallocStats(jemalloc_dll)
+    _JEMALLOC_STATS = stats
+
     class JemallocCollector(Collector):
         """Metrics for internal jemalloc stats."""
 
         def collect(self) -> Iterable[Metric]:
-            _jemalloc_refresh_stats()
+            stats.refresh_stats()
 
             g = GaugeMetricFamily(
                 "jemalloc_stats_app_memory_bytes",
@@ -184,7 +216,7 @@ def _setup_jemalloc_stats() -> None:
                 "metadata",
             ):
                 try:
-                    value = _mallctl(f"stats.{t}")
+                    value = stats.get_stat(t)
                 except Exception as e:
                     # There was an error fetching the value, skip.
                     logger.warning("Failed to read jemalloc stats.%s: %s", t, e)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 45ff0de638..a3b60578e3 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+import math
 import threading
 import weakref
 from enum import Enum
@@ -40,6 +41,7 @@ from twisted.internet.interfaces import IReactorTime
 
 from synapse.config import cache as cache_config
 from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.metrics.jemalloc import get_jemalloc_stats
 from synapse.util import Clock, caches
 from synapse.util.caches import CacheMetric, EvictionReason, register_cache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
@@ -106,10 +108,16 @@ GLOBAL_ROOT = ListNode["_Node"].create_root_node()
 
 
 @wrap_as_background_process("LruCache._expire_old_entries")
-async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
+async def _expire_old_entries(
+    clock: Clock, expiry_seconds: int, autotune_config: Optional[dict]
+) -> None:
     """Walks the global cache list to find cache entries that haven't been
-    accessed in the given number of seconds.
+    accessed in the given number of seconds, or if a given memory threshold has been breached.
     """
+    if autotune_config:
+        max_cache_memory_usage = autotune_config["max_cache_memory_usage"]
+        target_cache_memory_usage = autotune_config["target_cache_memory_usage"]
+        min_cache_ttl = autotune_config["min_cache_ttl"] / 1000
 
     now = int(clock.time())
     node = GLOBAL_ROOT.prev_node
@@ -119,11 +127,36 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
 
     logger.debug("Searching for stale caches")
 
+    evicting_due_to_memory = False
+
+    # determine if we're evicting due to memory
+    jemalloc_interface = get_jemalloc_stats()
+    if jemalloc_interface and autotune_config:
+        try:
+            jemalloc_interface.refresh_stats()
+            mem_usage = jemalloc_interface.get_stat("allocated")
+            if mem_usage > max_cache_memory_usage:
+                logger.info("Begin memory-based cache eviction.")
+                evicting_due_to_memory = True
+        except Exception:
+            logger.warning(
+                "Unable to read allocated memory, skipping memory-based cache eviction."
+            )
+
     while node is not GLOBAL_ROOT:
         # Only the root node isn't a `_TimedListNode`.
         assert isinstance(node, _TimedListNode)
 
-        if node.last_access_ts_secs > now - expiry_seconds:
+        # if node has not aged past expiry_seconds and we are not evicting due to memory usage, there's
+        # nothing to do here
+        if (
+            node.last_access_ts_secs > now - expiry_seconds
+            and not evicting_due_to_memory
+        ):
+            break
+
+        # if entry is newer than min_cache_entry_ttl then do not evict and don't evict anything newer
+        if evicting_due_to_memory and now - node.last_access_ts_secs < min_cache_ttl:
             break
 
         cache_entry = node.get_cache_entry()
@@ -136,10 +169,29 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
         assert cache_entry is not None
         cache_entry.drop_from_cache()
 
+        # Check mem allocation periodically if we are evicting a bunch of caches
+        if jemalloc_interface and evicting_due_to_memory and (i + 1) % 100 == 0:
+            try:
+                jemalloc_interface.refresh_stats()
+                mem_usage = jemalloc_interface.get_stat("allocated")
+                if mem_usage < target_cache_memory_usage:
+                    evicting_due_to_memory = False
+                    logger.info("Stop memory-based cache eviction.")
+            except Exception:
+                logger.warning(
+                    "Unable to read allocated memory, this may affect memory-based cache eviction."
+                )
+                # If we've failed to read the current memory usage then we
+                # should stop trying to evict based on memory usage
+                evicting_due_to_memory = False
+
         # If we do lots of work at once we yield to allow other stuff to happen.
         if (i + 1) % 10000 == 0:
             logger.debug("Waiting during drop")
-            await clock.sleep(0)
+            if node.last_access_ts_secs > now - expiry_seconds:
+                await clock.sleep(0.5)
+            else:
+                await clock.sleep(0)
             logger.debug("Waking during drop")
 
         node = next_node
@@ -156,21 +208,28 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None:
 
 def setup_expire_lru_cache_entries(hs: "HomeServer") -> None:
     """Start a background job that expires all cache entries if they have not
-    been accessed for the given number of seconds.
+    been accessed for the given number of seconds, or if a given memory usage threshold has been
+    breached.
     """
-    if not hs.config.caches.expiry_time_msec:
+    if not hs.config.caches.expiry_time_msec and not hs.config.caches.cache_autotuning:
         return
 
-    logger.info(
-        "Expiring LRU caches after %d seconds", hs.config.caches.expiry_time_msec / 1000
-    )
+    if hs.config.caches.expiry_time_msec:
+        expiry_time = hs.config.caches.expiry_time_msec / 1000
+        logger.info("Expiring LRU caches after %d seconds", expiry_time)
+    else:
+        expiry_time = math.inf
 
     global USE_GLOBAL_LIST
     USE_GLOBAL_LIST = True
 
     clock = hs.get_clock()
     clock.looping_call(
-        _expire_old_entries, 30 * 1000, clock, hs.config.caches.expiry_time_msec / 1000
+        _expire_old_entries,
+        30 * 1000,
+        clock,
+        expiry_time,
+        hs.config.caches.cache_autotuning,
     )