diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py
index 5ef0aff0c3..a35a1d3305 100644
--- a/tests/patch_inline_callbacks.py
+++ b/tests/patch_inline_callbacks.py
@@ -15,7 +15,6 @@
from __future__ import print_function
-import inspect
import functools
import sys
@@ -32,6 +31,8 @@ def do_patch():
from synapse.logging.context import LoggingContext
orig_inline_callbacks = defer.inlineCallbacks
+ if hasattr(orig_inline_callbacks, "patched_by_synapse"):
+ return
def new_inline_callbacks(f):
@functools.wraps(f)
@@ -100,13 +101,20 @@ def do_patch():
return wrapped
defer.inlineCallbacks = new_inline_callbacks
+ new_inline_callbacks.patched_by_synapse = True
def _check_yield_points(f, changes, start_context):
+ """Wraps a generator that is about to passed to defer.inlineCallbacks
+ checking that after every yield the log contexts are correct.
+ """
+
from synapse.logging.context import LoggingContext
@functools.wraps(f)
def check_yield_points_inner(*args, **kwargs):
+ expected_context = start_context
+
gen = f(*args, **kwargs)
last_yield_line_no = 1
@@ -119,12 +127,13 @@ def _check_yield_points(f, changes, start_context):
else:
d = gen.send(result)
except (StopIteration, defer._DefGen_Return) as e:
- if LoggingContext.current_context() != start_context:
+ if LoggingContext.current_context() != expected_context:
# This happens when the context is lost sometime *after* the
# final yield and returning. E.g. we forgot to yield on a
# function that returns a deferred.
err = (
- "%s returned and changed context from %s to %s, in %s between %d and end of func"
+ "Function %r returned and changed context from %s to %s,"
+ " in %s between %d and end of func"
% (
f.__qualname__,
start_context,
@@ -134,7 +143,6 @@ def _check_yield_points(f, changes, start_context):
)
)
changes.append(err)
- # print(err, file=sys.stderr)
# raise Exception(err)
return getattr(e, "value", None)
@@ -144,10 +152,8 @@ def _check_yield_points(f, changes, start_context):
result = Failure(e)
frame = gen.gi_frame
- if frame.f_code.co_name == "check_yield_points_inner":
- frame = inspect.getgeneratorlocals(gen)["gen"].gi_frame
- if LoggingContext.current_context() != start_context:
+ if LoggingContext.current_context() != expected_context:
# This happens because the context is lost sometime *after* the
# previous yield and *after* the current yield. E.g. the
# deferred we waited on didn't follow the rules, or we forgot to
@@ -164,9 +170,10 @@ def _check_yield_points(f, changes, start_context):
)
)
changes.append(err)
- # print(err, file=sys.stderr)
# raise Exception(err)
+ expected_context = LoggingContext.current_context()
+
last_yield_line_no = frame.f_lineno
return check_yield_points_inner
|