summary refs log tree commit diff
path: root/synapse/handlers/sso.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/sso.py')
-rw-r--r--synapse/handlers/sso.py57
1 files changed, 42 insertions, 15 deletions
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index e24767b921..112a7d5b2c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -22,6 +22,7 @@ from twisted.web.http import Request
 from synapse.api.errors import RedirectException
 from synapse.http.server import respond_with_html
 from synapse.types import UserID, contains_invalid_mxid_characters
+from synapse.util.async_helpers import Linearizer
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -54,6 +55,9 @@ class SsoHandler:
         self._error_template = hs.config.sso_error_template
         self._auth_handler = hs.get_auth_handler()
 
+        # a lock on the mappings
+        self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
+
     def render_error(
         self, request, error: str, error_description: Optional[str] = None
     ) -> None:
@@ -172,24 +176,38 @@ class SsoHandler:
                 to an additional page. (e.g. to prompt for more information)
 
         """
-        # first of all, check if we already have a mapping for this user
-        previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
-            auth_provider_id, remote_user_id,
-        )
-        if previously_registered_user_id:
-            return previously_registered_user_id
-
-        # Check for grandfathering of users.
-        if grandfather_existing_users:
-            previously_registered_user_id = await grandfather_existing_users()
+        # grab a lock while we try to find a mapping for this user. This seems...
+        # optimistic, especially for implementations that end up redirecting to
+        # interstitial pages.
+        with await self._mapping_lock.queue(auth_provider_id):
+            # first of all, check if we already have a mapping for this user
+            previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
+                auth_provider_id, remote_user_id,
+            )
             if previously_registered_user_id:
-                # Future logins should also match this user ID.
-                await self._store.record_user_external_id(
-                    auth_provider_id, remote_user_id, previously_registered_user_id
-                )
                 return previously_registered_user_id
 
-        # Otherwise, generate a new user.
+            # Check for grandfathering of users.
+            if grandfather_existing_users:
+                previously_registered_user_id = await grandfather_existing_users()
+                if previously_registered_user_id:
+                    # Future logins should also match this user ID.
+                    await self._store.record_user_external_id(
+                        auth_provider_id, remote_user_id, previously_registered_user_id
+                    )
+                    return previously_registered_user_id
+
+            # Otherwise, generate a new user.
+            attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
+            user_id = await self._register_mapped_user(
+                attributes, auth_provider_id, remote_user_id, user_agent, ip_address,
+            )
+            return user_id
+
+    async def _call_attribute_mapper(
+        self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+    ) -> UserAttributes:
+        """Call the attribute mapper function in a loop, until we get a unique userid"""
         for i in range(self._MAP_USERNAME_RETRIES):
             try:
                 attributes = await sso_to_matrix_id_mapper(i)
@@ -227,7 +245,16 @@ class SsoHandler:
             raise MappingException(
                 "Unable to generate a Matrix ID from the SSO response"
             )
+        return attributes
 
+    async def _register_mapped_user(
+        self,
+        attributes: UserAttributes,
+        auth_provider_id: str,
+        remote_user_id: str,
+        user_agent: str,
+        ip_address: str,
+    ) -> str:
         # Since the localpart is provided via a potentially untrusted module,
         # ensure the MXID is valid before registering.
         if contains_invalid_mxid_characters(attributes.localpart):