diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index b518dae256..64a2c891c3 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -18,11 +18,17 @@ from __future__ import print_function
import functools
import sys
+from typing import List, Callable, Any
+
from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
+# Tracks if we've already patched inlineCallbacks
+_already_patched = False
+
+
def do_patch():
"""
Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
@@ -30,16 +36,18 @@ def do_patch():
from synapse.logging.context import LoggingContext
+ global _already_patched
+
orig_inline_callbacks = defer.inlineCallbacks
- if hasattr(orig_inline_callbacks, "patched_by_synapse"):
+ if _already_patched:
return
def new_inline_callbacks(f):
@functools.wraps(f)
def wrapped(*args, **kwargs):
start_context = LoggingContext.current_context()
- changes = []
- orig = orig_inline_callbacks(_check_yield_points(f, changes, start_context))
+ changes: List[str] = []
+ orig = orig_inline_callbacks(_check_yield_points(f, changes))
try:
res = orig(*args, **kwargs)
@@ -101,10 +109,10 @@ def do_patch():
return wrapped
defer.inlineCallbacks = new_inline_callbacks
- new_inline_callbacks.patched_by_synapse = True
+ _already_patched = True
-def _check_yield_points(f, changes, start_context):
+def _check_yield_points(f: Callable, changes: List[str]):
"""Wraps a generator that is about to be passed to defer.inlineCallbacks
checking that after every yield the log contexts are correct.
@@ -114,9 +122,8 @@ def _check_yield_points(f, changes, start_context):
Args:
f: generator function to wrap
- changes (list[str]): A list of strings detailing how the contexts
+ changes: A list of strings detailing how the contexts
changed within a function.
- start_context (LoggingContext): The initial context we're expecting
Returns:
function
@@ -126,13 +133,13 @@ def _check_yield_points(f, changes, start_context):
@functools.wraps(f)
def check_yield_points_inner(*args, **kwargs):
- expected_context = start_context
-
gen = f(*args, **kwargs)
last_yield_line_no = gen.gi_frame.f_lineno
- result = None
+ result: Any = None
while True:
+ expected_context = LoggingContext.current_context()
+
try:
isFailure = isinstance(result, Failure)
if isFailure:
@@ -200,7 +207,7 @@ def _check_yield_points(f, changes, start_context):
"%s changed context from %s to %s, happened between lines %d and %d in %s"
% (
frame.f_code.co_name,
- start_context,
+ expected_context,
LoggingContext.current_context(),
last_yield_line_no,
frame.f_lineno,
@@ -209,8 +216,6 @@ def _check_yield_points(f, changes, start_context):
)
changes.append(err)
- expected_context = LoggingContext.current_context()
-
last_yield_line_no = frame.f_lineno
return check_yield_points_inner
|