summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/logging/opentracing.py63
1 files changed, 40 insertions, 23 deletions
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index c6c0e623c1..2101517575 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -733,37 +733,54 @@ def trace(func=None, opname=None):
 
         _opname = opname if opname else func.__name__
 
-        @wraps(func)
-        def _trace_inner(*args, **kwargs):
-            if opentracing is None:
-                return func(*args, **kwargs)
+        if inspect.iscoroutinefunction(func):
 
-            scope = start_active_span(_opname)
-            scope.__enter__()
+            @wraps(func)
+            async def _trace_inner(*args, **kwargs):
+                if opentracing is None:
+                    return await func(*args, **kwargs)
 
-            try:
-                result = func(*args, **kwargs)
-                if isinstance(result, defer.Deferred):
+                with start_active_span(_opname) as scope:
+                    try:
+                        return await func(*args, **kwargs)
+                    except Exception:
+                        scope.span.set_tag(tags.ERROR, True)
+                        raise
 
-                    def call_back(result):
-                        scope.__exit__(None, None, None)
-                        return result
+        else:
+            # The other case here handles both sync functions and those
+            # decorated with inlineDeferred.
+            @wraps(func)
+            def _trace_inner(*args, **kwargs):
+                if opentracing is None:
+                    return func(*args, **kwargs)
 
-                    def err_back(result):
-                        scope.span.set_tag(tags.ERROR, True)
-                        scope.__exit__(None, None, None)
-                        return result
+                scope = start_active_span(_opname)
+                scope.__enter__()
+
+                try:
+                    result = func(*args, **kwargs)
+                    if isinstance(result, defer.Deferred):
+
+                        def call_back(result):
+                            scope.__exit__(None, None, None)
+                            return result
 
-                    result.addCallbacks(call_back, err_back)
+                        def err_back(result):
+                            scope.span.set_tag(tags.ERROR, True)
+                            scope.__exit__(None, None, None)
+                            return result
 
-                else:
-                    scope.__exit__(None, None, None)
+                        result.addCallbacks(call_back, err_back)
+
+                    else:
+                        scope.__exit__(None, None, None)
 
-                return result
+                    return result
 
-            except Exception as e:
-                scope.__exit__(type(e), None, e.__traceback__)
-                raise
+                except Exception as e:
+                    scope.__exit__(type(e), None, e.__traceback__)
+                    raise
 
         return _trace_inner