diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 0f342c607b..04972f9cf0 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib.parse
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
from xml.etree import ElementTree as ET
import attr
@@ -33,8 +33,7 @@ logger = logging.getLogger(__name__)
class CasError(Exception):
- """Used to catch errors when validating the CAS ticket.
- """
+ """Used to catch errors when validating the CAS ticket."""
def __init__(self, error, error_description=None):
self.error = error
@@ -49,7 +48,7 @@ class CasError(Exception):
@attr.s(slots=True, frozen=True)
class CasResponse:
username = attr.ib(type=str)
- attributes = attr.ib(type=Dict[str, Optional[str]])
+ attributes = attr.ib(type=Dict[str, List[Optional[str]]])
class CasHandler:
@@ -80,9 +79,10 @@ class CasHandler:
# user-facing name of this auth provider
self.idp_name = "CAS"
- # we do not currently support icons for CAS auth, but this is required by
+ # we do not currently support brands/icons for CAS auth, but this is required by
# the SsoIdentityProvider protocol type.
self.idp_icon = None
+ self.idp_brand = None
self._sso_handler = hs.get_sso_handler()
@@ -99,9 +99,8 @@ class CasHandler:
Returns:
The URL to use as a "service" parameter.
"""
- return "%s%s?%s" % (
+ return "%s?%s" % (
self._cas_service_url,
- "/_matrix/client/r0/login/cas/ticket",
urllib.parse.urlencode(args),
)
@@ -172,7 +171,7 @@ class CasHandler:
# Iterate through the nodes and pull out the user and any extra attributes.
user = None
- attributes = {}
+ attributes = {} # type: Dict[str, List[Optional[str]]]
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
@@ -185,7 +184,7 @@ class CasHandler:
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
- attributes[tag] = attribute.text
+ attributes.setdefault(tag, []).append(attribute.text)
# Ensure a user was found.
if user is None:
@@ -299,36 +298,20 @@ class CasHandler:
# first check if we're doing a UIA
if session:
return await self._sso_handler.complete_sso_ui_auth_request(
- self.idp_id, cas_response.username, session, request,
+ self.idp_id,
+ cas_response.username,
+ session,
+ request,
)
# otherwise, we're handling a login request.
# Ensure that the attributes of the logged in user meet the required
# attributes.
- for required_attribute, required_value in self._cas_required_attributes.items():
- # If required attribute was not in CAS Response - Forbidden
- if required_attribute not in cas_response.attributes:
- self._sso_handler.render_error(
- request,
- "unauthorised",
- "You are not authorised to log in here.",
- 401,
- )
- return
-
- # Also need to check value
- if required_value is not None:
- actual_value = cas_response.attributes[required_attribute]
- # If required attribute value does not match expected - Forbidden
- if required_value != actual_value:
- self._sso_handler.render_error(
- request,
- "unauthorised",
- "You are not authorised to log in here.",
- 401,
- )
- return
+ if not self._sso_handler.check_required_attributes(
+ request, cas_response.attributes, self._cas_required_attributes
+ ):
+ return
# Call the mapper to register/login the user
@@ -375,9 +358,10 @@ class CasHandler:
if failures:
raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
+ # Arbitrarily use the first attribute found.
display_name = cas_response.attributes.get(
- self._cas_displayname_attribute, None
- )
+ self._cas_displayname_attribute, [None]
+ )[0]
return UserAttributes(localpart=localpart, display_name=display_name)
@@ -387,7 +371,8 @@ class CasHandler:
user_id = UserID(localpart, self._hostname).to_string()
logger.debug(
- "Looking for existing account based on mapped %s", user_id,
+ "Looking for existing account based on mapped %s",
+ user_id,
)
users = await self._store.get_users_by_id_case_insensitive(user_id)
|