summary refs log tree commit diff
path: root/synapse/util/patch_inline_callbacks.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/patch_inline_callbacks.py')
-rw-r--r--synapse/util/patch_inline_callbacks.py15
1 files changed, 9 insertions, 6 deletions
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index dace68666c..f97f98a057 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -16,6 +16,8 @@ import functools
 import sys
 from typing import Any, Callable, Generator, List, TypeVar, cast
 
+from typing_extensions import ParamSpec
+
 from twisted.internet import defer
 from twisted.internet.defer import Deferred
 from twisted.python.failure import Failure
@@ -25,6 +27,7 @@ _already_patched = False
 
 
 T = TypeVar("T")
+P = ParamSpec("P")
 
 
 def do_patch() -> None:
@@ -41,13 +44,13 @@ def do_patch() -> None:
         return
 
     def new_inline_callbacks(
-        f: Callable[..., Generator["Deferred[object]", object, T]]
-    ) -> Callable[..., "Deferred[T]"]:
+        f: Callable[P, Generator["Deferred[object]", object, T]]
+    ) -> Callable[P, "Deferred[T]"]:
         @functools.wraps(f)
-        def wrapped(*args: Any, **kwargs: Any) -> "Deferred[T]":
+        def wrapped(*args: P.args, **kwargs: P.kwargs) -> "Deferred[T]":
             start_context = current_context()
             changes: List[str] = []
-            orig: Callable[..., "Deferred[T]"] = orig_inline_callbacks(
+            orig: Callable[P, "Deferred[T]"] = orig_inline_callbacks(
                 _check_yield_points(f, changes)
             )
 
@@ -115,7 +118,7 @@ def do_patch() -> None:
 
 
 def _check_yield_points(
-    f: Callable[..., Generator["Deferred[object]", object, T]],
+    f: Callable[P, Generator["Deferred[object]", object, T]],
     changes: List[str],
 ) -> Callable:
     """Wraps a generator that is about to be passed to defer.inlineCallbacks
@@ -138,7 +141,7 @@ def _check_yield_points(
 
     @functools.wraps(f)
     def check_yield_points_inner(
-        *args: Any, **kwargs: Any
+        *args: P.args, **kwargs: P.kwargs
     ) -> Generator["Deferred[object]", object, T]:
         gen = f(*args, **kwargs)