summary refs log tree commit diff
path: root/synapse/handlers/cas_handler.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/cas_handler.py')
-rw-r--r--synapse/handlers/cas_handler.py356
1 files changed, 251 insertions, 105 deletions
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index f4ea0a9767..bd35d1fb87 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -13,13 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import urllib
-from typing import TYPE_CHECKING, Dict, Optional, Tuple
+import urllib.parse
+from typing import TYPE_CHECKING, Dict, Optional
 from xml.etree import ElementTree as ET
 
+import attr
+
 from twisted.web.client import PartialDownloadError
 
-from synapse.api.errors import Codes, LoginError
+from synapse.api.errors import HttpResponseException
+from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.types import UserID, map_username_to_mxid_localpart
 
@@ -29,6 +32,26 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class CasError(Exception):
+    """Used to catch errors when validating the CAS ticket.
+    """
+
+    def __init__(self, error, error_description=None):
+        self.error = error
+        self.error_description = error_description
+
+    def __str__(self):
+        if self.error_description:
+            return "{}: {}".format(self.error, self.error_description)
+        return self.error
+
+
+@attr.s(slots=True, frozen=True)
+class CasResponse:
+    username = attr.ib(type=str)
+    attributes = attr.ib(type=Dict[str, Optional[str]])
+
+
 class CasHandler:
     """
     Utility class for to handle the response from a CAS SSO service.
@@ -40,6 +63,7 @@ class CasHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self._hostname = hs.hostname
+        self._store = hs.get_datastore()
         self._auth_handler = hs.get_auth_handler()
         self._registration_handler = hs.get_registration_handler()
 
@@ -50,6 +74,21 @@ class CasHandler:
 
         self._http_client = hs.get_proxied_http_client()
 
+        # identifier for the external_ids table
+        self.idp_id = "cas"
+
+        # user-facing name of this auth provider
+        self.idp_name = "CAS"
+
+        # we do not currently support brands/icons for CAS auth, but this is required by
+        # the SsoIdentityProvider protocol type.
+        self.idp_icon = None
+        self.idp_brand = None
+
+        self._sso_handler = hs.get_sso_handler()
+
+        self._sso_handler.register_identity_provider(self)
+
     def _build_service_param(self, args: Dict[str, str]) -> str:
         """
         Generates a value to use as the "service" parameter when redirecting or
@@ -61,22 +100,24 @@ class CasHandler:
         Returns:
             The URL to use as a "service" parameter.
         """
-        return "%s%s?%s" % (
-            self._cas_service_url,
-            "/_matrix/client/r0/login/cas/ticket",
-            urllib.parse.urlencode(args),
-        )
+        return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),)
 
     async def _validate_ticket(
         self, ticket: str, service_args: Dict[str, str]
-    ) -> Tuple[str, Optional[str]]:
+    ) -> CasResponse:
         """
-        Validate a CAS ticket with the server, parse the response, and return the user and display name.
+        Validate a CAS ticket with the server, and return the parsed the response.
 
         Args:
             ticket: The CAS ticket from the client.
             service_args: Additional arguments to include in the service URL.
-                Should be the same as those passed to `get_redirect_url`.
+                Should be the same as those passed to `handle_redirect_request`.
+
+        Raises:
+            CasError: If there's an error parsing the CAS response.
+
+        Returns:
+            The parsed CAS response.
         """
         uri = self._cas_server_url + "/proxyValidate"
         args = {
@@ -89,77 +130,91 @@ class CasHandler:
             # Twisted raises this error if the connection is closed,
             # even if that's being used old-http style to signal end-of-data
             body = pde.response
+        except HttpResponseException as e:
+            description = (
+                (
+                    'Authorization server responded with a "{status}" error '
+                    "while exchanging the authorization code."
+                ).format(status=e.code),
+            )
+            raise CasError("server_error", description) from e
 
-        user, attributes = self._parse_cas_response(body)
-        displayname = attributes.pop(self._cas_displayname_attribute, None)
-
-        for required_attribute, required_value in self._cas_required_attributes.items():
-            # If required attribute was not in CAS Response - Forbidden
-            if required_attribute not in attributes:
-                raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
-            # Also need to check value
-            if required_value is not None:
-                actual_value = attributes[required_attribute]
-                # If required attribute value does not match expected - Forbidden
-                if required_value != actual_value:
-                    raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
-        return user, displayname
+        return self._parse_cas_response(body)
 
-    def _parse_cas_response(
-        self, cas_response_body: bytes
-    ) -> Tuple[str, Dict[str, Optional[str]]]:
+    def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
         """
         Retrieve the user and other parameters from the CAS response.
 
         Args:
             cas_response_body: The response from the CAS query.
 
+        Raises:
+            CasError: If there's an error parsing the CAS response.
+
         Returns:
-            A tuple of the user and a mapping of other attributes.
+            The parsed CAS response.
         """
-        user = None
-        attributes = {}
-        try:
-            root = ET.fromstring(cas_response_body)
-            if not root.tag.endswith("serviceResponse"):
-                raise Exception("root of CAS response is not serviceResponse")
-            success = root[0].tag.endswith("authenticationSuccess")
-            for child in root[0]:
-                if child.tag.endswith("user"):
-                    user = child.text
-                if child.tag.endswith("attributes"):
-                    for attribute in child:
-                        # ElementTree library expands the namespace in
-                        # attribute tags to the full URL of the namespace.
-                        # We don't care about namespace here and it will always
-                        # be encased in curly braces, so we remove them.
-                        tag = attribute.tag
-                        if "}" in tag:
-                            tag = tag.split("}")[1]
-                        attributes[tag] = attribute.text
-            if user is None:
-                raise Exception("CAS response does not contain user")
-        except Exception:
-            logger.exception("Error parsing CAS response")
-            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
-        if not success:
-            raise LoginError(
-                401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
+
+        # Ensure the response is valid.
+        root = ET.fromstring(cas_response_body)
+        if not root.tag.endswith("serviceResponse"):
+            raise CasError(
+                "missing_service_response",
+                "root of CAS response is not serviceResponse",
             )
-        return user, attributes
 
-    def get_redirect_url(self, service_args: Dict[str, str]) -> str:
-        """
-        Generates a URL for the CAS server where the client should be redirected.
+        success = root[0].tag.endswith("authenticationSuccess")
+        if not success:
+            raise CasError("unsucessful_response", "Unsuccessful CAS response")
+
+        # Iterate through the nodes and pull out the user and any extra attributes.
+        user = None
+        attributes = {}
+        for child in root[0]:
+            if child.tag.endswith("user"):
+                user = child.text
+            if child.tag.endswith("attributes"):
+                for attribute in child:
+                    # ElementTree library expands the namespace in
+                    # attribute tags to the full URL of the namespace.
+                    # We don't care about namespace here and it will always
+                    # be encased in curly braces, so we remove them.
+                    tag = attribute.tag
+                    if "}" in tag:
+                        tag = tag.split("}")[1]
+                    attributes[tag] = attribute.text
+
+        # Ensure a user was found.
+        if user is None:
+            raise CasError("no_user", "CAS response does not contain user")
+
+        return CasResponse(user, attributes)
+
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        """Generates a URL for the CAS server where the client should be redirected.
 
         Args:
-            service_args: Additional arguments to include in the final redirect URL.
+            request: the incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login (or None for UI Auth).
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
 
         Returns:
-            The URL to redirect the client to.
+            URL to redirect to
         """
+
+        if ui_auth_session_id:
+            service_args = {"session": ui_auth_session_id}
+        else:
+            assert client_redirect_url
+            service_args = {"redirectUrl": client_redirect_url.decode("utf8")}
+
         args = urllib.parse.urlencode(
             {"service": self._build_service_param(service_args)}
         )
@@ -201,59 +256,150 @@ class CasHandler:
             args["redirectUrl"] = client_redirect_url
         if session:
             args["session"] = session
-        username, user_display_name = await self._validate_ticket(ticket, args)
-
-        # Pull out the user-agent and IP from the request.
-        user_agent = request.get_user_agent("")
-        ip_address = self.hs.get_ip_from_request(request)
 
-        # Get the matrix ID from the CAS username.
-        user_id = await self._map_cas_user_to_matrix_user(
-            username, user_display_name, user_agent, ip_address
+        try:
+            cas_response = await self._validate_ticket(ticket, args)
+        except CasError as e:
+            logger.exception("Could not validate ticket")
+            self._sso_handler.render_error(request, e.error, e.error_description, 401)
+            return
+
+        await self._handle_cas_response(
+            request, cas_response, client_redirect_url, session
         )
 
+    async def _handle_cas_response(
+        self,
+        request: SynapseRequest,
+        cas_response: CasResponse,
+        client_redirect_url: Optional[str],
+        session: Optional[str],
+    ) -> None:
+        """Handle a CAS response to a ticket request.
+
+        Assumes that the response has been validated. Maps the user onto an MXID,
+        registering them if necessary, and returns a response to the browser.
+
+        Args:
+            request: the incoming request from the browser. We'll respond to it with an
+                HTML page or a redirect
+
+            cas_response: The parsed CAS response.
+
+            client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
+                This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
+
+            session: The session parameter from the `/cas/ticket` HTTP request, if given.
+                This should be the UI Auth session id.
+        """
+
+        # first check if we're doing a UIA
         if session:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, session, request,
+            return await self._sso_handler.complete_sso_ui_auth_request(
+                self.idp_id, cas_response.username, session, request,
             )
-        else:
-            # If this not a UI auth request than there must be a redirect URL.
-            assert client_redirect_url
 
-            await self._auth_handler.complete_sso_login(
-                user_id, request, client_redirect_url
-            )
+        # otherwise, we're handling a login request.
+
+        # Ensure that the attributes of the logged in user meet the required
+        # attributes.
+        for required_attribute, required_value in self._cas_required_attributes.items():
+            # If required attribute was not in CAS Response - Forbidden
+            if required_attribute not in cas_response.attributes:
+                self._sso_handler.render_error(
+                    request,
+                    "unauthorised",
+                    "You are not authorised to log in here.",
+                    401,
+                )
+                return
+
+            # Also need to check value
+            if required_value is not None:
+                actual_value = cas_response.attributes[required_attribute]
+                # If required attribute value does not match expected - Forbidden
+                if required_value != actual_value:
+                    self._sso_handler.render_error(
+                        request,
+                        "unauthorised",
+                        "You are not authorised to log in here.",
+                        401,
+                    )
+                    return
+
+        # Call the mapper to register/login the user
+
+        # If this not a UI auth request than there must be a redirect URL.
+        assert client_redirect_url is not None
 
-    async def _map_cas_user_to_matrix_user(
+        try:
+            await self._complete_cas_login(cas_response, request, client_redirect_url)
+        except MappingException as e:
+            logger.exception("Could not map user")
+            self._sso_handler.render_error(request, "mapping_error", str(e))
+
+    async def _complete_cas_login(
         self,
-        remote_user_id: str,
-        display_name: Optional[str],
-        user_agent: str,
-        ip_address: str,
-    ) -> str:
+        cas_response: CasResponse,
+        request: SynapseRequest,
+        client_redirect_url: str,
+    ) -> None:
         """
-        Given a CAS username, retrieve the user ID for it and possibly register the user.
+        Given a CAS response, complete the login flow
 
-        Args:
-            remote_user_id: The username from the CAS response.
-            display_name: The display name from the CAS response.
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
+        Retrieves the remote user ID, registers the user if necessary, and serves
+        a redirect back to the client with a login-token.
 
-        Returns:
-             The user ID associated with this response.
+        Args:
+            cas_response: The parsed CAS response.
+            request: The request to respond to
+            client_redirect_url: The redirect URL passed in by the client.
+
+        Raises:
+            MappingException if there was a problem mapping the response to a user.
+            RedirectException: some mapping providers may raise this if they need
+                to redirect to an interstitial page.
         """
+        # Note that CAS does not support a mapping provider, so the logic is hard-coded.
+        localpart = map_username_to_mxid_localpart(cas_response.username)
+
+        async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
+            """
+            Map from CAS attributes to user attributes.
+            """
+            # Due to the grandfathering logic matching any previously registered
+            # mxids it isn't expected for there to be any failures.
+            if failures:
+                raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
+
+            display_name = cas_response.attributes.get(
+                self._cas_displayname_attribute, None
+            )
 
-        localpart = map_username_to_mxid_localpart(remote_user_id)
-        user_id = UserID(localpart, self._hostname).to_string()
-        registered_user_id = await self._auth_handler.check_user_exists(user_id)
+            return UserAttributes(localpart=localpart, display_name=display_name)
 
-        # If the user does not exist, register it.
-        if not registered_user_id:
-            registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart,
-                default_display_name=display_name,
-                user_agent_ips=[(user_agent, ip_address)],
+        async def grandfather_existing_users() -> Optional[str]:
+            # Since CAS did not always use the user_external_ids table, always
+            # to attempt to map to existing users.
+            user_id = UserID(localpart, self._hostname).to_string()
+
+            logger.debug(
+                "Looking for existing account based on mapped %s", user_id,
             )
 
-        return registered_user_id
+            users = await self._store.get_users_by_id_case_insensitive(user_id)
+            if users:
+                registered_user_id = list(users.keys())[0]
+                logger.info("Grandfathering mapping to %s", registered_user_id)
+                return registered_user_id
+
+            return None
+
+        await self._sso_handler.complete_sso_login_request(
+            self.idp_id,
+            cas_response.username,
+            request,
+            client_redirect_url,
+            cas_response_to_user_attributes,
+            grandfather_existing_users,
+        )