diff options
Diffstat (limited to 'synapse/handlers/register.py')
-rw-r--r-- | synapse/handlers/register.py | 52 |
1 files changed, 46 insertions, 6 deletions
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 4b4b579741..26ef016179 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -15,9 +15,10 @@ """Contains functions for registering clients.""" import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from prometheus_client import Counter +from typing_extensions import TypedDict from synapse import types from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType @@ -54,6 +55,16 @@ login_counter = Counter( ["guest", "auth_provider"], ) +LoginDict = TypedDict( + "LoginDict", + { + "device_id": str, + "access_token": str, + "valid_until_ms": Optional[int], + "refresh_token": Optional[str], + }, +) + class RegistrationHandler(BaseHandler): def __init__(self, hs: "HomeServer"): @@ -85,6 +96,7 @@ class RegistrationHandler(BaseHandler): self.pusher_pool = hs.get_pusherpool() self.session_lifetime = hs.config.session_lifetime + self.access_token_lifetime = hs.config.access_token_lifetime async def check_username( self, @@ -696,7 +708,8 @@ class RegistrationHandler(BaseHandler): is_guest: bool = False, is_appservice_ghost: bool = False, auth_provider_id: Optional[str] = None, - ) -> Tuple[str, str]: + should_issue_refresh_token: bool = False, + ) -> Tuple[str, str, Optional[int], Optional[str]]: """Register a device for a user and generate an access token. The access token will be limited by the homeserver's session_lifetime config. @@ -708,8 +721,9 @@ class RegistrationHandler(BaseHandler): is_guest: Whether this is a guest account auth_provider_id: The SSO IdP the user used, if any (just used for the prometheus metrics). + should_issue_refresh_token: Whether it should also issue a refresh token Returns: - Tuple of device ID and access token + Tuple of device ID, access token, access token expiration time and refresh token """ res = await self._register_device_client( user_id=user_id, @@ -717,6 +731,7 @@ class RegistrationHandler(BaseHandler): initial_display_name=initial_display_name, is_guest=is_guest, is_appservice_ghost=is_appservice_ghost, + should_issue_refresh_token=should_issue_refresh_token, ) login_counter.labels( @@ -724,7 +739,12 @@ class RegistrationHandler(BaseHandler): auth_provider=(auth_provider_id or ""), ).inc() - return res["device_id"], res["access_token"] + return ( + res["device_id"], + res["access_token"], + res["valid_until_ms"], + res["refresh_token"], + ) async def register_device_inner( self, @@ -733,7 +753,8 @@ class RegistrationHandler(BaseHandler): initial_display_name: Optional[str], is_guest: bool = False, is_appservice_ghost: bool = False, - ) -> Dict[str, str]: + should_issue_refresh_token: bool = False, + ) -> LoginDict: """Helper for register_device Does the bits that need doing on the main process. Not for use outside this @@ -748,6 +769,9 @@ class RegistrationHandler(BaseHandler): ) valid_until_ms = self.clock.time_msec() + self.session_lifetime + refresh_token = None + refresh_token_id = None + registered_device_id = await self.device_handler.check_device_registered( user_id, device_id, initial_display_name ) @@ -755,14 +779,30 @@ class RegistrationHandler(BaseHandler): assert valid_until_ms is None access_token = self.macaroon_gen.generate_guest_access_token(user_id) else: + if should_issue_refresh_token: + ( + refresh_token, + refresh_token_id, + ) = await self._auth_handler.get_refresh_token_for_user_id( + user_id, + device_id=registered_device_id, + ) + valid_until_ms = self.clock.time_msec() + self.access_token_lifetime + access_token = await self._auth_handler.get_access_token_for_user_id( user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms, is_appservice_ghost=is_appservice_ghost, + refresh_token_id=refresh_token_id, ) - return {"device_id": registered_device_id, "access_token": access_token} + return { + "device_id": registered_device_id, + "access_token": access_token, + "valid_until_ms": valid_until_ms, + "refresh_token": refresh_token, + } async def post_registration_actions( self, user_id: str, auth_result: dict, access_token: Optional[str] |