diff options
Diffstat (limited to 'synapse/handlers/oidc_handler.py')
-rw-r--r-- | synapse/handlers/oidc_handler.py | 120 |
1 files changed, 50 insertions, 70 deletions
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 34de9109ea..78c4e94a9d 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar from urllib.parse import urlencode @@ -35,15 +36,10 @@ from twisted.web.client import readBody from synapse.config import ConfigError from synapse.handlers._base import BaseHandler -from synapse.handlers.sso import MappingException +from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable -from synapse.types import ( - JsonDict, - UserID, - contains_invalid_mxid_characters, - map_username_to_mxid_localpart, -) +from synapse.types import JsonDict, map_username_to_mxid_localpart from synapse.util import json_decoder if TYPE_CHECKING: @@ -869,73 +865,51 @@ class OidcHandler(BaseHandler): # to be strings. remote_user_id = str(remote_user_id) - # first of all, check if we already have a mapping for this user - previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id( - self._auth_provider_id, remote_user_id, + # Older mapping providers don't accept the `failures` argument, so we + # try and detect support. + mapper_signature = inspect.signature( + self._user_mapping_provider.map_user_attributes ) - if previously_registered_user_id: - return previously_registered_user_id + supports_failures = "failures" in mapper_signature.parameters - # Otherwise, generate a new user. - try: - attributes = await self._user_mapping_provider.map_user_attributes( - userinfo, token - ) - except Exception as e: - raise MappingException( - "Could not extract user attributes from OIDC response: " + str(e) - ) + async def oidc_response_to_user_attributes(failures: int) -> UserAttributes: + """ + Call the mapping provider to map the OIDC userinfo and token to user attributes. - logger.debug( - "Retrieved user attributes from user mapping provider: %r", attributes - ) + This is backwards compatibility for abstraction for the SSO handler. + """ + if supports_failures: + attributes = await self._user_mapping_provider.map_user_attributes( + userinfo, token, failures + ) + else: + # If the mapping provider does not support processing failures, + # do not continually generate the same Matrix ID since it will + # continue to already be in use. Note that the error raised is + # arbitrary and will get turned into a MappingException. + if failures: + raise RuntimeError( + "Mapping provider does not support de-duplicating Matrix IDs" + ) - localpart = attributes["localpart"] - if not localpart: - raise MappingException( - "Error parsing OIDC response: OIDC mapping provider plugin " - "did not return a localpart value" - ) + attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore + userinfo, token + ) - user_id = UserID(localpart, self.server_name).to_string() - users = await self.store.get_users_by_id_case_insensitive(user_id) - if users: - if self._allow_existing_users: - if len(users) == 1: - registered_user_id = next(iter(users)) - elif user_id in users: - registered_user_id = user_id - else: - raise MappingException( - "Attempted to login as '{}' but it matches more than one user inexactly: {}".format( - user_id, list(users.keys()) - ) - ) - else: - # This mxid is taken - raise MappingException("mxid '{}' is already taken".format(user_id)) - else: - # Since the localpart is provided via a potentially untrusted module, - # ensure the MXID is valid before registering. - if contains_invalid_mxid_characters(localpart): - raise MappingException("localpart is invalid: %s" % (localpart,)) - - # It's the first time this user is logging in and the mapped mxid was - # not taken, register the user - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=attributes["display_name"], - user_agent_ips=[(user_agent, ip_address)], - ) + return UserAttributes(**attributes) - await self.store.record_user_external_id( - self._auth_provider_id, remote_user_id, registered_user_id, + return await self._sso_handler.get_mxid_from_sso( + self._auth_provider_id, + remote_user_id, + user_agent, + ip_address, + oidc_response_to_user_attributes, + self._allow_existing_users, ) - return registered_user_id -UserAttribute = TypedDict( - "UserAttribute", {"localpart": str, "display_name": Optional[str]} +UserAttributeDict = TypedDict( + "UserAttributeDict", {"localpart": str, "display_name": Optional[str]} ) C = TypeVar("C") @@ -978,13 +952,15 @@ class OidcMappingProvider(Generic[C]): raise NotImplementedError() async def map_user_attributes( - self, userinfo: UserInfo, token: Token - ) -> UserAttribute: + self, userinfo: UserInfo, token: Token, failures: int + ) -> UserAttributeDict: """Map a `UserInfo` object into user attributes. Args: userinfo: An object representing the user given by the OIDC provider token: A dict with the tokens returned by the provider + failures: How many times a call to this function with this + UserInfo has resulted in a failure. Returns: A dict containing the ``localpart`` and (optionally) the ``display_name`` @@ -1084,13 +1060,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): return userinfo[self._config.subject_claim] async def map_user_attributes( - self, userinfo: UserInfo, token: Token - ) -> UserAttribute: + self, userinfo: UserInfo, token: Token, failures: int + ) -> UserAttributeDict: localpart = self._config.localpart_template.render(user=userinfo).strip() # Ensure only valid characters are included in the MXID. localpart = map_username_to_mxid_localpart(localpart) + # Append suffix integer if last call to this function failed to produce + # a usable mxid. + localpart += str(failures) if failures else "" + display_name = None # type: Optional[str] if self._config.display_name_template is not None: display_name = self._config.display_name_template.render( @@ -1100,7 +1080,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): if display_name == "": display_name = None - return UserAttribute(localpart=localpart, display_name=display_name) + return UserAttributeDict(localpart=localpart, display_name=display_name) async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: extras = {} # type: Dict[str, str] |