diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index a805f51df1..13775b43f9 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -15,6 +15,7 @@
import logging
from functools import wraps
+from typing import Any, Callable, Optional, TypeVar, cast
from prometheus_client import Counter
@@ -57,8 +58,10 @@ in_flight = InFlightGauge(
sub_metrics=["real_time_max", "real_time_sum"],
)
+T = TypeVar("T", bound=Callable[..., Any])
-def measure_func(name=None):
+
+def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
"""
Used to decorate an async function with a `Measure` context manager.
@@ -76,7 +79,7 @@ def measure_func(name=None):
"""
- def wrapper(func):
+ def wrapper(func: T) -> T:
block_name = func.__name__ if name is None else name
@wraps(func)
@@ -85,7 +88,7 @@ def measure_func(name=None):
r = await func(self, *args, **kwargs)
return r
- return measured_func
+ return cast(T, measured_func)
return wrapper
|