diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index c54f1c5797..368d600b33 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,14 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
-from urllib.parse import parse_qs, urlparse
+import re
+from typing import Dict
+from urllib.parse import parse_qs, urlencode, urlparse
from mock import ANY, Mock, patch
import pymacaroons
+from twisted.web.resource import Resource
+
+from synapse.api.errors import RedirectException
from synapse.handlers.oidc_handler import OidcError
from synapse.handlers.sso import MappingException
+from synapse.rest.client.v1 import login
+from synapse.rest.synapse.client.pick_username import pick_username_resource
from synapse.server import HomeServer
from synapse.types import UserID
@@ -793,6 +800,140 @@ class OidcHandlerTestCase(HomeserverTestCase):
"mapping_error", "Unable to generate a Matrix ID from the SSO response"
)
+ def test_empty_localpart(self):
+ """Attempts to map onto an empty localpart should be rejected."""
+ userinfo = {
+ "sub": "tester",
+ "username": "",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ self.assertRenderedError("mapping_error", "localpart is invalid: ")
+
+ @override_config(
+ {
+ "oidc_config": {
+ "user_mapping_provider": {
+ "config": {"localpart_template": "{{ user.username }}"}
+ }
+ }
+ }
+ )
+ def test_null_localpart(self):
+ """Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
+ userinfo = {
+ "sub": "tester",
+ "username": None,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ self.assertRenderedError("mapping_error", "localpart is invalid: ")
+
+
+class UsernamePickerTestCase(HomeserverTestCase):
+ servlets = [login.register_servlets]
+
+ def default_config(self):
+ config = super().default_config()
+ config["public_baseurl"] = BASE_URL
+ oidc_config = {
+ "enabled": True,
+ "client_id": CLIENT_ID,
+ "client_secret": CLIENT_SECRET,
+ "issuer": ISSUER,
+ "scopes": SCOPES,
+ "user_mapping_provider": {
+ "config": {"display_name_template": "{{ user.displayname }}"}
+ },
+ }
+
+ # Update this config with what's in the default config so that
+ # override_config works as expected.
+ oidc_config.update(config.get("oidc_config", {}))
+ config["oidc_config"] = oidc_config
+
+ # whitelist this client URI so we redirect straight to it rather than
+ # serving a confirmation page
+ config["sso"] = {"client_whitelist": ["https://whitelisted.client"]}
+ return config
+
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
+ return d
+
+ def test_username_picker(self):
+ """Test the happy path of a username picker flow."""
+ client_redirect_url = "https://whitelisted.client"
+
+ # first of all, mock up an OIDC callback to the OidcHandler, which should
+ # raise a RedirectException
+ userinfo = {"sub": "tester", "displayname": "Jonny"}
+ f = self.get_failure(
+ _make_callback_with_userinfo(
+ self.hs, userinfo, client_redirect_url=client_redirect_url
+ ),
+ RedirectException,
+ )
+
+ # check the Location and cookies returned by the RedirectException
+ self.assertEqual(f.value.location, b"/_synapse/client/pick_username")
+ cookieheader = f.value.cookies[0]
+ regex = re.compile(b"^username_mapping_session=([a-zA-Z]+);")
+ m = regex.search(cookieheader)
+ if not m:
+ self.fail("cookie header %s does not match %s" % (cookieheader, regex))
+
+ # introspect the sso handler a bit to check that the username mapping session
+ # looks ok.
+ session_id = m.group(1).decode("ascii")
+ username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
+ self.assertIn(
+ session_id, username_mapping_sessions, "session id not found in map"
+ )
+ session = username_mapping_sessions[session_id]
+ self.assertEqual(session.remote_user_id, "tester")
+ self.assertEqual(session.display_name, "Jonny")
+ self.assertEqual(session.client_redirect_url, client_redirect_url)
+
+ # the expiry time should be about 15 minutes away
+ expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
+ self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
+
+ # Now, submit a username to the username picker, which should serve a redirect
+ # back to the client
+ submit_path = f.value.location + b"/submit"
+ content = urlencode({b"username": b"bobby"}).encode("utf8")
+ chan = self.make_request(
+ "POST",
+ path=submit_path,
+ content=content,
+ content_is_form=True,
+ custom_headers=[
+ ("Cookie", cookieheader),
+ # old versions of twisted don't do form-parsing without a valid
+ # content-length header.
+ ("Content-Length", str(len(content))),
+ ],
+ )
+ self.assertEqual(chan.code, 302, chan.result)
+ location_headers = chan.headers.getRawHeaders("Location")
+ # ensure that the returned location starts with the requested redirect URL
+ self.assertEqual(
+ location_headers[0][: len(client_redirect_url)], client_redirect_url
+ )
+
+ # fish the login token out of the returned redirect uri
+ parts = urlparse(location_headers[0])
+ query = parse_qs(parts.query)
+ login_token = query["loginToken"][0]
+
+ # finally, submit the matrix login token to the login API, which gives us our
+ # matrix access token, mxid, and device id.
+ chan = self.make_request(
+ "POST", "/login", content={"type": "m.login.token", "token": login_token},
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+ self.assertEqual(chan.json_body["user_id"], "@bobby:test")
+
async def _make_callback_with_userinfo(
hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
|