diff options
Diffstat (limited to 'synapse/rest/client/v1/login.py')
-rw-r--r-- | synapse/rest/client/v1/login.py | 171 |
1 files changed, 148 insertions, 23 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index f6be5f1020..cbcb60fe31 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -14,7 +14,9 @@ import logging import re -from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional + +from typing_extensions import TypedDict from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter @@ -25,6 +27,8 @@ from synapse.http import get_request_uri from synapse.http.server import HttpServer, finish_request from synapse.http.servlet import ( RestServlet, + assert_params_in_dict, + parse_boolean, parse_bytes_from_args, parse_json_object_from_request, parse_string, @@ -40,6 +44,21 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +LoginResponse = TypedDict( + "LoginResponse", + { + "user_id": str, + "access_token": str, + "home_server": str, + "expires_in_ms": Optional[int], + "refresh_token": Optional[str], + "device_id": str, + "well_known": Optional[Dict[str, Any]], + }, + total=False, +) + + class LoginRestServlet(RestServlet): PATTERNS = client_patterns("/login$", v1=True) CAS_TYPE = "m.login.cas" @@ -48,6 +67,7 @@ class LoginRestServlet(RestServlet): JWT_TYPE = "org.matrix.login.jwt" JWT_TYPE_DEPRECATED = "m.login.jwt" APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" + REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token" def __init__(self, hs: "HomeServer"): super().__init__() @@ -65,9 +85,12 @@ class LoginRestServlet(RestServlet): self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled self._msc2858_enabled = hs.config.experimental.msc2858_enabled + self._msc2918_enabled = hs.config.access_token_lifetime is not None self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self._sso_handler = hs.get_sso_handler() @@ -138,6 +161,15 @@ class LoginRestServlet(RestServlet): async def on_POST(self, request: SynapseRequest): login_submission = parse_json_object_from_request(request) + if self._msc2918_enabled: + # Check if this login should also issue a refresh token, as per + # MSC2918 + should_issue_refresh_token = parse_boolean( + request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False + ) + else: + should_issue_refresh_token = False + try: if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: appservice = self.auth.get_appservice_by_req(request) @@ -147,19 +179,32 @@ class LoginRestServlet(RestServlet): None, request.getClientIP() ) - result = await self._do_appservice_login(login_submission, appservice) + result = await self._do_appservice_login( + login_submission, + appservice, + should_issue_refresh_token=should_issue_refresh_token, + ) elif self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): await self._address_ratelimiter.ratelimit(None, request.getClientIP()) - result = await self._do_jwt_login(login_submission) + result = await self._do_jwt_login( + login_submission, + should_issue_refresh_token=should_issue_refresh_token, + ) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: await self._address_ratelimiter.ratelimit(None, request.getClientIP()) - result = await self._do_token_login(login_submission) + result = await self._do_token_login( + login_submission, + should_issue_refresh_token=should_issue_refresh_token, + ) else: await self._address_ratelimiter.ratelimit(None, request.getClientIP()) - result = await self._do_other_login(login_submission) + result = await self._do_other_login( + login_submission, + should_issue_refresh_token=should_issue_refresh_token, + ) except KeyError: raise SynapseError(400, "Missing JSON keys.") @@ -169,7 +214,10 @@ class LoginRestServlet(RestServlet): return 200, result async def _do_appservice_login( - self, login_submission: JsonDict, appservice: ApplicationService + self, + login_submission: JsonDict, + appservice: ApplicationService, + should_issue_refresh_token: bool = False, ): identifier = login_submission.get("identifier") logger.info("Got appservice login request with identifier: %r", identifier) @@ -198,14 +246,21 @@ class LoginRestServlet(RestServlet): raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN) return await self._complete_login( - qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited() + qualified_user_id, + login_submission, + ratelimit=appservice.is_rate_limited(), + should_issue_refresh_token=should_issue_refresh_token, ) - async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]: + async def _do_other_login( + self, login_submission: JsonDict, should_issue_refresh_token: bool = False + ) -> LoginResponse: """Handle non-token/saml/jwt logins Args: login_submission: + should_issue_refresh_token: True if this login should issue + a refresh token alongside the access token. Returns: HTTP response @@ -224,7 +279,10 @@ class LoginRestServlet(RestServlet): login_submission, ratelimit=True ) result = await self._complete_login( - canonical_user_id, login_submission, callback + canonical_user_id, + login_submission, + callback, + should_issue_refresh_token=should_issue_refresh_token, ) return result @@ -232,11 +290,12 @@ class LoginRestServlet(RestServlet): self, user_id: str, login_submission: JsonDict, - callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None, + callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None, create_non_existent_users: bool = False, ratelimit: bool = True, auth_provider_id: Optional[str] = None, - ) -> Dict[str, str]: + should_issue_refresh_token: bool = False, + ) -> LoginResponse: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on all successful logins. @@ -253,6 +312,8 @@ class LoginRestServlet(RestServlet): ratelimit: Whether to ratelimit the login request. auth_provider_id: The SSO IdP the user used, if any (just used for the prometheus metrics). + should_issue_refresh_token: True if this login should issue + a refresh token alongside the access token. Returns: result: Dictionary of account information after successful login. @@ -274,28 +335,48 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") - device_id, access_token = await self.registration_handler.register_device( - user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id + ( + device_id, + access_token, + valid_until_ms, + refresh_token, + ) = await self.registration_handler.register_device( + user_id, + device_id, + initial_display_name, + auth_provider_id=auth_provider_id, + should_issue_refresh_token=should_issue_refresh_token, ) - result = { - "user_id": user_id, - "access_token": access_token, - "home_server": self.hs.hostname, - "device_id": device_id, - } + result = LoginResponse( + user_id=user_id, + access_token=access_token, + home_server=self.hs.hostname, + device_id=device_id, + ) + + if valid_until_ms is not None: + expires_in_ms = valid_until_ms - self.clock.time_msec() + result["expires_in_ms"] = expires_in_ms + + if refresh_token is not None: + result["refresh_token"] = refresh_token if callback is not None: await callback(result) return result - async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]: + async def _do_token_login( + self, login_submission: JsonDict, should_issue_refresh_token: bool = False + ) -> LoginResponse: """ Handle the final stage of SSO login. Args: - login_submission: The JSON request body. + login_submission: The JSON request body. + should_issue_refresh_token: True if this login should issue + a refresh token alongside the access token. Returns: The body of the JSON response. @@ -309,9 +390,12 @@ class LoginRestServlet(RestServlet): login_submission, self.auth_handler._sso_login_callback, auth_provider_id=res.auth_provider_id, + should_issue_refresh_token=should_issue_refresh_token, ) - async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]: + async def _do_jwt_login( + self, login_submission: JsonDict, should_issue_refresh_token: bool = False + ) -> LoginResponse: token = login_submission.get("token", None) if token is None: raise LoginError( @@ -342,7 +426,10 @@ class LoginRestServlet(RestServlet): user_id = UserID(user, self.hs.hostname).to_string() result = await self._complete_login( - user_id, login_submission, create_non_existent_users=True + user_id, + login_submission, + create_non_existent_users=True, + should_issue_refresh_token=should_issue_refresh_token, ) return result @@ -371,6 +458,42 @@ def _get_auth_flow_dict_for_idp( return e +class RefreshTokenServlet(RestServlet): + PATTERNS = client_patterns( + "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True + ) + + def __init__(self, hs: "HomeServer"): + self._auth_handler = hs.get_auth_handler() + self._clock = hs.get_clock() + self.access_token_lifetime = hs.config.access_token_lifetime + + async def on_POST( + self, + request: SynapseRequest, + ): + refresh_submission = parse_json_object_from_request(request) + + assert_params_in_dict(refresh_submission, ["refresh_token"]) + token = refresh_submission["refresh_token"] + if not isinstance(token, str): + raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM) + + valid_until_ms = self._clock.time_msec() + self.access_token_lifetime + access_token, refresh_token = await self._auth_handler.refresh_token( + token, valid_until_ms + ) + expires_in_ms = valid_until_ms - self._clock.time_msec() + return ( + 200, + { + "access_token": access_token, + "refresh_token": refresh_token, + "expires_in_ms": expires_in_ms, + }, + ) + + class SsoRedirectServlet(RestServlet): PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [ re.compile( @@ -477,6 +600,8 @@ class CasTicketServlet(RestServlet): def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) + if hs.config.access_token_lifetime is not None: + RefreshTokenServlet(hs).register(http_server) SsoRedirectServlet(hs).register(http_server) if hs.config.cas_enabled: CasTicketServlet(hs).register(http_server) |