diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index ec61e14423..13775b43f9 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -13,14 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
from functools import wraps
+from typing import Any, Callable, Optional, TypeVar, cast
from prometheus_client import Counter
-from twisted.internet import defer
-
from synapse.logging.context import LoggingContext, current_context
from synapse.metrics import InFlightGauge
@@ -60,29 +58,37 @@ in_flight = InFlightGauge(
sub_metrics=["real_time_max", "real_time_sum"],
)
+T = TypeVar("T", bound=Callable[..., Any])
-def measure_func(name=None):
- def wrapper(func):
- block_name = func.__name__ if name is None else name
- if inspect.iscoroutinefunction(func):
+def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
+ """
+ Used to decorate an async function with a `Measure` context manager.
+
+ Usage:
+
+ @measure_func()
+ async def foo(...):
+ ...
- @wraps(func)
- async def measured_func(self, *args, **kwargs):
- with Measure(self.clock, block_name):
- r = await func(self, *args, **kwargs)
- return r
+ Which is analogous to:
- else:
+ async def foo(...):
+ with Measure(...):
+ ...
+
+ """
+
+ def wrapper(func: T) -> T:
+ block_name = func.__name__ if name is None else name
- @wraps(func)
- @defer.inlineCallbacks
- def measured_func(self, *args, **kwargs):
- with Measure(self.clock, block_name):
- r = yield func(self, *args, **kwargs)
- return r
+ @wraps(func)
+ async def measured_func(self, *args, **kwargs):
+ with Measure(self.clock, block_name):
+ r = await func(self, *args, **kwargs)
+ return r
- return measured_func
+ return cast(T, measured_func)
return wrapper
|