diff --git a/changelog.d/10847.misc b/changelog.d/10847.misc
new file mode 100644
index 0000000000..7933a38dca
--- /dev/null
+++ b/changelog.d/10847.misc
@@ -0,0 +1 @@
+Add type annotations to `synapse.metrics`.
diff --git a/mypy.ini b/mypy.ini
index f32c6c41a3..308cfd95d8 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -160,6 +160,9 @@ disallow_untyped_defs = True
[mypy-synapse.handlers.*]
disallow_untyped_defs = True
+[mypy-synapse.metrics.*]
+disallow_untyped_defs = True
+
[mypy-synapse.push.*]
disallow_untyped_defs = True
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 573bb487b2..807ee3d46e 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -402,7 +402,7 @@ async def start(hs: "HomeServer") -> None:
if hasattr(signal, "SIGHUP"):
@wrap_as_background_process("sighup")
- def handle_sighup(*args: Any, **kwargs: Any) -> None:
+ async def handle_sighup(*args: Any, **kwargs: Any) -> None:
# Tell systemd our state, if we're using it. This will silently fail if
# we're not using systemd.
sdnotify(b"RELOADING=1")
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index 53f99031b1..a87896e538 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -40,6 +40,8 @@ from typing import TYPE_CHECKING, Optional, Tuple
from signedjson.sign import sign_json
+from twisted.internet.defer import Deferred
+
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import JsonDict, get_domain_from_id
@@ -166,7 +168,7 @@ class GroupAttestionRenewer:
return {}
- def _start_renew_attestations(self) -> None:
+ def _start_renew_attestations(self) -> "Deferred[None]":
return run_as_background_process("renew_attestations", self._renew_attestations)
async def _renew_attestations(self) -> None:
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 22c6174821..1676ebd057 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -90,7 +90,7 @@ class FollowerTypingHandler:
self.wheel_timer = WheelTimer(bucket_size=5000)
@wrap_as_background_process("typing._handle_timeouts")
- def _handle_timeouts(self) -> None:
+ async def _handle_timeouts(self) -> None:
logger.debug("Checking for typing timeouts")
now = self.clock.time_msec()
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 91ee5c8193..ceef57ad88 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -20,10 +20,25 @@ import os
import platform
import threading
import time
-from typing import Callable, Dict, Iterable, Mapping, Optional, Tuple, Union
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Generic,
+ Iterable,
+ Mapping,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+)
import attr
-from prometheus_client import Counter, Gauge, Histogram
+from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram, Metric
from prometheus_client.core import (
REGISTRY,
CounterMetricFamily,
@@ -32,6 +47,7 @@ from prometheus_client.core import (
)
from twisted.internet import reactor
+from twisted.internet.base import ReactorBase
from twisted.python.threadpool import ThreadPool
import synapse
@@ -54,7 +70,7 @@ HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
class RegistryProxy:
@staticmethod
- def collect():
+ def collect() -> Iterable[Metric]:
for metric in REGISTRY.collect():
if not metric.name.startswith("__"):
yield metric
@@ -74,7 +90,7 @@ class LaterGauge:
]
)
- def collect(self):
+ def collect(self) -> Iterable[Metric]:
g = GaugeMetricFamily(self.name, self.desc, labels=self.labels)
@@ -93,10 +109,10 @@ class LaterGauge:
yield g
- def __attrs_post_init__(self):
+ def __attrs_post_init__(self) -> None:
self._register()
- def _register(self):
+ def _register(self) -> None:
if self.name in all_gauges.keys():
logger.warning("%s already registered, reregistering" % (self.name,))
REGISTRY.unregister(all_gauges.pop(self.name))
@@ -105,7 +121,12 @@ class LaterGauge:
all_gauges[self.name] = self
-class InFlightGauge:
+# `MetricsEntry` only makes sense when it is a `Protocol`,
+# but `Protocol` can't be used as a `TypeVar` bound.
+MetricsEntry = TypeVar("MetricsEntry")
+
+
+class InFlightGauge(Generic[MetricsEntry]):
"""Tracks number of things (e.g. requests, Measure blocks, etc) in flight
at any given time.
@@ -115,14 +136,19 @@ class InFlightGauge:
callbacks.
Args:
- name (str)
- desc (str)
- labels (list[str])
- sub_metrics (list[str]): A list of sub metrics that the callbacks
- will update.
+ name
+ desc
+ labels
+ sub_metrics: A list of sub metrics that the callbacks will update.
"""
- def __init__(self, name, desc, labels, sub_metrics):
+ def __init__(
+ self,
+ name: str,
+ desc: str,
+ labels: Sequence[str],
+ sub_metrics: Sequence[str],
+ ):
self.name = name
self.desc = desc
self.labels = labels
@@ -130,19 +156,25 @@ class InFlightGauge:
# Create a class which have the sub_metrics values as attributes, which
# default to 0 on initialization. Used to pass to registered callbacks.
- self._metrics_class = attr.make_class(
+ self._metrics_class: Type[MetricsEntry] = attr.make_class(
"_MetricsEntry", attrs={x: attr.ib(0) for x in sub_metrics}, slots=True
)
# Counts number of in flight blocks for a given set of label values
- self._registrations: Dict = {}
+ self._registrations: Dict[
+ Tuple[str, ...], Set[Callable[[MetricsEntry], None]]
+ ] = {}
# Protects access to _registrations
self._lock = threading.Lock()
self._register_with_collector()
- def register(self, key, callback):
+ def register(
+ self,
+ key: Tuple[str, ...],
+ callback: Callable[[MetricsEntry], None],
+ ) -> None:
"""Registers that we've entered a new block with labels `key`.
`callback` gets called each time the metrics are collected. The same
@@ -158,13 +190,17 @@ class InFlightGauge:
with self._lock:
self._registrations.setdefault(key, set()).add(callback)
- def unregister(self, key, callback):
+ def unregister(
+ self,
+ key: Tuple[str, ...],
+ callback: Callable[[MetricsEntry], None],
+ ) -> None:
"""Registers that we've exited a block with labels `key`."""
with self._lock:
self._registrations.setdefault(key, set()).discard(callback)
- def collect(self):
+ def collect(self) -> Iterable[Metric]:
"""Called by prometheus client when it reads metrics.
Note: may be called by a separate thread.
@@ -200,7 +236,7 @@ class InFlightGauge:
gauge.add_metric(key, getattr(metrics, name))
yield gauge
- def _register_with_collector(self):
+ def _register_with_collector(self) -> None:
if self.name in all_gauges.keys():
logger.warning("%s already registered, reregistering" % (self.name,))
REGISTRY.unregister(all_gauges.pop(self.name))
@@ -230,7 +266,7 @@ class GaugeBucketCollector:
name: str,
documentation: str,
buckets: Iterable[float],
- registry=REGISTRY,
+ registry: CollectorRegistry = REGISTRY,
):
"""
Args:
@@ -257,12 +293,12 @@ class GaugeBucketCollector:
registry.register(self)
- def collect(self):
+ def collect(self) -> Iterable[Metric]:
# Don't report metrics unless we've already collected some data
if self._metric is not None:
yield self._metric
- def update_data(self, values: Iterable[float]):
+ def update_data(self, values: Iterable[float]) -> None:
"""Update the data to be reported by the metric
The existing data is cleared, and each measurement in the input is assigned
@@ -304,7 +340,7 @@ class GaugeBucketCollector:
class CPUMetrics:
- def __init__(self):
+ def __init__(self) -> None:
ticks_per_sec = 100
try:
# Try and get the system config
@@ -314,7 +350,7 @@ class CPUMetrics:
self.ticks_per_sec = ticks_per_sec
- def collect(self):
+ def collect(self) -> Iterable[Metric]:
if not HAVE_PROC_SELF_STAT:
return
@@ -364,7 +400,7 @@ gc_time = Histogram(
class GCCounts:
- def collect(self):
+ def collect(self) -> Iterable[Metric]:
cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
for n, m in enumerate(gc.get_count()):
cm.add_metric([str(n)], m)
@@ -382,7 +418,7 @@ if not running_on_pypy:
class PyPyGCStats:
- def collect(self):
+ def collect(self) -> Iterable[Metric]:
# @stats is a pretty-printer object with __str__() returning a nice table,
# plus some fields that contain data from that table.
@@ -565,7 +601,7 @@ def register_threadpool(name: str, threadpool: ThreadPool) -> None:
class ReactorLastSeenMetric:
- def collect(self):
+ def collect(self) -> Iterable[Metric]:
cm = GaugeMetricFamily(
"python_twisted_reactor_last_seen",
"Seconds since the Twisted reactor was last seen",
@@ -584,9 +620,12 @@ MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)
_last_gc = [0.0, 0.0, 0.0]
-def runUntilCurrentTimer(reactor, func):
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+def runUntilCurrentTimer(reactor: ReactorBase, func: F) -> F:
@functools.wraps(func)
- def f(*args, **kwargs):
+ def f(*args: Any, **kwargs: Any) -> Any:
now = reactor.seconds()
num_pending = 0
@@ -649,7 +688,7 @@ def runUntilCurrentTimer(reactor, func):
return ret
- return f
+ return cast(F, f)
try:
@@ -677,5 +716,5 @@ __all__ = [
"start_http_server",
"LaterGauge",
"InFlightGauge",
- "BucketCollector",
+ "GaugeBucketCollector",
]
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py
index bb9bcb5592..353d0a63b6 100644
--- a/synapse/metrics/_exposition.py
+++ b/synapse/metrics/_exposition.py
@@ -25,27 +25,25 @@ import math
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from socketserver import ThreadingMixIn
-from typing import Dict, List
+from typing import Any, Dict, List, Type, Union
from urllib.parse import parse_qs, urlparse
-from prometheus_client import REGISTRY
+from prometheus_client import REGISTRY, CollectorRegistry
+from prometheus_client.core import Sample
from twisted.web.resource import Resource
+from twisted.web.server import Request
from synapse.util import caches
CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8"
-INF = float("inf")
-MINUS_INF = float("-inf")
-
-
-def floatToGoString(d):
+def floatToGoString(d: Union[int, float]) -> str:
d = float(d)
- if d == INF:
+ if d == math.inf:
return "+Inf"
- elif d == MINUS_INF:
+ elif d == -math.inf:
return "-Inf"
elif math.isnan(d):
return "NaN"
@@ -60,7 +58,7 @@ def floatToGoString(d):
return s
-def sample_line(line, name):
+def sample_line(line: Sample, name: str) -> str:
if line.labels:
labelstr = "{{{0}}}".format(
",".join(
@@ -82,7 +80,7 @@ def sample_line(line, name):
return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp)
-def generate_latest(registry, emit_help=False):
+def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes:
# Trigger the cache metrics to be rescraped, which updates the common
# metrics but do not produce metrics themselves
@@ -187,7 +185,7 @@ class MetricsHandler(BaseHTTPRequestHandler):
registry = REGISTRY
- def do_GET(self):
+ def do_GET(self) -> None:
registry = self.registry
params = parse_qs(urlparse(self.path).query)
@@ -207,11 +205,11 @@ class MetricsHandler(BaseHTTPRequestHandler):
self.end_headers()
self.wfile.write(output)
- def log_message(self, format, *args):
+ def log_message(self, format: str, *args: Any) -> None:
"""Log nothing."""
@classmethod
- def factory(cls, registry):
+ def factory(cls, registry: CollectorRegistry) -> Type:
"""Returns a dynamic MetricsHandler class tied
to the passed registry.
"""
@@ -236,7 +234,9 @@ class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer):
daemon_threads = True
-def start_http_server(port, addr="", registry=REGISTRY):
+def start_http_server(
+ port: int, addr: str = "", registry: CollectorRegistry = REGISTRY
+) -> None:
"""Starts an HTTP server for prometheus metrics as a daemon thread"""
CustomMetricsHandler = MetricsHandler.factory(registry)
httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler)
@@ -252,10 +252,10 @@ class MetricsResource(Resource):
isLeaf = True
- def __init__(self, registry=REGISTRY):
+ def __init__(self, registry: CollectorRegistry = REGISTRY):
self.registry = registry
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> bytes:
request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii"))
response = generate_latest(self.registry)
request.setHeader(b"Content-Length", str(len(response)))
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 2ab599a334..53c508af91 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -15,19 +15,37 @@
import logging
import threading
from functools import wraps
-from typing import TYPE_CHECKING, Dict, Optional, Set, Union
+from types import TracebackType
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ Optional,
+ Set,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+)
+from prometheus_client import Metric
from prometheus_client.core import REGISTRY, Counter, Gauge
from twisted.internet import defer
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import (
+ ContextResourceUsage,
+ LoggingContext,
+ PreserveLoggingContext,
+)
from synapse.logging.opentracing import (
SynapseTags,
noop_context_manager,
start_active_span,
)
-from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
import resource
@@ -116,7 +134,7 @@ class _Collector:
before they are returned.
"""
- def collect(self):
+ def collect(self) -> Iterable[Metric]:
global _background_processes_active_since_last_scrape
# We swap out the _background_processes set with an empty one so that
@@ -144,12 +162,12 @@ REGISTRY.register(_Collector())
class _BackgroundProcess:
- def __init__(self, desc, ctx):
+ def __init__(self, desc: str, ctx: LoggingContext):
self.desc = desc
self._context = ctx
- self._reported_stats = None
+ self._reported_stats: Optional[ContextResourceUsage] = None
- def update_metrics(self):
+ def update_metrics(self) -> None:
"""Updates the metrics with values from this process."""
new_stats = self._context.get_resource_usage()
if self._reported_stats is None:
@@ -169,7 +187,16 @@ class _BackgroundProcess:
)
-def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwargs):
+R = TypeVar("R")
+
+
+def run_as_background_process(
+ desc: str,
+ func: Callable[..., Awaitable[Optional[R]]],
+ *args: Any,
+ bg_start_span: bool = True,
+ **kwargs: Any,
+) -> "defer.Deferred[Optional[R]]":
"""Run the given function in its own logcontext, with resource metrics
This should be used to wrap processes which are fired off to run in the
@@ -189,11 +216,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
args: positional args for func
kwargs: keyword args for func
- Returns: Deferred which returns the result of func, but note that it does not
- follow the synapse logcontext rules.
+ Returns:
+ Deferred which returns the result of func, or `None` if func raises.
+ Note that the returned Deferred does not follow the synapse logcontext
+ rules.
"""
- async def run():
+ async def run() -> Optional[R]:
with _bg_metrics_lock:
count = _background_process_counts.get(desc, 0)
_background_process_counts[desc] = count + 1
@@ -210,12 +239,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
else:
ctx = noop_context_manager()
with ctx:
- return await maybe_awaitable(func(*args, **kwargs))
+ return await func(*args, **kwargs)
except Exception:
logger.exception(
"Background process '%s' threw an exception",
desc,
)
+ return None
finally:
_background_process_in_flight_count.labels(desc).dec()
@@ -225,19 +255,24 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
return defer.ensureDeferred(run())
-def wrap_as_background_process(desc):
+F = TypeVar("F", bound=Callable[..., Awaitable[Optional[Any]]])
+
+
+def wrap_as_background_process(desc: str) -> Callable[[F], F]:
"""Decorator that wraps a function that gets called as a background
process.
- Equivalent of calling the function with `run_as_background_process`
+ Equivalent to calling the function with `run_as_background_process`
"""
- def wrap_as_background_process_inner(func):
+ def wrap_as_background_process_inner(func: F) -> F:
@wraps(func)
- def wrap_as_background_process_inner_2(*args, **kwargs):
+ def wrap_as_background_process_inner_2(
+ *args: Any, **kwargs: Any
+ ) -> "defer.Deferred[Optional[R]]":
return run_as_background_process(desc, func, *args, **kwargs)
- return wrap_as_background_process_inner_2
+ return cast(F, wrap_as_background_process_inner_2)
return wrap_as_background_process_inner
@@ -265,7 +300,7 @@ class BackgroundProcessLoggingContext(LoggingContext):
super().__init__("%s-%s" % (name, instance_id))
self._proc = _BackgroundProcess(name, self)
- def start(self, rusage: "Optional[resource.struct_rusage]"):
+ def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
"""Log context has started running (again)."""
super().start(rusage)
@@ -276,7 +311,12 @@ class BackgroundProcessLoggingContext(LoggingContext):
with _bg_metrics_lock:
_background_processes_active_since_last_scrape.add(self._proc)
- def __exit__(self, type, value, traceback) -> None:
+ def __exit__(
+ self,
+ type: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ traceback: Optional[TracebackType],
+ ) -> None:
"""Log context has finished."""
super().__exit__(type, value, traceback)
diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py
index 29ab6c0229..98ed9c0829 100644
--- a/synapse/metrics/jemalloc.py
+++ b/synapse/metrics/jemalloc.py
@@ -16,14 +16,16 @@ import ctypes
import logging
import os
import re
-from typing import Optional
+from typing import Iterable, Optional
+
+from prometheus_client import Metric
from synapse.metrics import REGISTRY, GaugeMetricFamily
logger = logging.getLogger(__name__)
-def _setup_jemalloc_stats():
+def _setup_jemalloc_stats() -> None:
"""Checks to see if jemalloc is loaded, and hooks up a collector to record
statistics exposed by jemalloc.
"""
@@ -135,7 +137,7 @@ def _setup_jemalloc_stats():
class JemallocCollector:
"""Metrics for internal jemalloc stats."""
- def collect(self):
+ def collect(self) -> Iterable[Metric]:
_jemalloc_refresh_stats()
g = GaugeMetricFamily(
@@ -185,7 +187,7 @@ def _setup_jemalloc_stats():
logger.debug("Added jemalloc stats")
-def setup_jemalloc_stats():
+def setup_jemalloc_stats() -> None:
"""Try to setup jemalloc stats, if jemalloc is loaded."""
try:
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d4cab69ebf..0693d39006 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -188,7 +188,7 @@ class LoggingDatabaseConnection:
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
-_CallbackListEntry = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]]
+_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]]
R = TypeVar("R")
@@ -235,7 +235,7 @@ class LoggingTransaction:
self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks
- def call_after(self, callback: Callable[..., None], *args: Any, **kwargs: Any):
+ def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
@@ -247,7 +247,7 @@ class LoggingTransaction:
self.after_callbacks.append((callback, args, kwargs))
def call_on_exception(
- self, callback: Callable[..., None], *args: Any, **kwargs: Any
+ self, callback: Callable[..., object], *args: Any, **kwargs: Any
):
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 6a7e534576..67ee4c693b 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -159,7 +159,7 @@ class ExpiringCache(Generic[KT, VT]):
self[key] = value
return value
- def _prune_cache(self) -> None:
+ async def _prune_cache(self) -> None:
if not self._expiry_ms:
# zero expiry time means don't expire. This should never get called
# since we have this check in start too.
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index ad775dfc7d..98ee49af6e 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -56,8 +56,15 @@ block_db_sched_duration = Counter(
"synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]
)
+
+# This is dynamically created in InFlightGauge.__init__.
+class _InFlightMetric(Protocol):
+ real_time_max: float
+ real_time_sum: float
+
+
# Tracks the number of blocks currently active
-in_flight = InFlightGauge(
+in_flight: InFlightGauge[_InFlightMetric] = InFlightGauge(
"synapse_util_metrics_block_in_flight",
"",
labels=["block_name"],
@@ -65,12 +72,6 @@ in_flight = InFlightGauge(
)
-# This is dynamically created in InFlightGauge.__init__.
-class _InFlightMetric(Protocol):
- real_time_max: float
- real_time_sum: float
-
-
T = TypeVar("T", bound=Callable[..., Any])
|