summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/8900.feature1
-rw-r--r--synapse/handlers/saml_handler.py21
-rw-r--r--synapse/handlers/sso.py57
3 files changed, 51 insertions, 28 deletions
diff --git a/changelog.d/8900.feature b/changelog.d/8900.feature
new file mode 100644
index 0000000000..d450ef4998
--- /dev/null
+++ b/changelog.d/8900.feature
@@ -0,0 +1 @@
+Add support for allowing users to pick their own user ID during a single-sign-on login.
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 5846f08609..f2ca1ddb53 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -34,7 +34,6 @@ from synapse.types import (
     map_username_to_mxid_localpart,
     mxid_localpart_allowed_characters,
 )
-from synapse.util.async_helpers import Linearizer
 from synapse.util.iterutils import chunk_seq
 
 if TYPE_CHECKING:
@@ -81,9 +80,6 @@ class SamlHandler(BaseHandler):
         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
 
-        # a lock on the mappings
-        self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
-
         self._sso_handler = hs.get_sso_handler()
 
     def handle_redirect_request(
@@ -299,15 +295,14 @@ class SamlHandler(BaseHandler):
 
             return None
 
-        with (await self._mapping_lock.queue(self._auth_provider_id)):
-            return await self._sso_handler.get_mxid_from_sso(
-                self._auth_provider_id,
-                remote_user_id,
-                user_agent,
-                ip_address,
-                saml_response_to_remapped_user_attributes,
-                grandfather_existing_users,
-            )
+        return await self._sso_handler.get_mxid_from_sso(
+            self._auth_provider_id,
+            remote_user_id,
+            user_agent,
+            ip_address,
+            saml_response_to_remapped_user_attributes,
+            grandfather_existing_users,
+        )
 
     def _remote_id_from_saml_response(
         self,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index e24767b921..112a7d5b2c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -22,6 +22,7 @@ from twisted.web.http import Request
 from synapse.api.errors import RedirectException
 from synapse.http.server import respond_with_html
 from synapse.types import UserID, contains_invalid_mxid_characters
+from synapse.util.async_helpers import Linearizer
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -54,6 +55,9 @@ class SsoHandler:
         self._error_template = hs.config.sso_error_template
         self._auth_handler = hs.get_auth_handler()
 
+        # a lock on the mappings
+        self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
+
     def render_error(
         self, request, error: str, error_description: Optional[str] = None
     ) -> None:
@@ -172,24 +176,38 @@ class SsoHandler:
                 to an additional page. (e.g. to prompt for more information)
 
         """
-        # first of all, check if we already have a mapping for this user
-        previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
-            auth_provider_id, remote_user_id,
-        )
-        if previously_registered_user_id:
-            return previously_registered_user_id
-
-        # Check for grandfathering of users.
-        if grandfather_existing_users:
-            previously_registered_user_id = await grandfather_existing_users()
+        # grab a lock while we try to find a mapping for this user. This seems...
+        # optimistic, especially for implementations that end up redirecting to
+        # interstitial pages.
+        with await self._mapping_lock.queue(auth_provider_id):
+            # first of all, check if we already have a mapping for this user
+            previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
+                auth_provider_id, remote_user_id,
+            )
             if previously_registered_user_id:
-                # Future logins should also match this user ID.
-                await self._store.record_user_external_id(
-                    auth_provider_id, remote_user_id, previously_registered_user_id
-                )
                 return previously_registered_user_id
 
-        # Otherwise, generate a new user.
+            # Check for grandfathering of users.
+            if grandfather_existing_users:
+                previously_registered_user_id = await grandfather_existing_users()
+                if previously_registered_user_id:
+                    # Future logins should also match this user ID.
+                    await self._store.record_user_external_id(
+                        auth_provider_id, remote_user_id, previously_registered_user_id
+                    )
+                    return previously_registered_user_id
+
+            # Otherwise, generate a new user.
+            attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
+            user_id = await self._register_mapped_user(
+                attributes, auth_provider_id, remote_user_id, user_agent, ip_address,
+            )
+            return user_id
+
+    async def _call_attribute_mapper(
+        self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+    ) -> UserAttributes:
+        """Call the attribute mapper function in a loop, until we get a unique userid"""
         for i in range(self._MAP_USERNAME_RETRIES):
             try:
                 attributes = await sso_to_matrix_id_mapper(i)
@@ -227,7 +245,16 @@ class SsoHandler:
             raise MappingException(
                 "Unable to generate a Matrix ID from the SSO response"
             )
+        return attributes
 
+    async def _register_mapped_user(
+        self,
+        attributes: UserAttributes,
+        auth_provider_id: str,
+        remote_user_id: str,
+        user_agent: str,
+        ip_address: str,
+    ) -> str:
         # Since the localpart is provided via a potentially untrusted module,
         # ensure the MXID is valid before registering.
         if contains_invalid_mxid_characters(attributes.localpart):