summary refs log tree commit diff
path: root/tests/rest/client/v1/test_login.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/v1/test_login.py')
-rw-r--r--tests/rest/client/v1/test_login.py234
1 files changed, 188 insertions, 46 deletions
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 1d1dc9f8a2..2672ce24c6 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -15,8 +15,8 @@
 
 import time
 import urllib.parse
-from html.parser import HTMLParser
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Union
+from urllib.parse import urlencode
 
 from mock import Mock
 
@@ -30,12 +30,15 @@ from synapse.rest.client.v1 import login, logout
 from synapse.rest.client.v2_alpha import devices, register
 from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
 from synapse.rest.synapse.client.pick_idp import PickIdpResource
+from synapse.rest.synapse.client.pick_username import pick_username_resource
+from synapse.types import create_requester
 
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
 from tests.handlers.test_saml import has_saml2
 from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
-from tests.unittest import override_config, skip_unless
+from tests.test_utils.html_parsers import TestHtmlParser
+from tests.unittest import HomeserverTestCase, override_config, skip_unless
 
 try:
     import jwt
@@ -66,6 +69,12 @@ TEST_SAML_METADATA = """
 LOGIN_URL = b"/_matrix/client/r0/login"
 TEST_URL = b"/_matrix/client/r0/account/whoami"
 
+# a (valid) url with some annoying characters in.  %3D is =, %26 is &, %2B is +
+TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
+
+# the query params in TEST_CLIENT_REDIRECT_URL
+EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
+
 
 class LoginRestServletTestCase(unittest.HomeserverTestCase):
 
@@ -386,23 +395,44 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             },
         }
 
+        # default OIDC provider
         config["oidc_config"] = TEST_OIDC_CONFIG
 
+        # additional OIDC providers
+        config["oidc_providers"] = [
+            {
+                "idp_id": "idp1",
+                "idp_name": "IDP1",
+                "discover": False,
+                "issuer": "https://issuer1",
+                "client_id": "test-client-id",
+                "client_secret": "test-client-secret",
+                "scopes": ["profile"],
+                "authorization_endpoint": "https://issuer1/auth",
+                "token_endpoint": "https://issuer1/token",
+                "userinfo_endpoint": "https://issuer1/userinfo",
+                "user_mapping_provider": {
+                    "config": {"localpart_template": "{{ user.sub }}"}
+                },
+            }
+        ]
         return config
 
     def create_resource_dict(self) -> Dict[str, Resource]:
+        from synapse.rest.oidc import OIDCResource
+
         d = super().create_resource_dict()
         d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
+        d["/_synapse/oidc"] = OIDCResource(self.hs)
         return d
 
     def test_multi_sso_redirect(self):
         """/login/sso/redirect should redirect to an identity picker"""
-        client_redirect_url = "https://x?<abc>"
-
         # first hit the redirect url, which should redirect to our idp picker
         channel = self.make_request(
             "GET",
-            "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url,
+            "/_matrix/client/r0/login/sso/redirect?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
         )
         self.assertEqual(channel.code, 302, channel.result)
         uri = channel.headers.getRawHeaders("Location")[0]
@@ -412,46 +442,22 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
 
         # parse the form to check it has fields assumed elsewhere in this class
-        class FormPageParser(HTMLParser):
-            def __init__(self):
-                super().__init__()
-
-                # the values of the hidden inputs: map from name to value
-                self.hiddens = {}  # type: Dict[str, Optional[str]]
-
-                # the values of the radio buttons
-                self.radios = []  # type: List[Optional[str]]
-
-            def handle_starttag(
-                self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
-            ) -> None:
-                attr_dict = dict(attrs)
-                if tag == "input":
-                    if attr_dict["type"] == "radio" and attr_dict["name"] == "idp":
-                        self.radios.append(attr_dict["value"])
-                    elif attr_dict["type"] == "hidden":
-                        input_name = attr_dict["name"]
-                        assert input_name
-                        self.hiddens[input_name] = attr_dict["value"]
-
-            def error(_, message):
-                self.fail(message)
-
-        p = FormPageParser()
+        p = TestHtmlParser()
         p.feed(channel.result["body"].decode("utf-8"))
         p.close()
 
-        self.assertCountEqual(p.radios, ["cas", "oidc", "saml"])
+        self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"])
 
-        self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url)
+        self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
 
     def test_multi_sso_redirect_to_cas(self):
         """If CAS is chosen, should redirect to the CAS server"""
-        client_redirect_url = "https://x?<abc>"
 
         channel = self.make_request(
             "GET",
-            "/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas",
+            "/_synapse/client/pick_idp?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+            + "&idp=cas",
             shorthand=False,
         )
         self.assertEqual(channel.code, 302, channel.result)
@@ -467,16 +473,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         service_uri = cas_uri_params["service"][0]
         _, service_uri_query = service_uri.split("?", 1)
         service_uri_params = urllib.parse.parse_qs(service_uri_query)
-        self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url)
+        self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)
 
     def test_multi_sso_redirect_to_saml(self):
         """If SAML is chosen, should redirect to the SAML server"""
-        client_redirect_url = "https://x?<abc>"
-
         channel = self.make_request(
             "GET",
             "/_synapse/client/pick_idp?redirectUrl="
-            + client_redirect_url
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
             + "&idp=saml",
         )
         self.assertEqual(channel.code, 302, channel.result)
@@ -489,16 +493,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         # the RelayState is used to carry the client redirect url
         saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
         relay_state_param = saml_uri_params["RelayState"][0]
-        self.assertEqual(relay_state_param, client_redirect_url)
+        self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)
 
-    def test_multi_sso_redirect_to_oidc(self):
+    def test_login_via_oidc(self):
         """If OIDC is chosen, should redirect to the OIDC auth endpoint"""
-        client_redirect_url = "https://x?<abc>"
 
+        # pick the default OIDC provider
         channel = self.make_request(
             "GET",
             "/_synapse/client/pick_idp?redirectUrl="
-            + client_redirect_url
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
             + "&idp=oidc",
         )
         self.assertEqual(channel.code, 302, channel.result)
@@ -518,8 +522,40 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
         self.assertEqual(
             self._get_value_from_macaroon(macaroon, "client_redirect_url"),
-            client_redirect_url,
+            TEST_CLIENT_REDIRECT_URL,
+        )
+
+        channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
+
+        # that should serve a confirmation page
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertTrue(
+            channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html")
+        )
+        p = TestHtmlParser()
+        p.feed(channel.text_body)
+        p.close()
+
+        # ... which should contain our redirect link
+        self.assertEqual(len(p.links), 1)
+        path, query = p.links[0].split("?", 1)
+        self.assertEqual(path, "https://x")
+
+        # it will have url-encoded the params properly, so we'll have to parse them
+        params = urllib.parse.parse_qsl(
+            query, keep_blank_values=True, strict_parsing=True, errors="strict"
+        )
+        self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
+        self.assertEqual(params[2][0], "loginToken")
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token, mxid, and device id.
+        login_token = params[2][1]
+        chan = self.make_request(
+            "POST", "/login", content={"type": "m.login.token", "token": login_token},
         )
+        self.assertEqual(chan.code, 200, chan.result)
+        self.assertEqual(chan.json_body["user_id"], "@user1:test")
 
     def test_multi_sso_redirect_to_unknown(self):
         """An unknown IdP should cause a 400"""
@@ -667,7 +703,9 @@ class CASTestCase(unittest.HomeserverTestCase):
 
         # Deactivate the account.
         self.get_success(
-            self.deactivate_account_handler.deactivate_account(self.user_id, False)
+            self.deactivate_account_handler.deactivate_account(
+                self.user_id, False, create_requester(self.user_id)
+            )
         )
 
         # Request the CAS ticket.
@@ -1057,3 +1095,107 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(b"POST", LOGIN_URL, params)
 
         self.assertEquals(channel.result["code"], b"401", channel.result)
+
+
+@skip_unless(HAS_OIDC, "requires OIDC")
+class UsernamePickerTestCase(HomeserverTestCase):
+    """Tests for the username picker flow of SSO login"""
+
+    servlets = [login.register_servlets]
+
+    def default_config(self):
+        config = super().default_config()
+        config["public_baseurl"] = BASE_URL
+
+        config["oidc_config"] = {}
+        config["oidc_config"].update(TEST_OIDC_CONFIG)
+        config["oidc_config"]["user_mapping_provider"] = {
+            "config": {"display_name_template": "{{ user.displayname }}"}
+        }
+
+        # whitelist this client URI so we redirect straight to it rather than
+        # serving a confirmation page
+        config["sso"] = {"client_whitelist": ["https://x"]}
+        return config
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        from synapse.rest.oidc import OIDCResource
+
+        d = super().create_resource_dict()
+        d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
+        d["/_synapse/oidc"] = OIDCResource(self.hs)
+        return d
+
+    def test_username_picker(self):
+        """Test the happy path of a username picker flow."""
+
+        # do the start of the login flow
+        channel = self.helper.auth_via_oidc(
+            {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL
+        )
+
+        # that should redirect to the username picker
+        self.assertEqual(channel.code, 302, channel.result)
+        picker_url = channel.headers.getRawHeaders("Location")[0]
+        self.assertEqual(picker_url, "/_synapse/client/pick_username")
+
+        # ... with a username_mapping_session cookie
+        cookies = {}  # type: Dict[str,str]
+        channel.extract_cookies(cookies)
+        self.assertIn("username_mapping_session", cookies)
+        session_id = cookies["username_mapping_session"]
+
+        # introspect the sso handler a bit to check that the username mapping session
+        # looks ok.
+        username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
+        self.assertIn(
+            session_id, username_mapping_sessions, "session id not found in map",
+        )
+        session = username_mapping_sessions[session_id]
+        self.assertEqual(session.remote_user_id, "tester")
+        self.assertEqual(session.display_name, "Jonny")
+        self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL)
+
+        # the expiry time should be about 15 minutes away
+        expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
+        self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
+
+        # Now, submit a username to the username picker, which should serve a redirect
+        # back to the client
+        submit_path = picker_url + "/submit"
+        content = urlencode({b"username": b"bobby"}).encode("utf8")
+        chan = self.make_request(
+            "POST",
+            path=submit_path,
+            content=content,
+            content_is_form=True,
+            custom_headers=[
+                ("Cookie", "username_mapping_session=" + session_id),
+                # old versions of twisted don't do form-parsing without a valid
+                # content-length header.
+                ("Content-Length", str(len(content))),
+            ],
+        )
+        self.assertEqual(chan.code, 302, chan.result)
+        location_headers = chan.headers.getRawHeaders("Location")
+        # ensure that the returned location matches the requested redirect URL
+        path, query = location_headers[0].split("?", 1)
+        self.assertEqual(path, "https://x")
+
+        # it will have url-encoded the params properly, so we'll have to parse them
+        params = urllib.parse.parse_qsl(
+            query, keep_blank_values=True, strict_parsing=True, errors="strict"
+        )
+        self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
+        self.assertEqual(params[2][0], "loginToken")
+
+        # fish the login token out of the returned redirect uri
+        login_token = params[2][1]
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token, mxid, and device id.
+        chan = self.make_request(
+            "POST", "/login", content={"type": "m.login.token", "token": login_token},
+        )
+        self.assertEqual(chan.code, 200, chan.result)
+        self.assertEqual(chan.json_body["user_id"], "@bobby:test")