diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index ec61e14423..ffdea0de8d 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -13,15 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
from functools import wraps
+from typing import Any, Callable, Optional, TypeVar, cast
from prometheus_client import Counter
-from twisted.internet import defer
-
-from synapse.logging.context import LoggingContext, current_context
+from synapse.logging.context import (
+ ContextResourceUsage,
+ LoggingContext,
+ current_context,
+)
from synapse.metrics import InFlightGauge
logger = logging.getLogger(__name__)
@@ -60,34 +62,42 @@ in_flight = InFlightGauge(
sub_metrics=["real_time_max", "real_time_sum"],
)
+T = TypeVar("T", bound=Callable[..., Any])
-def measure_func(name=None):
- def wrapper(func):
- block_name = func.__name__ if name is None else name
- if inspect.iscoroutinefunction(func):
+def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
+ """
+ Used to decorate an async function with a `Measure` context manager.
+
+ Usage:
+
+ @measure_func()
+ async def foo(...):
+ ...
- @wraps(func)
- async def measured_func(self, *args, **kwargs):
- with Measure(self.clock, block_name):
- r = await func(self, *args, **kwargs)
- return r
+ Which is analogous to:
- else:
+ async def foo(...):
+ with Measure(...):
+ ...
- @wraps(func)
- @defer.inlineCallbacks
- def measured_func(self, *args, **kwargs):
- with Measure(self.clock, block_name):
- r = yield func(self, *args, **kwargs)
- return r
+ """
- return measured_func
+ def wrapper(func: T) -> T:
+ block_name = func.__name__ if name is None else name
+
+ @wraps(func)
+ async def measured_func(self, *args, **kwargs):
+ with Measure(self.clock, block_name):
+ r = await func(self, *args, **kwargs)
+ return r
+
+ return cast(T, measured_func)
return wrapper
-class Measure(object):
+class Measure:
__slots__ = [
"clock",
"name",
@@ -98,27 +108,27 @@ class Measure(object):
def __init__(self, clock, name):
self.clock = clock
self.name = name
- self._logging_context = None
+ parent_context = current_context()
+ self._logging_context = LoggingContext(
+ "Measure[%s]" % (self.name,), parent_context
+ )
self.start = None
- def __enter__(self):
- if self._logging_context:
+ def __enter__(self) -> "Measure":
+ if self.start is not None:
raise RuntimeError("Measure() objects cannot be re-used")
self.start = self.clock.time()
- parent_context = current_context()
- self._logging_context = LoggingContext(
- "Measure[%s]" % (self.name,), parent_context
- )
self._logging_context.__enter__()
in_flight.register((self.name,), self._update_in_flight)
+ return self
def __exit__(self, exc_type, exc_val, exc_tb):
- if not self._logging_context:
+ if self.start is None:
raise RuntimeError("Measure() block exited without being entered")
duration = self.clock.time() - self.start
- usage = self._logging_context.get_resource_usage()
+ usage = self.get_resource_usage()
in_flight.unregister((self.name,), self._update_in_flight)
self._logging_context.__exit__(exc_type, exc_val, exc_tb)
@@ -134,6 +144,13 @@ class Measure(object):
except ValueError:
logger.warning("Failed to save metrics! Usage: %s", usage)
+ def get_resource_usage(self) -> ContextResourceUsage:
+ """Get the resources used within this Measure block
+
+ If the Measure block is still active, returns the resource usage so far.
+ """
+ return self._logging_context.get_resource_usage()
+
def _update_in_flight(self, metrics):
"""Gets called when processing in flight metrics
"""
|