summary refs log tree commit diff
path: root/synapse/util/metrics.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/metrics.py')
-rw-r--r--synapse/util/metrics.py31
1 files changed, 21 insertions, 10 deletions
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 98ee49af6e..bc3b4938ea 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -15,10 +15,10 @@
 import logging
 from functools import wraps
 from types import TracebackType
-from typing import Any, Callable, Optional, Type, TypeVar, cast
+from typing import Awaitable, Callable, Optional, Type, TypeVar
 
 from prometheus_client import Counter
-from typing_extensions import Protocol
+from typing_extensions import Concatenate, ParamSpec, Protocol
 
 from synapse.logging.context import (
     ContextResourceUsage,
@@ -72,16 +72,21 @@ in_flight: InFlightGauge[_InFlightMetric] = InFlightGauge(
 )
 
 
-T = TypeVar("T", bound=Callable[..., Any])
+P = ParamSpec("P")
+R = TypeVar("R")
 
 
 class HasClock(Protocol):
     clock: Clock
 
 
-def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
-    """
-    Used to decorate an async function with a `Measure` context manager.
+def measure_func(
+    name: Optional[str] = None,
+) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
+    """Decorate an async method with a `Measure` context manager.
+
+    The Measure is created using `self.clock`; it should only be used to decorate
+    methods in classes defining an instance-level `clock` attribute.
 
     Usage:
 
@@ -97,18 +102,24 @@ def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
 
     """
 
-    def wrapper(func: T) -> T:
+    def wrapper(
+        func: Callable[Concatenate[HasClock, P], Awaitable[R]]
+    ) -> Callable[P, Awaitable[R]]:
         block_name = func.__name__ if name is None else name
 
         @wraps(func)
-        async def measured_func(self: HasClock, *args: Any, **kwargs: Any) -> Any:
+        async def measured_func(self: HasClock, *args: P.args, **kwargs: P.kwargs) -> R:
             with Measure(self.clock, block_name):
                 r = await func(self, *args, **kwargs)
             return r
 
-        return cast(T, measured_func)
+        # There are some shenanigans here, because we're decorating a method but
+        # explicitly making use of the `self` parameter. The key thing here is that the
+        # return type within the return type for `measure_func` itself describes how the
+        # decorated function will be called.
+        return measured_func  # type: ignore[return-value]
 
-    return wrapper
+    return wrapper  # type: ignore[return-value]
 
 
 class Measure: