summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/8855.feature1
-rw-r--r--synapse/handlers/oidc_handler.py30
-rw-r--r--synapse/handlers/saml_handler.py9
-rw-r--r--synapse/handlers/sso.py60
-rw-r--r--tests/handlers/test_oidc.py8
-rw-r--r--tests/handlers/test_saml.py34
6 files changed, 94 insertions, 48 deletions
diff --git a/changelog.d/8855.feature b/changelog.d/8855.feature
new file mode 100644
index 0000000000..77f7fe4e5d
--- /dev/null
+++ b/changelog.d/8855.feature
@@ -0,0 +1 @@
+Add support for re-trying generation of a localpart for OpenID Connect mapping providers.
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 78c4e94a9d..55c4377890 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -39,7 +39,7 @@ from synapse.handlers._base import BaseHandler
 from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
-from synapse.types import JsonDict, map_username_to_mxid_localpart
+from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
@@ -898,13 +898,39 @@ class OidcHandler(BaseHandler):
 
             return UserAttributes(**attributes)
 
+        async def grandfather_existing_users() -> Optional[str]:
+            if self._allow_existing_users:
+                # If allowing existing users we want to generate a single localpart
+                # and attempt to match it.
+                attributes = await oidc_response_to_user_attributes(failures=0)
+
+                user_id = UserID(attributes.localpart, self.server_name).to_string()
+                users = await self.store.get_users_by_id_case_insensitive(user_id)
+                if users:
+                    # If an existing matrix ID is returned, then use it.
+                    if len(users) == 1:
+                        previously_registered_user_id = next(iter(users))
+                    elif user_id in users:
+                        previously_registered_user_id = user_id
+                    else:
+                        # Do not attempt to continue generating Matrix IDs.
+                        raise MappingException(
+                            "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+                                user_id, users
+                            )
+                        )
+
+                    return previously_registered_user_id
+
+            return None
+
         return await self._sso_handler.get_mxid_from_sso(
             self._auth_provider_id,
             remote_user_id,
             user_agent,
             ip_address,
             oidc_response_to_user_attributes,
-            self._allow_existing_users,
+            grandfather_existing_users,
         )
 
 
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 7ffad7d8af..76d4169fe2 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -268,7 +268,7 @@ class SamlHandler(BaseHandler):
                 emails=result.get("emails", []),
             )
 
-        with (await self._mapping_lock.queue(self._auth_provider_id)):
+        async def grandfather_existing_users() -> Optional[str]:
             # backwards-compatibility hack: see if there is an existing user with a
             # suitable mapping from the uid
             if (
@@ -290,17 +290,18 @@ class SamlHandler(BaseHandler):
                 if users:
                     registered_user_id = list(users.keys())[0]
                     logger.info("Grandfathering mapping to %s", registered_user_id)
-                    await self.store.record_user_external_id(
-                        self._auth_provider_id, remote_user_id, registered_user_id
-                    )
                     return registered_user_id
 
+            return None
+
+        with (await self._mapping_lock.queue(self._auth_provider_id)):
             return await self._sso_handler.get_mxid_from_sso(
                 self._auth_provider_id,
                 remote_user_id,
                 user_agent,
                 ip_address,
                 saml_response_to_remapped_user_attributes,
+                grandfather_existing_users,
             )
 
     def expire_sessions(self):
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index d963082210..f42b90e1bc 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -116,7 +116,7 @@ class SsoHandler(BaseHandler):
         user_agent: str,
         ip_address: str,
         sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
-        allow_existing_users: bool = False,
+        grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
     ) -> str:
         """
         Given an SSO ID, retrieve the user ID for it and possibly register the user.
@@ -125,6 +125,10 @@ class SsoHandler(BaseHandler):
         if it has that matrix ID is returned regardless of the current mapping
         logic.
 
+        If a callable is provided for grandfathering users, it is called and can
+        potentially return a matrix ID to use. If it does, the SSO ID is linked to
+        this matrix ID for subsequent calls.
+
         The mapping function is called (potentially multiple times) to generate
         a localpart for the user.
 
@@ -132,17 +136,6 @@ class SsoHandler(BaseHandler):
         given user-agent and IP address and the SSO ID is linked to this matrix
         ID for subsequent calls.
 
-        If allow_existing_users is true the mapping function is only called once
-        and results in:
-
-            1. The use of a previously registered matrix ID. In this case, the
-               SSO ID is linked to the matrix ID. (Note it is possible that
-               other SSO IDs are linked to the same matrix ID.)
-            2. An unused localpart, in which case the user is registered (as
-               discussed above).
-            3. An error if the generated localpart matches multiple pre-existing
-               matrix IDs. Generally this should not happen.
-
         Args:
             auth_provider_id: A unique identifier for this SSO provider, e.g.
                 "oidc" or "saml".
@@ -152,8 +145,9 @@ class SsoHandler(BaseHandler):
             sso_to_matrix_id_mapper: A callable to generate the user attributes.
                 The only parameter is an integer which represents the amount of
                 times the returned mxid localpart mapping has failed.
-            allow_existing_users: True if the localpart returned from the
-                mapping provider can be linked to an existing matrix ID.
+            grandfather_existing_users: A callable which can return an previously
+                existing matrix ID. The SSO ID is then linked to the returned
+                matrix ID.
 
         Returns:
              The user ID associated with the SSO response.
@@ -171,6 +165,16 @@ class SsoHandler(BaseHandler):
         if previously_registered_user_id:
             return previously_registered_user_id
 
+        # Check for grandfathering of users.
+        if grandfather_existing_users:
+            previously_registered_user_id = await grandfather_existing_users()
+            if previously_registered_user_id:
+                # Future logins should also match this user ID.
+                await self.store.record_user_external_id(
+                    auth_provider_id, remote_user_id, previously_registered_user_id
+                )
+                return previously_registered_user_id
+
         # Otherwise, generate a new user.
         for i in range(self._MAP_USERNAME_RETRIES):
             try:
@@ -194,33 +198,7 @@ class SsoHandler(BaseHandler):
 
             # Check if this mxid already exists
             user_id = UserID(attributes.localpart, self.server_name).to_string()
-            users = await self.store.get_users_by_id_case_insensitive(user_id)
-            # Note, if allow_existing_users is true then the loop is guaranteed
-            # to end on the first iteration: either by matching an existing user,
-            # raising an error, or registering a new user. See the docstring for
-            # more in-depth an explanation.
-            if users and allow_existing_users:
-                # If an existing matrix ID is returned, then use it.
-                if len(users) == 1:
-                    previously_registered_user_id = next(iter(users))
-                elif user_id in users:
-                    previously_registered_user_id = user_id
-                else:
-                    # Do not attempt to continue generating Matrix IDs.
-                    raise MappingException(
-                        "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
-                            user_id, users
-                        )
-                    )
-
-                # Future logins should also match this user ID.
-                await self.store.record_user_external_id(
-                    auth_provider_id, remote_user_id, previously_registered_user_id
-                )
-
-                return previously_registered_user_id
-
-            elif not users:
+            if not await self.store.get_users_by_id_case_insensitive(user_id):
                 # This mxid is free
                 break
         else:
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index c9807a7b73..d485af52fd 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -731,6 +731,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
         self.assertEqual(mxid, "@test_user:test")
 
+        # Subsequent calls should map to the same mxid.
+        mxid = self.get_success(
+            self.handler._map_userinfo_to_user(
+                userinfo, token, "user-agent", "10.10.10.10"
+            )
+        )
+        self.assertEqual(mxid, "@test_user:test")
+
         # Note that a second SSO user can be mapped to the same Matrix ID. (This
         # requires a unique sub, but something that maps to the same matrix ID,
         # in this case we'll just use the same username. A more realistic example
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 79fd47036f..e1e13a5faf 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -16,7 +16,7 @@ import attr
 
 from synapse.handlers.sso import MappingException
 
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
 
 # These are a few constants that are used as config parameters in the tests.
 BASE_URL = "https://synapse/"
@@ -59,6 +59,10 @@ class SamlHandlerTestCase(HomeserverTestCase):
             "grandfathered_mxid_source_attribute": None,
             "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
         }
+
+        # Update this config with what's in the default config so that
+        # override_config works as expected.
+        saml_config.update(config.get("saml2_config", {}))
         config["saml2_config"] = saml_config
 
         return config
@@ -86,6 +90,34 @@ class SamlHandlerTestCase(HomeserverTestCase):
         )
         self.assertEqual(mxid, "@test_user:test")
 
+    @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
+    def test_map_saml_response_to_existing_user(self):
+        """Existing users can log in with SAML account."""
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.register_user(user_id="@test_user:test", password_hash=None)
+        )
+
+        # Map a user via SSO.
+        saml_response = FakeAuthnResponse(
+            {"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
+        )
+        redirect_url = ""
+        mxid = self.get_success(
+            self.handler._map_saml_response_to_user(
+                saml_response, redirect_url, "user-agent", "10.10.10.10"
+            )
+        )
+        self.assertEqual(mxid, "@test_user:test")
+
+        # Subsequent calls should map to the same mxid.
+        mxid = self.get_success(
+            self.handler._map_saml_response_to_user(
+                saml_response, redirect_url, "user-agent", "10.10.10.10"
+            )
+        )
+        self.assertEqual(mxid, "@test_user:test")
+
     def test_map_saml_response_to_invalid_localpart(self):
         """If the mapping provider generates an invalid localpart it should be rejected."""
         saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})