summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorEric Eastwood <erice@element.io>2022-08-03 20:50:44 -0500
committerEric Eastwood <erice@element.io>2022-08-03 20:50:44 -0500
commitfdce1c2ec317b8ecafde8c1f272f7062bccac85a (patch)
treee2a7caa6fb70016b66b71d6fd19ad42216bc7019 /synapse
parentFix @tag_args being one-off (ahead) (diff)
downloadsynapse-fdce1c2ec317b8ecafde8c1f272f7062bccac85a.tar.xz
Allow @trace and @tag_args to be used together
Diffstat (limited to 'synapse')
-rw-r--r--synapse/logging/tracing.py144
-rw-r--r--synapse/storage/databases/main/event_federation.py3
2 files changed, 96 insertions, 51 deletions
diff --git a/synapse/logging/tracing.py b/synapse/logging/tracing.py
index 109ae185e1..552f5f8504 100644
--- a/synapse/logging/tracing.py
+++ b/synapse/logging/tracing.py
@@ -166,6 +166,7 @@ from functools import wraps
 from typing import (
     TYPE_CHECKING,
     Any,
+    Awaitable,
     Callable,
     ContextManager,
     Dict,
@@ -789,67 +790,108 @@ def extract_text_map(
 # Tracing decorators
 
 
-def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
+def create_decorator(
+    func: Callable[P, R],
+    # TODO: What is the correct type for these `Any`? `P.args, P.kwargs` isn't allowed here
+    wrapping_logic: Callable[[Callable[P, R], Any, Any], ContextManager[None]],
+) -> Callable[P, R]:
     """
-    Decorator to trace a function with a custom opname.
+    Creates a decorator that is able to handle sync functions, async functions
+    (coroutines), and inlineDeferred from Twisted.
+
+    Example usage:
+    ```py
+    # Decorator to time the functiona and log it out
+    def duration(func: Callable[P, R]) -> Callable[P, R]:
+        @contextlib.contextmanager
+        def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs):
+            start_ts = time.time()
+            yield
+            end_ts = time.time()
+            duration = end_ts - start_ts
+            logger.info("%s took %s seconds", func.__name__, duration)
 
-    See the module's doc string for usage examples.
+        return create_decorator(func, _wrapping_logic)
+    ```
 
+    Args:
+        func: The function to be decorated
+        wrapping_logic: The business logic of your custom decorator.
+            This should be a ContextManager so you are able to run your logic
+            before/after the function as desired.
     """
 
-    def decorator(func: Callable[P, R]) -> Callable[P, R]:
-        if opentelemetry is None:
-            return func  # type: ignore[unreachable]
-
+    @wraps(func)
+    async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
         if inspect.iscoroutinefunction(func):
-
-            @wraps(func)
-            async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
-                with start_active_span(opname):
-                    return await func(*args, **kwargs)  # type: ignore[misc]
-
+            with wrapping_logic(func, *args, **kwargs):
+                return await func(*args, **kwargs)
         else:
             # The other case here handles both sync functions and those
             # decorated with inlineDeferred.
-            @wraps(func)
-            def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
-                scope = start_active_span(opname)
-                scope.__enter__()
-
-                try:
-                    result = func(*args, **kwargs)
-                    if isinstance(result, defer.Deferred):
-
-                        def call_back(result: R) -> R:
-                            scope.__exit__(None, None, None)
-                            return result
-
-                        def err_back(result: R) -> R:
-                            scope.__exit__(None, None, None)
-                            return result
-
-                        result.addCallbacks(call_back, err_back)
-
-                    else:
-                        if inspect.isawaitable(result):
-                            logger.error(
-                                "@trace may not have wrapped %s correctly! "
-                                "The function is not async but returned a %s.",
-                                func.__qualname__,
-                                type(result).__name__,
-                            )
+            scope = wrapping_logic(func, *args, **kwargs)
+            scope.__enter__()
+
+            try:
+                result = func(*args, **kwargs)
+                if isinstance(result, defer.Deferred):
+
+                    def call_back(result: R) -> R:
+                        scope.__exit__(None, None, None)
+                        return result
 
+                    def err_back(result: R) -> R:
                         scope.__exit__(None, None, None)
+                        return result
 
-                    return result
+                    result.addCallbacks(call_back, err_back)
 
-                except Exception as e:
-                    scope.__exit__(type(e), None, e.__traceback__)
-                    raise
+                else:
+                    if inspect.isawaitable(result):
+                        logger.error(
+                            "@trace may not have wrapped %s correctly! "
+                            "The function is not async but returned a %s.",
+                            func.__qualname__,
+                            type(result).__name__,
+                        )
 
-        return _trace_inner  # type: ignore[return-value]
+                    scope.__exit__(None, None, None)
 
-    return decorator
+                return result
+
+            except Exception as e:
+                scope.__exit__(type(e), None, e.__traceback__)
+                raise
+
+    return _wrapper  # type: ignore[return-value]
+
+
+def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
+    """
+    Decorator to trace a function with a custom opname.
+
+    See the module's doc string for usage examples.
+    """
+
+    @contextlib.contextmanager
+    def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs):
+        if opentelemetry is None:
+            return None
+
+        scope = start_active_span(opname)
+        scope.__enter__()
+        try:
+            yield
+        except Exception as e:
+            scope.__exit__(type(e), None, e.__traceback__)
+            raise
+        finally:
+            scope.__exit__(None, None, None)
+
+    def _decorator(func: Callable[P, R]):
+        return create_decorator(func, _wrapping_logic)
+
+    return _decorator
 
 
 def trace(func: Callable[P, R]) -> Callable[P, R]:
@@ -866,22 +908,22 @@ def trace(func: Callable[P, R]) -> Callable[P, R]:
 
 def tag_args(func: Callable[P, R]) -> Callable[P, R]:
     """
-    Tags all of the args to the active span.
+    Decorator to tag all of the args to the active span.
     """
 
     if not opentelemetry:
         return func
 
-    @wraps(func)
-    def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R:
+    @contextlib.contextmanager
+    def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs):
         argspec = inspect.getfullargspec(func)
         for i, arg in enumerate(argspec.args[1:]):
             set_attribute("ARG_" + arg, str(args[i + 1]))  # type: ignore[index]
         set_attribute("args", str(args[len(argspec.args) :]))  # type: ignore[index]
         set_attribute("kwargs", str(kwargs))
-        return func(*args, **kwargs)
+        yield
 
-    return _tag_args_inner
+    return create_decorator(func, _wrapping_logic)
 
 
 @contextlib.contextmanager
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index eec55b6478..22e72a31de 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -33,6 +33,7 @@ from synapse.api.constants import MAX_DEPTH, EventTypes
 from synapse.api.errors import StoreError
 from synapse.api.room_versions import EventFormatVersions, RoomVersion
 from synapse.events import EventBase, make_event_from_dict
+from synapse.logging.tracing import tag_args, trace
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
@@ -709,6 +710,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # Return all events where not all sets can reach them.
         return {eid for eid, n in event_to_missing_sets.items() if n}
 
+    @trace
+    @tag_args
     async def get_oldest_event_ids_with_depth_in_room(
         self, room_id: str
     ) -> List[Tuple[str, int]]: