diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 2672ce24c6..e2bb945453 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -75,6 +75,10 @@ TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
# the query params in TEST_CLIENT_REDIRECT_URL
EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
+# (possibly experimental) login flows we expect to appear in the list after the normal
+# ones
+ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+
class LoginRestServletTestCase(unittest.HomeserverTestCase):
@@ -426,6 +430,57 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
d["/_synapse/oidc"] = OIDCResource(self.hs)
return d
+ def test_get_login_flows(self):
+ """GET /login should return password and SSO flows"""
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+
+ expected_flows = [
+ {"type": "m.login.cas"},
+ {"type": "m.login.sso"},
+ {"type": "m.login.token"},
+ {"type": "m.login.password"},
+ ] + ADDITIONAL_LOGIN_FLOWS
+
+ self.assertCountEqual(channel.json_body["flows"], expected_flows)
+
+ @override_config({"experimental_features": {"msc2858_enabled": True}})
+ def test_get_msc2858_login_flows(self):
+ """The SSO flow should include IdP info if MSC2858 is enabled"""
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # stick the flows results in a dict by type
+ flow_results = {} # type: Dict[str, Any]
+ for f in channel.json_body["flows"]:
+ flow_type = f["type"]
+ self.assertNotIn(
+ flow_type, flow_results, "duplicate flow type %s" % (flow_type,)
+ )
+ flow_results[flow_type] = f
+
+ self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned")
+ sso_flow = flow_results.pop("m.login.sso")
+ # we should have a set of IdPs
+ self.assertCountEqual(
+ sso_flow["org.matrix.msc2858.identity_providers"],
+ [
+ {"id": "cas", "name": "CAS"},
+ {"id": "saml", "name": "SAML"},
+ {"id": "oidc-idp1", "name": "IDP1"},
+ {"id": "oidc", "name": "OIDC"},
+ ],
+ )
+
+ # the rest of the flows are simple
+ expected_flows = [
+ {"type": "m.login.cas"},
+ {"type": "m.login.token"},
+ {"type": "m.login.password"},
+ ] + ADDITIONAL_LOGIN_FLOWS
+
+ self.assertCountEqual(flow_results.values(), expected_flows)
+
def test_multi_sso_redirect(self):
"""/login/sso/redirect should redirect to an identity picker"""
# first hit the redirect url, which should redirect to our idp picker
@@ -564,6 +619,43 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 400, channel.result)
+ def test_client_idp_redirect_msc2858_disabled(self):
+ """If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+
+ @override_config({"experimental_features": {"msc2858_enabled": True}})
+ def test_client_idp_redirect_to_unknown(self):
+ """If the client tries to pick an unknown IdP, return a 404"""
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+ )
+ self.assertEqual(channel.code, 404, channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
+
+ @override_config({"experimental_features": {"msc2858_enabled": True}})
+ def test_client_idp_redirect_to_oidc(self):
+ """If the client pick a known IdP, redirect to it"""
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+ )
+
+ self.assertEqual(channel.code, 302, channel.result)
+ oidc_uri = channel.headers.getRawHeaders("Location")[0]
+ oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+ # it should redirect us to the auth page of the OIDC server
+ self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
@staticmethod
def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
prefix = key + " = "
|