diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 11d07776b2..4be502a77b 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -1,4 +1,4 @@
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
import logging
import re
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple
from typing_extensions import TypedDict
@@ -110,7 +110,7 @@ class LoginRestServlet(RestServlet):
# counters are initialised for the auth_provider_ids.
_load_sso_handlers(hs)
- def on_GET(self, request: SynapseRequest):
+ def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
@@ -157,7 +157,7 @@ class LoginRestServlet(RestServlet):
return 200, {"flows": flows}
- async def on_POST(self, request: SynapseRequest):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
login_submission = parse_json_object_from_request(request)
if self._msc2918_enabled:
@@ -217,7 +217,7 @@ class LoginRestServlet(RestServlet):
login_submission: JsonDict,
appservice: ApplicationService,
should_issue_refresh_token: bool = False,
- ):
+ ) -> LoginResponse:
identifier = login_submission.get("identifier")
logger.info("Got appservice login request with identifier: %r", identifier)
@@ -467,10 +467,7 @@ class RefreshTokenServlet(RestServlet):
self._clock = hs.get_clock()
self.access_token_lifetime = hs.config.access_token_lifetime
- async def on_POST(
- self,
- request: SynapseRequest,
- ):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
refresh_submission = parse_json_object_from_request(request)
assert_params_in_dict(refresh_submission, ["refresh_token"])
@@ -570,7 +567,7 @@ class SsoRedirectServlet(RestServlet):
class CasTicketServlet(RestServlet):
PATTERNS = client_patterns("/login/cas/ticket", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self._cas_handler = hs.get_cas_handler()
@@ -592,7 +589,7 @@ class CasTicketServlet(RestServlet):
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
LoginRestServlet(hs).register(http_server)
if hs.config.access_token_lifetime is not None:
RefreshTokenServlet(hs).register(http_server)
@@ -601,7 +598,7 @@ def register_servlets(hs, http_server):
CasTicketServlet(hs).register(http_server)
-def _load_sso_handlers(hs: "HomeServer"):
+def _load_sso_handlers(hs: "HomeServer") -> None:
"""Ensure that the SSO handlers are loaded, if they are enabled by configuration.
This is mostly useful to ensure that the CAS/SAML/OIDC handlers register themselves
|