summary refs log tree commit diff
path: root/synapse/rest/client/v1/login.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/v1/login.py')
-rw-r--r--synapse/rest/client/v1/login.py171
1 files changed, 148 insertions, 23 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index f6be5f1020..cbcb60fe31 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,7 +14,9 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
+
+from typing_extensions import TypedDict
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
@@ -25,6 +27,8 @@ from synapse.http import get_request_uri
 from synapse.http.server import HttpServer, finish_request
 from synapse.http.servlet import (
     RestServlet,
+    assert_params_in_dict,
+    parse_boolean,
     parse_bytes_from_args,
     parse_json_object_from_request,
     parse_string,
@@ -40,6 +44,21 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+LoginResponse = TypedDict(
+    "LoginResponse",
+    {
+        "user_id": str,
+        "access_token": str,
+        "home_server": str,
+        "expires_in_ms": Optional[int],
+        "refresh_token": Optional[str],
+        "device_id": str,
+        "well_known": Optional[Dict[str, Any]],
+    },
+    total=False,
+)
+
+
 class LoginRestServlet(RestServlet):
     PATTERNS = client_patterns("/login$", v1=True)
     CAS_TYPE = "m.login.cas"
@@ -48,6 +67,7 @@ class LoginRestServlet(RestServlet):
     JWT_TYPE = "org.matrix.login.jwt"
     JWT_TYPE_DEPRECATED = "m.login.jwt"
     APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
+    REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token"
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
@@ -65,9 +85,12 @@ class LoginRestServlet(RestServlet):
         self.cas_enabled = hs.config.cas_enabled
         self.oidc_enabled = hs.config.oidc_enabled
         self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+        self._msc2918_enabled = hs.config.access_token_lifetime is not None
 
         self.auth = hs.get_auth()
 
+        self.clock = hs.get_clock()
+
         self.auth_handler = self.hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
         self._sso_handler = hs.get_sso_handler()
@@ -138,6 +161,15 @@ class LoginRestServlet(RestServlet):
     async def on_POST(self, request: SynapseRequest):
         login_submission = parse_json_object_from_request(request)
 
+        if self._msc2918_enabled:
+            # Check if this login should also issue a refresh token, as per
+            # MSC2918
+            should_issue_refresh_token = parse_boolean(
+                request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False
+            )
+        else:
+            should_issue_refresh_token = False
+
         try:
             if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
                 appservice = self.auth.get_appservice_by_req(request)
@@ -147,19 +179,32 @@ class LoginRestServlet(RestServlet):
                         None, request.getClientIP()
                     )
 
-                result = await self._do_appservice_login(login_submission, appservice)
+                result = await self._do_appservice_login(
+                    login_submission,
+                    appservice,
+                    should_issue_refresh_token=should_issue_refresh_token,
+                )
             elif self.jwt_enabled and (
                 login_submission["type"] == LoginRestServlet.JWT_TYPE
                 or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
             ):
                 await self._address_ratelimiter.ratelimit(None, request.getClientIP())
-                result = await self._do_jwt_login(login_submission)
+                result = await self._do_jwt_login(
+                    login_submission,
+                    should_issue_refresh_token=should_issue_refresh_token,
+                )
             elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
                 await self._address_ratelimiter.ratelimit(None, request.getClientIP())
-                result = await self._do_token_login(login_submission)
+                result = await self._do_token_login(
+                    login_submission,
+                    should_issue_refresh_token=should_issue_refresh_token,
+                )
             else:
                 await self._address_ratelimiter.ratelimit(None, request.getClientIP())
-                result = await self._do_other_login(login_submission)
+                result = await self._do_other_login(
+                    login_submission,
+                    should_issue_refresh_token=should_issue_refresh_token,
+                )
         except KeyError:
             raise SynapseError(400, "Missing JSON keys.")
 
@@ -169,7 +214,10 @@ class LoginRestServlet(RestServlet):
         return 200, result
 
     async def _do_appservice_login(
-        self, login_submission: JsonDict, appservice: ApplicationService
+        self,
+        login_submission: JsonDict,
+        appservice: ApplicationService,
+        should_issue_refresh_token: bool = False,
     ):
         identifier = login_submission.get("identifier")
         logger.info("Got appservice login request with identifier: %r", identifier)
@@ -198,14 +246,21 @@ class LoginRestServlet(RestServlet):
             raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
 
         return await self._complete_login(
-            qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
+            qualified_user_id,
+            login_submission,
+            ratelimit=appservice.is_rate_limited(),
+            should_issue_refresh_token=should_issue_refresh_token,
         )
 
-    async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
+    async def _do_other_login(
+        self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+    ) -> LoginResponse:
         """Handle non-token/saml/jwt logins
 
         Args:
             login_submission:
+            should_issue_refresh_token: True if this login should issue
+                a refresh token alongside the access token.
 
         Returns:
             HTTP response
@@ -224,7 +279,10 @@ class LoginRestServlet(RestServlet):
             login_submission, ratelimit=True
         )
         result = await self._complete_login(
-            canonical_user_id, login_submission, callback
+            canonical_user_id,
+            login_submission,
+            callback,
+            should_issue_refresh_token=should_issue_refresh_token,
         )
         return result
 
@@ -232,11 +290,12 @@ class LoginRestServlet(RestServlet):
         self,
         user_id: str,
         login_submission: JsonDict,
-        callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
+        callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None,
         create_non_existent_users: bool = False,
         ratelimit: bool = True,
         auth_provider_id: Optional[str] = None,
-    ) -> Dict[str, str]:
+        should_issue_refresh_token: bool = False,
+    ) -> LoginResponse:
         """Called when we've successfully authed the user and now need to
         actually login them in (e.g. create devices). This gets called on
         all successful logins.
@@ -253,6 +312,8 @@ class LoginRestServlet(RestServlet):
             ratelimit: Whether to ratelimit the login request.
             auth_provider_id: The SSO IdP the user used, if any (just used for the
                 prometheus metrics).
+            should_issue_refresh_token: True if this login should issue
+                a refresh token alongside the access token.
 
         Returns:
             result: Dictionary of account information after successful login.
@@ -274,28 +335,48 @@ class LoginRestServlet(RestServlet):
 
         device_id = login_submission.get("device_id")
         initial_display_name = login_submission.get("initial_device_display_name")
-        device_id, access_token = await self.registration_handler.register_device(
-            user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
+        (
+            device_id,
+            access_token,
+            valid_until_ms,
+            refresh_token,
+        ) = await self.registration_handler.register_device(
+            user_id,
+            device_id,
+            initial_display_name,
+            auth_provider_id=auth_provider_id,
+            should_issue_refresh_token=should_issue_refresh_token,
         )
 
-        result = {
-            "user_id": user_id,
-            "access_token": access_token,
-            "home_server": self.hs.hostname,
-            "device_id": device_id,
-        }
+        result = LoginResponse(
+            user_id=user_id,
+            access_token=access_token,
+            home_server=self.hs.hostname,
+            device_id=device_id,
+        )
+
+        if valid_until_ms is not None:
+            expires_in_ms = valid_until_ms - self.clock.time_msec()
+            result["expires_in_ms"] = expires_in_ms
+
+        if refresh_token is not None:
+            result["refresh_token"] = refresh_token
 
         if callback is not None:
             await callback(result)
 
         return result
 
-    async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
+    async def _do_token_login(
+        self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+    ) -> LoginResponse:
         """
         Handle the final stage of SSO login.
 
         Args:
-             login_submission: The JSON request body.
+            login_submission: The JSON request body.
+            should_issue_refresh_token: True if this login should issue
+                a refresh token alongside the access token.
 
         Returns:
             The body of the JSON response.
@@ -309,9 +390,12 @@ class LoginRestServlet(RestServlet):
             login_submission,
             self.auth_handler._sso_login_callback,
             auth_provider_id=res.auth_provider_id,
+            should_issue_refresh_token=should_issue_refresh_token,
         )
 
-    async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
+    async def _do_jwt_login(
+        self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+    ) -> LoginResponse:
         token = login_submission.get("token", None)
         if token is None:
             raise LoginError(
@@ -342,7 +426,10 @@ class LoginRestServlet(RestServlet):
 
         user_id = UserID(user, self.hs.hostname).to_string()
         result = await self._complete_login(
-            user_id, login_submission, create_non_existent_users=True
+            user_id,
+            login_submission,
+            create_non_existent_users=True,
+            should_issue_refresh_token=should_issue_refresh_token,
         )
         return result
 
@@ -371,6 +458,42 @@ def _get_auth_flow_dict_for_idp(
     return e
 
 
+class RefreshTokenServlet(RestServlet):
+    PATTERNS = client_patterns(
+        "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        self._auth_handler = hs.get_auth_handler()
+        self._clock = hs.get_clock()
+        self.access_token_lifetime = hs.config.access_token_lifetime
+
+    async def on_POST(
+        self,
+        request: SynapseRequest,
+    ):
+        refresh_submission = parse_json_object_from_request(request)
+
+        assert_params_in_dict(refresh_submission, ["refresh_token"])
+        token = refresh_submission["refresh_token"]
+        if not isinstance(token, str):
+            raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM)
+
+        valid_until_ms = self._clock.time_msec() + self.access_token_lifetime
+        access_token, refresh_token = await self._auth_handler.refresh_token(
+            token, valid_until_ms
+        )
+        expires_in_ms = valid_until_ms - self._clock.time_msec()
+        return (
+            200,
+            {
+                "access_token": access_token,
+                "refresh_token": refresh_token,
+                "expires_in_ms": expires_in_ms,
+            },
+        )
+
+
 class SsoRedirectServlet(RestServlet):
     PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
         re.compile(
@@ -477,6 +600,8 @@ class CasTicketServlet(RestServlet):
 
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
+    if hs.config.access_token_lifetime is not None:
+        RefreshTokenServlet(hs).register(http_server)
     SsoRedirectServlet(hs).register(http_server)
     if hs.config.cas_enabled:
         CasTicketServlet(hs).register(http_server)