summary refs log tree commit diff
path: root/synapse/handlers/oidc_handler.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/oidc_handler.py')
-rw-r--r--synapse/handlers/oidc_handler.py120
1 files changed, 50 insertions, 70 deletions
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 34de9109ea..78c4e94a9d 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import logging
 from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
 from urllib.parse import urlencode
@@ -35,15 +36,10 @@ from twisted.web.client import readBody
 
 from synapse.config import ConfigError
 from synapse.handlers._base import BaseHandler
-from synapse.handlers.sso import MappingException
+from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
-from synapse.types import (
-    JsonDict,
-    UserID,
-    contains_invalid_mxid_characters,
-    map_username_to_mxid_localpart,
-)
+from synapse.types import JsonDict, map_username_to_mxid_localpart
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
@@ -869,73 +865,51 @@ class OidcHandler(BaseHandler):
         # to be strings.
         remote_user_id = str(remote_user_id)
 
-        # first of all, check if we already have a mapping for this user
-        previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
-            self._auth_provider_id, remote_user_id,
+        # Older mapping providers don't accept the `failures` argument, so we
+        # try and detect support.
+        mapper_signature = inspect.signature(
+            self._user_mapping_provider.map_user_attributes
         )
-        if previously_registered_user_id:
-            return previously_registered_user_id
+        supports_failures = "failures" in mapper_signature.parameters
 
-        # Otherwise, generate a new user.
-        try:
-            attributes = await self._user_mapping_provider.map_user_attributes(
-                userinfo, token
-            )
-        except Exception as e:
-            raise MappingException(
-                "Could not extract user attributes from OIDC response: " + str(e)
-            )
+        async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
+            """
+            Call the mapping provider to map the OIDC userinfo and token to user attributes.
 
-        logger.debug(
-            "Retrieved user attributes from user mapping provider: %r", attributes
-        )
+            This is backwards compatibility for abstraction for the SSO handler.
+            """
+            if supports_failures:
+                attributes = await self._user_mapping_provider.map_user_attributes(
+                    userinfo, token, failures
+                )
+            else:
+                # If the mapping provider does not support processing failures,
+                # do not continually generate the same Matrix ID since it will
+                # continue to already be in use. Note that the error raised is
+                # arbitrary and will get turned into a MappingException.
+                if failures:
+                    raise RuntimeError(
+                        "Mapping provider does not support de-duplicating Matrix IDs"
+                    )
 
-        localpart = attributes["localpart"]
-        if not localpart:
-            raise MappingException(
-                "Error parsing OIDC response: OIDC mapping provider plugin "
-                "did not return a localpart value"
-            )
+                attributes = await self._user_mapping_provider.map_user_attributes(  # type: ignore
+                    userinfo, token
+                )
 
-        user_id = UserID(localpart, self.server_name).to_string()
-        users = await self.store.get_users_by_id_case_insensitive(user_id)
-        if users:
-            if self._allow_existing_users:
-                if len(users) == 1:
-                    registered_user_id = next(iter(users))
-                elif user_id in users:
-                    registered_user_id = user_id
-                else:
-                    raise MappingException(
-                        "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
-                            user_id, list(users.keys())
-                        )
-                    )
-            else:
-                # This mxid is taken
-                raise MappingException("mxid '{}' is already taken".format(user_id))
-        else:
-            # Since the localpart is provided via a potentially untrusted module,
-            # ensure the MXID is valid before registering.
-            if contains_invalid_mxid_characters(localpart):
-                raise MappingException("localpart is invalid: %s" % (localpart,))
-
-            # It's the first time this user is logging in and the mapped mxid was
-            # not taken, register the user
-            registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart,
-                default_display_name=attributes["display_name"],
-                user_agent_ips=[(user_agent, ip_address)],
-            )
+            return UserAttributes(**attributes)
 
-        await self.store.record_user_external_id(
-            self._auth_provider_id, remote_user_id, registered_user_id,
+        return await self._sso_handler.get_mxid_from_sso(
+            self._auth_provider_id,
+            remote_user_id,
+            user_agent,
+            ip_address,
+            oidc_response_to_user_attributes,
+            self._allow_existing_users,
         )
-        return registered_user_id
 
 
-UserAttribute = TypedDict(
-    "UserAttribute", {"localpart": str, "display_name": Optional[str]}
+UserAttributeDict = TypedDict(
+    "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
 )
 C = TypeVar("C")
 
@@ -978,13 +952,15 @@ class OidcMappingProvider(Generic[C]):
         raise NotImplementedError()
 
     async def map_user_attributes(
-        self, userinfo: UserInfo, token: Token
-    ) -> UserAttribute:
+        self, userinfo: UserInfo, token: Token, failures: int
+    ) -> UserAttributeDict:
         """Map a `UserInfo` object into user attributes.
 
         Args:
             userinfo: An object representing the user given by the OIDC provider
             token: A dict with the tokens returned by the provider
+            failures: How many times a call to this function with this
+                UserInfo has resulted in a failure.
 
         Returns:
             A dict containing the ``localpart`` and (optionally) the ``display_name``
@@ -1084,13 +1060,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
         return userinfo[self._config.subject_claim]
 
     async def map_user_attributes(
-        self, userinfo: UserInfo, token: Token
-    ) -> UserAttribute:
+        self, userinfo: UserInfo, token: Token, failures: int
+    ) -> UserAttributeDict:
         localpart = self._config.localpart_template.render(user=userinfo).strip()
 
         # Ensure only valid characters are included in the MXID.
         localpart = map_username_to_mxid_localpart(localpart)
 
+        # Append suffix integer if last call to this function failed to produce
+        # a usable mxid.
+        localpart += str(failures) if failures else ""
+
         display_name = None  # type: Optional[str]
         if self._config.display_name_template is not None:
             display_name = self._config.display_name_template.render(
@@ -1100,7 +1080,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
             if display_name == "":
                 display_name = None
 
-        return UserAttribute(localpart=localpart, display_name=display_name)
+        return UserAttributeDict(localpart=localpart, display_name=display_name)
 
     async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
         extras = {}  # type: Dict[str, str]