summary refs log tree commit diff
path: root/synapse/handlers/saml_handler.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-02-11 10:05:15 -0500
committerGitHub <noreply@github.com>2021-02-11 10:05:15 -0500
commit6dade80048380166ac7543d96c4d4686401b1e37 (patch)
tree31e9f226a6f77a701a5849878c2b0cffd71b89c6 /synapse/handlers/saml_handler.py
parentRemove conflicting sqlite tables that are "reserved" (shadow fts4 tables) (#9... (diff)
downloadsynapse-6dade80048380166ac7543d96c4d4686401b1e37.tar.xz
Combine the CAS & SAML implementations for required attributes. (#9326)
Diffstat (limited to 'synapse/handlers/saml_handler.py')
-rw-r--r--synapse/handlers/saml_handler.py26
1 files changed, 4 insertions, 22 deletions
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index e88fd59749..78f130e152 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -23,7 +23,6 @@ from saml2.client import Saml2Client
 
 from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
-from synapse.config.saml2_config import SamlAttributeRequirement
 from synapse.handlers._base import BaseHandler
 from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.servlet import parse_string
@@ -239,12 +238,10 @@ class SamlHandler(BaseHandler):
 
         # Ensure that the attributes of the logged in user meet the required
         # attributes.
-        for requirement in self._saml2_attribute_requirements:
-            if not _check_attribute_requirement(saml2_auth.ava, requirement):
-                self._sso_handler.render_error(
-                    request, "unauthorised", "You are not authorised to log in here."
-                )
-                return
+        if not self._sso_handler.check_required_attributes(
+            request, saml2_auth.ava, self._saml2_attribute_requirements
+        ):
+            return
 
         # Call the mapper to register/login the user
         try:
@@ -373,21 +370,6 @@ class SamlHandler(BaseHandler):
             del self._outstanding_requests_dict[reqid]
 
 
-def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
-    values = ava.get(req.attribute, [])
-    for v in values:
-        if v == req.value:
-            return True
-
-    logger.info(
-        "SAML2 attribute %s did not match required value '%s' (was '%s')",
-        req.attribute,
-        req.value,
-        values,
-    )
-    return False
-
-
 DOT_REPLACE_PATTERN = re.compile(
     ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
 )