diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 4b4b579741..26ef016179 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -15,9 +15,10 @@
"""Contains functions for registering clients."""
import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from prometheus_client import Counter
+from typing_extensions import TypedDict
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
@@ -54,6 +55,16 @@ login_counter = Counter(
["guest", "auth_provider"],
)
+LoginDict = TypedDict(
+ "LoginDict",
+ {
+ "device_id": str,
+ "access_token": str,
+ "valid_until_ms": Optional[int],
+ "refresh_token": Optional[str],
+ },
+)
+
class RegistrationHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
@@ -85,6 +96,7 @@ class RegistrationHandler(BaseHandler):
self.pusher_pool = hs.get_pusherpool()
self.session_lifetime = hs.config.session_lifetime
+ self.access_token_lifetime = hs.config.access_token_lifetime
async def check_username(
self,
@@ -696,7 +708,8 @@ class RegistrationHandler(BaseHandler):
is_guest: bool = False,
is_appservice_ghost: bool = False,
auth_provider_id: Optional[str] = None,
- ) -> Tuple[str, str]:
+ should_issue_refresh_token: bool = False,
+ ) -> Tuple[str, str, Optional[int], Optional[str]]:
"""Register a device for a user and generate an access token.
The access token will be limited by the homeserver's session_lifetime config.
@@ -708,8 +721,9 @@ class RegistrationHandler(BaseHandler):
is_guest: Whether this is a guest account
auth_provider_id: The SSO IdP the user used, if any (just used for the
prometheus metrics).
+ should_issue_refresh_token: Whether it should also issue a refresh token
Returns:
- Tuple of device ID and access token
+ Tuple of device ID, access token, access token expiration time and refresh token
"""
res = await self._register_device_client(
user_id=user_id,
@@ -717,6 +731,7 @@ class RegistrationHandler(BaseHandler):
initial_display_name=initial_display_name,
is_guest=is_guest,
is_appservice_ghost=is_appservice_ghost,
+ should_issue_refresh_token=should_issue_refresh_token,
)
login_counter.labels(
@@ -724,7 +739,12 @@ class RegistrationHandler(BaseHandler):
auth_provider=(auth_provider_id or ""),
).inc()
- return res["device_id"], res["access_token"]
+ return (
+ res["device_id"],
+ res["access_token"],
+ res["valid_until_ms"],
+ res["refresh_token"],
+ )
async def register_device_inner(
self,
@@ -733,7 +753,8 @@ class RegistrationHandler(BaseHandler):
initial_display_name: Optional[str],
is_guest: bool = False,
is_appservice_ghost: bool = False,
- ) -> Dict[str, str]:
+ should_issue_refresh_token: bool = False,
+ ) -> LoginDict:
"""Helper for register_device
Does the bits that need doing on the main process. Not for use outside this
@@ -748,6 +769,9 @@ class RegistrationHandler(BaseHandler):
)
valid_until_ms = self.clock.time_msec() + self.session_lifetime
+ refresh_token = None
+ refresh_token_id = None
+
registered_device_id = await self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
@@ -755,14 +779,30 @@ class RegistrationHandler(BaseHandler):
assert valid_until_ms is None
access_token = self.macaroon_gen.generate_guest_access_token(user_id)
else:
+ if should_issue_refresh_token:
+ (
+ refresh_token,
+ refresh_token_id,
+ ) = await self._auth_handler.get_refresh_token_for_user_id(
+ user_id,
+ device_id=registered_device_id,
+ )
+ valid_until_ms = self.clock.time_msec() + self.access_token_lifetime
+
access_token = await self._auth_handler.get_access_token_for_user_id(
user_id,
device_id=registered_device_id,
valid_until_ms=valid_until_ms,
is_appservice_ghost=is_appservice_ghost,
+ refresh_token_id=refresh_token_id,
)
- return {"device_id": registered_device_id, "access_token": access_token}
+ return {
+ "device_id": registered_device_id,
+ "access_token": access_token,
+ "valid_until_ms": valid_until_ms,
+ "refresh_token": refresh_token,
+ }
async def post_registration_actions(
self, user_id: str, auth_result: dict, access_token: Optional[str]
|