diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index b226890c2a..daea848d24 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -13,7 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, List
+
+from synapse.config.sso import SsoAttributeRequirement
+
from ._base import Config
+from ._util import validate_config
class CasConfig(Config):
@@ -38,12 +43,16 @@ class CasConfig(Config):
public_base_url + "_matrix/client/r0/login/cas/ticket"
)
self.cas_displayname_attribute = cas_config.get("displayname_attribute")
- self.cas_required_attributes = cas_config.get("required_attributes") or {}
+ required_attributes = cas_config.get("required_attributes") or {}
+ self.cas_required_attributes = _parsed_required_attributes_def(
+ required_attributes
+ )
+
else:
self.cas_server_url = None
self.cas_service_url = None
self.cas_displayname_attribute = None
- self.cas_required_attributes = {}
+ self.cas_required_attributes = []
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
@@ -75,3 +84,22 @@ class CasConfig(Config):
# userGroup: "staff"
# department: None
"""
+
+
+# CAS uses a legacy required attributes mapping, not the one provided by
+# SsoAttributeRequirement.
+REQUIRED_ATTRIBUTES_SCHEMA = {
+ "type": "object",
+ "additionalProperties": {"anyOf": [{"type": "string"}, {"type": "null"}]},
+}
+
+
+def _parsed_required_attributes_def(
+ required_attributes: Any,
+) -> List[SsoAttributeRequirement]:
+ validate_config(
+ REQUIRED_ATTRIBUTES_SCHEMA,
+ required_attributes,
+ config_path=("cas_config", "required_attributes"),
+ )
+ return [SsoAttributeRequirement(k, v) for k, v in required_attributes.items()]
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index ad865a667f..1820614bc0 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -17,8 +17,7 @@
import logging
from typing import Any, List
-import attr
-
+from synapse.config.sso import SsoAttributeRequirement
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module, load_python_module
@@ -396,32 +395,18 @@ class SAML2Config(Config):
}
-@attr.s(frozen=True)
-class SamlAttributeRequirement:
- """Object describing a single requirement for SAML attributes."""
-
- attribute = attr.ib(type=str)
- value = attr.ib(type=str)
-
- JSON_SCHEMA = {
- "type": "object",
- "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
- "required": ["attribute", "value"],
- }
-
-
ATTRIBUTE_REQUIREMENTS_SCHEMA = {
"type": "array",
- "items": SamlAttributeRequirement.JSON_SCHEMA,
+ "items": SsoAttributeRequirement.JSON_SCHEMA,
}
def _parse_attribute_requirements_def(
attribute_requirements: Any,
-) -> List[SamlAttributeRequirement]:
+) -> List[SsoAttributeRequirement]:
validate_config(
ATTRIBUTE_REQUIREMENTS_SCHEMA,
attribute_requirements,
- config_path=["saml2_config", "attribute_requirements"],
+ config_path=("saml2_config", "attribute_requirements"),
)
- return [SamlAttributeRequirement(**x) for x in attribute_requirements]
+ return [SsoAttributeRequirement(**x) for x in attribute_requirements]
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 6c60c6fea4..b94d3cd5e1 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -12,11 +12,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict
+from typing import Any, Dict, Optional
+
+import attr
from ._base import Config
+@attr.s(frozen=True)
+class SsoAttributeRequirement:
+ """Object describing a single requirement for SSO attributes."""
+
+ attribute = attr.ib(type=str)
+ # If a value is not given, than the attribute must simply exist.
+ value = attr.ib(type=Optional[str])
+
+ JSON_SCHEMA = {
+ "type": "object",
+ "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
+ "required": ["attribute", "value"],
+ }
+
+
class SSOConfig(Config):
"""SSO Configuration
"""
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index bd35d1fb87..81ed44ac87 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib.parse
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
from xml.etree import ElementTree as ET
import attr
@@ -49,7 +49,7 @@ class CasError(Exception):
@attr.s(slots=True, frozen=True)
class CasResponse:
username = attr.ib(type=str)
- attributes = attr.ib(type=Dict[str, Optional[str]])
+ attributes = attr.ib(type=Dict[str, List[Optional[str]]])
class CasHandler:
@@ -169,7 +169,7 @@ class CasHandler:
# Iterate through the nodes and pull out the user and any extra attributes.
user = None
- attributes = {}
+ attributes = {} # type: Dict[str, List[Optional[str]]]
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
@@ -182,7 +182,7 @@ class CasHandler:
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
- attributes[tag] = attribute.text
+ attributes.setdefault(tag, []).append(attribute.text)
# Ensure a user was found.
if user is None:
@@ -303,29 +303,10 @@ class CasHandler:
# Ensure that the attributes of the logged in user meet the required
# attributes.
- for required_attribute, required_value in self._cas_required_attributes.items():
- # If required attribute was not in CAS Response - Forbidden
- if required_attribute not in cas_response.attributes:
- self._sso_handler.render_error(
- request,
- "unauthorised",
- "You are not authorised to log in here.",
- 401,
- )
- return
-
- # Also need to check value
- if required_value is not None:
- actual_value = cas_response.attributes[required_attribute]
- # If required attribute value does not match expected - Forbidden
- if required_value != actual_value:
- self._sso_handler.render_error(
- request,
- "unauthorised",
- "You are not authorised to log in here.",
- 401,
- )
- return
+ if not self._sso_handler.check_required_attributes(
+ request, cas_response.attributes, self._cas_required_attributes
+ ):
+ return
# Call the mapper to register/login the user
@@ -372,9 +353,10 @@ class CasHandler:
if failures:
raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
+ # Arbitrarily use the first attribute found.
display_name = cas_response.attributes.get(
- self._cas_displayname_attribute, None
- )
+ self._cas_displayname_attribute, [None]
+ )[0]
return UserAttributes(localpart=localpart, display_name=display_name)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index e88fd59749..78f130e152 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -23,7 +23,6 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
-from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
@@ -239,12 +238,10 @@ class SamlHandler(BaseHandler):
# Ensure that the attributes of the logged in user meet the required
# attributes.
- for requirement in self._saml2_attribute_requirements:
- if not _check_attribute_requirement(saml2_auth.ava, requirement):
- self._sso_handler.render_error(
- request, "unauthorised", "You are not authorised to log in here."
- )
- return
+ if not self._sso_handler.check_required_attributes(
+ request, saml2_auth.ava, self._saml2_attribute_requirements
+ ):
+ return
# Call the mapper to register/login the user
try:
@@ -373,21 +370,6 @@ class SamlHandler(BaseHandler):
del self._outstanding_requests_dict[reqid]
-def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
- values = ava.get(req.attribute, [])
- for v in values:
- if v == req.value:
- return True
-
- logger.info(
- "SAML2 attribute %s did not match required value '%s' (was '%s')",
- req.attribute,
- req.value,
- values,
- )
- return False
-
-
DOT_REPLACE_PATTERN = re.compile(
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
)
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 96ccd991ed..a63fd52485 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -16,10 +16,12 @@ import abc
import logging
from typing import (
TYPE_CHECKING,
+ Any,
Awaitable,
Callable,
Dict,
Iterable,
+ List,
Mapping,
Optional,
Set,
@@ -34,6 +36,7 @@ from twisted.web.iweb import IRequest
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
+from synapse.config.sso import SsoAttributeRequirement
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html, respond_with_redirect
@@ -893,6 +896,41 @@ class SsoHandler:
logger.info("Expiring mapping session %s", session_id)
del self._username_mapping_sessions[session_id]
+ def check_required_attributes(
+ self,
+ request: SynapseRequest,
+ attributes: Mapping[str, List[Any]],
+ attribute_requirements: Iterable[SsoAttributeRequirement],
+ ) -> bool:
+ """
+ Confirm that the required attributes were present in the SSO response.
+
+ If all requirements are met, this will return True.
+
+ If any requirement is not met, then the request will be finalized by
+ showing an error page to the user and False will be returned.
+
+ Args:
+ request: The request to (potentially) respond to.
+ attributes: The attributes from the SSO IdP.
+ attribute_requirements: The requirements that attributes must meet.
+
+ Returns:
+ True if all requirements are met, False if any attribute fails to
+ meet the requirement.
+
+ """
+ # Ensure that the attributes of the logged in user meet the required
+ # attributes.
+ for requirement in attribute_requirements:
+ if not _check_attribute_requirement(attributes, requirement):
+ self.render_error(
+ request, "unauthorised", "You are not authorised to log in here."
+ )
+ return False
+
+ return True
+
def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
"""Extract the session ID from the cookie
@@ -903,3 +941,36 @@ def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
if not session_id:
raise SynapseError(code=400, msg="missing session_id")
return session_id.decode("ascii", errors="replace")
+
+
+def _check_attribute_requirement(
+ attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement
+) -> bool:
+ """Check if SSO attributes meet the proper requirements.
+
+ Args:
+ attributes: A mapping of attributes to an iterable of one or more values.
+ requirement: The configured requirement to check.
+
+ Returns:
+ True if the required attribute was found and had a proper value.
+ """
+ if req.attribute not in attributes:
+ logger.info("SSO attribute missing: %s", req.attribute)
+ return False
+
+ # If the requirement is None, the attribute existing is enough.
+ if req.value is None:
+ return True
+
+ values = attributes[req.attribute]
+ if req.value in values:
+ return True
+
+ logger.info(
+ "SSO attribute %s did not match required value '%s' (was '%s')",
+ req.attribute,
+ req.value,
+ values,
+ )
+ return False
|