diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index c16e8d43f4..cf23430f6a 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -186,3 +186,31 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertGreater(len(details["support"]), 0)
for room_version in details["support"]:
self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version))
+
+ def test_get_get_token_login_fields_when_disabled(self) -> None:
+ """By default login via an existing session is disabled."""
+ access_token = self.get_success(
+ self.auth_handler.create_access_token_for_user_id(
+ self.user, device_id=None, valid_until_ms=None
+ )
+ )
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertFalse(capabilities["m.get_login_token"]["enabled"])
+
+ @override_config({"login_via_existing_session": {"enabled": True}})
+ def test_get_get_token_login_fields_when_enabled(self) -> None:
+ access_token = self.get_success(
+ self.auth_handler.create_access_token_for_user_id(
+ self.user, device_id=None, valid_until_ms=None
+ )
+ )
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertTrue(capabilities["m.get_login_token"]["enabled"])
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index dc32982e22..f3c3bc69a9 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -446,6 +446,29 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
)
+ def test_get_login_flows_with_login_via_existing_disabled(self) -> None:
+ """GET /login should return m.login.token without get_login_token"""
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+
+ flows = {flow["type"]: flow for flow in channel.json_body["flows"]}
+ self.assertNotIn("m.login.token", flows)
+
+ @override_config({"login_via_existing_session": {"enabled": True}})
+ def test_get_login_flows_with_login_via_existing_enabled(self) -> None:
+ """GET /login should return m.login.token with get_login_token true"""
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self.assertCountEqual(
+ channel.json_body["flows"],
+ [
+ {"type": "m.login.token", "get_login_token": True},
+ {"type": "m.login.password"},
+ {"type": "m.login.application_service"},
+ ],
+ )
+
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
index b8187db982..f05e619aa8 100644
--- a/tests/rest/client/test_login_token_request.py
+++ b/tests/rest/client/test_login_token_request.py
@@ -15,14 +15,14 @@
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin
-from synapse.rest.client import login, login_token_request
+from synapse.rest.client import login, login_token_request, versions
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.unittest import override_config
-endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token"
+GET_TOKEN_ENDPOINT = "/_matrix/client/v1/login/get_token"
class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
@@ -30,6 +30,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
login.register_servlets,
admin.register_servlets,
login_token_request.register_servlets,
+ versions.register_servlets, # TODO: remove once unstable revision 0 support is removed
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@@ -46,26 +47,26 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
self.password = "password"
def test_disabled(self) -> None:
- channel = self.make_request("POST", endpoint, {}, access_token=None)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=None)
self.assertEqual(channel.code, 404)
self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
- channel = self.make_request("POST", endpoint, {}, access_token=token)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
self.assertEqual(channel.code, 404)
- @override_config({"experimental_features": {"msc3882_enabled": True}})
+ @override_config({"login_via_existing_session": {"enabled": True}})
def test_require_auth(self) -> None:
- channel = self.make_request("POST", endpoint, {}, access_token=None)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=None)
self.assertEqual(channel.code, 401)
- @override_config({"experimental_features": {"msc3882_enabled": True}})
+ @override_config({"login_via_existing_session": {"enabled": True}})
def test_uia_on(self) -> None:
user_id = self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
- channel = self.make_request("POST", endpoint, {}, access_token=token)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
self.assertEqual(channel.code, 401)
self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
@@ -80,9 +81,9 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
},
}
- channel = self.make_request("POST", endpoint, uia, access_token=token)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, uia, access_token=token)
self.assertEqual(channel.code, 200)
- self.assertEqual(channel.json_body["expires_in"], 300)
+ self.assertEqual(channel.json_body["expires_in_ms"], 300000)
login_token = channel.json_body["login_token"]
@@ -95,15 +96,15 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["user_id"], user_id)
@override_config(
- {"experimental_features": {"msc3882_enabled": True, "msc3882_ui_auth": False}}
+ {"login_via_existing_session": {"enabled": True, "require_ui_auth": False}}
)
def test_uia_off(self) -> None:
user_id = self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
- channel = self.make_request("POST", endpoint, {}, access_token=token)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
self.assertEqual(channel.code, 200)
- self.assertEqual(channel.json_body["expires_in"], 300)
+ self.assertEqual(channel.json_body["expires_in_ms"], 300000)
login_token = channel.json_body["login_token"]
@@ -117,10 +118,10 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
@override_config(
{
- "experimental_features": {
- "msc3882_enabled": True,
- "msc3882_ui_auth": False,
- "msc3882_token_timeout": "15s",
+ "login_via_existing_session": {
+ "enabled": True,
+ "require_ui_auth": False,
+ "token_timeout": "15s",
}
}
)
@@ -128,6 +129,40 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
- channel = self.make_request("POST", endpoint, {}, access_token=token)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["expires_in_ms"], 15000)
+
+ @override_config(
+ {
+ "login_via_existing_session": {
+ "enabled": True,
+ "require_ui_auth": False,
+ "token_timeout": "15s",
+ }
+ }
+ )
+ def test_unstable_support(self) -> None:
+ # TODO: remove support for unstable MSC3882 is no longer needed
+
+ # check feature is advertised in versions response:
+ channel = self.make_request(
+ "GET", "/_matrix/client/versions", {}, access_token=None
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body["unstable_features"]["org.matrix.msc3882"], True
+ )
+
+ self.register_user(self.user, self.password)
+ token = self.login(self.user, self.password)
+
+ # check feature is available via the unstable endpoint and returns an expires_in value in seconds
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/unstable/org.matrix.msc3882/login/token",
+ {},
+ access_token=token,
+ )
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["expires_in"], 15)
|