summary refs log tree commit diff
path: root/synapse/handlers/sso.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/sso.py')
-rw-r--r--synapse/handlers/sso.py90
1 files changed, 87 insertions, 3 deletions
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index b450668f1c..a63fd52485 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -16,10 +16,12 @@ import abc
 import logging
 from typing import (
     TYPE_CHECKING,
+    Any,
     Awaitable,
     Callable,
     Dict,
     Iterable,
+    List,
     Mapping,
     Optional,
     Set,
@@ -34,6 +36,7 @@ from twisted.web.iweb import IRequest
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
+from synapse.config.sso import SsoAttributeRequirement
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http import get_request_user_agent
 from synapse.http.server import respond_with_html, respond_with_redirect
@@ -742,7 +745,11 @@ class SsoHandler:
             use_display_name: whether the user wants to use the suggested display name
             emails_to_use: emails that the user would like to use
         """
-        session = self.get_mapping_session(session_id)
+        try:
+            session = self.get_mapping_session(session_id)
+        except SynapseError as e:
+            self.render_error(request, "bad_session", e.msg, code=e.code)
+            return
 
         # update the session with the user's choices
         session.chosen_localpart = localpart
@@ -793,7 +800,12 @@ class SsoHandler:
             session_id,
             terms_version,
         )
-        session = self.get_mapping_session(session_id)
+        try:
+            session = self.get_mapping_session(session_id)
+        except SynapseError as e:
+            self.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+
         session.terms_accepted_version = terms_version
 
         # we're done; now we can register the user
@@ -808,7 +820,11 @@ class SsoHandler:
             request: HTTP request
             session_id: ID of the username mapping session, extracted from a cookie
         """
-        session = self.get_mapping_session(session_id)
+        try:
+            session = self.get_mapping_session(session_id)
+        except SynapseError as e:
+            self.render_error(request, "bad_session", e.msg, code=e.code)
+            return
 
         logger.info(
             "[session %s] Registering localpart %s",
@@ -880,6 +896,41 @@ class SsoHandler:
             logger.info("Expiring mapping session %s", session_id)
             del self._username_mapping_sessions[session_id]
 
+    def check_required_attributes(
+        self,
+        request: SynapseRequest,
+        attributes: Mapping[str, List[Any]],
+        attribute_requirements: Iterable[SsoAttributeRequirement],
+    ) -> bool:
+        """
+        Confirm that the required attributes were present in the SSO response.
+
+        If all requirements are met, this will return True.
+
+        If any requirement is not met, then the request will be finalized by
+        showing an error page to the user and False will be returned.
+
+        Args:
+            request: The request to (potentially) respond to.
+            attributes: The attributes from the SSO IdP.
+            attribute_requirements: The requirements that attributes must meet.
+
+        Returns:
+            True if all requirements are met, False if any attribute fails to
+            meet the requirement.
+
+        """
+        # Ensure that the attributes of the logged in user meet the required
+        # attributes.
+        for requirement in attribute_requirements:
+            if not _check_attribute_requirement(attributes, requirement):
+                self.render_error(
+                    request, "unauthorised", "You are not authorised to log in here."
+                )
+                return False
+
+        return True
+
 
 def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
     """Extract the session ID from the cookie
@@ -890,3 +941,36 @@ def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
     if not session_id:
         raise SynapseError(code=400, msg="missing session_id")
     return session_id.decode("ascii", errors="replace")
+
+
+def _check_attribute_requirement(
+    attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement
+) -> bool:
+    """Check if SSO attributes meet the proper requirements.
+
+    Args:
+        attributes: A mapping of attributes to an iterable of one or more values.
+        requirement: The configured requirement to check.
+
+    Returns:
+        True if the required attribute was found and had a proper value.
+    """
+    if req.attribute not in attributes:
+        logger.info("SSO attribute missing: %s", req.attribute)
+        return False
+
+    # If the requirement is None, the attribute existing is enough.
+    if req.value is None:
+        return True
+
+    values = attributes[req.attribute]
+    if req.value in values:
+        return True
+
+    logger.info(
+        "SSO attribute %s did not match required value '%s' (was '%s')",
+        req.attribute,
+        req.value,
+        values,
+    )
+    return False