summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorSean Quah <8349537+squahtx@users.noreply.github.com>2021-11-17 19:07:02 +0000
committerGitHub <noreply@github.com>2021-11-17 19:07:02 +0000
commit84fac0f814f69645ff1ad564ef8294b31203dc95 (patch)
tree041505d9f8711a230ac6e11c707ff19017f038c0 /synapse
parentAdd support for `/_matrix/media/v3` APIs (#11371) (diff)
downloadsynapse-84fac0f814f69645ff1ad564ef8294b31203dc95.tar.xz
Add type annotations to `synapse.metrics` (#10847)
Diffstat (limited to '')
-rw-r--r--synapse/app/_base.py2
-rw-r--r--synapse/groups/attestations.py4
-rw-r--r--synapse/handlers/typing.py2
-rw-r--r--synapse/metrics/__init__.py101
-rw-r--r--synapse/metrics/_exposition.py34
-rw-r--r--synapse/metrics/background_process_metrics.py78
-rw-r--r--synapse/metrics/jemalloc.py10
-rw-r--r--synapse/storage/database.py6
-rw-r--r--synapse/util/caches/expiringcache.py2
-rw-r--r--synapse/util/metrics.py15
10 files changed, 169 insertions, 85 deletions
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 573bb487b2..807ee3d46e 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -402,7 +402,7 @@ async def start(hs: "HomeServer") -> None:
     if hasattr(signal, "SIGHUP"):
 
         @wrap_as_background_process("sighup")
-        def handle_sighup(*args: Any, **kwargs: Any) -> None:
+        async def handle_sighup(*args: Any, **kwargs: Any) -> None:
             # Tell systemd our state, if we're using it. This will silently fail if
             # we're not using systemd.
             sdnotify(b"RELOADING=1")
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index 53f99031b1..a87896e538 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -40,6 +40,8 @@ from typing import TYPE_CHECKING, Optional, Tuple
 
 from signedjson.sign import sign_json
 
+from twisted.internet.defer import Deferred
+
 from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import JsonDict, get_domain_from_id
@@ -166,7 +168,7 @@ class GroupAttestionRenewer:
 
         return {}
 
-    def _start_renew_attestations(self) -> None:
+    def _start_renew_attestations(self) -> "Deferred[None]":
         return run_as_background_process("renew_attestations", self._renew_attestations)
 
     async def _renew_attestations(self) -> None:
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 22c6174821..1676ebd057 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -90,7 +90,7 @@ class FollowerTypingHandler:
         self.wheel_timer = WheelTimer(bucket_size=5000)
 
     @wrap_as_background_process("typing._handle_timeouts")
-    def _handle_timeouts(self) -> None:
+    async def _handle_timeouts(self) -> None:
         logger.debug("Checking for typing timeouts")
 
         now = self.clock.time_msec()
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 91ee5c8193..ceef57ad88 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -20,10 +20,25 @@ import os
 import platform
 import threading
 import time
-from typing import Callable, Dict, Iterable, Mapping, Optional, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    Iterable,
+    Mapping,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import attr
-from prometheus_client import Counter, Gauge, Histogram
+from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram, Metric
 from prometheus_client.core import (
     REGISTRY,
     CounterMetricFamily,
@@ -32,6 +47,7 @@ from prometheus_client.core import (
 )
 
 from twisted.internet import reactor
+from twisted.internet.base import ReactorBase
 from twisted.python.threadpool import ThreadPool
 
 import synapse
@@ -54,7 +70,7 @@ HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
 
 class RegistryProxy:
     @staticmethod
-    def collect():
+    def collect() -> Iterable[Metric]:
         for metric in REGISTRY.collect():
             if not metric.name.startswith("__"):
                 yield metric
@@ -74,7 +90,7 @@ class LaterGauge:
         ]
     )
 
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
 
         g = GaugeMetricFamily(self.name, self.desc, labels=self.labels)
 
@@ -93,10 +109,10 @@ class LaterGauge:
 
         yield g
 
-    def __attrs_post_init__(self):
+    def __attrs_post_init__(self) -> None:
         self._register()
 
-    def _register(self):
+    def _register(self) -> None:
         if self.name in all_gauges.keys():
             logger.warning("%s already registered, reregistering" % (self.name,))
             REGISTRY.unregister(all_gauges.pop(self.name))
@@ -105,7 +121,12 @@ class LaterGauge:
         all_gauges[self.name] = self
 
 
-class InFlightGauge:
+# `MetricsEntry` only makes sense when it is a `Protocol`,
+# but `Protocol` can't be used as a `TypeVar` bound.
+MetricsEntry = TypeVar("MetricsEntry")
+
+
+class InFlightGauge(Generic[MetricsEntry]):
     """Tracks number of things (e.g. requests, Measure blocks, etc) in flight
     at any given time.
 
@@ -115,14 +136,19 @@ class InFlightGauge:
     callbacks.
 
     Args:
-        name (str)
-        desc (str)
-        labels (list[str])
-        sub_metrics (list[str]): A list of sub metrics that the callbacks
-            will update.
+        name
+        desc
+        labels
+        sub_metrics: A list of sub metrics that the callbacks will update.
     """
 
-    def __init__(self, name, desc, labels, sub_metrics):
+    def __init__(
+        self,
+        name: str,
+        desc: str,
+        labels: Sequence[str],
+        sub_metrics: Sequence[str],
+    ):
         self.name = name
         self.desc = desc
         self.labels = labels
@@ -130,19 +156,25 @@ class InFlightGauge:
 
         # Create a class which have the sub_metrics values as attributes, which
         # default to 0 on initialization. Used to pass to registered callbacks.
-        self._metrics_class = attr.make_class(
+        self._metrics_class: Type[MetricsEntry] = attr.make_class(
             "_MetricsEntry", attrs={x: attr.ib(0) for x in sub_metrics}, slots=True
         )
 
         # Counts number of in flight blocks for a given set of label values
-        self._registrations: Dict = {}
+        self._registrations: Dict[
+            Tuple[str, ...], Set[Callable[[MetricsEntry], None]]
+        ] = {}
 
         # Protects access to _registrations
         self._lock = threading.Lock()
 
         self._register_with_collector()
 
-    def register(self, key, callback):
+    def register(
+        self,
+        key: Tuple[str, ...],
+        callback: Callable[[MetricsEntry], None],
+    ) -> None:
         """Registers that we've entered a new block with labels `key`.
 
         `callback` gets called each time the metrics are collected. The same
@@ -158,13 +190,17 @@ class InFlightGauge:
         with self._lock:
             self._registrations.setdefault(key, set()).add(callback)
 
-    def unregister(self, key, callback):
+    def unregister(
+        self,
+        key: Tuple[str, ...],
+        callback: Callable[[MetricsEntry], None],
+    ) -> None:
         """Registers that we've exited a block with labels `key`."""
 
         with self._lock:
             self._registrations.setdefault(key, set()).discard(callback)
 
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
         """Called by prometheus client when it reads metrics.
 
         Note: may be called by a separate thread.
@@ -200,7 +236,7 @@ class InFlightGauge:
                 gauge.add_metric(key, getattr(metrics, name))
             yield gauge
 
-    def _register_with_collector(self):
+    def _register_with_collector(self) -> None:
         if self.name in all_gauges.keys():
             logger.warning("%s already registered, reregistering" % (self.name,))
             REGISTRY.unregister(all_gauges.pop(self.name))
@@ -230,7 +266,7 @@ class GaugeBucketCollector:
         name: str,
         documentation: str,
         buckets: Iterable[float],
-        registry=REGISTRY,
+        registry: CollectorRegistry = REGISTRY,
     ):
         """
         Args:
@@ -257,12 +293,12 @@ class GaugeBucketCollector:
 
         registry.register(self)
 
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
         # Don't report metrics unless we've already collected some data
         if self._metric is not None:
             yield self._metric
 
-    def update_data(self, values: Iterable[float]):
+    def update_data(self, values: Iterable[float]) -> None:
         """Update the data to be reported by the metric
 
         The existing data is cleared, and each measurement in the input is assigned
@@ -304,7 +340,7 @@ class GaugeBucketCollector:
 
 
 class CPUMetrics:
-    def __init__(self):
+    def __init__(self) -> None:
         ticks_per_sec = 100
         try:
             # Try and get the system config
@@ -314,7 +350,7 @@ class CPUMetrics:
 
         self.ticks_per_sec = ticks_per_sec
 
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
         if not HAVE_PROC_SELF_STAT:
             return
 
@@ -364,7 +400,7 @@ gc_time = Histogram(
 
 
 class GCCounts:
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
         cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
         for n, m in enumerate(gc.get_count()):
             cm.add_metric([str(n)], m)
@@ -382,7 +418,7 @@ if not running_on_pypy:
 
 
 class PyPyGCStats:
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
 
         # @stats is a pretty-printer object with __str__() returning a nice table,
         # plus some fields that contain data from that table.
@@ -565,7 +601,7 @@ def register_threadpool(name: str, threadpool: ThreadPool) -> None:
 
 
 class ReactorLastSeenMetric:
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
         cm = GaugeMetricFamily(
             "python_twisted_reactor_last_seen",
             "Seconds since the Twisted reactor was last seen",
@@ -584,9 +620,12 @@ MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)
 _last_gc = [0.0, 0.0, 0.0]
 
 
-def runUntilCurrentTimer(reactor, func):
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+def runUntilCurrentTimer(reactor: ReactorBase, func: F) -> F:
     @functools.wraps(func)
-    def f(*args, **kwargs):
+    def f(*args: Any, **kwargs: Any) -> Any:
         now = reactor.seconds()
         num_pending = 0
 
@@ -649,7 +688,7 @@ def runUntilCurrentTimer(reactor, func):
 
         return ret
 
-    return f
+    return cast(F, f)
 
 
 try:
@@ -677,5 +716,5 @@ __all__ = [
     "start_http_server",
     "LaterGauge",
     "InFlightGauge",
-    "BucketCollector",
+    "GaugeBucketCollector",
 ]
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py
index bb9bcb5592..353d0a63b6 100644
--- a/synapse/metrics/_exposition.py
+++ b/synapse/metrics/_exposition.py
@@ -25,27 +25,25 @@ import math
 import threading
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from socketserver import ThreadingMixIn
-from typing import Dict, List
+from typing import Any, Dict, List, Type, Union
 from urllib.parse import parse_qs, urlparse
 
-from prometheus_client import REGISTRY
+from prometheus_client import REGISTRY, CollectorRegistry
+from prometheus_client.core import Sample
 
 from twisted.web.resource import Resource
+from twisted.web.server import Request
 
 from synapse.util import caches
 
 CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8"
 
 
-INF = float("inf")
-MINUS_INF = float("-inf")
-
-
-def floatToGoString(d):
+def floatToGoString(d: Union[int, float]) -> str:
     d = float(d)
-    if d == INF:
+    if d == math.inf:
         return "+Inf"
-    elif d == MINUS_INF:
+    elif d == -math.inf:
         return "-Inf"
     elif math.isnan(d):
         return "NaN"
@@ -60,7 +58,7 @@ def floatToGoString(d):
         return s
 
 
-def sample_line(line, name):
+def sample_line(line: Sample, name: str) -> str:
     if line.labels:
         labelstr = "{{{0}}}".format(
             ",".join(
@@ -82,7 +80,7 @@ def sample_line(line, name):
     return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp)
 
 
-def generate_latest(registry, emit_help=False):
+def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes:
 
     # Trigger the cache metrics to be rescraped, which updates the common
     # metrics but do not produce metrics themselves
@@ -187,7 +185,7 @@ class MetricsHandler(BaseHTTPRequestHandler):
 
     registry = REGISTRY
 
-    def do_GET(self):
+    def do_GET(self) -> None:
         registry = self.registry
         params = parse_qs(urlparse(self.path).query)
 
@@ -207,11 +205,11 @@ class MetricsHandler(BaseHTTPRequestHandler):
         self.end_headers()
         self.wfile.write(output)
 
-    def log_message(self, format, *args):
+    def log_message(self, format: str, *args: Any) -> None:
         """Log nothing."""
 
     @classmethod
-    def factory(cls, registry):
+    def factory(cls, registry: CollectorRegistry) -> Type:
         """Returns a dynamic MetricsHandler class tied
         to the passed registry.
         """
@@ -236,7 +234,9 @@ class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer):
     daemon_threads = True
 
 
-def start_http_server(port, addr="", registry=REGISTRY):
+def start_http_server(
+    port: int, addr: str = "", registry: CollectorRegistry = REGISTRY
+) -> None:
     """Starts an HTTP server for prometheus metrics as a daemon thread"""
     CustomMetricsHandler = MetricsHandler.factory(registry)
     httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler)
@@ -252,10 +252,10 @@ class MetricsResource(Resource):
 
     isLeaf = True
 
-    def __init__(self, registry=REGISTRY):
+    def __init__(self, registry: CollectorRegistry = REGISTRY):
         self.registry = registry
 
-    def render_GET(self, request):
+    def render_GET(self, request: Request) -> bytes:
         request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii"))
         response = generate_latest(self.registry)
         request.setHeader(b"Content-Length", str(len(response)))
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 2ab599a334..53c508af91 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -15,19 +15,37 @@
 import logging
 import threading
 from functools import wraps
-from typing import TYPE_CHECKING, Dict, Optional, Set, Union
+from types import TracebackType
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    Optional,
+    Set,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
 
+from prometheus_client import Metric
 from prometheus_client.core import REGISTRY, Counter, Gauge
 
 from twisted.internet import defer
 
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import (
+    ContextResourceUsage,
+    LoggingContext,
+    PreserveLoggingContext,
+)
 from synapse.logging.opentracing import (
     SynapseTags,
     noop_context_manager,
     start_active_span,
 )
-from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
     import resource
@@ -116,7 +134,7 @@ class _Collector:
     before they are returned.
     """
 
-    def collect(self):
+    def collect(self) -> Iterable[Metric]:
         global _background_processes_active_since_last_scrape
 
         # We swap out the _background_processes set with an empty one so that
@@ -144,12 +162,12 @@ REGISTRY.register(_Collector())
 
 
 class _BackgroundProcess:
-    def __init__(self, desc, ctx):
+    def __init__(self, desc: str, ctx: LoggingContext):
         self.desc = desc
         self._context = ctx
-        self._reported_stats = None
+        self._reported_stats: Optional[ContextResourceUsage] = None
 
-    def update_metrics(self):
+    def update_metrics(self) -> None:
         """Updates the metrics with values from this process."""
         new_stats = self._context.get_resource_usage()
         if self._reported_stats is None:
@@ -169,7 +187,16 @@ class _BackgroundProcess:
         )
 
 
-def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwargs):
+R = TypeVar("R")
+
+
+def run_as_background_process(
+    desc: str,
+    func: Callable[..., Awaitable[Optional[R]]],
+    *args: Any,
+    bg_start_span: bool = True,
+    **kwargs: Any,
+) -> "defer.Deferred[Optional[R]]":
     """Run the given function in its own logcontext, with resource metrics
 
     This should be used to wrap processes which are fired off to run in the
@@ -189,11 +216,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
         args: positional args for func
         kwargs: keyword args for func
 
-    Returns: Deferred which returns the result of func, but note that it does not
-        follow the synapse logcontext rules.
+    Returns:
+        Deferred which returns the result of func, or `None` if func raises.
+        Note that the returned Deferred does not follow the synapse logcontext
+        rules.
     """
 
-    async def run():
+    async def run() -> Optional[R]:
         with _bg_metrics_lock:
             count = _background_process_counts.get(desc, 0)
             _background_process_counts[desc] = count + 1
@@ -210,12 +239,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
                 else:
                     ctx = noop_context_manager()
                 with ctx:
-                    return await maybe_awaitable(func(*args, **kwargs))
+                    return await func(*args, **kwargs)
             except Exception:
                 logger.exception(
                     "Background process '%s' threw an exception",
                     desc,
                 )
+                return None
             finally:
                 _background_process_in_flight_count.labels(desc).dec()
 
@@ -225,19 +255,24 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
         return defer.ensureDeferred(run())
 
 
-def wrap_as_background_process(desc):
+F = TypeVar("F", bound=Callable[..., Awaitable[Optional[Any]]])
+
+
+def wrap_as_background_process(desc: str) -> Callable[[F], F]:
     """Decorator that wraps a function that gets called as a background
     process.
 
-    Equivalent of calling the function with `run_as_background_process`
+    Equivalent to calling the function with `run_as_background_process`
     """
 
-    def wrap_as_background_process_inner(func):
+    def wrap_as_background_process_inner(func: F) -> F:
         @wraps(func)
-        def wrap_as_background_process_inner_2(*args, **kwargs):
+        def wrap_as_background_process_inner_2(
+            *args: Any, **kwargs: Any
+        ) -> "defer.Deferred[Optional[R]]":
             return run_as_background_process(desc, func, *args, **kwargs)
 
-        return wrap_as_background_process_inner_2
+        return cast(F, wrap_as_background_process_inner_2)
 
     return wrap_as_background_process_inner
 
@@ -265,7 +300,7 @@ class BackgroundProcessLoggingContext(LoggingContext):
         super().__init__("%s-%s" % (name, instance_id))
         self._proc = _BackgroundProcess(name, self)
 
-    def start(self, rusage: "Optional[resource.struct_rusage]"):
+    def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
         """Log context has started running (again)."""
 
         super().start(rusage)
@@ -276,7 +311,12 @@ class BackgroundProcessLoggingContext(LoggingContext):
         with _bg_metrics_lock:
             _background_processes_active_since_last_scrape.add(self._proc)
 
-    def __exit__(self, type, value, traceback) -> None:
+    def __exit__(
+        self,
+        type: Optional[Type[BaseException]],
+        value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
         """Log context has finished."""
 
         super().__exit__(type, value, traceback)
diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py
index 29ab6c0229..98ed9c0829 100644
--- a/synapse/metrics/jemalloc.py
+++ b/synapse/metrics/jemalloc.py
@@ -16,14 +16,16 @@ import ctypes
 import logging
 import os
 import re
-from typing import Optional
+from typing import Iterable, Optional
+
+from prometheus_client import Metric
 
 from synapse.metrics import REGISTRY, GaugeMetricFamily
 
 logger = logging.getLogger(__name__)
 
 
-def _setup_jemalloc_stats():
+def _setup_jemalloc_stats() -> None:
     """Checks to see if jemalloc is loaded, and hooks up a collector to record
     statistics exposed by jemalloc.
     """
@@ -135,7 +137,7 @@ def _setup_jemalloc_stats():
     class JemallocCollector:
         """Metrics for internal jemalloc stats."""
 
-        def collect(self):
+        def collect(self) -> Iterable[Metric]:
             _jemalloc_refresh_stats()
 
             g = GaugeMetricFamily(
@@ -185,7 +187,7 @@ def _setup_jemalloc_stats():
     logger.debug("Added jemalloc stats")
 
 
-def setup_jemalloc_stats():
+def setup_jemalloc_stats() -> None:
     """Try to setup jemalloc stats, if jemalloc is loaded."""
 
     try:
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d4cab69ebf..0693d39006 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -188,7 +188,7 @@ class LoggingDatabaseConnection:
 
 
 # The type of entry which goes on our after_callbacks and exception_callbacks lists.
-_CallbackListEntry = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]]
+_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]]
 
 
 R = TypeVar("R")
@@ -235,7 +235,7 @@ class LoggingTransaction:
         self.after_callbacks = after_callbacks
         self.exception_callbacks = exception_callbacks
 
-    def call_after(self, callback: Callable[..., None], *args: Any, **kwargs: Any):
+    def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
         """Call the given callback on the main twisted thread after the
         transaction has finished. Used to invalidate the caches on the
         correct thread.
@@ -247,7 +247,7 @@ class LoggingTransaction:
         self.after_callbacks.append((callback, args, kwargs))
 
     def call_on_exception(
-        self, callback: Callable[..., None], *args: Any, **kwargs: Any
+        self, callback: Callable[..., object], *args: Any, **kwargs: Any
     ):
         # if self.exception_callbacks is None, that means that whatever constructed the
         # LoggingTransaction isn't expecting there to be any callbacks; assert that
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 6a7e534576..67ee4c693b 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -159,7 +159,7 @@ class ExpiringCache(Generic[KT, VT]):
             self[key] = value
             return value
 
-    def _prune_cache(self) -> None:
+    async def _prune_cache(self) -> None:
         if not self._expiry_ms:
             # zero expiry time means don't expire. This should never get called
             # since we have this check in start too.
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index ad775dfc7d..98ee49af6e 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -56,8 +56,15 @@ block_db_sched_duration = Counter(
     "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]
 )
 
+
+# This is dynamically created in InFlightGauge.__init__.
+class _InFlightMetric(Protocol):
+    real_time_max: float
+    real_time_sum: float
+
+
 # Tracks the number of blocks currently active
-in_flight = InFlightGauge(
+in_flight: InFlightGauge[_InFlightMetric] = InFlightGauge(
     "synapse_util_metrics_block_in_flight",
     "",
     labels=["block_name"],
@@ -65,12 +72,6 @@ in_flight = InFlightGauge(
 )
 
 
-# This is dynamically created in InFlightGauge.__init__.
-class _InFlightMetric(Protocol):
-    real_time_max: float
-    real_time_sum: float
-
-
 T = TypeVar("T", bound=Callable[..., Any])