diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py
index e84dde31b1..11fc0b0678 100644
--- a/synapse/rest/client/capabilities.py
+++ b/synapse/rest/client/capabilities.py
@@ -82,6 +82,11 @@ class CapabilitiesRestServlet(RestServlet):
"enabled": self.config.experimental.msc3664_enabled,
}
+ if self.config.experimental.msc3882_enabled:
+ response["capabilities"]["org.matrix.msc3882.get_login_token"] = {
+ "enabled": True,
+ }
+
return HTTPStatus.OK, response
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index b7e9c8f6b5..896cf2cdbe 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -107,6 +107,9 @@ class LoginRestServlet(RestServlet):
and hs.config.experimental.msc3866.require_approval_for_new_accounts
)
+ # Whether MSC3882 get login token is enabled.
+ self._get_login_token_enabled = hs.config.experimental.msc3882_enabled
+
self.auth = hs.get_auth()
self.clock = hs.get_clock()
@@ -145,7 +148,12 @@ class LoginRestServlet(RestServlet):
# to SSO.
flows.append({"type": LoginRestServlet.CAS_TYPE})
- if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
+ if (
+ self.cas_enabled
+ or self.saml2_enabled
+ or self.oidc_enabled
+ or self._get_login_token_enabled
+ ):
flows.append(
{
"type": LoginRestServlet.SSO_TYPE,
@@ -163,7 +171,11 @@ class LoginRestServlet(RestServlet):
# don't know how to implement, since they (currently) will always
# fall back to the fallback API if they don't understand one of the
# login flow types returned.
- flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+ tokenTypeFlow: Dict[str, Any] = {"type": LoginRestServlet.TOKEN_TYPE}
+ # If MSC3882 is enabled we advertise the get_login_token flag.
+ if self._get_login_token_enabled:
+ tokenTypeFlow["org.matrix.msc3882.get_login_token"] = True
+ flows.append(tokenTypeFlow)
flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types())
diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py
index 43ea21d5e6..2d8726ac4c 100644
--- a/synapse/rest/client/login_token_request.py
+++ b/synapse/rest/client/login_token_request.py
@@ -33,7 +33,7 @@ class LoginTokenRequestServlet(RestServlet):
Request:
- POST /login/token HTTP/1.1
+ POST /login/get_token HTTP/1.1
Content-Type: application/json
{}
@@ -43,12 +43,12 @@ class LoginTokenRequestServlet(RestServlet):
HTTP/1.1 200 OK
{
"login_token": "ABDEFGH",
- "expires_in": 3600,
+ "expires_in_ms": 3600000,
}
"""
PATTERNS = client_patterns(
- "/org.matrix.msc3882/login/token$", releases=[], v1=False, unstable=True
+ "/org.matrix.msc3882/login/get_token$", releases=[], v1=False, unstable=True
)
def __init__(self, hs: "HomeServer"):
@@ -77,7 +77,7 @@ class LoginTokenRequestServlet(RestServlet):
login_token = await self.auth_handler.create_login_token_for_user_id(
user_id=requester.user.to_string(),
- auth_provider_id="org.matrix.msc3882.login_token_request",
+ auth_provider_id="org.matrix.msc3882.get_login_token",
duration_ms=self.token_timeout,
)
@@ -85,7 +85,7 @@ class LoginTokenRequestServlet(RestServlet):
200,
{
"login_token": login_token,
- "expires_in": self.token_timeout // 1000,
+ "expires_in_ms": self.token_timeout,
},
)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 59aed66464..ecd84f435f 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -112,8 +112,6 @@ class VersionsRestServlet(RestServlet):
"fi.mau.msc2815": self.config.experimental.msc2815_enabled,
# Adds a ping endpoint for appservices to check HS->AS connection
"fi.mau.msc2659": self.config.experimental.msc2659_enabled,
- # Adds support for login token requests as per MSC3882
- "org.matrix.msc3882": self.config.experimental.msc3882_enabled,
# Adds support for remotely enabling/disabling pushers, as per MSC3881
"org.matrix.msc3881": self.config.experimental.msc3881_enabled,
# Adds support for filtering /messages by event relation.
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index c16e8d43f4..bc33854b22 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -186,3 +186,33 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertGreater(len(details["support"]), 0)
for room_version in details["support"]:
self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version))
+
+ def test_get_does_not_include_msc3882_fields_when_disabled(self) -> None:
+ access_token = self.get_success(
+ self.auth_handler.create_access_token_for_user_id(
+ self.user, device_id=None, valid_until_ms=None
+ )
+ )
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertTrue(
+ "org.matrix.msc3882.get_login_token" not in capabilities
+ or not capabilities["org.matrix.msc3882.get_login_token"]["enabled"]
+ )
+
+ @override_config({"experimental_features": {"msc3882_enabled": True}})
+ def test_get_does_include_msc3882_fields_when_enabled(self) -> None:
+ access_token = self.get_success(
+ self.auth_handler.create_access_token_for_user_id(
+ self.user, device_id=None, valid_until_ms=None
+ )
+ )
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertTrue(capabilities["org.matrix.msc3882.get_login_token"]["enabled"])
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 62acf4f44e..69b4638900 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -446,6 +446,29 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
)
+ def test_get_login_flows_with_msc3882_disabled(self) -> None:
+ """GET /login should return m.login.token without get_login_token true"""
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+
+ flows = {flow["type"]: flow for flow in channel.json_body["flows"]}
+ self.assertTrue(
+ "m.login.token" not in flows
+ or "org.matrix.msc3882.get_login_token" not in flows["m.login.token"]
+ or not flows["m.login.token"]["org.matrix.msc3882.get_login_token"]
+ )
+
+ @override_config({"experimental_features": {"msc3882_enabled": True}})
+ def test_get_login_flows_with_msc3882_enabled(self) -> None:
+ """GET /login should return m.login.token without get_login_token true"""
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+
+ print(channel.json_body)
+
+ flows = {flow["type"]: flow for flow in channel.json_body["flows"]}
+ self.assertTrue(flows["m.login.token"]["org.matrix.msc3882.get_login_token"])
+
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
index b8187db982..cdf4134cbe 100644
--- a/tests/rest/client/test_login_token_request.py
+++ b/tests/rest/client/test_login_token_request.py
@@ -22,7 +22,7 @@ from synapse.util import Clock
from tests import unittest
from tests.unittest import override_config
-endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token"
+endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/get_token"
class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
@@ -82,7 +82,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", endpoint, uia, access_token=token)
self.assertEqual(channel.code, 200)
- self.assertEqual(channel.json_body["expires_in"], 300)
+ self.assertEqual(channel.json_body["expires_in_ms"], 300000)
login_token = channel.json_body["login_token"]
@@ -103,7 +103,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", endpoint, {}, access_token=token)
self.assertEqual(channel.code, 200)
- self.assertEqual(channel.json_body["expires_in"], 300)
+ self.assertEqual(channel.json_body["expires_in_ms"], 300000)
login_token = channel.json_body["login_token"]
@@ -130,4 +130,4 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", endpoint, {}, access_token=token)
self.assertEqual(channel.code, 200)
- self.assertEqual(channel.json_body["expires_in"], 15)
+ self.assertEqual(channel.json_body["expires_in_ms"], 15000)
|