summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
authorQuentin Gliech <quentingliech@gmail.com>2021-06-24 15:33:20 +0200
committerGitHub <noreply@github.com>2021-06-24 14:33:20 +0100
commitbd4919fb72b2a75f1c0a7f0c78bd619fd2ae30e8 (patch)
tree04a988e47720e9c58c99f05b74121e03ebe1f5f4 /synapse/handlers
parentMerge tag 'v1.37.0rc1' into develop (diff)
downloadsynapse-bd4919fb72b2a75f1c0a7f0c78bd619fd2ae30e8.tar.xz
MSC2918 Refresh tokens implementation (#9450)
This implements refresh tokens, as defined by MSC2918

This MSC has been implemented client side in Hydrogen Web: vector-im/hydrogen-web#235

The basics of the MSC works: requesting refresh tokens on login, having the access tokens expire, and using the refresh token to get a new one.

Signed-off-by: Quentin Gliech <quentingliech@gmail.com>
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/auth.py132
-rw-r--r--synapse/handlers/register.py52
2 files changed, 173 insertions, 11 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 1971e373ed..e2ac595a62 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -30,6 +30,7 @@ from typing import (
     Optional,
     Tuple,
     Union,
+    cast,
 )
 
 import attr
@@ -72,6 +73,7 @@ from synapse.util.stringutils import base62_encode
 from synapse.util.threepids import canonicalise_email
 
 if TYPE_CHECKING:
+    from synapse.rest.client.v1.login import LoginResponse
     from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
@@ -777,6 +779,108 @@ class AuthHandler(BaseHandler):
             "params": params,
         }
 
+    async def refresh_token(
+        self,
+        refresh_token: str,
+        valid_until_ms: Optional[int],
+    ) -> Tuple[str, str]:
+        """
+        Consumes a refresh token and generate both a new access token and a new refresh token from it.
+
+        The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
+
+        Args:
+            refresh_token: The token to consume.
+            valid_until_ms: The expiration timestamp of the new access token.
+
+        Returns:
+            A tuple containing the new access token and refresh token
+        """
+
+        # Verify the token signature first before looking up the token
+        if not self._verify_refresh_token(refresh_token):
+            raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)
+
+        existing_token = await self.store.lookup_refresh_token(refresh_token)
+        if existing_token is None:
+            raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)
+
+        if (
+            existing_token.has_next_access_token_been_used
+            or existing_token.has_next_refresh_token_been_refreshed
+        ):
+            raise SynapseError(
+                403, "refresh token isn't valid anymore", Codes.FORBIDDEN
+            )
+
+        (
+            new_refresh_token,
+            new_refresh_token_id,
+        ) = await self.get_refresh_token_for_user_id(
+            user_id=existing_token.user_id, device_id=existing_token.device_id
+        )
+        access_token = await self.get_access_token_for_user_id(
+            user_id=existing_token.user_id,
+            device_id=existing_token.device_id,
+            valid_until_ms=valid_until_ms,
+            refresh_token_id=new_refresh_token_id,
+        )
+        await self.store.replace_refresh_token(
+            existing_token.token_id, new_refresh_token_id
+        )
+        return access_token, new_refresh_token
+
+    def _verify_refresh_token(self, token: str) -> bool:
+        """
+        Verifies the shape of a refresh token.
+
+        Args:
+            token: The refresh token to verify
+
+        Returns:
+            Whether the token has the right shape
+        """
+        parts = token.split("_", maxsplit=4)
+        if len(parts) != 4:
+            return False
+
+        type, localpart, rand, crc = parts
+
+        # Refresh tokens are prefixed by "syr_", let's check that
+        if type != "syr":
+            return False
+
+        # Check the CRC
+        base = f"{type}_{localpart}_{rand}"
+        expected_crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
+        if crc != expected_crc:
+            return False
+
+        return True
+
+    async def get_refresh_token_for_user_id(
+        self,
+        user_id: str,
+        device_id: str,
+    ) -> Tuple[str, int]:
+        """
+        Creates a new refresh token for the user with the given user ID.
+
+        Args:
+            user_id: canonical user ID
+            device_id: the device ID to associate with the token.
+
+        Returns:
+            The newly created refresh token and its ID in the database
+        """
+        refresh_token = self.generate_refresh_token(UserID.from_string(user_id))
+        refresh_token_id = await self.store.add_refresh_token_to_user(
+            user_id=user_id,
+            token=refresh_token,
+            device_id=device_id,
+        )
+        return refresh_token, refresh_token_id
+
     async def get_access_token_for_user_id(
         self,
         user_id: str,
@@ -784,6 +888,7 @@ class AuthHandler(BaseHandler):
         valid_until_ms: Optional[int],
         puppets_user_id: Optional[str] = None,
         is_appservice_ghost: bool = False,
+        refresh_token_id: Optional[int] = None,
     ) -> str:
         """
         Creates a new access token for the user with the given user ID.
@@ -801,6 +906,8 @@ class AuthHandler(BaseHandler):
             valid_until_ms: when the token is valid until. None for
                 no expiry.
             is_appservice_ghost: Whether the user is an application ghost user
+            refresh_token_id: the refresh token ID that will be associated with
+                this access token.
         Returns:
               The access token for the user's session.
         Raises:
@@ -836,6 +943,7 @@ class AuthHandler(BaseHandler):
             device_id=device_id,
             valid_until_ms=valid_until_ms,
             puppets_user_id=puppets_user_id,
+            refresh_token_id=refresh_token_id,
         )
 
         # the device *should* have been registered before we got here; however,
@@ -928,7 +1036,7 @@ class AuthHandler(BaseHandler):
         self,
         login_submission: Dict[str, Any],
         ratelimit: bool = False,
-    ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
+    ) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
         """Authenticates the user for the /login API
 
         Also used by the user-interactive auth flow to validate auth types which don't
@@ -1073,7 +1181,7 @@ class AuthHandler(BaseHandler):
         self,
         username: str,
         login_submission: Dict[str, Any],
-    ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
+    ) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
         """Helper for validate_login
 
         Handles login, once we've mapped 3pids onto userids
@@ -1151,7 +1259,7 @@ class AuthHandler(BaseHandler):
 
     async def check_password_provider_3pid(
         self, medium: str, address: str, password: str
-    ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
+    ) -> Tuple[Optional[str], Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
         """Check if a password provider is able to validate a thirdparty login
 
         Args:
@@ -1215,6 +1323,19 @@ class AuthHandler(BaseHandler):
         crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
         return f"{base}_{crc}"
 
+    def generate_refresh_token(self, for_user: UserID) -> str:
+        """Generates an opaque string, for use as a refresh token"""
+
+        # we use the following format for refresh tokens:
+        #    syr_<base64 local part>_<random string>_<base62 crc check>
+
+        b64local = unpaddedbase64.encode_base64(for_user.localpart.encode("utf-8"))
+        random_string = stringutils.random_string(20)
+        base = f"syr_{b64local}_{random_string}"
+
+        crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
+        return f"{base}_{crc}"
+
     async def validate_short_term_login_token(
         self, login_token: str
     ) -> LoginTokenAttributes:
@@ -1563,7 +1684,7 @@ class AuthHandler(BaseHandler):
         )
         respond_with_html(request, 200, html)
 
-    async def _sso_login_callback(self, login_result: JsonDict) -> None:
+    async def _sso_login_callback(self, login_result: "LoginResponse") -> None:
         """
         A login callback which might add additional attributes to the login response.
 
@@ -1577,7 +1698,8 @@ class AuthHandler(BaseHandler):
 
         extra_attributes = self._extra_attributes.get(login_result["user_id"])
         if extra_attributes:
-            login_result.update(extra_attributes.extra_attributes)
+            login_result_dict = cast(Dict[str, Any], login_result)
+            login_result_dict.update(extra_attributes.extra_attributes)
 
     def _expire_sso_extra_attributes(self) -> None:
         """
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 4b4b579741..26ef016179 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -15,9 +15,10 @@
 """Contains functions for registering clients."""
 
 import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
 from prometheus_client import Counter
+from typing_extensions import TypedDict
 
 from synapse import types
 from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
@@ -54,6 +55,16 @@ login_counter = Counter(
     ["guest", "auth_provider"],
 )
 
+LoginDict = TypedDict(
+    "LoginDict",
+    {
+        "device_id": str,
+        "access_token": str,
+        "valid_until_ms": Optional[int],
+        "refresh_token": Optional[str],
+    },
+)
+
 
 class RegistrationHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
@@ -85,6 +96,7 @@ class RegistrationHandler(BaseHandler):
             self.pusher_pool = hs.get_pusherpool()
 
         self.session_lifetime = hs.config.session_lifetime
+        self.access_token_lifetime = hs.config.access_token_lifetime
 
     async def check_username(
         self,
@@ -696,7 +708,8 @@ class RegistrationHandler(BaseHandler):
         is_guest: bool = False,
         is_appservice_ghost: bool = False,
         auth_provider_id: Optional[str] = None,
-    ) -> Tuple[str, str]:
+        should_issue_refresh_token: bool = False,
+    ) -> Tuple[str, str, Optional[int], Optional[str]]:
         """Register a device for a user and generate an access token.
 
         The access token will be limited by the homeserver's session_lifetime config.
@@ -708,8 +721,9 @@ class RegistrationHandler(BaseHandler):
             is_guest: Whether this is a guest account
             auth_provider_id: The SSO IdP the user used, if any (just used for the
                 prometheus metrics).
+            should_issue_refresh_token: Whether it should also issue a refresh token
         Returns:
-            Tuple of device ID and access token
+            Tuple of device ID, access token, access token expiration time and refresh token
         """
         res = await self._register_device_client(
             user_id=user_id,
@@ -717,6 +731,7 @@ class RegistrationHandler(BaseHandler):
             initial_display_name=initial_display_name,
             is_guest=is_guest,
             is_appservice_ghost=is_appservice_ghost,
+            should_issue_refresh_token=should_issue_refresh_token,
         )
 
         login_counter.labels(
@@ -724,7 +739,12 @@ class RegistrationHandler(BaseHandler):
             auth_provider=(auth_provider_id or ""),
         ).inc()
 
-        return res["device_id"], res["access_token"]
+        return (
+            res["device_id"],
+            res["access_token"],
+            res["valid_until_ms"],
+            res["refresh_token"],
+        )
 
     async def register_device_inner(
         self,
@@ -733,7 +753,8 @@ class RegistrationHandler(BaseHandler):
         initial_display_name: Optional[str],
         is_guest: bool = False,
         is_appservice_ghost: bool = False,
-    ) -> Dict[str, str]:
+        should_issue_refresh_token: bool = False,
+    ) -> LoginDict:
         """Helper for register_device
 
         Does the bits that need doing on the main process. Not for use outside this
@@ -748,6 +769,9 @@ class RegistrationHandler(BaseHandler):
                 )
             valid_until_ms = self.clock.time_msec() + self.session_lifetime
 
+        refresh_token = None
+        refresh_token_id = None
+
         registered_device_id = await self.device_handler.check_device_registered(
             user_id, device_id, initial_display_name
         )
@@ -755,14 +779,30 @@ class RegistrationHandler(BaseHandler):
             assert valid_until_ms is None
             access_token = self.macaroon_gen.generate_guest_access_token(user_id)
         else:
+            if should_issue_refresh_token:
+                (
+                    refresh_token,
+                    refresh_token_id,
+                ) = await self._auth_handler.get_refresh_token_for_user_id(
+                    user_id,
+                    device_id=registered_device_id,
+                )
+                valid_until_ms = self.clock.time_msec() + self.access_token_lifetime
+
             access_token = await self._auth_handler.get_access_token_for_user_id(
                 user_id,
                 device_id=registered_device_id,
                 valid_until_ms=valid_until_ms,
                 is_appservice_ghost=is_appservice_ghost,
+                refresh_token_id=refresh_token_id,
             )
 
-        return {"device_id": registered_device_id, "access_token": access_token}
+        return {
+            "device_id": registered_device_id,
+            "access_token": access_token,
+            "valid_until_ms": valid_until_ms,
+            "refresh_token": refresh_token,
+        }
 
     async def post_registration_actions(
         self, user_id: str, auth_result: dict, access_token: Optional[str]