diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 34de9109ea..78c4e94a9d 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode
@@ -35,15 +36,10 @@ from twisted.web.client import readBody
from synapse.config import ConfigError
from synapse.handlers._base import BaseHandler
-from synapse.handlers.sso import MappingException
+from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
-from synapse.types import (
- JsonDict,
- UserID,
- contains_invalid_mxid_characters,
- map_username_to_mxid_localpart,
-)
+from synapse.types import JsonDict, map_username_to_mxid_localpart
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -869,73 +865,51 @@ class OidcHandler(BaseHandler):
# to be strings.
remote_user_id = str(remote_user_id)
- # first of all, check if we already have a mapping for this user
- previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
- self._auth_provider_id, remote_user_id,
+ # Older mapping providers don't accept the `failures` argument, so we
+ # try and detect support.
+ mapper_signature = inspect.signature(
+ self._user_mapping_provider.map_user_attributes
)
- if previously_registered_user_id:
- return previously_registered_user_id
+ supports_failures = "failures" in mapper_signature.parameters
- # Otherwise, generate a new user.
- try:
- attributes = await self._user_mapping_provider.map_user_attributes(
- userinfo, token
- )
- except Exception as e:
- raise MappingException(
- "Could not extract user attributes from OIDC response: " + str(e)
- )
+ async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
+ """
+ Call the mapping provider to map the OIDC userinfo and token to user attributes.
- logger.debug(
- "Retrieved user attributes from user mapping provider: %r", attributes
- )
+ This is backwards compatibility for abstraction for the SSO handler.
+ """
+ if supports_failures:
+ attributes = await self._user_mapping_provider.map_user_attributes(
+ userinfo, token, failures
+ )
+ else:
+ # If the mapping provider does not support processing failures,
+ # do not continually generate the same Matrix ID since it will
+ # continue to already be in use. Note that the error raised is
+ # arbitrary and will get turned into a MappingException.
+ if failures:
+ raise RuntimeError(
+ "Mapping provider does not support de-duplicating Matrix IDs"
+ )
- localpart = attributes["localpart"]
- if not localpart:
- raise MappingException(
- "Error parsing OIDC response: OIDC mapping provider plugin "
- "did not return a localpart value"
- )
+ attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore
+ userinfo, token
+ )
- user_id = UserID(localpart, self.server_name).to_string()
- users = await self.store.get_users_by_id_case_insensitive(user_id)
- if users:
- if self._allow_existing_users:
- if len(users) == 1:
- registered_user_id = next(iter(users))
- elif user_id in users:
- registered_user_id = user_id
- else:
- raise MappingException(
- "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
- user_id, list(users.keys())
- )
- )
- else:
- # This mxid is taken
- raise MappingException("mxid '{}' is already taken".format(user_id))
- else:
- # Since the localpart is provided via a potentially untrusted module,
- # ensure the MXID is valid before registering.
- if contains_invalid_mxid_characters(localpart):
- raise MappingException("localpart is invalid: %s" % (localpart,))
-
- # It's the first time this user is logging in and the mapped mxid was
- # not taken, register the user
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=attributes["display_name"],
- user_agent_ips=[(user_agent, ip_address)],
- )
+ return UserAttributes(**attributes)
- await self.store.record_user_external_id(
- self._auth_provider_id, remote_user_id, registered_user_id,
+ return await self._sso_handler.get_mxid_from_sso(
+ self._auth_provider_id,
+ remote_user_id,
+ user_agent,
+ ip_address,
+ oidc_response_to_user_attributes,
+ self._allow_existing_users,
)
- return registered_user_id
-UserAttribute = TypedDict(
- "UserAttribute", {"localpart": str, "display_name": Optional[str]}
+UserAttributeDict = TypedDict(
+ "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
)
C = TypeVar("C")
@@ -978,13 +952,15 @@ class OidcMappingProvider(Generic[C]):
raise NotImplementedError()
async def map_user_attributes(
- self, userinfo: UserInfo, token: Token
- ) -> UserAttribute:
+ self, userinfo: UserInfo, token: Token, failures: int
+ ) -> UserAttributeDict:
"""Map a `UserInfo` object into user attributes.
Args:
userinfo: An object representing the user given by the OIDC provider
token: A dict with the tokens returned by the provider
+ failures: How many times a call to this function with this
+ UserInfo has resulted in a failure.
Returns:
A dict containing the ``localpart`` and (optionally) the ``display_name``
@@ -1084,13 +1060,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
return userinfo[self._config.subject_claim]
async def map_user_attributes(
- self, userinfo: UserInfo, token: Token
- ) -> UserAttribute:
+ self, userinfo: UserInfo, token: Token, failures: int
+ ) -> UserAttributeDict:
localpart = self._config.localpart_template.render(user=userinfo).strip()
# Ensure only valid characters are included in the MXID.
localpart = map_username_to_mxid_localpart(localpart)
+ # Append suffix integer if last call to this function failed to produce
+ # a usable mxid.
+ localpart += str(failures) if failures else ""
+
display_name = None # type: Optional[str]
if self._config.display_name_template is not None:
display_name = self._config.display_name_template.render(
@@ -1100,7 +1080,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
if display_name == "":
display_name = None
- return UserAttribute(localpart=localpart, display_name=display_name)
+ return UserAttributeDict(localpart=localpart, display_name=display_name)
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
extras = {} # type: Dict[str, str]
|