diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 96ccd991ed..514b1f69d8 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -16,10 +16,12 @@ import abc
import logging
from typing import (
TYPE_CHECKING,
+ Any,
Awaitable,
Callable,
Dict,
Iterable,
+ List,
Mapping,
Optional,
Set,
@@ -34,6 +36,7 @@ from twisted.web.iweb import IRequest
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
+from synapse.config.sso import SsoAttributeRequirement
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html, respond_with_redirect
@@ -324,7 +327,8 @@ class SsoHandler:
# Check if we already have a mapping for this user.
previously_registered_user_id = await self._store.get_user_by_external_id(
- auth_provider_id, remote_user_id,
+ auth_provider_id,
+ remote_user_id,
)
# A match was found, return the user ID.
@@ -413,7 +417,8 @@ class SsoHandler:
with await self._mapping_lock.queue(auth_provider_id):
# first of all, check if we already have a mapping for this user
user_id = await self.get_sso_user_by_remote_user_id(
- auth_provider_id, remote_user_id,
+ auth_provider_id,
+ remote_user_id,
)
# Check for grandfathering of users.
@@ -458,7 +463,8 @@ class SsoHandler:
)
async def _call_attribute_mapper(
- self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+ self,
+ sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
) -> UserAttributes:
"""Call the attribute mapper function in a loop, until we get a unique userid"""
for i in range(self._MAP_USERNAME_RETRIES):
@@ -629,7 +635,8 @@ class SsoHandler:
"""
user_id = await self.get_sso_user_by_remote_user_id(
- auth_provider_id, remote_user_id,
+ auth_provider_id,
+ remote_user_id,
)
user_id_to_verify = await self._auth_handler.get_session_data(
@@ -668,7 +675,8 @@ class SsoHandler:
# render an error page.
html = self._bad_user_template.render(
- server_name=self._server_name, user_id_to_verify=user_id_to_verify,
+ server_name=self._server_name,
+ user_id_to_verify=user_id_to_verify,
)
respond_with_html(request, 200, html)
@@ -692,7 +700,9 @@ class SsoHandler:
raise SynapseError(400, "unknown session")
async def check_username_availability(
- self, localpart: str, session_id: str,
+ self,
+ localpart: str,
+ session_id: str,
) -> bool:
"""Handle an "is username available" callback check
@@ -830,7 +840,8 @@ class SsoHandler:
)
attributes = UserAttributes(
- localpart=session.chosen_localpart, emails=session.emails_to_use,
+ localpart=session.chosen_localpart,
+ emails=session.emails_to_use,
)
if session.use_display_name:
@@ -893,6 +904,41 @@ class SsoHandler:
logger.info("Expiring mapping session %s", session_id)
del self._username_mapping_sessions[session_id]
+ def check_required_attributes(
+ self,
+ request: SynapseRequest,
+ attributes: Mapping[str, List[Any]],
+ attribute_requirements: Iterable[SsoAttributeRequirement],
+ ) -> bool:
+ """
+ Confirm that the required attributes were present in the SSO response.
+
+ If all requirements are met, this will return True.
+
+ If any requirement is not met, then the request will be finalized by
+ showing an error page to the user and False will be returned.
+
+ Args:
+ request: The request to (potentially) respond to.
+ attributes: The attributes from the SSO IdP.
+ attribute_requirements: The requirements that attributes must meet.
+
+ Returns:
+ True if all requirements are met, False if any attribute fails to
+ meet the requirement.
+
+ """
+ # Ensure that the attributes of the logged in user meet the required
+ # attributes.
+ for requirement in attribute_requirements:
+ if not _check_attribute_requirement(attributes, requirement):
+ self.render_error(
+ request, "unauthorised", "You are not authorised to log in here."
+ )
+ return False
+
+ return True
+
def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
"""Extract the session ID from the cookie
@@ -903,3 +949,36 @@ def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
if not session_id:
raise SynapseError(code=400, msg="missing session_id")
return session_id.decode("ascii", errors="replace")
+
+
+def _check_attribute_requirement(
+ attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement
+) -> bool:
+ """Check if SSO attributes meet the proper requirements.
+
+ Args:
+ attributes: A mapping of attributes to an iterable of one or more values.
+ requirement: The configured requirement to check.
+
+ Returns:
+ True if the required attribute was found and had a proper value.
+ """
+ if req.attribute not in attributes:
+ logger.info("SSO attribute missing: %s", req.attribute)
+ return False
+
+ # If the requirement is None, the attribute existing is enough.
+ if req.value is None:
+ return True
+
+ values = attributes[req.attribute]
+ if req.value in values:
+ return True
+
+ logger.info(
+ "SSO attribute %s did not match required value '%s' (was '%s')",
+ req.attribute,
+ req.value,
+ values,
+ )
+ return False
|