summary refs log tree commit diff
path: root/synapse/handlers/cas_handler.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/cas_handler.py')
-rw-r--r--synapse/handlers/cas_handler.py40
1 files changed, 11 insertions, 29 deletions
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index bd35d1fb87..81ed44ac87 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import urllib.parse
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
 from xml.etree import ElementTree as ET
 
 import attr
@@ -49,7 +49,7 @@ class CasError(Exception):
 @attr.s(slots=True, frozen=True)
 class CasResponse:
     username = attr.ib(type=str)
-    attributes = attr.ib(type=Dict[str, Optional[str]])
+    attributes = attr.ib(type=Dict[str, List[Optional[str]]])
 
 
 class CasHandler:
@@ -169,7 +169,7 @@ class CasHandler:
 
         # Iterate through the nodes and pull out the user and any extra attributes.
         user = None
-        attributes = {}
+        attributes = {}  # type: Dict[str, List[Optional[str]]]
         for child in root[0]:
             if child.tag.endswith("user"):
                 user = child.text
@@ -182,7 +182,7 @@ class CasHandler:
                     tag = attribute.tag
                     if "}" in tag:
                         tag = tag.split("}")[1]
-                    attributes[tag] = attribute.text
+                    attributes.setdefault(tag, []).append(attribute.text)
 
         # Ensure a user was found.
         if user is None:
@@ -303,29 +303,10 @@ class CasHandler:
 
         # Ensure that the attributes of the logged in user meet the required
         # attributes.
-        for required_attribute, required_value in self._cas_required_attributes.items():
-            # If required attribute was not in CAS Response - Forbidden
-            if required_attribute not in cas_response.attributes:
-                self._sso_handler.render_error(
-                    request,
-                    "unauthorised",
-                    "You are not authorised to log in here.",
-                    401,
-                )
-                return
-
-            # Also need to check value
-            if required_value is not None:
-                actual_value = cas_response.attributes[required_attribute]
-                # If required attribute value does not match expected - Forbidden
-                if required_value != actual_value:
-                    self._sso_handler.render_error(
-                        request,
-                        "unauthorised",
-                        "You are not authorised to log in here.",
-                        401,
-                    )
-                    return
+        if not self._sso_handler.check_required_attributes(
+            request, cas_response.attributes, self._cas_required_attributes
+        ):
+            return
 
         # Call the mapper to register/login the user
 
@@ -372,9 +353,10 @@ class CasHandler:
             if failures:
                 raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
 
+            # Arbitrarily use the first attribute found.
             display_name = cas_response.attributes.get(
-                self._cas_displayname_attribute, None
-            )
+                self._cas_displayname_attribute, [None]
+            )[0]
 
             return UserAttributes(localpart=localpart, display_name=display_name)