summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_oidc.py143
-rw-r--r--tests/unittest.py8
2 files changed, 149 insertions, 2 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index c54f1c5797..368d600b33 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,14 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-from urllib.parse import parse_qs, urlparse
+import re
+from typing import Dict
+from urllib.parse import parse_qs, urlencode, urlparse
 
 from mock import ANY, Mock, patch
 
 import pymacaroons
 
+from twisted.web.resource import Resource
+
+from synapse.api.errors import RedirectException
 from synapse.handlers.oidc_handler import OidcError
 from synapse.handlers.sso import MappingException
+from synapse.rest.client.v1 import login
+from synapse.rest.synapse.client.pick_username import pick_username_resource
 from synapse.server import HomeServer
 from synapse.types import UserID
 
@@ -793,6 +800,140 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "mapping_error", "Unable to generate a Matrix ID from the SSO response"
         )
 
+    def test_empty_localpart(self):
+        """Attempts to map onto an empty localpart should be rejected."""
+        userinfo = {
+            "sub": "tester",
+            "username": "",
+        }
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        self.assertRenderedError("mapping_error", "localpart is invalid: ")
+
+    @override_config(
+        {
+            "oidc_config": {
+                "user_mapping_provider": {
+                    "config": {"localpart_template": "{{ user.username }}"}
+                }
+            }
+        }
+    )
+    def test_null_localpart(self):
+        """Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
+        userinfo = {
+            "sub": "tester",
+            "username": None,
+        }
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        self.assertRenderedError("mapping_error", "localpart is invalid: ")
+
+
+class UsernamePickerTestCase(HomeserverTestCase):
+    servlets = [login.register_servlets]
+
+    def default_config(self):
+        config = super().default_config()
+        config["public_baseurl"] = BASE_URL
+        oidc_config = {
+            "enabled": True,
+            "client_id": CLIENT_ID,
+            "client_secret": CLIENT_SECRET,
+            "issuer": ISSUER,
+            "scopes": SCOPES,
+            "user_mapping_provider": {
+                "config": {"display_name_template": "{{ user.displayname }}"}
+            },
+        }
+
+        # Update this config with what's in the default config so that
+        # override_config works as expected.
+        oidc_config.update(config.get("oidc_config", {}))
+        config["oidc_config"] = oidc_config
+
+        # whitelist this client URI so we redirect straight to it rather than
+        # serving a confirmation page
+        config["sso"] = {"client_whitelist": ["https://whitelisted.client"]}
+        return config
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
+        return d
+
+    def test_username_picker(self):
+        """Test the happy path of a username picker flow."""
+        client_redirect_url = "https://whitelisted.client"
+
+        # first of all, mock up an OIDC callback to the OidcHandler, which should
+        # raise a RedirectException
+        userinfo = {"sub": "tester", "displayname": "Jonny"}
+        f = self.get_failure(
+            _make_callback_with_userinfo(
+                self.hs, userinfo, client_redirect_url=client_redirect_url
+            ),
+            RedirectException,
+        )
+
+        # check the Location and cookies returned by the RedirectException
+        self.assertEqual(f.value.location, b"/_synapse/client/pick_username")
+        cookieheader = f.value.cookies[0]
+        regex = re.compile(b"^username_mapping_session=([a-zA-Z]+);")
+        m = regex.search(cookieheader)
+        if not m:
+            self.fail("cookie header %s does not match %s" % (cookieheader, regex))
+
+        # introspect the sso handler a bit to check that the username mapping session
+        # looks ok.
+        session_id = m.group(1).decode("ascii")
+        username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
+        self.assertIn(
+            session_id, username_mapping_sessions, "session id not found in map"
+        )
+        session = username_mapping_sessions[session_id]
+        self.assertEqual(session.remote_user_id, "tester")
+        self.assertEqual(session.display_name, "Jonny")
+        self.assertEqual(session.client_redirect_url, client_redirect_url)
+
+        # the expiry time should be about 15 minutes away
+        expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
+        self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
+
+        # Now, submit a username to the username picker, which should serve a redirect
+        # back to the client
+        submit_path = f.value.location + b"/submit"
+        content = urlencode({b"username": b"bobby"}).encode("utf8")
+        chan = self.make_request(
+            "POST",
+            path=submit_path,
+            content=content,
+            content_is_form=True,
+            custom_headers=[
+                ("Cookie", cookieheader),
+                # old versions of twisted don't do form-parsing without a valid
+                # content-length header.
+                ("Content-Length", str(len(content))),
+            ],
+        )
+        self.assertEqual(chan.code, 302, chan.result)
+        location_headers = chan.headers.getRawHeaders("Location")
+        # ensure that the returned location starts with the requested redirect URL
+        self.assertEqual(
+            location_headers[0][: len(client_redirect_url)], client_redirect_url
+        )
+
+        # fish the login token out of the returned redirect uri
+        parts = urlparse(location_headers[0])
+        query = parse_qs(parts.query)
+        login_token = query["loginToken"][0]
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token, mxid, and device id.
+        chan = self.make_request(
+            "POST", "/login", content={"type": "m.login.token", "token": login_token},
+        )
+        self.assertEqual(chan.code, 200, chan.result)
+        self.assertEqual(chan.json_body["user_id"], "@bobby:test")
+
 
 async def _make_callback_with_userinfo(
     hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
diff --git a/tests/unittest.py b/tests/unittest.py
index 39e5e7b85c..af7f752c5a 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import hmac
 import inspect
 import logging
 import time
-from typing import Dict, Optional, Type, TypeVar, Union
+from typing import Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
 
 from mock import Mock, patch
 
@@ -383,6 +383,9 @@ class HomeserverTestCase(TestCase):
         federation_auth_origin: str = None,
         content_is_form: bool = False,
         await_result: bool = True,
+        custom_headers: Optional[
+            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+        ] = None,
     ) -> FakeChannel:
         """
         Create a SynapseRequest at the path using the method and containing the
@@ -405,6 +408,8 @@ class HomeserverTestCase(TestCase):
                  true (the default), will pump the test reactor until the the renderer
                  tells the channel the request is finished.
 
+            custom_headers: (name, value) pairs to add as request headers
+
         Returns:
             The FakeChannel object which stores the result of the request.
         """
@@ -420,6 +425,7 @@ class HomeserverTestCase(TestCase):
             federation_auth_origin,
             content_is_form,
             await_result,
+            custom_headers,
         )
 
     def setup_test_homeserver(self, *args, **kwargs):