summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8941.feature1
-rw-r--r--synapse/handlers/oidc_handler.py62
-rw-r--r--synapse/handlers/saml_handler.py37
-rw-r--r--synapse/handlers/sso.py58
-rw-r--r--tests/handlers/test_saml.py8
5 files changed, 80 insertions, 86 deletions
diff --git a/changelog.d/8941.feature b/changelog.d/8941.feature
new file mode 100644
index 0000000000..d450ef4998
--- /dev/null
+++ b/changelog.d/8941.feature
@@ -0,0 +1 @@
+Add support for allowing users to pick their own user ID during a single-sign-on login.
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index f626117f76..cbd11a1382 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -115,8 +115,6 @@ class OidcHandler(BaseHandler):
         self._allow_existing_users = hs.config.oidc_allow_existing_users  # type: bool
 
         self._http_client = hs.get_proxied_http_client()
-        self._auth_handler = hs.get_auth_handler()
-        self._registration_handler = hs.get_registration_handler()
         self._server_name = hs.config.server_name  # type: str
         self._macaroon_secret_key = hs.config.macaroon_secret_key
 
@@ -689,33 +687,14 @@ class OidcHandler(BaseHandler):
 
         # otherwise, it's a login
 
-        # Pull out the user-agent and IP from the request.
-        user_agent = request.get_user_agent("")
-        ip_address = self.hs.get_ip_from_request(request)
-
         # Call the mapper to register/login the user
         try:
-            user_id = await self._map_userinfo_to_user(
-                userinfo, token, user_agent, ip_address
+            await self._complete_oidc_login(
+                userinfo, token, request, client_redirect_url
             )
         except MappingException as e:
             logger.exception("Could not map user")
             self._sso_handler.render_error(request, "mapping_error", str(e))
-            return
-
-        # Mapping providers might not have get_extra_attributes: only call this
-        # method if it exists.
-        extra_attributes = None
-        get_extra_attributes = getattr(
-            self._user_mapping_provider, "get_extra_attributes", None
-        )
-        if get_extra_attributes:
-            extra_attributes = await get_extra_attributes(userinfo, token)
-
-        # and finally complete the login
-        await self._auth_handler.complete_sso_login(
-            user_id, request, client_redirect_url, extra_attributes
-        )
 
     def _generate_oidc_session_token(
         self,
@@ -838,10 +817,14 @@ class OidcHandler(BaseHandler):
         now = self.clock.time_msec()
         return now < expiry
 
-    async def _map_userinfo_to_user(
-        self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
-    ) -> str:
-        """Maps a UserInfo object to a mxid.
+    async def _complete_oidc_login(
+        self,
+        userinfo: UserInfo,
+        token: Token,
+        request: SynapseRequest,
+        client_redirect_url: str,
+    ) -> None:
+        """Given a UserInfo response, complete the login flow
 
         UserInfo should have a claim that uniquely identifies users. This claim
         is usually `sub`, but can be configured with `oidc_config.subject_claim`.
@@ -853,17 +836,16 @@ class OidcHandler(BaseHandler):
         If a user already exists with the mxid we've mapped and allow_existing_users
         is disabled, raise an exception.
 
+        Otherwise, render a redirect back to the client_redirect_url with a loginToken.
+
         Args:
             userinfo: an object representing the user
             token: a dict with the tokens obtained from the provider
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
+            request: The request to respond to
+            client_redirect_url: The redirect URL passed in by the client.
 
         Raises:
             MappingException: if there was an error while mapping some properties
-
-        Returns:
-            The mxid of the user
         """
         try:
             remote_user_id = self._remote_id_from_userinfo(userinfo)
@@ -931,13 +913,23 @@ class OidcHandler(BaseHandler):
 
             return None
 
-        return await self._sso_handler.get_mxid_from_sso(
+        # Mapping providers might not have get_extra_attributes: only call this
+        # method if it exists.
+        extra_attributes = None
+        get_extra_attributes = getattr(
+            self._user_mapping_provider, "get_extra_attributes", None
+        )
+        if get_extra_attributes:
+            extra_attributes = await get_extra_attributes(userinfo, token)
+
+        await self._sso_handler.complete_sso_login_request(
             self._auth_provider_id,
             remote_user_id,
-            user_agent,
-            ip_address,
+            request,
+            client_redirect_url,
             oidc_response_to_user_attributes,
             grandfather_existing_users,
+            extra_attributes,
         )
 
     def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 6001fe3e27..5fa7ab3f8b 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -58,8 +58,6 @@ class SamlHandler(BaseHandler):
         super().__init__(hs)
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
         self._saml_idp_entityid = hs.config.saml2_idp_entityid
-        self._auth_handler = hs.get_auth_handler()
-        self._registration_handler = hs.get_registration_handler()
 
         self._saml2_session_lifetime = hs.config.saml2_session_lifetime
         self._grandfathered_mxid_source_attribute = (
@@ -229,40 +227,29 @@ class SamlHandler(BaseHandler):
                 )
                 return
 
-        # Pull out the user-agent and IP from the request.
-        user_agent = request.get_user_agent("")
-        ip_address = self.hs.get_ip_from_request(request)
-
         # Call the mapper to register/login the user
         try:
-            user_id = await self._map_saml_response_to_user(
-                saml2_auth, relay_state, user_agent, ip_address
-            )
+            await self._complete_saml_login(saml2_auth, request, relay_state)
         except MappingException as e:
             logger.exception("Could not map user")
             self._sso_handler.render_error(request, "mapping_error", str(e))
-            return
 
-        await self._auth_handler.complete_sso_login(user_id, request, relay_state)
-
-    async def _map_saml_response_to_user(
+    async def _complete_saml_login(
         self,
         saml2_auth: saml2.response.AuthnResponse,
+        request: SynapseRequest,
         client_redirect_url: str,
-        user_agent: str,
-        ip_address: str,
-    ) -> str:
+    ) -> None:
         """
-        Given a SAML response, retrieve the user ID for it and possibly register the user.
+        Given a SAML response, complete the login flow
+
+        Retrieves the remote user ID, registers the user if necessary, and serves
+        a redirect back to the client with a login-token.
 
         Args:
             saml2_auth: The parsed SAML2 response.
+            request: The request to respond to
             client_redirect_url: The redirect URL passed in by the client.
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
-
-        Returns:
-             The user ID associated with this response.
 
         Raises:
             MappingException if there was a problem mapping the response to a user.
@@ -318,11 +305,11 @@ class SamlHandler(BaseHandler):
 
             return None
 
-        return await self._sso_handler.get_mxid_from_sso(
+        await self._sso_handler.complete_sso_login_request(
             self._auth_provider_id,
             remote_user_id,
-            user_agent,
-            ip_address,
+            request,
+            client_redirect_url,
             saml_response_to_remapped_user_attributes,
             grandfather_existing_users,
         )
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 112a7d5b2c..f054b66a53 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -21,7 +21,8 @@ from twisted.web.http import Request
 
 from synapse.api.errors import RedirectException
 from synapse.http.server import respond_with_html
-from synapse.types import UserID, contains_invalid_mxid_characters
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
 from synapse.util.async_helpers import Linearizer
 
 if TYPE_CHECKING:
@@ -119,15 +120,16 @@ class SsoHandler:
         # No match.
         return None
 
-    async def get_mxid_from_sso(
+    async def complete_sso_login_request(
         self,
         auth_provider_id: str,
         remote_user_id: str,
-        user_agent: str,
-        ip_address: str,
+        request: SynapseRequest,
+        client_redirect_url: str,
         sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
         grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
-    ) -> str:
+        extra_login_attributes: Optional[JsonDict] = None,
+    ) -> None:
         """
         Given an SSO ID, retrieve the user ID for it and possibly register the user.
 
@@ -146,12 +148,18 @@ class SsoHandler:
         given user-agent and IP address and the SSO ID is linked to this matrix
         ID for subsequent calls.
 
+        Finally, we generate a redirect to the supplied redirect uri, with a login token
+
         Args:
             auth_provider_id: A unique identifier for this SSO provider, e.g.
                 "oidc" or "saml".
+
             remote_user_id: The unique identifier from the SSO provider.
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
+
+            request: The request to respond to
+
+            client_redirect_url: The redirect URL passed in by the client.
+
             sso_to_matrix_id_mapper: A callable to generate the user attributes.
                 The only parameter is an integer which represents the amount of
                 times the returned mxid localpart mapping has failed.
@@ -163,12 +171,13 @@ class SsoHandler:
                         to the user.
                     RedirectException to redirect to an additional page (e.g.
                         to prompt the user for more information).
+
             grandfather_existing_users: A callable which can return an previously
                 existing matrix ID. The SSO ID is then linked to the returned
                 matrix ID.
 
-        Returns:
-             The user ID associated with the SSO response.
+            extra_login_attributes: An optional dictionary of extra
+                attributes to be provided to the client in the login response.
 
         Raises:
             MappingException if there was a problem mapping the response to a user.
@@ -181,28 +190,33 @@ class SsoHandler:
         # interstitial pages.
         with await self._mapping_lock.queue(auth_provider_id):
             # first of all, check if we already have a mapping for this user
-            previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
+            user_id = await self.get_sso_user_by_remote_user_id(
                 auth_provider_id, remote_user_id,
             )
-            if previously_registered_user_id:
-                return previously_registered_user_id
 
             # Check for grandfathering of users.
-            if grandfather_existing_users:
-                previously_registered_user_id = await grandfather_existing_users()
-                if previously_registered_user_id:
+            if not user_id and grandfather_existing_users:
+                user_id = await grandfather_existing_users()
+                if user_id:
                     # Future logins should also match this user ID.
                     await self._store.record_user_external_id(
-                        auth_provider_id, remote_user_id, previously_registered_user_id
+                        auth_provider_id, remote_user_id, user_id
                     )
-                    return previously_registered_user_id
 
             # Otherwise, generate a new user.
-            attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
-            user_id = await self._register_mapped_user(
-                attributes, auth_provider_id, remote_user_id, user_agent, ip_address,
-            )
-            return user_id
+            if not user_id:
+                attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
+                user_id = await self._register_mapped_user(
+                    attributes,
+                    auth_provider_id,
+                    remote_user_id,
+                    request.get_user_agent(""),
+                    request.getClientIP(),
+                )
+
+        await self._auth_handler.complete_sso_login(
+            user_id, request, client_redirect_url, extra_login_attributes
+        )
 
     async def _call_attribute_mapper(
         self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 69927cf6be..548038214b 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri"
+            "@test_user:test", request, "redirect_uri", None
         )
 
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, ""
+            "@test_user:test", request, "", None
         )
 
         # Subsequent calls should map to the same mxid.
@@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
             self.handler._handle_authn_response(request, saml_response, "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, ""
+            "@test_user:test", request, "", None
         )
 
     def test_map_saml_response_to_invalid_localpart(self):
@@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", request, ""
+            "@test_user1:test", request, "", None
         )
         auth_handler.complete_sso_login.reset_mock()