summary refs log tree commit diff
path: root/synapse/metrics/jemalloc.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/metrics/jemalloc.py')
-rw-r--r--synapse/metrics/jemalloc.py114
1 files changed, 73 insertions, 41 deletions
diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py
index 6bc329f04a..836edae586 100644
--- a/synapse/metrics/jemalloc.py
+++ b/synapse/metrics/jemalloc.py
@@ -18,6 +18,7 @@ import os
 import re
 from typing import Iterable, Optional, overload
 
+import attr
 from prometheus_client import REGISTRY, Metric
 from typing_extensions import Literal
 
@@ -27,52 +28,24 @@ from synapse.metrics._types import Collector
 logger = logging.getLogger(__name__)
 
 
-def _setup_jemalloc_stats() -> None:
-    """Checks to see if jemalloc is loaded, and hooks up a collector to record
-    statistics exposed by jemalloc.
-    """
-
-    # Try to find the loaded jemalloc shared library, if any. We need to
-    # introspect into what is loaded, rather than loading whatever is on the
-    # path, as if we load a *different* jemalloc version things will seg fault.
-
-    # We look in `/proc/self/maps`, which only exists on linux.
-    if not os.path.exists("/proc/self/maps"):
-        logger.debug("Not looking for jemalloc as no /proc/self/maps exist")
-        return
-
-    # We're looking for a path at the end of the line that includes
-    # "libjemalloc".
-    regex = re.compile(r"/\S+/libjemalloc.*$")
-
-    jemalloc_path = None
-    with open("/proc/self/maps") as f:
-        for line in f:
-            match = regex.search(line.strip())
-            if match:
-                jemalloc_path = match.group()
-
-    if not jemalloc_path:
-        # No loaded jemalloc was found.
-        logger.debug("jemalloc not found")
-        return
-
-    logger.debug("Found jemalloc at %s", jemalloc_path)
-
-    jemalloc = ctypes.CDLL(jemalloc_path)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class JemallocStats:
+    jemalloc: ctypes.CDLL
 
     @overload
     def _mallctl(
-        name: str, read: Literal[True] = True, write: Optional[int] = None
+        self, name: str, read: Literal[True] = True, write: Optional[int] = None
     ) -> int:
         ...
 
     @overload
-    def _mallctl(name: str, read: Literal[False], write: Optional[int] = None) -> None:
+    def _mallctl(
+        self, name: str, read: Literal[False], write: Optional[int] = None
+    ) -> None:
         ...
 
     def _mallctl(
-        name: str, read: bool = True, write: Optional[int] = None
+        self, name: str, read: bool = True, write: Optional[int] = None
     ) -> Optional[int]:
         """Wrapper around `mallctl` for reading and writing integers to
         jemalloc.
@@ -120,7 +93,7 @@ def _setup_jemalloc_stats() -> None:
         # Where oldp/oldlenp is a buffer where the old value will be written to
         # (if not null), and newp/newlen is the buffer with the new value to set
         # (if not null). Note that they're all references *except* newlen.
-        result = jemalloc.mallctl(
+        result = self.jemalloc.mallctl(
             name.encode("ascii"),
             input_var_ref,
             input_len_ref,
@@ -136,21 +109,80 @@ def _setup_jemalloc_stats() -> None:
 
         return input_var.value
 
-    def _jemalloc_refresh_stats() -> None:
+    def refresh_stats(self) -> None:
         """Request that jemalloc updates its internal statistics. This needs to
         be called before querying for stats, otherwise it will return stale
         values.
         """
         try:
-            _mallctl("epoch", read=False, write=1)
+            self._mallctl("epoch", read=False, write=1)
         except Exception as e:
             logger.warning("Failed to reload jemalloc stats: %s", e)
 
+    def get_stat(self, name: str) -> int:
+        """Request the stat of the given name at the time of the last
+        `refresh_stats` call. This may throw if we fail to read
+        the stat.
+        """
+        return self._mallctl(f"stats.{name}")
+
+
+_JEMALLOC_STATS: Optional[JemallocStats] = None
+
+
+def get_jemalloc_stats() -> Optional[JemallocStats]:
+    """Returns an interface to jemalloc, if it is being used.
+
+    Note that this will always return None until `setup_jemalloc_stats` has been
+    called.
+    """
+    return _JEMALLOC_STATS
+
+
+def _setup_jemalloc_stats() -> None:
+    """Checks to see if jemalloc is loaded, and hooks up a collector to record
+    statistics exposed by jemalloc.
+    """
+
+    global _JEMALLOC_STATS
+
+    # Try to find the loaded jemalloc shared library, if any. We need to
+    # introspect into what is loaded, rather than loading whatever is on the
+    # path, as if we load a *different* jemalloc version things will seg fault.
+
+    # We look in `/proc/self/maps`, which only exists on linux.
+    if not os.path.exists("/proc/self/maps"):
+        logger.debug("Not looking for jemalloc as no /proc/self/maps exist")
+        return
+
+    # We're looking for a path at the end of the line that includes
+    # "libjemalloc".
+    regex = re.compile(r"/\S+/libjemalloc.*$")
+
+    jemalloc_path = None
+    with open("/proc/self/maps") as f:
+        for line in f:
+            match = regex.search(line.strip())
+            if match:
+                jemalloc_path = match.group()
+
+    if not jemalloc_path:
+        # No loaded jemalloc was found.
+        logger.debug("jemalloc not found")
+        return
+
+    logger.debug("Found jemalloc at %s", jemalloc_path)
+
+    jemalloc_dll = ctypes.CDLL(jemalloc_path)
+
+    stats = JemallocStats(jemalloc_dll)
+    _JEMALLOC_STATS = stats
+
     class JemallocCollector(Collector):
         """Metrics for internal jemalloc stats."""
 
         def collect(self) -> Iterable[Metric]:
-            _jemalloc_refresh_stats()
+            stats.refresh_stats
 
             g = GaugeMetricFamily(
                 "jemalloc_stats_app_memory_bytes",
@@ -184,7 +216,7 @@ def _setup_jemalloc_stats() -> None:
                 "metadata",
             ):
                 try:
-                    value = _mallctl(f"stats.{t}")
+                    value = stats.get_stat(t)
                 except Exception as e:
                     # There was an error fetching the value, skip.
                     logger.warning("Failed to read jemalloc stats.%s: %s", t, e)