summary refs log tree commit diff
path: root/synapse/util/patch_inline_callbacks.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/patch_inline_callbacks.py')
-rw-r--r--synapse/util/patch_inline_callbacks.py28
1 files changed, 20 insertions, 8 deletions
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 9dd010af3b..1f18654d47 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -14,7 +14,7 @@
 
 import functools
 import sys
-from typing import Any, Callable, List
+from typing import Any, Callable, Generator, List, TypeVar
 
 from twisted.internet import defer
 from twisted.internet.defer import Deferred
@@ -24,6 +24,9 @@ from twisted.python.failure import Failure
 _already_patched = False
 
 
+T = TypeVar("T")
+
+
 def do_patch() -> None:
     """
     Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
@@ -37,15 +40,19 @@ def do_patch() -> None:
     if _already_patched:
         return
 
-    def new_inline_callbacks(f):
+    def new_inline_callbacks(
+        f: Callable[..., Generator["Deferred[object]", object, T]]
+    ) -> Callable[..., "Deferred[T]"]:
         @functools.wraps(f)
-        def wrapped(*args, **kwargs):
+        def wrapped(*args: Any, **kwargs: Any) -> "Deferred[T]":
             start_context = current_context()
             changes: List[str] = []
-            orig = orig_inline_callbacks(_check_yield_points(f, changes))
+            orig: Callable[..., "Deferred[T]"] = orig_inline_callbacks(
+                _check_yield_points(f, changes)
+            )
 
             try:
-                res = orig(*args, **kwargs)
+                res: "Deferred[T]" = orig(*args, **kwargs)
             except Exception:
                 if current_context() != start_context:
                     for err in changes:
@@ -84,7 +91,7 @@ def do_patch() -> None:
                 print(err, file=sys.stderr)
                 raise Exception(err)
 
-            def check_ctx(r):
+            def check_ctx(r: T) -> T:
                 if current_context() != start_context:
                     for err in changes:
                         print(err, file=sys.stderr)
@@ -107,7 +114,10 @@ def do_patch() -> None:
     _already_patched = True
 
 
-def _check_yield_points(f: Callable, changes: List[str]) -> Callable:
+def _check_yield_points(
+    f: Callable[..., Generator["Deferred[object]", object, T]],
+    changes: List[str],
+) -> Callable:
     """Wraps a generator that is about to be passed to defer.inlineCallbacks
     checking that after every yield the log contexts are correct.
 
@@ -127,7 +137,9 @@ def _check_yield_points(f: Callable, changes: List[str]) -> Callable:
     from synapse.logging.context import current_context
 
     @functools.wraps(f)
-    def check_yield_points_inner(*args, **kwargs):
+    def check_yield_points_inner(
+        *args: Any, **kwargs: Any
+    ) -> Generator["Deferred[object]", object, T]:
         gen = f(*args, **kwargs)
 
         last_yield_line_no = gen.gi_frame.f_lineno