summary refs log tree commit diff
path: root/synapse/metrics/__init__.py
diff options
context:
space:
mode:
authorSean Quah <8349537+squahtx@users.noreply.github.com>2021-11-17 19:07:02 +0000
committerGitHub <noreply@github.com>2021-11-17 19:07:02 +0000
commit84fac0f814f69645ff1ad564ef8294b31203dc95 (patch)
tree041505d9f8711a230ac6e11c707ff19017f038c0 /synapse/metrics/__init__.py
parentAdd support for `/_matrix/media/v3` APIs (#11371) (diff)
downloadsynapse-84fac0f814f69645ff1ad564ef8294b31203dc95.tar.xz
Add type annotations to `synapse.metrics` (#10847)
Diffstat (limited to 'synapse/metrics/__init__.py')
-rw-r--r--synapse/metrics/__init__.py101
1 files changed, 70 insertions, 31 deletions
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 91ee5c8193..ceef57ad88 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -20,10 +20,25 @@ import os
 import platform
 import threading
 import time
-from typing import Callable, Dict, Iterable, Mapping, Optional, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    Iterable,
+    Mapping,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import attr
-from prometheus_client import Counter, Gauge, Histogram
+from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram, Metric
 from prometheus_client.core import (
     REGISTRY,
     CounterMetricFamily,
@@ -32,6 +47,7 @@ from prometheus_client.core import (
 )
 
 from twisted.internet import reactor
+from twisted.internet.base import ReactorBase
 from twisted.python.threadpool import ThreadPool
 
 import synapse
@@ -54,7 +70,7 @@ HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
 
 class RegistryProxy:
     @staticmethod
-    def collect():
+    def collect() -> Iterable[Metric]:
         for metric in REGISTRY.collect():
             if not metric.name.startswith("__"):
                 yield metric
@@ -74,7 +90,7 @@ class LaterGauge:
         ]
     )
 
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
 
         g = GaugeMetricFamily(self.name, self.desc, labels=self.labels)
 
@@ -93,10 +109,10 @@ class LaterGauge:
 
         yield g
 
-    def __attrs_post_init__(self):
+    def __attrs_post_init__(self) -> None:
         self._register()
 
-    def _register(self):
+    def _register(self) -> None:
         if self.name in all_gauges.keys():
             logger.warning("%s already registered, reregistering" % (self.name,))
             REGISTRY.unregister(all_gauges.pop(self.name))
@@ -105,7 +121,12 @@ class LaterGauge:
         all_gauges[self.name] = self
 
 
-class InFlightGauge:
+# `MetricsEntry` only makes sense when it is a `Protocol`,
+# but `Protocol` can't be used as a `TypeVar` bound.
+MetricsEntry = TypeVar("MetricsEntry")
+
+
+class InFlightGauge(Generic[MetricsEntry]):
     """Tracks number of things (e.g. requests, Measure blocks, etc) in flight
     at any given time.
 
@@ -115,14 +136,19 @@ class InFlightGauge:
     callbacks.
 
     Args:
-        name (str)
-        desc (str)
-        labels (list[str])
-        sub_metrics (list[str]): A list of sub metrics that the callbacks
-            will update.
+        name
+        desc
+        labels
+        sub_metrics: A list of sub metrics that the callbacks will update.
     """
 
-    def __init__(self, name, desc, labels, sub_metrics):
+    def __init__(
+        self,
+        name: str,
+        desc: str,
+        labels: Sequence[str],
+        sub_metrics: Sequence[str],
+    ):
         self.name = name
         self.desc = desc
         self.labels = labels
@@ -130,19 +156,25 @@ class InFlightGauge:
 
         # Create a class which have the sub_metrics values as attributes, which
         # default to 0 on initialization. Used to pass to registered callbacks.
-        self._metrics_class = attr.make_class(
+        self._metrics_class: Type[MetricsEntry] = attr.make_class(
             "_MetricsEntry", attrs={x: attr.ib(0) for x in sub_metrics}, slots=True
         )
 
         # Counts number of in flight blocks for a given set of label values
-        self._registrations: Dict = {}
+        self._registrations: Dict[
+            Tuple[str, ...], Set[Callable[[MetricsEntry], None]]
+        ] = {}
 
         # Protects access to _registrations
         self._lock = threading.Lock()
 
         self._register_with_collector()
 
-    def register(self, key, callback):
+    def register(
+        self,
+        key: Tuple[str, ...],
+        callback: Callable[[MetricsEntry], None],
+    ) -> None:
         """Registers that we've entered a new block with labels `key`.
 
         `callback` gets called each time the metrics are collected. The same
@@ -158,13 +190,17 @@ class InFlightGauge:
         with self._lock:
             self._registrations.setdefault(key, set()).add(callback)
 
-    def unregister(self, key, callback):
+    def unregister(
+        self,
+        key: Tuple[str, ...],
+        callback: Callable[[MetricsEntry], None],
+    ) -> None:
         """Registers that we've exited a block with labels `key`."""
 
         with self._lock:
             self._registrations.setdefault(key, set()).discard(callback)
 
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
         """Called by prometheus client when it reads metrics.
 
         Note: may be called by a separate thread.
@@ -200,7 +236,7 @@ class InFlightGauge:
                 gauge.add_metric(key, getattr(metrics, name))
             yield gauge
 
-    def _register_with_collector(self):
+    def _register_with_collector(self) -> None:
         if self.name in all_gauges.keys():
             logger.warning("%s already registered, reregistering" % (self.name,))
             REGISTRY.unregister(all_gauges.pop(self.name))
@@ -230,7 +266,7 @@ class GaugeBucketCollector:
         name: str,
         documentation: str,
         buckets: Iterable[float],
-        registry=REGISTRY,
+        registry: CollectorRegistry = REGISTRY,
     ):
         """
         Args:
@@ -257,12 +293,12 @@ class GaugeBucketCollector:
 
         registry.register(self)
 
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
         # Don't report metrics unless we've already collected some data
         if self._metric is not None:
             yield self._metric
 
-    def update_data(self, values: Iterable[float]):
+    def update_data(self, values: Iterable[float]) -> None:
         """Update the data to be reported by the metric
 
         The existing data is cleared, and each measurement in the input is assigned
@@ -304,7 +340,7 @@ class GaugeBucketCollector:
 
 
 class CPUMetrics:
-    def __init__(self):
+    def __init__(self) -> None:
         ticks_per_sec = 100
         try:
             # Try and get the system config
@@ -314,7 +350,7 @@ class CPUMetrics:
 
         self.ticks_per_sec = ticks_per_sec
 
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
         if not HAVE_PROC_SELF_STAT:
             return
 
@@ -364,7 +400,7 @@ gc_time = Histogram(
 
 
 class GCCounts:
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
         cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
         for n, m in enumerate(gc.get_count()):
             cm.add_metric([str(n)], m)
@@ -382,7 +418,7 @@ if not running_on_pypy:
 
 
 class PyPyGCStats:
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
 
         # @stats is a pretty-printer object with __str__() returning a nice table,
         # plus some fields that contain data from that table.
@@ -565,7 +601,7 @@ def register_threadpool(name: str, threadpool: ThreadPool) -> None:
 
 
 class ReactorLastSeenMetric:
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
         cm = GaugeMetricFamily(
             "python_twisted_reactor_last_seen",
             "Seconds since the Twisted reactor was last seen",
@@ -584,9 +620,12 @@ MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)
 _last_gc = [0.0, 0.0, 0.0]
 
 
-def runUntilCurrentTimer(reactor, func):
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+def runUntilCurrentTimer(reactor: ReactorBase, func: F) -> F:
     @functools.wraps(func)
-    def f(*args, **kwargs):
+    def f(*args: Any, **kwargs: Any) -> Any:
         now = reactor.seconds()
         num_pending = 0
 
@@ -649,7 +688,7 @@ def runUntilCurrentTimer(reactor, func):
 
         return ret
 
-    return f
+    return cast(F, f)
 
 
 try:
@@ -677,5 +716,5 @@ __all__ = [
     "start_http_server",
     "LaterGauge",
     "InFlightGauge",
-    "BucketCollector",
+    "GaugeBucketCollector",
 ]