diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 9dd010af3b..1f18654d47 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -14,7 +14,7 @@
import functools
import sys
-from typing import Any, Callable, List
+from typing import Any, Callable, Generator, List, TypeVar
from twisted.internet import defer
from twisted.internet.defer import Deferred
@@ -24,6 +24,9 @@ from twisted.python.failure import Failure
_already_patched = False
+T = TypeVar("T")
+
+
def do_patch() -> None:
"""
Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
@@ -37,15 +40,19 @@ def do_patch() -> None:
if _already_patched:
return
- def new_inline_callbacks(f):
+ def new_inline_callbacks(
+ f: Callable[..., Generator["Deferred[object]", object, T]]
+ ) -> Callable[..., "Deferred[T]"]:
@functools.wraps(f)
- def wrapped(*args, **kwargs):
+ def wrapped(*args: Any, **kwargs: Any) -> "Deferred[T]":
start_context = current_context()
changes: List[str] = []
- orig = orig_inline_callbacks(_check_yield_points(f, changes))
+ orig: Callable[..., "Deferred[T]"] = orig_inline_callbacks(
+ _check_yield_points(f, changes)
+ )
try:
- res = orig(*args, **kwargs)
+ res: "Deferred[T]" = orig(*args, **kwargs)
except Exception:
if current_context() != start_context:
for err in changes:
@@ -84,7 +91,7 @@ def do_patch() -> None:
print(err, file=sys.stderr)
raise Exception(err)
- def check_ctx(r):
+ def check_ctx(r: T) -> T:
if current_context() != start_context:
for err in changes:
print(err, file=sys.stderr)
@@ -107,7 +114,10 @@ def do_patch() -> None:
_already_patched = True
-def _check_yield_points(f: Callable, changes: List[str]) -> Callable:
+def _check_yield_points(
+ f: Callable[..., Generator["Deferred[object]", object, T]],
+ changes: List[str],
+) -> Callable:
"""Wraps a generator that is about to be passed to defer.inlineCallbacks
checking that after every yield the log contexts are correct.
@@ -127,7 +137,9 @@ def _check_yield_points(f: Callable, changes: List[str]) -> Callable:
from synapse.logging.context import current_context
@functools.wraps(f)
- def check_yield_points_inner(*args, **kwargs):
+ def check_yield_points_inner(
+ *args: Any, **kwargs: Any
+ ) -> Generator["Deferred[object]", object, T]:
gen = f(*args, **kwargs)
last_yield_line_no = gen.gi_frame.f_lineno
|