diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 1b82dca81b..1e784b3f1f 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -14,9 +14,11 @@
import logging
from functools import wraps
-from typing import Any, Callable, Optional, TypeVar, cast
+from types import TracebackType
+from typing import Any, Callable, Optional, Type, TypeVar, cast
from prometheus_client import Counter
+from typing_extensions import Protocol
from synapse.logging.context import (
ContextResourceUsage,
@@ -24,6 +26,7 @@ from synapse.logging.context import (
current_context,
)
from synapse.metrics import InFlightGauge
+from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -64,6 +67,10 @@ in_flight = InFlightGauge(
T = TypeVar("T", bound=Callable[..., Any])
+class HasClock(Protocol):
+ clock: Clock
+
+
def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
"""
Used to decorate an async function with a `Measure` context manager.
@@ -86,7 +93,7 @@ def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
block_name = func.__name__ if name is None else name
@wraps(func)
- async def measured_func(self, *args, **kwargs):
+ async def measured_func(self: HasClock, *args: Any, **kwargs: Any) -> Any:
with Measure(self.clock, block_name):
r = await func(self, *args, **kwargs)
return r
@@ -104,10 +111,10 @@ class Measure:
"start",
]
- def __init__(self, clock, name: str):
+ def __init__(self, clock: Clock, name: str) -> None:
"""
Args:
- clock: A n object with a "time()" method, which returns the current
+ clock: An object with a "time()" method, which returns the current
time in seconds.
name: The name of the metric to report.
"""
@@ -124,7 +131,7 @@ class Measure:
assert isinstance(curr_context, LoggingContext)
parent_context = curr_context
self._logging_context = LoggingContext(str(curr_context), parent_context)
- self.start: Optional[int] = None
+ self.start: Optional[float] = None
def __enter__(self) -> "Measure":
if self.start is not None:
@@ -138,7 +145,12 @@ class Measure:
return self
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
if self.start is None:
raise RuntimeError("Measure() block exited without being entered")
@@ -168,8 +180,9 @@ class Measure:
"""
return self._logging_context.get_resource_usage()
- def _update_in_flight(self, metrics):
+ def _update_in_flight(self, metrics) -> None:
"""Gets called when processing in flight metrics"""
+ assert self.start is not None
duration = self.clock.time() - self.start
metrics.real_time_max = max(metrics.real_time_max, duration)
|