diff options
Diffstat (limited to 'synapse/util/patch_inline_callbacks.py')
-rw-r--r-- | synapse/util/patch_inline_callbacks.py | 28 |
1 files changed, 20 insertions, 8 deletions
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 9dd010af3b..1f18654d47 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -14,7 +14,7 @@ import functools import sys -from typing import Any, Callable, List +from typing import Any, Callable, Generator, List, TypeVar from twisted.internet import defer from twisted.internet.defer import Deferred @@ -24,6 +24,9 @@ from twisted.python.failure import Failure _already_patched = False +T = TypeVar("T") + + def do_patch() -> None: """ Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit @@ -37,15 +40,19 @@ def do_patch() -> None: if _already_patched: return - def new_inline_callbacks(f): + def new_inline_callbacks( + f: Callable[..., Generator["Deferred[object]", object, T]] + ) -> Callable[..., "Deferred[T]"]: @functools.wraps(f) - def wrapped(*args, **kwargs): + def wrapped(*args: Any, **kwargs: Any) -> "Deferred[T]": start_context = current_context() changes: List[str] = [] - orig = orig_inline_callbacks(_check_yield_points(f, changes)) + orig: Callable[..., "Deferred[T]"] = orig_inline_callbacks( + _check_yield_points(f, changes) + ) try: - res = orig(*args, **kwargs) + res: "Deferred[T]" = orig(*args, **kwargs) except Exception: if current_context() != start_context: for err in changes: @@ -84,7 +91,7 @@ def do_patch() -> None: print(err, file=sys.stderr) raise Exception(err) - def check_ctx(r): + def check_ctx(r: T) -> T: if current_context() != start_context: for err in changes: print(err, file=sys.stderr) @@ -107,7 +114,10 @@ def do_patch() -> None: _already_patched = True -def _check_yield_points(f: Callable, changes: List[str]) -> Callable: +def _check_yield_points( + f: Callable[..., Generator["Deferred[object]", object, T]], + changes: List[str], +) -> Callable: """Wraps a generator that is about to be passed to defer.inlineCallbacks checking that after every yield the log contexts are correct. @@ -127,7 +137,9 @@ def _check_yield_points(f: Callable, changes: List[str]) -> Callable: from synapse.logging.context import current_context @functools.wraps(f) - def check_yield_points_inner(*args, **kwargs): + def check_yield_points_inner( + *args: Any, **kwargs: Any + ) -> Generator["Deferred[object]", object, T]: gen = f(*args, **kwargs) last_yield_line_no = gen.gi_frame.f_lineno |