diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 0322b60cfc..00eae92052 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -137,6 +137,15 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
}
+@attr.s(slots=True)
+class SsoLoginExtraAttributes:
+ """Data we track about SAML2 sessions"""
+
+ # time the session was created, in milliseconds
+ creation_time = attr.ib(type=int)
+ extra_attributes = attr.ib(type=JsonDict)
+
+
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
@@ -239,6 +248,10 @@ class AuthHandler(BaseHandler):
# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
+ # A mapping of user ID to extra attributes to include in the login
+ # response.
+ self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes]
+
async def validate_user_via_ui_auth(
self,
requester: Requester,
@@ -1165,6 +1178,7 @@ class AuthHandler(BaseHandler):
registered_user_id: str,
request: SynapseRequest,
client_redirect_url: str,
+ extra_attributes: Optional[JsonDict] = None,
):
"""Having figured out a mxid for this user, complete the HTTP request
@@ -1173,6 +1187,8 @@ class AuthHandler(BaseHandler):
request: The request to complete.
client_redirect_url: The URL to which to redirect the user at the end of the
process.
+ extra_attributes: Extra attributes which will be passed to the client
+ during successful login. Must be JSON serializable.
"""
# If the account has been deactivated, do not proceed with the login
# flow.
@@ -1181,19 +1197,30 @@ class AuthHandler(BaseHandler):
respond_with_html(request, 403, self._sso_account_deactivated_template)
return
- self._complete_sso_login(registered_user_id, request, client_redirect_url)
+ self._complete_sso_login(
+ registered_user_id, request, client_redirect_url, extra_attributes
+ )
def _complete_sso_login(
self,
registered_user_id: str,
request: SynapseRequest,
client_redirect_url: str,
+ extra_attributes: Optional[JsonDict] = None,
):
"""
The synchronous portion of complete_sso_login.
This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
"""
+ # Store any extra attributes which will be passed in the login response.
+ # Note that this is per-user so it may overwrite a previous value, this
+ # is considered OK since the newest SSO attributes should be most valid.
+ if extra_attributes:
+ self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes(
+ self._clock.time_msec(), extra_attributes,
+ )
+
# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id
@@ -1226,6 +1253,37 @@ class AuthHandler(BaseHandler):
)
respond_with_html(request, 200, html)
+ async def _sso_login_callback(self, login_result: JsonDict) -> None:
+ """
+ A login callback which might add additional attributes to the login response.
+
+ Args:
+ login_result: The data to be sent to the client. Includes the user
+ ID and access token.
+ """
+ # Expire attributes before processing. Note that there shouldn't be any
+ # valid logins that still have extra attributes.
+ self._expire_sso_extra_attributes()
+
+ extra_attributes = self._extra_attributes.get(login_result["user_id"])
+ if extra_attributes:
+ login_result.update(extra_attributes.extra_attributes)
+
+ def _expire_sso_extra_attributes(self) -> None:
+ """
+ Iterate through the mapping of user IDs to extra attributes and remove any that are no longer valid.
+ """
+ # TODO This should match the amount of time the macaroon is valid for.
+ LOGIN_TOKEN_EXPIRATION_TIME = 2 * 60 * 1000
+ expire_before = self._clock.time_msec() - LOGIN_TOKEN_EXPIRATION_TIME
+ to_expire = set()
+ for user_id, data in self._extra_attributes.items():
+ if data.creation_time < expire_before:
+ to_expire.add(user_id)
+ for user_id in to_expire:
+ logger.debug("Expiring extra attributes for user %s", user_id)
+ del self._extra_attributes[user_id]
+
@staticmethod
def add_query_param_to_url(url: str, param_name: str, param: Any):
url_parts = list(urllib.parse.urlparse(url))
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 0e06e4408d..19cd652675 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -37,7 +37,7 @@ from synapse.config import ConfigError
from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -707,6 +707,15 @@ class OidcHandler:
self._render_error(request, "mapping_error", str(e))
return
+ # Mapping providers might not have get_extra_attributes: only call this
+ # method if it exists.
+ extra_attributes = None
+ get_extra_attributes = getattr(
+ self._user_mapping_provider, "get_extra_attributes", None
+ )
+ if get_extra_attributes:
+ extra_attributes = await get_extra_attributes(userinfo, token)
+
# and finally complete the login
if ui_auth_session_id:
await self._auth_handler.complete_sso_ui_auth(
@@ -714,7 +723,7 @@ class OidcHandler:
)
else:
await self._auth_handler.complete_sso_login(
- user_id, request, client_redirect_url
+ user_id, request, client_redirect_url, extra_attributes
)
def _generate_oidc_session_token(
@@ -984,7 +993,7 @@ class OidcMappingProvider(Generic[C]):
async def map_user_attributes(
self, userinfo: UserInfo, token: Token
) -> UserAttribute:
- """Map a ``UserInfo`` objects into user attributes.
+ """Map a `UserInfo` object into user attributes.
Args:
userinfo: An object representing the user given by the OIDC provider
@@ -995,6 +1004,18 @@ class OidcMappingProvider(Generic[C]):
"""
raise NotImplementedError()
+ async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
+ """Map a `UserInfo` object into additional attributes passed to the client during login.
+
+ Args:
+ userinfo: An object representing the user given by the OIDC provider
+ token: A dict with the tokens returned by the provider
+
+ Returns:
+ A dict containing additional attributes. Must be JSON serializable.
+ """
+ return {}
+
# Used to clear out "None" values in templates
def jinja_finalize(thing):
@@ -1009,6 +1030,7 @@ class JinjaOidcMappingConfig:
subject_claim = attr.ib() # type: str
localpart_template = attr.ib() # type: Template
display_name_template = attr.ib() # type: Optional[Template]
+ extra_attributes = attr.ib() # type: Dict[str, Template]
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1047,10 +1069,28 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
% (e,)
)
+ extra_attributes = {} # type Dict[str, Template]
+ if "extra_attributes" in config:
+ extra_attributes_config = config.get("extra_attributes") or {}
+ if not isinstance(extra_attributes_config, dict):
+ raise ConfigError(
+ "oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
+ )
+
+ for key, value in extra_attributes_config.items():
+ try:
+ extra_attributes[key] = env.from_string(value)
+ except Exception as e:
+ raise ConfigError(
+ "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
+ % (key, e)
+ )
+
return JinjaOidcMappingConfig(
subject_claim=subject_claim,
localpart_template=localpart_template,
display_name_template=display_name_template,
+ extra_attributes=extra_attributes,
)
def get_remote_user_id(self, userinfo: UserInfo) -> str:
@@ -1071,3 +1111,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
display_name = None
return UserAttribute(localpart=localpart, display_name=display_name)
+
+ async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
+ extras = {} # type: Dict[str, str]
+ for key, template in self._config.extra_attributes.items():
+ try:
+ extras[key] = template.render(user=userinfo).strip()
+ except Exception as e:
+ # Log an error and skip this value (don't break login for this).
+ logger.error("Failed to render OIDC extra attribute %s: %s" % (key, e))
+ return extras
|