summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/9326.misc1
-rw-r--r--synapse/config/cas.py32
-rw-r--r--synapse/config/saml2_config.py25
-rw-r--r--synapse/config/sso.py19
-rw-r--r--synapse/handlers/cas_handler.py40
-rw-r--r--synapse/handlers/saml_handler.py26
-rw-r--r--synapse/handlers/sso.py71
-rw-r--r--tests/handlers/test_cas.py52
-rw-r--r--tests/handlers/test_saml.py56
9 files changed, 245 insertions, 77 deletions
diff --git a/changelog.d/9326.misc b/changelog.d/9326.misc
new file mode 100644
index 0000000000..768c18d27e
--- /dev/null
+++ b/changelog.d/9326.misc
@@ -0,0 +1 @@
+Share the code for handling required attributes between the CAS and SAML handlers.
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index b226890c2a..daea848d24 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -13,7 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any, List
+
+from synapse.config.sso import SsoAttributeRequirement
+
 from ._base import Config
+from ._util import validate_config
 
 
 class CasConfig(Config):
@@ -38,12 +43,16 @@ class CasConfig(Config):
                 public_base_url + "_matrix/client/r0/login/cas/ticket"
             )
             self.cas_displayname_attribute = cas_config.get("displayname_attribute")
-            self.cas_required_attributes = cas_config.get("required_attributes") or {}
+            required_attributes = cas_config.get("required_attributes") or {}
+            self.cas_required_attributes = _parsed_required_attributes_def(
+                required_attributes
+            )
+
         else:
             self.cas_server_url = None
             self.cas_service_url = None
             self.cas_displayname_attribute = None
-            self.cas_required_attributes = {}
+            self.cas_required_attributes = []
 
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """\
@@ -75,3 +84,22 @@ class CasConfig(Config):
           #  userGroup: "staff"
           #  department: None
         """
+
+
+# CAS uses a legacy required attributes mapping, not the one provided by
+# SsoAttributeRequirement.
+REQUIRED_ATTRIBUTES_SCHEMA = {
+    "type": "object",
+    "additionalProperties": {"anyOf": [{"type": "string"}, {"type": "null"}]},
+}
+
+
+def _parsed_required_attributes_def(
+    required_attributes: Any,
+) -> List[SsoAttributeRequirement]:
+    validate_config(
+        REQUIRED_ATTRIBUTES_SCHEMA,
+        required_attributes,
+        config_path=("cas_config", "required_attributes"),
+    )
+    return [SsoAttributeRequirement(k, v) for k, v in required_attributes.items()]
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index ad865a667f..1820614bc0 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -17,8 +17,7 @@
 import logging
 from typing import Any, List
 
-import attr
-
+from synapse.config.sso import SsoAttributeRequirement
 from synapse.python_dependencies import DependencyException, check_requirements
 from synapse.util.module_loader import load_module, load_python_module
 
@@ -396,32 +395,18 @@ class SAML2Config(Config):
         }
 
 
-@attr.s(frozen=True)
-class SamlAttributeRequirement:
-    """Object describing a single requirement for SAML attributes."""
-
-    attribute = attr.ib(type=str)
-    value = attr.ib(type=str)
-
-    JSON_SCHEMA = {
-        "type": "object",
-        "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
-        "required": ["attribute", "value"],
-    }
-
-
 ATTRIBUTE_REQUIREMENTS_SCHEMA = {
     "type": "array",
-    "items": SamlAttributeRequirement.JSON_SCHEMA,
+    "items": SsoAttributeRequirement.JSON_SCHEMA,
 }
 
 
 def _parse_attribute_requirements_def(
     attribute_requirements: Any,
-) -> List[SamlAttributeRequirement]:
+) -> List[SsoAttributeRequirement]:
     validate_config(
         ATTRIBUTE_REQUIREMENTS_SCHEMA,
         attribute_requirements,
-        config_path=["saml2_config", "attribute_requirements"],
+        config_path=("saml2_config", "attribute_requirements"),
     )
-    return [SamlAttributeRequirement(**x) for x in attribute_requirements]
+    return [SsoAttributeRequirement(**x) for x in attribute_requirements]
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 6c60c6fea4..b94d3cd5e1 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -12,11 +12,28 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict
+from typing import Any, Dict, Optional
+
+import attr
 
 from ._base import Config
 
 
+@attr.s(frozen=True)
+class SsoAttributeRequirement:
+    """Object describing a single requirement for SSO attributes."""
+
+    attribute = attr.ib(type=str)
+    # If a value is not given, than the attribute must simply exist.
+    value = attr.ib(type=Optional[str])
+
+    JSON_SCHEMA = {
+        "type": "object",
+        "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
+        "required": ["attribute", "value"],
+    }
+
+
 class SSOConfig(Config):
     """SSO Configuration
     """
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index bd35d1fb87..81ed44ac87 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import urllib.parse
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
 from xml.etree import ElementTree as ET
 
 import attr
@@ -49,7 +49,7 @@ class CasError(Exception):
 @attr.s(slots=True, frozen=True)
 class CasResponse:
     username = attr.ib(type=str)
-    attributes = attr.ib(type=Dict[str, Optional[str]])
+    attributes = attr.ib(type=Dict[str, List[Optional[str]]])
 
 
 class CasHandler:
@@ -169,7 +169,7 @@ class CasHandler:
 
         # Iterate through the nodes and pull out the user and any extra attributes.
         user = None
-        attributes = {}
+        attributes = {}  # type: Dict[str, List[Optional[str]]]
         for child in root[0]:
             if child.tag.endswith("user"):
                 user = child.text
@@ -182,7 +182,7 @@ class CasHandler:
                     tag = attribute.tag
                     if "}" in tag:
                         tag = tag.split("}")[1]
-                    attributes[tag] = attribute.text
+                    attributes.setdefault(tag, []).append(attribute.text)
 
         # Ensure a user was found.
         if user is None:
@@ -303,29 +303,10 @@ class CasHandler:
 
         # Ensure that the attributes of the logged in user meet the required
         # attributes.
-        for required_attribute, required_value in self._cas_required_attributes.items():
-            # If required attribute was not in CAS Response - Forbidden
-            if required_attribute not in cas_response.attributes:
-                self._sso_handler.render_error(
-                    request,
-                    "unauthorised",
-                    "You are not authorised to log in here.",
-                    401,
-                )
-                return
-
-            # Also need to check value
-            if required_value is not None:
-                actual_value = cas_response.attributes[required_attribute]
-                # If required attribute value does not match expected - Forbidden
-                if required_value != actual_value:
-                    self._sso_handler.render_error(
-                        request,
-                        "unauthorised",
-                        "You are not authorised to log in here.",
-                        401,
-                    )
-                    return
+        if not self._sso_handler.check_required_attributes(
+            request, cas_response.attributes, self._cas_required_attributes
+        ):
+            return
 
         # Call the mapper to register/login the user
 
@@ -372,9 +353,10 @@ class CasHandler:
             if failures:
                 raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
 
+            # Arbitrarily use the first attribute found.
             display_name = cas_response.attributes.get(
-                self._cas_displayname_attribute, None
-            )
+                self._cas_displayname_attribute, [None]
+            )[0]
 
             return UserAttributes(localpart=localpart, display_name=display_name)
 
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index e88fd59749..78f130e152 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -23,7 +23,6 @@ from saml2.client import Saml2Client
 
 from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
-from synapse.config.saml2_config import SamlAttributeRequirement
 from synapse.handlers._base import BaseHandler
 from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.servlet import parse_string
@@ -239,12 +238,10 @@ class SamlHandler(BaseHandler):
 
         # Ensure that the attributes of the logged in user meet the required
         # attributes.
-        for requirement in self._saml2_attribute_requirements:
-            if not _check_attribute_requirement(saml2_auth.ava, requirement):
-                self._sso_handler.render_error(
-                    request, "unauthorised", "You are not authorised to log in here."
-                )
-                return
+        if not self._sso_handler.check_required_attributes(
+            request, saml2_auth.ava, self._saml2_attribute_requirements
+        ):
+            return
 
         # Call the mapper to register/login the user
         try:
@@ -373,21 +370,6 @@ class SamlHandler(BaseHandler):
             del self._outstanding_requests_dict[reqid]
 
 
-def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
-    values = ava.get(req.attribute, [])
-    for v in values:
-        if v == req.value:
-            return True
-
-    logger.info(
-        "SAML2 attribute %s did not match required value '%s' (was '%s')",
-        req.attribute,
-        req.value,
-        values,
-    )
-    return False
-
-
 DOT_REPLACE_PATTERN = re.compile(
     ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
 )
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 96ccd991ed..a63fd52485 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -16,10 +16,12 @@ import abc
 import logging
 from typing import (
     TYPE_CHECKING,
+    Any,
     Awaitable,
     Callable,
     Dict,
     Iterable,
+    List,
     Mapping,
     Optional,
     Set,
@@ -34,6 +36,7 @@ from twisted.web.iweb import IRequest
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
+from synapse.config.sso import SsoAttributeRequirement
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http import get_request_user_agent
 from synapse.http.server import respond_with_html, respond_with_redirect
@@ -893,6 +896,41 @@ class SsoHandler:
             logger.info("Expiring mapping session %s", session_id)
             del self._username_mapping_sessions[session_id]
 
+    def check_required_attributes(
+        self,
+        request: SynapseRequest,
+        attributes: Mapping[str, List[Any]],
+        attribute_requirements: Iterable[SsoAttributeRequirement],
+    ) -> bool:
+        """
+        Confirm that the required attributes were present in the SSO response.
+
+        If all requirements are met, this will return True.
+
+        If any requirement is not met, then the request will be finalized by
+        showing an error page to the user and False will be returned.
+
+        Args:
+            request: The request to (potentially) respond to.
+            attributes: The attributes from the SSO IdP.
+            attribute_requirements: The requirements that attributes must meet.
+
+        Returns:
+            True if all requirements are met, False if any attribute fails to
+            meet the requirement.
+
+        """
+        # Ensure that the attributes of the logged in user meet the required
+        # attributes.
+        for requirement in attribute_requirements:
+            if not _check_attribute_requirement(attributes, requirement):
+                self.render_error(
+                    request, "unauthorised", "You are not authorised to log in here."
+                )
+                return False
+
+        return True
+
 
 def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
     """Extract the session ID from the cookie
@@ -903,3 +941,36 @@ def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
     if not session_id:
         raise SynapseError(code=400, msg="missing session_id")
     return session_id.decode("ascii", errors="replace")
+
+
+def _check_attribute_requirement(
+    attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement
+) -> bool:
+    """Check if SSO attributes meet the proper requirements.
+
+    Args:
+        attributes: A mapping of attributes to an iterable of one or more values.
+        requirement: The configured requirement to check.
+
+    Returns:
+        True if the required attribute was found and had a proper value.
+    """
+    if req.attribute not in attributes:
+        logger.info("SSO attribute missing: %s", req.attribute)
+        return False
+
+    # If the requirement is None, the attribute existing is enough.
+    if req.value is None:
+        return True
+
+    values = attributes[req.attribute]
+    if req.value in values:
+        return True
+
+    logger.info(
+        "SSO attribute %s did not match required value '%s' (was '%s')",
+        req.attribute,
+        req.value,
+        values,
+    )
+    return False
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 7baf224f7e..6f992291b8 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -16,7 +16,7 @@ from mock import Mock
 from synapse.handlers.cas_handler import CasResponse
 
 from tests.test_utils import simple_async_mock
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
 
 # These are a few constants that are used as config parameters in the tests.
 BASE_URL = "https://synapse/"
@@ -32,6 +32,10 @@ class CasHandlerTestCase(HomeserverTestCase):
             "server_url": SERVER_URL,
             "service_url": BASE_URL,
         }
+
+        # Update this config with what's in the default config so that
+        # override_config works as expected.
+        cas_config.update(config.get("cas_config", {}))
         config["cas_config"] = cas_config
 
         return config
@@ -115,7 +119,51 @@ class CasHandlerTestCase(HomeserverTestCase):
             "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
         )
 
+    @override_config(
+        {
+            "cas_config": {
+                "required_attributes": {"userGroup": "staff", "department": None}
+            }
+        }
+    )
+    def test_required_attributes(self):
+        """The required attributes must be met from the CAS response."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # The response doesn't have the proper userGroup or department.
+        cas_response = CasResponse("test_user", {})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # The response doesn't have any department.
+        cas_response = CasResponse("test_user", {"userGroup": "staff"})
+        request.reset_mock()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # Add the proper attributes and it should succeed.
+        cas_response = CasResponse(
+            "test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]}
+        )
+        request.reset_mock()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=True
+        )
+
 
 def _mock_request():
     """Returns a mock which will stand in as a SynapseRequest"""
-    return Mock(spec=["getClientIP", "getHeader"])
+    return Mock(spec=["getClientIP", "getHeader", "_disconnected"])
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index a8d6c0f617..029af2853e 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -259,7 +259,61 @@ class SamlHandlerTestCase(HomeserverTestCase):
         )
         self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
 
+    @override_config(
+        {
+            "saml2_config": {
+                "attribute_requirements": [
+                    {"attribute": "userGroup", "value": "staff"},
+                    {"attribute": "department", "value": "sales"},
+                ],
+            },
+        }
+    )
+    def test_attribute_requirements(self):
+        """The required attributes must be met from the SAML response."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # The response doesn't have the proper userGroup or department.
+        saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+        )
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # The response doesn't have the proper department.
+        saml_response = FakeAuthnResponse(
+            {"uid": "test_user", "username": "test_user", "userGroup": ["staff"]}
+        )
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+        )
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # Add the proper attributes and it should succeed.
+        saml_response = FakeAuthnResponse(
+            {
+                "uid": "test_user",
+                "username": "test_user",
+                "userGroup": ["staff", "admin"],
+                "department": ["sales"],
+            }
+        )
+        request.reset_mock()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=True
+        )
+
 
 def _mock_request():
     """Returns a mock which will stand in as a SynapseRequest"""
-    return Mock(spec=["getClientIP", "getHeader"])
+    return Mock(spec=["getClientIP", "getHeader", "_disconnected"])