diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py
index 0443f4571c..a0971ce994 100644
--- a/synapse/rest/client/_base.py
+++ b/synapse/rest/client/_base.py
@@ -16,7 +16,7 @@
"""
import logging
import re
-from typing import Iterable, Pattern
+from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar, cast
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_API_PREFIX
@@ -76,7 +76,10 @@ def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int)
)
-def interactive_auth_handler(orig):
+C = TypeVar("C", bound=Callable[..., Awaitable[Tuple[int, JsonDict]]])
+
+
+def interactive_auth_handler(orig: C) -> C:
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
Takes a on_POST method which returns an Awaitable (errcode, body) response
@@ -91,10 +94,10 @@ def interactive_auth_handler(orig):
await self.auth_handler.check_auth
"""
- async def wrapped(*args, **kwargs):
+ async def wrapped(*args: Any, **kwargs: Any) -> Tuple[int, JsonDict]:
try:
return await orig(*args, **kwargs)
except InteractiveAuthIncompleteError as e:
return 401, e.result
- return wrapped
+ return cast(C, wrapped)
|