summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
authorQuentin Gliech <quentingliech@gmail.com>2021-06-24 15:33:20 +0200
committerGitHub <noreply@github.com>2021-06-24 14:33:20 +0100
commitbd4919fb72b2a75f1c0a7f0c78bd619fd2ae30e8 (patch)
tree04a988e47720e9c58c99f05b74121e03ebe1f5f4 /synapse/rest
parentMerge tag 'v1.37.0rc1' into develop (diff)
downloadsynapse-bd4919fb72b2a75f1c0a7f0c78bd619fd2ae30e8.tar.xz
MSC2918 Refresh tokens implementation (#9450)
This implements refresh tokens, as defined by MSC2918

This MSC has been implemented client side in Hydrogen Web: vector-im/hydrogen-web#235

The basics of the MSC works: requesting refresh tokens on login, having the access tokens expire, and using the refresh token to get a new one.

Signed-off-by: Quentin Gliech <quentingliech@gmail.com>
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/client/v1/login.py171
-rw-r--r--synapse/rest/client/v2_alpha/register.py88
2 files changed, 218 insertions, 41 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index f6be5f1020..cbcb60fe31 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,7 +14,9 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
+
+from typing_extensions import TypedDict
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
@@ -25,6 +27,8 @@ from synapse.http import get_request_uri
 from synapse.http.server import HttpServer, finish_request
 from synapse.http.servlet import (
     RestServlet,
+    assert_params_in_dict,
+    parse_boolean,
     parse_bytes_from_args,
     parse_json_object_from_request,
     parse_string,
@@ -40,6 +44,21 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+LoginResponse = TypedDict(
+    "LoginResponse",
+    {
+        "user_id": str,
+        "access_token": str,
+        "home_server": str,
+        "expires_in_ms": Optional[int],
+        "refresh_token": Optional[str],
+        "device_id": str,
+        "well_known": Optional[Dict[str, Any]],
+    },
+    total=False,
+)
+
+
 class LoginRestServlet(RestServlet):
     PATTERNS = client_patterns("/login$", v1=True)
     CAS_TYPE = "m.login.cas"
@@ -48,6 +67,7 @@ class LoginRestServlet(RestServlet):
     JWT_TYPE = "org.matrix.login.jwt"
     JWT_TYPE_DEPRECATED = "m.login.jwt"
     APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
+    REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token"
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
@@ -65,9 +85,12 @@ class LoginRestServlet(RestServlet):
         self.cas_enabled = hs.config.cas_enabled
         self.oidc_enabled = hs.config.oidc_enabled
         self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+        self._msc2918_enabled = hs.config.access_token_lifetime is not None
 
         self.auth = hs.get_auth()
 
+        self.clock = hs.get_clock()
+
         self.auth_handler = self.hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
         self._sso_handler = hs.get_sso_handler()
@@ -138,6 +161,15 @@ class LoginRestServlet(RestServlet):
     async def on_POST(self, request: SynapseRequest):
         login_submission = parse_json_object_from_request(request)
 
+        if self._msc2918_enabled:
+            # Check if this login should also issue a refresh token, as per
+            # MSC2918
+            should_issue_refresh_token = parse_boolean(
+                request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False
+            )
+        else:
+            should_issue_refresh_token = False
+
         try:
             if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
                 appservice = self.auth.get_appservice_by_req(request)
@@ -147,19 +179,32 @@ class LoginRestServlet(RestServlet):
                         None, request.getClientIP()
                     )
 
-                result = await self._do_appservice_login(login_submission, appservice)
+                result = await self._do_appservice_login(
+                    login_submission,
+                    appservice,
+                    should_issue_refresh_token=should_issue_refresh_token,
+                )
             elif self.jwt_enabled and (
                 login_submission["type"] == LoginRestServlet.JWT_TYPE
                 or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
             ):
                 await self._address_ratelimiter.ratelimit(None, request.getClientIP())
-                result = await self._do_jwt_login(login_submission)
+                result = await self._do_jwt_login(
+                    login_submission,
+                    should_issue_refresh_token=should_issue_refresh_token,
+                )
             elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
                 await self._address_ratelimiter.ratelimit(None, request.getClientIP())
-                result = await self._do_token_login(login_submission)
+                result = await self._do_token_login(
+                    login_submission,
+                    should_issue_refresh_token=should_issue_refresh_token,
+                )
             else:
                 await self._address_ratelimiter.ratelimit(None, request.getClientIP())
-                result = await self._do_other_login(login_submission)
+                result = await self._do_other_login(
+                    login_submission,
+                    should_issue_refresh_token=should_issue_refresh_token,
+                )
         except KeyError:
             raise SynapseError(400, "Missing JSON keys.")
 
@@ -169,7 +214,10 @@ class LoginRestServlet(RestServlet):
         return 200, result
 
     async def _do_appservice_login(
-        self, login_submission: JsonDict, appservice: ApplicationService
+        self,
+        login_submission: JsonDict,
+        appservice: ApplicationService,
+        should_issue_refresh_token: bool = False,
     ):
         identifier = login_submission.get("identifier")
         logger.info("Got appservice login request with identifier: %r", identifier)
@@ -198,14 +246,21 @@ class LoginRestServlet(RestServlet):
             raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
 
         return await self._complete_login(
-            qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
+            qualified_user_id,
+            login_submission,
+            ratelimit=appservice.is_rate_limited(),
+            should_issue_refresh_token=should_issue_refresh_token,
         )
 
-    async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
+    async def _do_other_login(
+        self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+    ) -> LoginResponse:
         """Handle non-token/saml/jwt logins
 
         Args:
             login_submission:
+            should_issue_refresh_token: True if this login should issue
+                a refresh token alongside the access token.
 
         Returns:
             HTTP response
@@ -224,7 +279,10 @@ class LoginRestServlet(RestServlet):
             login_submission, ratelimit=True
         )
         result = await self._complete_login(
-            canonical_user_id, login_submission, callback
+            canonical_user_id,
+            login_submission,
+            callback,
+            should_issue_refresh_token=should_issue_refresh_token,
         )
         return result
 
@@ -232,11 +290,12 @@ class LoginRestServlet(RestServlet):
         self,
         user_id: str,
         login_submission: JsonDict,
-        callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
+        callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None,
         create_non_existent_users: bool = False,
         ratelimit: bool = True,
         auth_provider_id: Optional[str] = None,
-    ) -> Dict[str, str]:
+        should_issue_refresh_token: bool = False,
+    ) -> LoginResponse:
         """Called when we've successfully authed the user and now need to
         actually login them in (e.g. create devices). This gets called on
         all successful logins.
@@ -253,6 +312,8 @@ class LoginRestServlet(RestServlet):
             ratelimit: Whether to ratelimit the login request.
             auth_provider_id: The SSO IdP the user used, if any (just used for the
                 prometheus metrics).
+            should_issue_refresh_token: True if this login should issue
+                a refresh token alongside the access token.
 
         Returns:
             result: Dictionary of account information after successful login.
@@ -274,28 +335,48 @@ class LoginRestServlet(RestServlet):
 
         device_id = login_submission.get("device_id")
         initial_display_name = login_submission.get("initial_device_display_name")
-        device_id, access_token = await self.registration_handler.register_device(
-            user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
+        (
+            device_id,
+            access_token,
+            valid_until_ms,
+            refresh_token,
+        ) = await self.registration_handler.register_device(
+            user_id,
+            device_id,
+            initial_display_name,
+            auth_provider_id=auth_provider_id,
+            should_issue_refresh_token=should_issue_refresh_token,
         )
 
-        result = {
-            "user_id": user_id,
-            "access_token": access_token,
-            "home_server": self.hs.hostname,
-            "device_id": device_id,
-        }
+        result = LoginResponse(
+            user_id=user_id,
+            access_token=access_token,
+            home_server=self.hs.hostname,
+            device_id=device_id,
+        )
+
+        if valid_until_ms is not None:
+            expires_in_ms = valid_until_ms - self.clock.time_msec()
+            result["expires_in_ms"] = expires_in_ms
+
+        if refresh_token is not None:
+            result["refresh_token"] = refresh_token
 
         if callback is not None:
             await callback(result)
 
         return result
 
-    async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
+    async def _do_token_login(
+        self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+    ) -> LoginResponse:
         """
         Handle the final stage of SSO login.
 
         Args:
-             login_submission: The JSON request body.
+            login_submission: The JSON request body.
+            should_issue_refresh_token: True if this login should issue
+                a refresh token alongside the access token.
 
         Returns:
             The body of the JSON response.
@@ -309,9 +390,12 @@ class LoginRestServlet(RestServlet):
             login_submission,
             self.auth_handler._sso_login_callback,
             auth_provider_id=res.auth_provider_id,
+            should_issue_refresh_token=should_issue_refresh_token,
         )
 
-    async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
+    async def _do_jwt_login(
+        self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+    ) -> LoginResponse:
         token = login_submission.get("token", None)
         if token is None:
             raise LoginError(
@@ -342,7 +426,10 @@ class LoginRestServlet(RestServlet):
 
         user_id = UserID(user, self.hs.hostname).to_string()
         result = await self._complete_login(
-            user_id, login_submission, create_non_existent_users=True
+            user_id,
+            login_submission,
+            create_non_existent_users=True,
+            should_issue_refresh_token=should_issue_refresh_token,
         )
         return result
 
@@ -371,6 +458,42 @@ def _get_auth_flow_dict_for_idp(
     return e
 
 
+class RefreshTokenServlet(RestServlet):
+    PATTERNS = client_patterns(
+        "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        self._auth_handler = hs.get_auth_handler()
+        self._clock = hs.get_clock()
+        self.access_token_lifetime = hs.config.access_token_lifetime
+
+    async def on_POST(
+        self,
+        request: SynapseRequest,
+    ):
+        refresh_submission = parse_json_object_from_request(request)
+
+        assert_params_in_dict(refresh_submission, ["refresh_token"])
+        token = refresh_submission["refresh_token"]
+        if not isinstance(token, str):
+            raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM)
+
+        valid_until_ms = self._clock.time_msec() + self.access_token_lifetime
+        access_token, refresh_token = await self._auth_handler.refresh_token(
+            token, valid_until_ms
+        )
+        expires_in_ms = valid_until_ms - self._clock.time_msec()
+        return (
+            200,
+            {
+                "access_token": access_token,
+                "refresh_token": refresh_token,
+                "expires_in_ms": expires_in_ms,
+            },
+        )
+
+
 class SsoRedirectServlet(RestServlet):
     PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
         re.compile(
@@ -477,6 +600,8 @@ class CasTicketServlet(RestServlet):
 
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
+    if hs.config.access_token_lifetime is not None:
+        RefreshTokenServlet(hs).register(http_server)
     SsoRedirectServlet(hs).register(http_server)
     if hs.config.cas_enabled:
         CasTicketServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index a30a5df1b1..4d31584acd 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -41,11 +41,13 @@ from synapse.http.server import finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
+    parse_boolean,
     parse_json_object_from_request,
     parse_string,
 )
 from synapse.metrics import threepid_send_requests
 from synapse.push.mailer import Mailer
+from synapse.types import JsonDict
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -399,6 +401,7 @@ class RegisterRestServlet(RestServlet):
         self.password_policy_handler = hs.get_password_policy_handler()
         self.clock = hs.get_clock()
         self._registration_enabled = self.hs.config.enable_registration
+        self._msc2918_enabled = hs.config.access_token_lifetime is not None
 
         self._registration_flows = _calculate_registration_flows(
             hs.config, self.auth_handler
@@ -424,6 +427,15 @@ class RegisterRestServlet(RestServlet):
                 "Do not understand membership kind: %s" % (kind.decode("utf8"),)
             )
 
+        if self._msc2918_enabled:
+            # Check if this registration should also issue a refresh token, as
+            # per MSC2918
+            should_issue_refresh_token = parse_boolean(
+                request, name="org.matrix.msc2918.refresh_token", default=False
+            )
+        else:
+            should_issue_refresh_token = False
+
         # Pull out the provided username and do basic sanity checks early since
         # the auth layer will store these in sessions.
         desired_username = None
@@ -462,7 +474,10 @@ class RegisterRestServlet(RestServlet):
                 raise SynapseError(400, "Desired Username is missing or not a string")
 
             result = await self._do_appservice_registration(
-                desired_username, access_token, body
+                desired_username,
+                access_token,
+                body,
+                should_issue_refresh_token=should_issue_refresh_token,
             )
 
             return 200, result
@@ -665,7 +680,9 @@ class RegisterRestServlet(RestServlet):
             registered = True
 
         return_dict = await self._create_registration_details(
-            registered_user_id, params
+            registered_user_id,
+            params,
+            should_issue_refresh_token=should_issue_refresh_token,
         )
 
         if registered:
@@ -677,7 +694,9 @@ class RegisterRestServlet(RestServlet):
 
         return 200, return_dict
 
-    async def _do_appservice_registration(self, username, as_token, body):
+    async def _do_appservice_registration(
+        self, username, as_token, body, should_issue_refresh_token: bool = False
+    ):
         user_id = await self.registration_handler.appservice_register(
             username, as_token
         )
@@ -685,19 +704,27 @@ class RegisterRestServlet(RestServlet):
             user_id,
             body,
             is_appservice_ghost=True,
+            should_issue_refresh_token=should_issue_refresh_token,
         )
 
     async def _create_registration_details(
-        self, user_id, params, is_appservice_ghost=False
+        self,
+        user_id: str,
+        params: JsonDict,
+        is_appservice_ghost: bool = False,
+        should_issue_refresh_token: bool = False,
     ):
         """Complete registration of newly-registered user
 
         Allocates device_id if one was not given; also creates access_token.
 
         Args:
-            (str) user_id: full canonical @user:id
-            (object) params: registration parameters, from which we pull
-                device_id, initial_device_name and inhibit_login
+            user_id: full canonical @user:id
+            params: registration parameters, from which we pull device_id,
+                initial_device_name and inhibit_login
+            is_appservice_ghost
+            should_issue_refresh_token: True if this registration should issue
+                a refresh token alongside the access token.
         Returns:
              dictionary for response from /register
         """
@@ -705,15 +732,29 @@ class RegisterRestServlet(RestServlet):
         if not params.get("inhibit_login", False):
             device_id = params.get("device_id")
             initial_display_name = params.get("initial_device_display_name")
-            device_id, access_token = await self.registration_handler.register_device(
+            (
+                device_id,
+                access_token,
+                valid_until_ms,
+                refresh_token,
+            ) = await self.registration_handler.register_device(
                 user_id,
                 device_id,
                 initial_display_name,
                 is_guest=False,
                 is_appservice_ghost=is_appservice_ghost,
+                should_issue_refresh_token=should_issue_refresh_token,
             )
 
             result.update({"access_token": access_token, "device_id": device_id})
+
+            if valid_until_ms is not None:
+                expires_in_ms = valid_until_ms - self.clock.time_msec()
+                result["expires_in_ms"] = expires_in_ms
+
+            if refresh_token is not None:
+                result["refresh_token"] = refresh_token
+
         return result
 
     async def _do_guest_registration(self, params, address=None):
@@ -727,19 +768,30 @@ class RegisterRestServlet(RestServlet):
         # we have nowhere to store it.
         device_id = synapse.api.auth.GUEST_DEVICE_ID
         initial_display_name = params.get("initial_device_display_name")
-        device_id, access_token = await self.registration_handler.register_device(
+        (
+            device_id,
+            access_token,
+            valid_until_ms,
+            refresh_token,
+        ) = await self.registration_handler.register_device(
             user_id, device_id, initial_display_name, is_guest=True
         )
 
-        return (
-            200,
-            {
-                "user_id": user_id,
-                "device_id": device_id,
-                "access_token": access_token,
-                "home_server": self.hs.hostname,
-            },
-        )
+        result = {
+            "user_id": user_id,
+            "device_id": device_id,
+            "access_token": access_token,
+            "home_server": self.hs.hostname,
+        }
+
+        if valid_until_ms is not None:
+            expires_in_ms = valid_until_ms - self.clock.time_msec()
+            result["expires_in_ms"] = expires_in_ms
+
+        if refresh_token is not None:
+            result["refresh_token"] = refresh_token
+
+        return 200, result
 
 
 def _calculate_registration_flows(