summary refs log tree commit diff
path: root/synapse/util/metrics.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/metrics.py')
-rw-r--r--synapse/util/metrics.py48
1 files changed, 27 insertions, 21 deletions
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index ec61e14423..6e57c1ee72 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -13,14 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
 import logging
 from functools import wraps
+from typing import Any, Callable, Optional, TypeVar, cast
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
-
 from synapse.logging.context import LoggingContext, current_context
 from synapse.metrics import InFlightGauge
 
@@ -60,34 +58,42 @@ in_flight = InFlightGauge(
     sub_metrics=["real_time_max", "real_time_sum"],
 )
 
+T = TypeVar("T", bound=Callable[..., Any])
 
-def measure_func(name=None):
-    def wrapper(func):
-        block_name = func.__name__ if name is None else name
 
-        if inspect.iscoroutinefunction(func):
+def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
+    """
+    Used to decorate an async function with a `Measure` context manager.
+
+    Usage:
+
+    @measure_func()
+    async def foo(...):
+        ...
 
-            @wraps(func)
-            async def measured_func(self, *args, **kwargs):
-                with Measure(self.clock, block_name):
-                    r = await func(self, *args, **kwargs)
-                return r
+    Which is analogous to:
 
-        else:
+    async def foo(...):
+        with Measure(...):
+            ...
+
+    """
+
+    def wrapper(func: T) -> T:
+        block_name = func.__name__ if name is None else name
 
-            @wraps(func)
-            @defer.inlineCallbacks
-            def measured_func(self, *args, **kwargs):
-                with Measure(self.clock, block_name):
-                    r = yield func(self, *args, **kwargs)
-                return r
+        @wraps(func)
+        async def measured_func(self, *args, **kwargs):
+            with Measure(self.clock, block_name):
+                r = await func(self, *args, **kwargs)
+            return r
 
-        return measured_func
+        return cast(T, measured_func)
 
     return wrapper
 
 
-class Measure(object):
+class Measure:
     __slots__ = [
         "clock",
         "name",