summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8911.feature1
-rw-r--r--tests/handlers/test_oidc.py286
2 files changed, 146 insertions, 141 deletions
diff --git a/changelog.d/8911.feature b/changelog.d/8911.feature
new file mode 100644
index 0000000000..d450ef4998
--- /dev/null
+++ b/changelog.d/8911.feature
@@ -0,0 +1 @@
+Add support for allowing users to pick their own user ID during a single-sign-on login.
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1d99a45436..9878527bab 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -15,7 +15,7 @@
 import json
 from urllib.parse import parse_qs, urlparse
 
-from mock import Mock, patch
+from mock import ANY, Mock, patch
 
 import pymacaroons
 
@@ -82,7 +82,7 @@ class TestMappingProviderFailures(TestMappingProvider):
         }
 
 
-def simple_async_mock(return_value=None, raises=None):
+def simple_async_mock(return_value=None, raises=None) -> Mock:
     # AsyncMock is not available in python3.5, this mimics part of its behaviour
     async def cb(*args, **kwargs):
         if raises:
@@ -160,6 +160,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             self.assertEqual(args[2], error_description)
         # Reset the render_error mock
         self.render_error.reset_mock()
+        return args
 
     def test_config(self):
         """Basic config correctly sets up the callback URL and client auth correctly."""
@@ -374,26 +375,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "id_token": "id_token",
             "access_token": "access_token",
         }
+        username = "bar"
         userinfo = {
             "sub": "foo",
-            "preferred_username": "bar",
+            "username": username,
         }
-        user_id = "@foo:domain.org"
+        expected_user_id = "@%s:%s" % (username, self.hs.hostname)
         self.handler._exchange_code = simple_async_mock(return_value=token)
         self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
         self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
-        self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
-        self.handler._auth_handler.complete_sso_login = simple_async_mock()
-        request = Mock(
-            spec=[
-                "args",
-                "getCookie",
-                "addCookie",
-                "requestHeaders",
-                "getClientIP",
-                "get_user_agent",
-            ]
-        )
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
 
         code = "code"
         state = "state"
@@ -401,64 +393,54 @@ class OidcHandlerTestCase(HomeserverTestCase):
         client_redirect_url = "http://client/redirect"
         user_agent = "Browser"
         ip_address = "10.0.0.1"
-        request.getCookie.return_value = self.handler._generate_oidc_session_token(
+        session = self.handler._generate_oidc_session_token(
             state=state,
             nonce=nonce,
             client_redirect_url=client_redirect_url,
             ui_auth_session_id=None,
         )
-
-        request.args = {}
-        request.args[b"code"] = [code.encode("utf-8")]
-        request.args[b"state"] = [state.encode("utf-8")]
-
-        request.getClientIP.return_value = ip_address
-        request.get_user_agent.return_value = user_agent
+        request = self._build_callback_request(
+            code, state, session, user_agent=user_agent, ip_address=ip_address
+        )
 
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        self.handler._auth_handler.complete_sso_login.assert_called_once_with(
-            user_id, request, client_redirect_url, {},
+        auth_handler.complete_sso_login.assert_called_once_with(
+            expected_user_id, request, client_redirect_url, {},
         )
         self.handler._exchange_code.assert_called_once_with(code)
         self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.handler._map_userinfo_to_user.assert_called_once_with(
-            userinfo, token, user_agent, ip_address
-        )
         self.handler._fetch_userinfo.assert_not_called()
         self.render_error.assert_not_called()
 
         # Handle mapping errors
-        self.handler._map_userinfo_to_user = simple_async_mock(
-            raises=MappingException()
-        )
-        self.get_success(self.handler.handle_oidc_callback(request))
-        self.assertRenderedError("mapping_error")
-        self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+        with patch.object(
+            self.handler,
+            "_remote_id_from_userinfo",
+            new=Mock(side_effect=MappingException()),
+        ):
+            self.get_success(self.handler.handle_oidc_callback(request))
+            self.assertRenderedError("mapping_error")
 
         # Handle ID token errors
         self.handler._parse_id_token = simple_async_mock(raises=Exception())
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_token")
 
-        self.handler._auth_handler.complete_sso_login.reset_mock()
+        auth_handler.complete_sso_login.reset_mock()
         self.handler._exchange_code.reset_mock()
         self.handler._parse_id_token.reset_mock()
-        self.handler._map_userinfo_to_user.reset_mock()
         self.handler._fetch_userinfo.reset_mock()
 
         # With userinfo fetching
         self.handler._scopes = []  # do not ask the "openid" scope
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        self.handler._auth_handler.complete_sso_login.assert_called_once_with(
-            user_id, request, client_redirect_url, {},
+        auth_handler.complete_sso_login.assert_called_once_with(
+            expected_user_id, request, client_redirect_url, {},
         )
         self.handler._exchange_code.assert_called_once_with(code)
         self.handler._parse_id_token.assert_not_called()
-        self.handler._map_userinfo_to_user.assert_called_once_with(
-            userinfo, token, user_agent, ip_address
-        )
         self.handler._fetch_userinfo.assert_called_once_with(token)
         self.render_error.assert_not_called()
 
@@ -609,72 +591,55 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         userinfo = {
             "sub": "foo",
+            "username": "foo",
             "phone": "1234567",
         }
-        user_id = "@foo:domain.org"
         self.handler._exchange_code = simple_async_mock(return_value=token)
         self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
-        self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
-        self.handler._auth_handler.complete_sso_login = simple_async_mock()
-        request = Mock(
-            spec=[
-                "args",
-                "getCookie",
-                "addCookie",
-                "requestHeaders",
-                "getClientIP",
-                "get_user_agent",
-            ]
-        )
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
 
         state = "state"
         client_redirect_url = "http://client/redirect"
-        request.getCookie.return_value = self.handler._generate_oidc_session_token(
+        session = self.handler._generate_oidc_session_token(
             state=state,
             nonce="nonce",
             client_redirect_url=client_redirect_url,
             ui_auth_session_id=None,
         )
-
-        request.args = {}
-        request.args[b"code"] = [b"code"]
-        request.args[b"state"] = [state.encode("utf-8")]
-
-        request.getClientIP.return_value = "10.0.0.1"
-        request.get_user_agent.return_value = "Browser"
+        request = self._build_callback_request("code", state, session)
 
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        self.handler._auth_handler.complete_sso_login.assert_called_once_with(
-            user_id, request, client_redirect_url, {"phone": "1234567"},
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@foo:test", request, client_redirect_url, {"phone": "1234567"},
         )
 
     def test_map_userinfo_to_user(self):
         """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
         userinfo = {
             "sub": "test_user",
             "username": "test_user",
         }
-        # The token doesn't matter with the default user mapping provider.
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self._make_callback_with_userinfo(userinfo)
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", ANY, ANY, {}
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Some providers return an integer ID.
         userinfo = {
             "sub": 1234,
             "username": "test_user_2",
         }
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self._make_callback_with_userinfo(userinfo)
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user_2:test", ANY, ANY, {}
         )
-        self.assertEqual(mxid, "@test_user_2:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Test if the mxid is already taken
         store = self.hs.get_datastore()
@@ -683,14 +648,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user3.to_string(), password_hash=None)
         )
         userinfo = {"sub": "test3", "username": "test_user_3"}
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
-        )
-        self.assertEqual(
-            str(e.value), "Mapping provider does not support de-duplicating Matrix IDs",
+        self._make_callback_with_userinfo(userinfo)
+        auth_handler.complete_sso_login.assert_not_called()
+        self.assertRenderedError(
+            "mapping_error",
+            "Mapping provider does not support de-duplicating Matrix IDs",
         )
 
     @override_config({"oidc_config": {"allow_existing_users": True}})
@@ -702,26 +664,26 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user.to_string(), password_hash=None)
         )
 
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
         # Map a user via SSO.
         userinfo = {
             "sub": "test",
             "username": "test_user",
         }
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self._make_callback_with_userinfo(userinfo)
+        auth_handler.complete_sso_login.assert_called_once_with(
+            user.to_string(), ANY, ANY, {},
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Subsequent calls should map to the same mxid.
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self._make_callback_with_userinfo(userinfo)
+        auth_handler.complete_sso_login.assert_called_once_with(
+            user.to_string(), ANY, ANY, {},
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Note that a second SSO user can be mapped to the same Matrix ID. (This
         # requires a unique sub, but something that maps to the same matrix ID,
@@ -732,13 +694,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test1",
             "username": "test_user",
         }
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self._make_callback_with_userinfo(userinfo)
+        auth_handler.complete_sso_login.assert_called_once_with(
+            user.to_string(), ANY, ANY, {},
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Register some non-exact matching cases.
         user2 = UserID.from_string("@TEST_user_2:test")
@@ -755,14 +715,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test2",
             "username": "TEST_USER_2",
         }
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
-        )
+        self._make_callback_with_userinfo(userinfo)
+        auth_handler.complete_sso_login.assert_not_called()
+        args = self.assertRenderedError("mapping_error")
         self.assertTrue(
-            str(e.value).startswith(
+            args[2].startswith(
                 "Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
             )
         )
@@ -773,28 +730,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user2.to_string(), password_hash=None)
         )
 
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self._make_callback_with_userinfo(userinfo)
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@TEST_USER_2:test", ANY, ANY, {},
         )
-        self.assertEqual(mxid, "@TEST_USER_2:test")
 
     def test_map_userinfo_to_invalid_localpart(self):
         """If the mapping provider generates an invalid localpart it should be rejected."""
-        userinfo = {
-            "sub": "test2",
-            "username": "föö",
-        }
-        token = {}
-
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
-        )
-        self.assertEqual(str(e.value), "localpart is invalid: föö")
+        self._make_callback_with_userinfo({"sub": "test2", "username": "föö"})
+        self.assertRenderedError("mapping_error", "localpart is invalid: föö")
 
     @override_config(
         {
@@ -807,6 +751,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_map_userinfo_to_user_retries(self):
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
         store = self.hs.get_datastore()
         self.get_success(
             store.register_user(user_id="@test_user:test", password_hash=None)
@@ -815,14 +762,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test",
             "username": "test_user",
         }
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
-        )
+        self._make_callback_with_userinfo(userinfo)
+
         # test_user is already taken, so test_user1 gets registered instead.
-        self.assertEqual(mxid, "@test_user1:test")
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user1:test", ANY, ANY, {},
+        )
+        auth_handler.complete_sso_login.reset_mock()
 
         # Register all of the potential mxids for a particular OIDC username.
         self.get_success(
@@ -838,12 +784,70 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "tester",
             "username": "tester",
         }
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
+        self._make_callback_with_userinfo(userinfo)
+        auth_handler.complete_sso_login.assert_not_called()
+        self.assertRenderedError(
+            "mapping_error", "Unable to generate a Matrix ID from the SSO response"
         )
-        self.assertEqual(
-            str(e.value), "Unable to generate a Matrix ID from the SSO response"
+
+    def _make_callback_with_userinfo(
+        self, userinfo: dict, client_redirect_url: str = "http://client/redirect"
+    ) -> None:
+        self.handler._exchange_code = simple_async_mock(return_value={})
+        self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
+        self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        state = "state"
+        session = self.handler._generate_oidc_session_token(
+            state=state,
+            nonce="nonce",
+            client_redirect_url=client_redirect_url,
+            ui_auth_session_id=None,
+        )
+        request = self._build_callback_request("code", state, session)
+
+        self.get_success(self.handler.handle_oidc_callback(request))
+
+    def _build_callback_request(
+        self,
+        code: str,
+        state: str,
+        session: str,
+        user_agent: str = "Browser",
+        ip_address: str = "10.0.0.1",
+    ):
+        """Builds a fake SynapseRequest to mock the browser callback
+
+        Returns a Mock object which looks like the SynapseRequest we get from a browser
+        after SSO (before we return to the client)
+
+        Args:
+            code: the authorization code which would have been returned by the OIDC
+               provider
+            state: the "state" param which would have been passed around in the
+               query param. Should be the same as was embedded in the session in
+               _build_oidc_session.
+            session: the "session" which would have been passed around in the cookie.
+            user_agent: the user-agent to present
+            ip_address: the IP address to pretend the request came from
+        """
+        request = Mock(
+            spec=[
+                "args",
+                "getCookie",
+                "addCookie",
+                "requestHeaders",
+                "getClientIP",
+                "get_user_agent",
+            ]
         )
+
+        request.getCookie.return_value = session
+        request.args = {}
+        request.args[b"code"] = [code.encode("utf-8")]
+        request.args[b"state"] = [state.encode("utf-8")]
+        request.getClientIP.return_value = ip_address
+        request.get_user_agent.return_value = user_agent
+        return request