diff --git a/synapse/logging/tracing.py b/synapse/logging/tracing.py
index 109ae185e1..552f5f8504 100644
--- a/synapse/logging/tracing.py
+++ b/synapse/logging/tracing.py
@@ -166,6 +166,7 @@ from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
+ Awaitable,
Callable,
ContextManager,
Dict,
@@ -789,67 +790,108 @@ def extract_text_map(
# Tracing decorators
-def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
+def create_decorator(
+ func: Callable[P, R],
+ # TODO: What is the correct type for these `Any`? `P.args, P.kwargs` isn't allowed here
+ wrapping_logic: Callable[[Callable[P, R], Any, Any], ContextManager[None]],
+) -> Callable[P, R]:
"""
- Decorator to trace a function with a custom opname.
+ Creates a decorator that is able to handle sync functions, async functions
+ (coroutines), and inlineDeferred from Twisted.
+
+ Example usage:
+ ```py
+ # Decorator to time the functiona and log it out
+ def duration(func: Callable[P, R]) -> Callable[P, R]:
+ @contextlib.contextmanager
+ def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs):
+ start_ts = time.time()
+ yield
+ end_ts = time.time()
+ duration = end_ts - start_ts
+ logger.info("%s took %s seconds", func.__name__, duration)
- See the module's doc string for usage examples.
+ return create_decorator(func, _wrapping_logic)
+ ```
+ Args:
+ func: The function to be decorated
+ wrapping_logic: The business logic of your custom decorator.
+ This should be a ContextManager so you are able to run your logic
+ before/after the function as desired.
"""
- def decorator(func: Callable[P, R]) -> Callable[P, R]:
- if opentelemetry is None:
- return func # type: ignore[unreachable]
-
+ @wraps(func)
+ async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
if inspect.iscoroutinefunction(func):
-
- @wraps(func)
- async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
- with start_active_span(opname):
- return await func(*args, **kwargs) # type: ignore[misc]
-
+ with wrapping_logic(func, *args, **kwargs):
+ return await func(*args, **kwargs)
else:
# The other case here handles both sync functions and those
# decorated with inlineDeferred.
- @wraps(func)
- def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
- scope = start_active_span(opname)
- scope.__enter__()
-
- try:
- result = func(*args, **kwargs)
- if isinstance(result, defer.Deferred):
-
- def call_back(result: R) -> R:
- scope.__exit__(None, None, None)
- return result
-
- def err_back(result: R) -> R:
- scope.__exit__(None, None, None)
- return result
-
- result.addCallbacks(call_back, err_back)
-
- else:
- if inspect.isawaitable(result):
- logger.error(
- "@trace may not have wrapped %s correctly! "
- "The function is not async but returned a %s.",
- func.__qualname__,
- type(result).__name__,
- )
+ scope = wrapping_logic(func, *args, **kwargs)
+ scope.__enter__()
+
+ try:
+ result = func(*args, **kwargs)
+ if isinstance(result, defer.Deferred):
+
+ def call_back(result: R) -> R:
+ scope.__exit__(None, None, None)
+ return result
+ def err_back(result: R) -> R:
scope.__exit__(None, None, None)
+ return result
- return result
+ result.addCallbacks(call_back, err_back)
- except Exception as e:
- scope.__exit__(type(e), None, e.__traceback__)
- raise
+ else:
+ if inspect.isawaitable(result):
+ logger.error(
+ "@trace may not have wrapped %s correctly! "
+ "The function is not async but returned a %s.",
+ func.__qualname__,
+ type(result).__name__,
+ )
- return _trace_inner # type: ignore[return-value]
+ scope.__exit__(None, None, None)
- return decorator
+ return result
+
+ except Exception as e:
+ scope.__exit__(type(e), None, e.__traceback__)
+ raise
+
+ return _wrapper # type: ignore[return-value]
+
+
+def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
+ """
+ Decorator to trace a function with a custom opname.
+
+ See the module's doc string for usage examples.
+ """
+
+ @contextlib.contextmanager
+ def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs):
+ if opentelemetry is None:
+ return None
+
+ scope = start_active_span(opname)
+ scope.__enter__()
+ try:
+ yield
+ except Exception as e:
+ scope.__exit__(type(e), None, e.__traceback__)
+ raise
+ finally:
+ scope.__exit__(None, None, None)
+
+ def _decorator(func: Callable[P, R]):
+ return create_decorator(func, _wrapping_logic)
+
+ return _decorator
def trace(func: Callable[P, R]) -> Callable[P, R]:
@@ -866,22 +908,22 @@ def trace(func: Callable[P, R]) -> Callable[P, R]:
def tag_args(func: Callable[P, R]) -> Callable[P, R]:
"""
- Tags all of the args to the active span.
+ Decorator to tag all of the args to the active span.
"""
if not opentelemetry:
return func
- @wraps(func)
- def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R:
+ @contextlib.contextmanager
+ def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs):
argspec = inspect.getfullargspec(func)
for i, arg in enumerate(argspec.args[1:]):
set_attribute("ARG_" + arg, str(args[i + 1])) # type: ignore[index]
set_attribute("args", str(args[len(argspec.args) :])) # type: ignore[index]
set_attribute("kwargs", str(kwargs))
- return func(*args, **kwargs)
+ yield
- return _tag_args_inner
+ return create_decorator(func, _wrapping_logic)
@contextlib.contextmanager
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index eec55b6478..22e72a31de 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -33,6 +33,7 @@ from synapse.api.constants import MAX_DEPTH, EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.events import EventBase, make_event_from_dict
+from synapse.logging.tracing import tag_args, trace
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
@@ -709,6 +710,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
+ @trace
+ @tag_args
async def get_oldest_event_ids_with_depth_in_room(
self, room_id: str
) -> List[Tuple[str, int]]:
|