summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/client/v1/test_login.py43
1 files changed, 27 insertions, 16 deletions
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 20af3285bd..988821b16f 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -437,14 +437,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         channel = self.make_request("GET", "/_matrix/client/r0/login")
         self.assertEqual(channel.code, 200, channel.result)
 
-        expected_flows = [
-            {"type": "m.login.cas"},
-            {"type": "m.login.sso"},
-            {"type": "m.login.token"},
-            {"type": "m.login.password"},
-        ] + ADDITIONAL_LOGIN_FLOWS
+        expected_flow_types = [
+            "m.login.cas",
+            "m.login.sso",
+            "m.login.token",
+            "m.login.password",
+        ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS]
 
-        self.assertCountEqual(channel.json_body["flows"], expected_flows)
+        self.assertCountEqual(
+            [f["type"] for f in channel.json_body["flows"]], expected_flow_types
+        )
 
     @override_config({"experimental_features": {"msc2858_enabled": True}})
     def test_get_msc2858_login_flows(self):
@@ -636,22 +638,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.code, 400, channel.result)
 
-    def test_client_idp_redirect_msc2858_disabled(self):
-        """If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
-        channel = self._make_sso_redirect_request(True, "oidc")
-        self.assertEqual(channel.code, 400, channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
-
-    @override_config({"experimental_features": {"msc2858_enabled": True}})
     def test_client_idp_redirect_to_unknown(self):
         """If the client tries to pick an unknown IdP, return a 404"""
-        channel = self._make_sso_redirect_request(True, "xxx")
+        channel = self._make_sso_redirect_request(False, "xxx")
         self.assertEqual(channel.code, 404, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
 
-    @override_config({"experimental_features": {"msc2858_enabled": True}})
     def test_client_idp_redirect_to_oidc(self):
         """If the client pick a known IdP, redirect to it"""
+        channel = self._make_sso_redirect_request(False, "oidc")
+        self.assertEqual(channel.code, 302, channel.result)
+        oidc_uri = channel.headers.getRawHeaders("Location")[0]
+        oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+        # it should redirect us to the auth page of the OIDC server
+        self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
+    @override_config({"experimental_features": {"msc2858_enabled": True}})
+    def test_client_msc2858_redirect_to_oidc(self):
+        """Test the unstable API"""
         channel = self._make_sso_redirect_request(True, "oidc")
         self.assertEqual(channel.code, 302, channel.result)
         oidc_uri = channel.headers.getRawHeaders("Location")[0]
@@ -660,6 +665,12 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         # it should redirect us to the auth page of the OIDC server
         self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
 
+    def test_client_idp_redirect_msc2858_disabled(self):
+        """If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400"""
+        channel = self._make_sso_redirect_request(True, "oidc")
+        self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+
     def _make_sso_redirect_request(
         self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
     ):