diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 308a27213b..0638cec429 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -169,7 +169,9 @@ import contextlib
import inspect
import logging
import re
+import types
from functools import wraps
+from typing import Dict
from canonicaljson import json
@@ -547,7 +549,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
return
span = opentracing.tracer.active_span
- carrier = {}
+ carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items():
@@ -584,7 +586,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
span = opentracing.tracer.active_span
- carrier = {}
+ carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items():
@@ -639,7 +641,7 @@ def get_active_span_text_map(destination=None):
if destination and not whitelisted_homeserver(destination):
return {}
- carrier = {}
+ carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
)
@@ -653,7 +655,7 @@ def active_span_context_as_string():
Returns:
The active span context encoded as a string.
"""
- carrier = {}
+ carrier = {} # type: Dict[str, str]
if opentracing:
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
@@ -777,8 +779,7 @@ def trace_servlet(servlet_name, extract_context=False):
return func
@wraps(func)
- @defer.inlineCallbacks
- def _trace_servlet_inner(request, *args, **kwargs):
+ async def _trace_servlet_inner(request, *args, **kwargs):
request_tags = {
"request_id": request.get_request_id(),
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
@@ -795,8 +796,14 @@ def trace_servlet(servlet_name, extract_context=False):
scope = start_active_span(servlet_name, tags=request_tags)
with scope:
- result = yield defer.maybeDeferred(func, request, *args, **kwargs)
- return result
+ result = func(request, *args, **kwargs)
+
+ if not isinstance(result, (types.CoroutineType, defer.Deferred)):
+ # Some servlets aren't async and just return results
+ # directly, so we handle that here.
+ return result
+
+ return await result
return _trace_servlet_inner
diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
index 7df0fa6087..6073fc2725 100644
--- a/synapse/logging/utils.py
+++ b/synapse/logging/utils.py
@@ -119,7 +119,11 @@ def trace_function(f):
logger = logging.getLogger(name)
level = logging.DEBUG
- s = inspect.currentframe().f_back
+ frame = inspect.currentframe()
+ if frame is None:
+ raise Exception("Can't get current frame!")
+
+ s = frame.f_back
to_print = [
"\t%s:%s %s. Args: args=%s, kwargs=%s"
@@ -144,7 +148,7 @@ def trace_function(f):
pathname=pathname,
lineno=lineno,
msg=msg,
- args=None,
+ args=tuple(),
exc_info=None,
)
@@ -157,7 +161,12 @@ def trace_function(f):
def get_previous_frames():
- s = inspect.currentframe().f_back.f_back
+
+ frame = inspect.currentframe()
+ if frame is None:
+ raise Exception("Can't get current frame!")
+
+ s = frame.f_back.f_back
to_return = []
while s:
if s.f_globals["__name__"].startswith("synapse"):
@@ -174,7 +183,10 @@ def get_previous_frames():
def get_previous_frame(ignore=[]):
- s = inspect.currentframe().f_back.f_back
+ frame = inspect.currentframe()
+ if frame is None:
+ raise Exception("Can't get current frame!")
+ s = frame.f_back.f_back
while s:
if s.f_globals["__name__"].startswith("synapse"):
|