summary refs log tree commit diff
path: root/tests/rest/client
diff options
context:
space:
mode:
authorHugh Nimmo-Smith <hughns@users.noreply.github.com>2023-06-01 13:52:51 +0100
committerGitHub <noreply@github.com>2023-06-01 08:52:51 -0400
commitd1693f03626391097b59ea9568cd8a869ed89569 (patch)
treea88e675174b8ba030b231f7661e59d44e61e0654 /tests/rest/client
parentAdd Synapse version deploy annotations to Grafana dashboard (#15674) (diff)
downloadsynapse-d1693f03626391097b59ea9568cd8a869ed89569.tar.xz
Implement stable support for MSC3882 to allow an existing device/session to generate a login token for use on a new device/session (#15388)
Implements stable support for MSC3882; this involves updating Synapse's support to
match the MSC / the spec says.

Continue to support the unstable version to allow clients to transition.
Diffstat (limited to 'tests/rest/client')
-rw-r--r--tests/rest/client/test_capabilities.py28
-rw-r--r--tests/rest/client/test_login.py23
-rw-r--r--tests/rest/client/test_login_token_request.py71
3 files changed, 104 insertions, 18 deletions
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index c16e8d43f4..cf23430f6a 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -186,3 +186,31 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
             self.assertGreater(len(details["support"]), 0)
             for room_version in details["support"]:
                 self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version))
+
+    def test_get_get_token_login_fields_when_disabled(self) -> None:
+        """By default login via an existing session is disabled."""
+        access_token = self.get_success(
+            self.auth_handler.create_access_token_for_user_id(
+                self.user, device_id=None, valid_until_ms=None
+            )
+        )
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, HTTPStatus.OK)
+        self.assertFalse(capabilities["m.get_login_token"]["enabled"])
+
+    @override_config({"login_via_existing_session": {"enabled": True}})
+    def test_get_get_token_login_fields_when_enabled(self) -> None:
+        access_token = self.get_success(
+            self.auth_handler.create_access_token_for_user_id(
+                self.user, device_id=None, valid_until_ms=None
+            )
+        )
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, HTTPStatus.OK)
+        self.assertTrue(capabilities["m.get_login_token"]["enabled"])
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index dc32982e22..f3c3bc69a9 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -446,6 +446,29 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
         )
 
+    def test_get_login_flows_with_login_via_existing_disabled(self) -> None:
+        """GET /login should return m.login.token without get_login_token"""
+        channel = self.make_request("GET", "/_matrix/client/r0/login")
+        self.assertEqual(channel.code, 200, channel.result)
+
+        flows = {flow["type"]: flow for flow in channel.json_body["flows"]}
+        self.assertNotIn("m.login.token", flows)
+
+    @override_config({"login_via_existing_session": {"enabled": True}})
+    def test_get_login_flows_with_login_via_existing_enabled(self) -> None:
+        """GET /login should return m.login.token with get_login_token true"""
+        channel = self.make_request("GET", "/_matrix/client/r0/login")
+        self.assertEqual(channel.code, 200, channel.result)
+
+        self.assertCountEqual(
+            channel.json_body["flows"],
+            [
+                {"type": "m.login.token", "get_login_token": True},
+                {"type": "m.login.password"},
+                {"type": "m.login.application_service"},
+            ],
+        )
+
 
 @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
 class MultiSSOTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
index b8187db982..f05e619aa8 100644
--- a/tests/rest/client/test_login_token_request.py
+++ b/tests/rest/client/test_login_token_request.py
@@ -15,14 +15,14 @@
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.rest import admin
-from synapse.rest.client import login, login_token_request
+from synapse.rest.client import login, login_token_request, versions
 from synapse.server import HomeServer
 from synapse.util import Clock
 
 from tests import unittest
 from tests.unittest import override_config
 
-endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token"
+GET_TOKEN_ENDPOINT = "/_matrix/client/v1/login/get_token"
 
 
 class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
@@ -30,6 +30,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
         admin.register_servlets,
         login_token_request.register_servlets,
+        versions.register_servlets,  # TODO: remove once unstable revision 0 support is removed
     ]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@@ -46,26 +47,26 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
         self.password = "password"
 
     def test_disabled(self) -> None:
-        channel = self.make_request("POST", endpoint, {}, access_token=None)
+        channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=None)
         self.assertEqual(channel.code, 404)
 
         self.register_user(self.user, self.password)
         token = self.login(self.user, self.password)
 
-        channel = self.make_request("POST", endpoint, {}, access_token=token)
+        channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
         self.assertEqual(channel.code, 404)
 
-    @override_config({"experimental_features": {"msc3882_enabled": True}})
+    @override_config({"login_via_existing_session": {"enabled": True}})
     def test_require_auth(self) -> None:
-        channel = self.make_request("POST", endpoint, {}, access_token=None)
+        channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=None)
         self.assertEqual(channel.code, 401)
 
-    @override_config({"experimental_features": {"msc3882_enabled": True}})
+    @override_config({"login_via_existing_session": {"enabled": True}})
     def test_uia_on(self) -> None:
         user_id = self.register_user(self.user, self.password)
         token = self.login(self.user, self.password)
 
-        channel = self.make_request("POST", endpoint, {}, access_token=token)
+        channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
         self.assertEqual(channel.code, 401)
         self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
 
@@ -80,9 +81,9 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
             },
         }
 
-        channel = self.make_request("POST", endpoint, uia, access_token=token)
+        channel = self.make_request("POST", GET_TOKEN_ENDPOINT, uia, access_token=token)
         self.assertEqual(channel.code, 200)
-        self.assertEqual(channel.json_body["expires_in"], 300)
+        self.assertEqual(channel.json_body["expires_in_ms"], 300000)
 
         login_token = channel.json_body["login_token"]
 
@@ -95,15 +96,15 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["user_id"], user_id)
 
     @override_config(
-        {"experimental_features": {"msc3882_enabled": True, "msc3882_ui_auth": False}}
+        {"login_via_existing_session": {"enabled": True, "require_ui_auth": False}}
     )
     def test_uia_off(self) -> None:
         user_id = self.register_user(self.user, self.password)
         token = self.login(self.user, self.password)
 
-        channel = self.make_request("POST", endpoint, {}, access_token=token)
+        channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
         self.assertEqual(channel.code, 200)
-        self.assertEqual(channel.json_body["expires_in"], 300)
+        self.assertEqual(channel.json_body["expires_in_ms"], 300000)
 
         login_token = channel.json_body["login_token"]
 
@@ -117,10 +118,10 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
 
     @override_config(
         {
-            "experimental_features": {
-                "msc3882_enabled": True,
-                "msc3882_ui_auth": False,
-                "msc3882_token_timeout": "15s",
+            "login_via_existing_session": {
+                "enabled": True,
+                "require_ui_auth": False,
+                "token_timeout": "15s",
             }
         }
     )
@@ -128,6 +129,40 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
         self.register_user(self.user, self.password)
         token = self.login(self.user, self.password)
 
-        channel = self.make_request("POST", endpoint, {}, access_token=token)
+        channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["expires_in_ms"], 15000)
+
+    @override_config(
+        {
+            "login_via_existing_session": {
+                "enabled": True,
+                "require_ui_auth": False,
+                "token_timeout": "15s",
+            }
+        }
+    )
+    def test_unstable_support(self) -> None:
+        # TODO: remove support for unstable MSC3882 is no longer needed
+
+        # check feature is advertised in versions response:
+        channel = self.make_request(
+            "GET", "/_matrix/client/versions", {}, access_token=None
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            channel.json_body["unstable_features"]["org.matrix.msc3882"], True
+        )
+
+        self.register_user(self.user, self.password)
+        token = self.login(self.user, self.password)
+
+        # check feature is available via the unstable endpoint and returns an expires_in value in seconds
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/unstable/org.matrix.msc3882/login/token",
+            {},
+            access_token=token,
+        )
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body["expires_in"], 15)