diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index f6be5f1020..cbcb60fe31 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,7 +14,9 @@
import logging
import re
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
+
+from typing_extensions import TypedDict
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
@@ -25,6 +27,8 @@ from synapse.http import get_request_uri
from synapse.http.server import HttpServer, finish_request
from synapse.http.servlet import (
RestServlet,
+ assert_params_in_dict,
+ parse_boolean,
parse_bytes_from_args,
parse_json_object_from_request,
parse_string,
@@ -40,6 +44,21 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+LoginResponse = TypedDict(
+ "LoginResponse",
+ {
+ "user_id": str,
+ "access_token": str,
+ "home_server": str,
+ "expires_in_ms": Optional[int],
+ "refresh_token": Optional[str],
+ "device_id": str,
+ "well_known": Optional[Dict[str, Any]],
+ },
+ total=False,
+)
+
+
class LoginRestServlet(RestServlet):
PATTERNS = client_patterns("/login$", v1=True)
CAS_TYPE = "m.login.cas"
@@ -48,6 +67,7 @@ class LoginRestServlet(RestServlet):
JWT_TYPE = "org.matrix.login.jwt"
JWT_TYPE_DEPRECATED = "m.login.jwt"
APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
+ REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token"
def __init__(self, hs: "HomeServer"):
super().__init__()
@@ -65,9 +85,12 @@ class LoginRestServlet(RestServlet):
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+ self._msc2918_enabled = hs.config.access_token_lifetime is not None
self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self._sso_handler = hs.get_sso_handler()
@@ -138,6 +161,15 @@ class LoginRestServlet(RestServlet):
async def on_POST(self, request: SynapseRequest):
login_submission = parse_json_object_from_request(request)
+ if self._msc2918_enabled:
+ # Check if this login should also issue a refresh token, as per
+ # MSC2918
+ should_issue_refresh_token = parse_boolean(
+ request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False
+ )
+ else:
+ should_issue_refresh_token = False
+
try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
appservice = self.auth.get_appservice_by_req(request)
@@ -147,19 +179,32 @@ class LoginRestServlet(RestServlet):
None, request.getClientIP()
)
- result = await self._do_appservice_login(login_submission, appservice)
+ result = await self._do_appservice_login(
+ login_submission,
+ appservice,
+ should_issue_refresh_token=should_issue_refresh_token,
+ )
elif self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
- result = await self._do_jwt_login(login_submission)
+ result = await self._do_jwt_login(
+ login_submission,
+ should_issue_refresh_token=should_issue_refresh_token,
+ )
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
- result = await self._do_token_login(login_submission)
+ result = await self._do_token_login(
+ login_submission,
+ should_issue_refresh_token=should_issue_refresh_token,
+ )
else:
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
- result = await self._do_other_login(login_submission)
+ result = await self._do_other_login(
+ login_submission,
+ should_issue_refresh_token=should_issue_refresh_token,
+ )
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@@ -169,7 +214,10 @@ class LoginRestServlet(RestServlet):
return 200, result
async def _do_appservice_login(
- self, login_submission: JsonDict, appservice: ApplicationService
+ self,
+ login_submission: JsonDict,
+ appservice: ApplicationService,
+ should_issue_refresh_token: bool = False,
):
identifier = login_submission.get("identifier")
logger.info("Got appservice login request with identifier: %r", identifier)
@@ -198,14 +246,21 @@ class LoginRestServlet(RestServlet):
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
return await self._complete_login(
- qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
+ qualified_user_id,
+ login_submission,
+ ratelimit=appservice.is_rate_limited(),
+ should_issue_refresh_token=should_issue_refresh_token,
)
- async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
+ async def _do_other_login(
+ self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+ ) -> LoginResponse:
"""Handle non-token/saml/jwt logins
Args:
login_submission:
+ should_issue_refresh_token: True if this login should issue
+ a refresh token alongside the access token.
Returns:
HTTP response
@@ -224,7 +279,10 @@ class LoginRestServlet(RestServlet):
login_submission, ratelimit=True
)
result = await self._complete_login(
- canonical_user_id, login_submission, callback
+ canonical_user_id,
+ login_submission,
+ callback,
+ should_issue_refresh_token=should_issue_refresh_token,
)
return result
@@ -232,11 +290,12 @@ class LoginRestServlet(RestServlet):
self,
user_id: str,
login_submission: JsonDict,
- callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
+ callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
ratelimit: bool = True,
auth_provider_id: Optional[str] = None,
- ) -> Dict[str, str]:
+ should_issue_refresh_token: bool = False,
+ ) -> LoginResponse:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
all successful logins.
@@ -253,6 +312,8 @@ class LoginRestServlet(RestServlet):
ratelimit: Whether to ratelimit the login request.
auth_provider_id: The SSO IdP the user used, if any (just used for the
prometheus metrics).
+ should_issue_refresh_token: True if this login should issue
+ a refresh token alongside the access token.
Returns:
result: Dictionary of account information after successful login.
@@ -274,28 +335,48 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
- device_id, access_token = await self.registration_handler.register_device(
- user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
+ (
+ device_id,
+ access_token,
+ valid_until_ms,
+ refresh_token,
+ ) = await self.registration_handler.register_device(
+ user_id,
+ device_id,
+ initial_display_name,
+ auth_provider_id=auth_provider_id,
+ should_issue_refresh_token=should_issue_refresh_token,
)
- result = {
- "user_id": user_id,
- "access_token": access_token,
- "home_server": self.hs.hostname,
- "device_id": device_id,
- }
+ result = LoginResponse(
+ user_id=user_id,
+ access_token=access_token,
+ home_server=self.hs.hostname,
+ device_id=device_id,
+ )
+
+ if valid_until_ms is not None:
+ expires_in_ms = valid_until_ms - self.clock.time_msec()
+ result["expires_in_ms"] = expires_in_ms
+
+ if refresh_token is not None:
+ result["refresh_token"] = refresh_token
if callback is not None:
await callback(result)
return result
- async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
+ async def _do_token_login(
+ self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+ ) -> LoginResponse:
"""
Handle the final stage of SSO login.
Args:
- login_submission: The JSON request body.
+ login_submission: The JSON request body.
+ should_issue_refresh_token: True if this login should issue
+ a refresh token alongside the access token.
Returns:
The body of the JSON response.
@@ -309,9 +390,12 @@ class LoginRestServlet(RestServlet):
login_submission,
self.auth_handler._sso_login_callback,
auth_provider_id=res.auth_provider_id,
+ should_issue_refresh_token=should_issue_refresh_token,
)
- async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
+ async def _do_jwt_login(
+ self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+ ) -> LoginResponse:
token = login_submission.get("token", None)
if token is None:
raise LoginError(
@@ -342,7 +426,10 @@ class LoginRestServlet(RestServlet):
user_id = UserID(user, self.hs.hostname).to_string()
result = await self._complete_login(
- user_id, login_submission, create_non_existent_users=True
+ user_id,
+ login_submission,
+ create_non_existent_users=True,
+ should_issue_refresh_token=should_issue_refresh_token,
)
return result
@@ -371,6 +458,42 @@ def _get_auth_flow_dict_for_idp(
return e
+class RefreshTokenServlet(RestServlet):
+ PATTERNS = client_patterns(
+ "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True
+ )
+
+ def __init__(self, hs: "HomeServer"):
+ self._auth_handler = hs.get_auth_handler()
+ self._clock = hs.get_clock()
+ self.access_token_lifetime = hs.config.access_token_lifetime
+
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ ):
+ refresh_submission = parse_json_object_from_request(request)
+
+ assert_params_in_dict(refresh_submission, ["refresh_token"])
+ token = refresh_submission["refresh_token"]
+ if not isinstance(token, str):
+ raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM)
+
+ valid_until_ms = self._clock.time_msec() + self.access_token_lifetime
+ access_token, refresh_token = await self._auth_handler.refresh_token(
+ token, valid_until_ms
+ )
+ expires_in_ms = valid_until_ms - self._clock.time_msec()
+ return (
+ 200,
+ {
+ "access_token": access_token,
+ "refresh_token": refresh_token,
+ "expires_in_ms": expires_in_ms,
+ },
+ )
+
+
class SsoRedirectServlet(RestServlet):
PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
re.compile(
@@ -477,6 +600,8 @@ class CasTicketServlet(RestServlet):
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
+ if hs.config.access_token_lifetime is not None:
+ RefreshTokenServlet(hs).register(http_server)
SsoRedirectServlet(hs).register(http_server)
if hs.config.cas_enabled:
CasTicketServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index a30a5df1b1..4d31584acd 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -41,11 +41,13 @@ from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
+ parse_boolean,
parse_json_object_from_request,
parse_string,
)
from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
+from synapse.types import JsonDict
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -399,6 +401,7 @@ class RegisterRestServlet(RestServlet):
self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
self._registration_enabled = self.hs.config.enable_registration
+ self._msc2918_enabled = hs.config.access_token_lifetime is not None
self._registration_flows = _calculate_registration_flows(
hs.config, self.auth_handler
@@ -424,6 +427,15 @@ class RegisterRestServlet(RestServlet):
"Do not understand membership kind: %s" % (kind.decode("utf8"),)
)
+ if self._msc2918_enabled:
+ # Check if this registration should also issue a refresh token, as
+ # per MSC2918
+ should_issue_refresh_token = parse_boolean(
+ request, name="org.matrix.msc2918.refresh_token", default=False
+ )
+ else:
+ should_issue_refresh_token = False
+
# Pull out the provided username and do basic sanity checks early since
# the auth layer will store these in sessions.
desired_username = None
@@ -462,7 +474,10 @@ class RegisterRestServlet(RestServlet):
raise SynapseError(400, "Desired Username is missing or not a string")
result = await self._do_appservice_registration(
- desired_username, access_token, body
+ desired_username,
+ access_token,
+ body,
+ should_issue_refresh_token=should_issue_refresh_token,
)
return 200, result
@@ -665,7 +680,9 @@ class RegisterRestServlet(RestServlet):
registered = True
return_dict = await self._create_registration_details(
- registered_user_id, params
+ registered_user_id,
+ params,
+ should_issue_refresh_token=should_issue_refresh_token,
)
if registered:
@@ -677,7 +694,9 @@ class RegisterRestServlet(RestServlet):
return 200, return_dict
- async def _do_appservice_registration(self, username, as_token, body):
+ async def _do_appservice_registration(
+ self, username, as_token, body, should_issue_refresh_token: bool = False
+ ):
user_id = await self.registration_handler.appservice_register(
username, as_token
)
@@ -685,19 +704,27 @@ class RegisterRestServlet(RestServlet):
user_id,
body,
is_appservice_ghost=True,
+ should_issue_refresh_token=should_issue_refresh_token,
)
async def _create_registration_details(
- self, user_id, params, is_appservice_ghost=False
+ self,
+ user_id: str,
+ params: JsonDict,
+ is_appservice_ghost: bool = False,
+ should_issue_refresh_token: bool = False,
):
"""Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token.
Args:
- (str) user_id: full canonical @user:id
- (object) params: registration parameters, from which we pull
- device_id, initial_device_name and inhibit_login
+ user_id: full canonical @user:id
+ params: registration parameters, from which we pull device_id,
+ initial_device_name and inhibit_login
+ is_appservice_ghost
+ should_issue_refresh_token: True if this registration should issue
+ a refresh token alongside the access token.
Returns:
dictionary for response from /register
"""
@@ -705,15 +732,29 @@ class RegisterRestServlet(RestServlet):
if not params.get("inhibit_login", False):
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
- device_id, access_token = await self.registration_handler.register_device(
+ (
+ device_id,
+ access_token,
+ valid_until_ms,
+ refresh_token,
+ ) = await self.registration_handler.register_device(
user_id,
device_id,
initial_display_name,
is_guest=False,
is_appservice_ghost=is_appservice_ghost,
+ should_issue_refresh_token=should_issue_refresh_token,
)
result.update({"access_token": access_token, "device_id": device_id})
+
+ if valid_until_ms is not None:
+ expires_in_ms = valid_until_ms - self.clock.time_msec()
+ result["expires_in_ms"] = expires_in_ms
+
+ if refresh_token is not None:
+ result["refresh_token"] = refresh_token
+
return result
async def _do_guest_registration(self, params, address=None):
@@ -727,19 +768,30 @@ class RegisterRestServlet(RestServlet):
# we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name")
- device_id, access_token = await self.registration_handler.register_device(
+ (
+ device_id,
+ access_token,
+ valid_until_ms,
+ refresh_token,
+ ) = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest=True
)
- return (
- 200,
- {
- "user_id": user_id,
- "device_id": device_id,
- "access_token": access_token,
- "home_server": self.hs.hostname,
- },
- )
+ result = {
+ "user_id": user_id,
+ "device_id": device_id,
+ "access_token": access_token,
+ "home_server": self.hs.hostname,
+ }
+
+ if valid_until_ms is not None:
+ expires_in_ms = valid_until_ms - self.clock.time_msec()
+ result["expires_in_ms"] = expires_in_ms
+
+ if refresh_token is not None:
+ result["refresh_token"] = refresh_token
+
+ return 200, result
def _calculate_registration_flows(
|