diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 4230dbaf99..05ac86e697 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -37,7 +37,7 @@ from synapse.config import ConfigError
from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -96,6 +96,7 @@ class OidcHandler:
self.hs = hs
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
+ self._user_profile_method = hs.config.oidc_user_profile_method # type: str
self._client_auth = ClientAuth(
hs.config.oidc_client_id,
hs.config.oidc_client_secret,
@@ -114,6 +115,7 @@ class OidcHandler:
hs.config.oidc_user_mapping_provider_config
) # type: OidcMappingProvider
self._skip_verification = hs.config.oidc_skip_verification # type: bool
+ self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
@@ -195,11 +197,11 @@ class OidcHandler:
% (m["response_types_supported"],)
)
- # If the openid scope was not requested, we need a userinfo endpoint to fetch user infos
+ # Ensure there's a userinfo endpoint to fetch from if it is required.
if self._uses_userinfo:
if m.get("userinfo_endpoint") is None:
raise ValueError(
- 'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested'
+ 'provider has no "userinfo_endpoint", even though it is required'
)
else:
# If we're not using userinfo, we need a valid jwks to validate the ID token
@@ -219,8 +221,10 @@ class OidcHandler:
``access_token`` with the ``userinfo_endpoint``.
"""
- # Maybe that should be user-configurable and not inferred?
- return "openid" not in self._scopes
+ return (
+ "openid" not in self._scopes
+ or self._user_profile_method == "userinfo_endpoint"
+ )
async def load_metadata(self) -> OpenIDProviderMetadata:
"""Load and validate the provider metadata.
@@ -706,6 +710,15 @@ class OidcHandler:
self._render_error(request, "mapping_error", str(e))
return
+ # Mapping providers might not have get_extra_attributes: only call this
+ # method if it exists.
+ extra_attributes = None
+ get_extra_attributes = getattr(
+ self._user_mapping_provider, "get_extra_attributes", None
+ )
+ if get_extra_attributes:
+ extra_attributes = await get_extra_attributes(userinfo, token)
+
# and finally complete the login
if ui_auth_session_id:
await self._auth_handler.complete_sso_ui_auth(
@@ -713,7 +726,7 @@ class OidcHandler:
)
else:
await self._auth_handler.complete_sso_login(
- user_id, request, client_redirect_url
+ user_id, request, client_redirect_url, extra_attributes
)
def _generate_oidc_session_token(
@@ -849,7 +862,8 @@ class OidcHandler:
If we don't find the user that way, we should register the user,
mapping the localpart and the display name from the UserInfo.
- If a user already exists with the mxid we've mapped, raise an exception.
+ If a user already exists with the mxid we've mapped and allow_existing_users
+ is disabled, raise an exception.
Args:
userinfo: an object representing the user
@@ -905,21 +919,31 @@ class OidcHandler:
localpart = map_username_to_mxid_localpart(attributes["localpart"])
- user_id = UserID(localpart, self._hostname)
- if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()):
- # This mxid is taken
- raise MappingException(
- "mxid '{}' is already taken".format(user_id.to_string())
+ user_id = UserID(localpart, self._hostname).to_string()
+ users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+ if users:
+ if self._allow_existing_users:
+ if len(users) == 1:
+ registered_user_id = next(iter(users))
+ elif user_id in users:
+ registered_user_id = user_id
+ else:
+ raise MappingException(
+ "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+ user_id, list(users.keys())
+ )
+ )
+ else:
+ # This mxid is taken
+ raise MappingException("mxid '{}' is already taken".format(user_id))
+ else:
+ # It's the first time this user is logging in and the mapped mxid was
+ # not taken, register the user
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=localpart,
+ default_display_name=attributes["display_name"],
+ user_agent_ips=(user_agent, ip_address),
)
-
- # It's the first time this user is logging in and the mapped mxid was
- # not taken, register the user
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=attributes["display_name"],
- user_agent_ips=(user_agent, ip_address),
- )
-
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id,
)
@@ -972,7 +996,7 @@ class OidcMappingProvider(Generic[C]):
async def map_user_attributes(
self, userinfo: UserInfo, token: Token
) -> UserAttribute:
- """Map a ``UserInfo`` objects into user attributes.
+ """Map a `UserInfo` object into user attributes.
Args:
userinfo: An object representing the user given by the OIDC provider
@@ -983,6 +1007,18 @@ class OidcMappingProvider(Generic[C]):
"""
raise NotImplementedError()
+ async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
+ """Map a `UserInfo` object into additional attributes passed to the client during login.
+
+ Args:
+ userinfo: An object representing the user given by the OIDC provider
+ token: A dict with the tokens returned by the provider
+
+ Returns:
+ A dict containing additional attributes. Must be JSON serializable.
+ """
+ return {}
+
# Used to clear out "None" values in templates
def jinja_finalize(thing):
@@ -997,6 +1033,7 @@ class JinjaOidcMappingConfig:
subject_claim = attr.ib() # type: str
localpart_template = attr.ib() # type: Template
display_name_template = attr.ib() # type: Optional[Template]
+ extra_attributes = attr.ib() # type: Dict[str, Template]
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1035,10 +1072,28 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
% (e,)
)
+ extra_attributes = {} # type Dict[str, Template]
+ if "extra_attributes" in config:
+ extra_attributes_config = config.get("extra_attributes") or {}
+ if not isinstance(extra_attributes_config, dict):
+ raise ConfigError(
+ "oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
+ )
+
+ for key, value in extra_attributes_config.items():
+ try:
+ extra_attributes[key] = env.from_string(value)
+ except Exception as e:
+ raise ConfigError(
+ "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
+ % (key, e)
+ )
+
return JinjaOidcMappingConfig(
subject_claim=subject_claim,
localpart_template=localpart_template,
display_name_template=display_name_template,
+ extra_attributes=extra_attributes,
)
def get_remote_user_id(self, userinfo: UserInfo) -> str:
@@ -1059,3 +1114,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
display_name = None
return UserAttribute(localpart=localpart, display_name=display_name)
+
+ async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
+ extras = {} # type: Dict[str, str]
+ for key, template in self._config.extra_attributes.items():
+ try:
+ extras[key] = template.render(user=userinfo).strip()
+ except Exception as e:
+ # Log an error and skip this value (don't break login for this).
+ logger.error("Failed to render OIDC extra attribute %s: %s" % (key, e))
+ return extras
|