summary refs log tree commit diff
path: root/tests/rest/client/v1
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2021-01-15 13:45:13 +0000
committerGitHub <noreply@github.com>2021-01-15 13:45:13 +0000
commit0dd2649c127e4eb538dfbf0c879bd66c9ff1599c (patch)
treef12fe3ea07eb230569c61183d90cb1dd94036c83 /tests/rest/client/v1
parentStore an IdP ID in the OIDC session (#9109) (diff)
downloadsynapse-0dd2649c127e4eb538dfbf0c879bd66c9ff1599c.tar.xz
Improve UsernamePickerTestCase (#9112)
* make the OIDC bits of the test work at a higher level - via the REST api instead of poking the OIDCHandler directly.
* Move it to test_login.py, where I think it fits better.
Diffstat (limited to 'tests/rest/client/v1')
-rw-r--r--tests/rest/client/v1/test_login.py105
-rw-r--r--tests/rest/client/v1/utils.py11
2 files changed, 110 insertions, 6 deletions
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index f9b8011961..73a009efd1 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -17,6 +17,7 @@ import time
 import urllib.parse
 from html.parser import HTMLParser
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from urllib.parse import parse_qs, urlencode, urlparse
 
 from mock import Mock
 
@@ -30,13 +31,14 @@ from synapse.rest.client.v1 import login, logout
 from synapse.rest.client.v2_alpha import devices, register
 from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
 from synapse.rest.synapse.client.pick_idp import PickIdpResource
+from synapse.rest.synapse.client.pick_username import pick_username_resource
 from synapse.types import create_requester
 
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
 from tests.handlers.test_saml import has_saml2
 from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
-from tests.unittest import override_config, skip_unless
+from tests.unittest import HomeserverTestCase, override_config, skip_unless
 
 try:
     import jwt
@@ -1060,3 +1062,104 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(b"POST", LOGIN_URL, params)
 
         self.assertEquals(channel.result["code"], b"401", channel.result)
+
+
+@skip_unless(HAS_OIDC, "requires OIDC")
+class UsernamePickerTestCase(HomeserverTestCase):
+    """Tests for the username picker flow of SSO login"""
+
+    servlets = [login.register_servlets]
+
+    def default_config(self):
+        config = super().default_config()
+        config["public_baseurl"] = BASE_URL
+
+        config["oidc_config"] = {}
+        config["oidc_config"].update(TEST_OIDC_CONFIG)
+        config["oidc_config"]["user_mapping_provider"] = {
+            "config": {"display_name_template": "{{ user.displayname }}"}
+        }
+
+        # whitelist this client URI so we redirect straight to it rather than
+        # serving a confirmation page
+        config["sso"] = {"client_whitelist": ["https://whitelisted.client"]}
+        return config
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        from synapse.rest.oidc import OIDCResource
+
+        d = super().create_resource_dict()
+        d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
+        d["/_synapse/oidc"] = OIDCResource(self.hs)
+        return d
+
+    def test_username_picker(self):
+        """Test the happy path of a username picker flow."""
+        client_redirect_url = "https://whitelisted.client"
+
+        # do the start of the login flow
+        channel = self.helper.auth_via_oidc(
+            {"sub": "tester", "displayname": "Jonny"}, client_redirect_url
+        )
+
+        # that should redirect to the username picker
+        self.assertEqual(channel.code, 302, channel.result)
+        picker_url = channel.headers.getRawHeaders("Location")[0]
+        self.assertEqual(picker_url, "/_synapse/client/pick_username")
+
+        # ... with a username_mapping_session cookie
+        cookies = {}  # type: Dict[str,str]
+        channel.extract_cookies(cookies)
+        self.assertIn("username_mapping_session", cookies)
+        session_id = cookies["username_mapping_session"]
+
+        # introspect the sso handler a bit to check that the username mapping session
+        # looks ok.
+        username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
+        self.assertIn(
+            session_id, username_mapping_sessions, "session id not found in map",
+        )
+        session = username_mapping_sessions[session_id]
+        self.assertEqual(session.remote_user_id, "tester")
+        self.assertEqual(session.display_name, "Jonny")
+        self.assertEqual(session.client_redirect_url, client_redirect_url)
+
+        # the expiry time should be about 15 minutes away
+        expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
+        self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
+
+        # Now, submit a username to the username picker, which should serve a redirect
+        # back to the client
+        submit_path = picker_url + "/submit"
+        content = urlencode({b"username": b"bobby"}).encode("utf8")
+        chan = self.make_request(
+            "POST",
+            path=submit_path,
+            content=content,
+            content_is_form=True,
+            custom_headers=[
+                ("Cookie", "username_mapping_session=" + session_id),
+                # old versions of twisted don't do form-parsing without a valid
+                # content-length header.
+                ("Content-Length", str(len(content))),
+            ],
+        )
+        self.assertEqual(chan.code, 302, chan.result)
+        location_headers = chan.headers.getRawHeaders("Location")
+        # ensure that the returned location starts with the requested redirect URL
+        self.assertEqual(
+            location_headers[0][: len(client_redirect_url)], client_redirect_url
+        )
+
+        # fish the login token out of the returned redirect uri
+        parts = urlparse(location_headers[0])
+        query = parse_qs(parts.query)
+        login_token = query["loginToken"][0]
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token, mxid, and device id.
+        chan = self.make_request(
+            "POST", "/login", content={"type": "m.login.token", "token": login_token},
+        )
+        self.assertEqual(chan.code, 200, chan.result)
+        self.assertEqual(chan.json_body["user_id"], "@bobby:test")
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 85d1709ead..c6647dbe08 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -363,10 +363,10 @@ class RestHelper:
         the normal places.
         """
         client_redirect_url = "https://x"
-        channel = self.auth_via_oidc(remote_user_id, client_redirect_url)
+        channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
 
         # expect a confirmation page
-        assert channel.code == 200
+        assert channel.code == 200, channel.result
 
         # fish the matrix login token out of the body of the confirmation page
         m = re.search(
@@ -390,7 +390,7 @@ class RestHelper:
 
     def auth_via_oidc(
         self,
-        remote_user_id: str,
+        user_info_dict: JsonDict,
         client_redirect_url: Optional[str] = None,
         ui_auth_session_id: Optional[str] = None,
     ) -> FakeChannel:
@@ -411,7 +411,8 @@ class RestHelper:
         the normal places.
 
         Args:
-            remote_user_id: the remote id that the OIDC provider should present
+            user_info_dict: the remote userinfo that the OIDC provider should present.
+                Typically this should be '{"sub": "<remote user id>"}'.
             client_redirect_url: for a login flow, the client redirect URL to pass to
                 the login redirect endpoint
             ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
@@ -457,7 +458,7 @@ class RestHelper:
             # a dummy OIDC access token
             ("https://issuer.test/token", {"access_token": "TEST"}),
             # and then one to the user_info endpoint, which returns our remote user id.
-            ("https://issuer.test/userinfo", {"sub": remote_user_id}),
+            ("https://issuer.test/userinfo", user_info_dict),
         ]
 
         async def mock_req(method: str, uri: str, data=None, headers=None):