diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 4ff050617e..d58dc3cc29 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
import time
import unicodedata
@@ -22,6 +21,7 @@ import urllib.parse
from typing import (
TYPE_CHECKING,
Any,
+ Awaitable,
Callable,
Dict,
Iterable,
@@ -58,6 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
+from synapse.util.async_helpers import maybe_awaitable
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
@@ -851,7 +852,7 @@ class AuthHandler(BaseHandler):
async def validate_login(
self, login_submission: Dict[str, Any], ratelimit: bool = False,
- ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+ ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate auth types which don't
@@ -994,7 +995,7 @@ class AuthHandler(BaseHandler):
async def _validate_userid_login(
self, username: str, login_submission: Dict[str, Any],
- ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+ ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Helper for validate_login
Handles login, once we've mapped 3pids onto userids
@@ -1072,7 +1073,7 @@ class AuthHandler(BaseHandler):
async def check_password_provider_3pid(
self, medium: str, address: str, password: str
- ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]:
+ ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Check if a password provider is able to validate a thirdparty login
Args:
@@ -1628,6 +1629,6 @@ class PasswordProvider:
# This might return an awaitable, if it does block the log out
# until it completes.
- result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
- if inspect.isawaitable(result):
- await result
+ await maybe_awaitable(
+ g(user_id=user_id, device_id=device_id, access_token=access_token,)
+ )
|