summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/util/patch_inline_callbacks.py31
1 files changed, 18 insertions, 13 deletions
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index b518dae256..64a2c891c3 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -18,11 +18,17 @@ from __future__ import print_function
 import functools
 import sys
 
+from typing import List, Callable, Any
+
 from twisted.internet import defer
 from twisted.internet.defer import Deferred
 from twisted.python.failure import Failure
 
 
+# Tracks if we've already patched inlineCallbacks
+_already_patched = False
+
+
 def do_patch():
     """
     Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
@@ -30,16 +36,18 @@ def do_patch():
 
     from synapse.logging.context import LoggingContext
 
+    global _already_patched
+
     orig_inline_callbacks = defer.inlineCallbacks
-    if hasattr(orig_inline_callbacks, "patched_by_synapse"):
+    if _already_patched:
         return
 
     def new_inline_callbacks(f):
         @functools.wraps(f)
         def wrapped(*args, **kwargs):
             start_context = LoggingContext.current_context()
-            changes = []
-            orig = orig_inline_callbacks(_check_yield_points(f, changes, start_context))
+            changes: List[str] = []
+            orig = orig_inline_callbacks(_check_yield_points(f, changes))
 
             try:
                 res = orig(*args, **kwargs)
@@ -101,10 +109,10 @@ def do_patch():
         return wrapped
 
     defer.inlineCallbacks = new_inline_callbacks
-    new_inline_callbacks.patched_by_synapse = True
+    _already_patched = True
 
 
-def _check_yield_points(f, changes, start_context):
+def _check_yield_points(f: Callable, changes: List[str]):
     """Wraps a generator that is about to be passed to defer.inlineCallbacks
     checking that after every yield the log contexts are correct.
 
@@ -114,9 +122,8 @@ def _check_yield_points(f, changes, start_context):
 
     Args:
         f: generator function to wrap
-        changes (list[str]): A list of strings detailing how the contexts
+        changes: A list of strings detailing how the contexts
             changed within a function.
-        start_context (LoggingContext): The initial context we're expecting
 
     Returns:
         function
@@ -126,13 +133,13 @@ def _check_yield_points(f, changes, start_context):
 
     @functools.wraps(f)
     def check_yield_points_inner(*args, **kwargs):
-        expected_context = start_context
-
         gen = f(*args, **kwargs)
 
         last_yield_line_no = gen.gi_frame.f_lineno
-        result = None
+        result: Any = None
         while True:
+            expected_context = LoggingContext.current_context()
+
             try:
                 isFailure = isinstance(result, Failure)
                 if isFailure:
@@ -200,7 +207,7 @@ def _check_yield_points(f, changes, start_context):
                     "%s changed context from %s to %s, happened between lines %d and %d in %s"
                     % (
                         frame.f_code.co_name,
-                        start_context,
+                        expected_context,
                         LoggingContext.current_context(),
                         last_yield_line_no,
                         frame.f_lineno,
@@ -209,8 +216,6 @@ def _check_yield_points(f, changes, start_context):
                 )
                 changes.append(err)
 
-                expected_context = LoggingContext.current_context()
-
             last_yield_line_no = frame.f_lineno
 
     return check_yield_points_inner