summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/cas_handler.py215
-rw-r--r--synapse/handlers/sso.py9
2 files changed, 158 insertions, 66 deletions
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index f4ea0a9767..e9891e1316 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -13,13 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import urllib
-from typing import TYPE_CHECKING, Dict, Optional, Tuple
+import urllib.parse
+from typing import TYPE_CHECKING, Dict, Optional
 from xml.etree import ElementTree as ET
 
+import attr
+
 from twisted.web.client import PartialDownloadError
 
-from synapse.api.errors import Codes, LoginError
+from synapse.api.errors import HttpResponseException
 from synapse.http.site import SynapseRequest
 from synapse.types import UserID, map_username_to_mxid_localpart
 
@@ -29,6 +31,26 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class CasError(Exception):
+    """Used to catch errors when validating the CAS ticket.
+    """
+
+    def __init__(self, error, error_description=None):
+        self.error = error
+        self.error_description = error_description
+
+    def __str__(self):
+        if self.error_description:
+            return "{}: {}".format(self.error, self.error_description)
+        return self.error
+
+
+@attr.s(slots=True, frozen=True)
+class CasResponse:
+    username = attr.ib(type=str)
+    attributes = attr.ib(type=Dict[str, Optional[str]])
+
+
 class CasHandler:
     """
     Utility class for to handle the response from a CAS SSO service.
@@ -50,6 +72,8 @@ class CasHandler:
 
         self._http_client = hs.get_proxied_http_client()
 
+        self._sso_handler = hs.get_sso_handler()
+
     def _build_service_param(self, args: Dict[str, str]) -> str:
         """
         Generates a value to use as the "service" parameter when redirecting or
@@ -69,14 +93,20 @@ class CasHandler:
 
     async def _validate_ticket(
         self, ticket: str, service_args: Dict[str, str]
-    ) -> Tuple[str, Optional[str]]:
+    ) -> CasResponse:
         """
-        Validate a CAS ticket with the server, parse the response, and return the user and display name.
+        Validate a CAS ticket with the server, and return the parsed the response.
 
         Args:
             ticket: The CAS ticket from the client.
             service_args: Additional arguments to include in the service URL.
                 Should be the same as those passed to `get_redirect_url`.
+
+        Raises:
+            CasError: If there's an error parsing the CAS response.
+
+        Returns:
+            The parsed CAS response.
         """
         uri = self._cas_server_url + "/proxyValidate"
         args = {
@@ -89,66 +119,65 @@ class CasHandler:
             # Twisted raises this error if the connection is closed,
             # even if that's being used old-http style to signal end-of-data
             body = pde.response
+        except HttpResponseException as e:
+            description = (
+                (
+                    'Authorization server responded with a "{status}" error '
+                    "while exchanging the authorization code."
+                ).format(status=e.code),
+            )
+            raise CasError("server_error", description) from e
 
-        user, attributes = self._parse_cas_response(body)
-        displayname = attributes.pop(self._cas_displayname_attribute, None)
-
-        for required_attribute, required_value in self._cas_required_attributes.items():
-            # If required attribute was not in CAS Response - Forbidden
-            if required_attribute not in attributes:
-                raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
-            # Also need to check value
-            if required_value is not None:
-                actual_value = attributes[required_attribute]
-                # If required attribute value does not match expected - Forbidden
-                if required_value != actual_value:
-                    raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
-        return user, displayname
+        return self._parse_cas_response(body)
 
-    def _parse_cas_response(
-        self, cas_response_body: bytes
-    ) -> Tuple[str, Dict[str, Optional[str]]]:
+    def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
         """
         Retrieve the user and other parameters from the CAS response.
 
         Args:
             cas_response_body: The response from the CAS query.
 
+        Raises:
+            CasError: If there's an error parsing the CAS response.
+
         Returns:
-            A tuple of the user and a mapping of other attributes.
+            The parsed CAS response.
         """
+
+        # Ensure the response is valid.
+        root = ET.fromstring(cas_response_body)
+        if not root.tag.endswith("serviceResponse"):
+            raise CasError(
+                "missing_service_response",
+                "root of CAS response is not serviceResponse",
+            )
+
+        success = root[0].tag.endswith("authenticationSuccess")
+        if not success:
+            raise CasError("unsucessful_response", "Unsuccessful CAS response")
+
+        # Iterate through the nodes and pull out the user and any extra attributes.
         user = None
         attributes = {}
-        try:
-            root = ET.fromstring(cas_response_body)
-            if not root.tag.endswith("serviceResponse"):
-                raise Exception("root of CAS response is not serviceResponse")
-            success = root[0].tag.endswith("authenticationSuccess")
-            for child in root[0]:
-                if child.tag.endswith("user"):
-                    user = child.text
-                if child.tag.endswith("attributes"):
-                    for attribute in child:
-                        # ElementTree library expands the namespace in
-                        # attribute tags to the full URL of the namespace.
-                        # We don't care about namespace here and it will always
-                        # be encased in curly braces, so we remove them.
-                        tag = attribute.tag
-                        if "}" in tag:
-                            tag = tag.split("}")[1]
-                        attributes[tag] = attribute.text
-            if user is None:
-                raise Exception("CAS response does not contain user")
-        except Exception:
-            logger.exception("Error parsing CAS response")
-            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
-        if not success:
-            raise LoginError(
-                401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
-            )
-        return user, attributes
+        for child in root[0]:
+            if child.tag.endswith("user"):
+                user = child.text
+            if child.tag.endswith("attributes"):
+                for attribute in child:
+                    # ElementTree library expands the namespace in
+                    # attribute tags to the full URL of the namespace.
+                    # We don't care about namespace here and it will always
+                    # be encased in curly braces, so we remove them.
+                    tag = attribute.tag
+                    if "}" in tag:
+                        tag = tag.split("}")[1]
+                    attributes[tag] = attribute.text
+
+        # Ensure a user was found.
+        if user is None:
+            raise CasError("no_user", "CAS response does not contain user")
+
+        return CasResponse(user, attributes)
 
     def get_redirect_url(self, service_args: Dict[str, str]) -> str:
         """
@@ -201,7 +230,68 @@ class CasHandler:
             args["redirectUrl"] = client_redirect_url
         if session:
             args["session"] = session
-        username, user_display_name = await self._validate_ticket(ticket, args)
+
+        try:
+            cas_response = await self._validate_ticket(ticket, args)
+        except CasError as e:
+            logger.exception("Could not validate ticket")
+            self._sso_handler.render_error(request, e.error, e.error_description, 401)
+            return
+
+        await self._handle_cas_response(
+            request, cas_response, client_redirect_url, session
+        )
+
+    async def _handle_cas_response(
+        self,
+        request: SynapseRequest,
+        cas_response: CasResponse,
+        client_redirect_url: Optional[str],
+        session: Optional[str],
+    ) -> None:
+        """Handle a CAS response to a ticket request.
+
+        Assumes that the response has been validated. Maps the user onto an MXID,
+        registering them if necessary, and returns a response to the browser.
+
+        Args:
+            request: the incoming request from the browser. We'll respond to it with an
+                HTML page or a redirect
+
+            cas_response: The parsed CAS response.
+
+            client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
+                This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
+
+            session: The session parameter from the `/cas/ticket` HTTP request, if given.
+                This should be the UI Auth session id.
+        """
+
+        # Ensure that the attributes of the logged in user meet the required
+        # attributes.
+        for required_attribute, required_value in self._cas_required_attributes.items():
+            # If required attribute was not in CAS Response - Forbidden
+            if required_attribute not in cas_response.attributes:
+                self._sso_handler.render_error(
+                    request,
+                    "unauthorised",
+                    "You are not authorised to log in here.",
+                    401,
+                )
+                return
+
+            # Also need to check value
+            if required_value is not None:
+                actual_value = cas_response.attributes[required_attribute]
+                # If required attribute value does not match expected - Forbidden
+                if required_value != actual_value:
+                    self._sso_handler.render_error(
+                        request,
+                        "unauthorised",
+                        "You are not authorised to log in here.",
+                        401,
+                    )
+                    return
 
         # Pull out the user-agent and IP from the request.
         user_agent = request.get_user_agent("")
@@ -209,7 +299,7 @@ class CasHandler:
 
         # Get the matrix ID from the CAS username.
         user_id = await self._map_cas_user_to_matrix_user(
-            username, user_display_name, user_agent, ip_address
+            cas_response, user_agent, ip_address
         )
 
         if session:
@@ -225,18 +315,13 @@ class CasHandler:
             )
 
     async def _map_cas_user_to_matrix_user(
-        self,
-        remote_user_id: str,
-        display_name: Optional[str],
-        user_agent: str,
-        ip_address: str,
+        self, cas_response: CasResponse, user_agent: str, ip_address: str,
     ) -> str:
         """
         Given a CAS username, retrieve the user ID for it and possibly register the user.
 
         Args:
-            remote_user_id: The username from the CAS response.
-            display_name: The display name from the CAS response.
+            cas_response: The parsed CAS response.
             user_agent: The user agent of the client making the request.
             ip_address: The IP address of the client making the request.
 
@@ -244,15 +329,17 @@ class CasHandler:
              The user ID associated with this response.
         """
 
-        localpart = map_username_to_mxid_localpart(remote_user_id)
+        localpart = map_username_to_mxid_localpart(cas_response.username)
         user_id = UserID(localpart, self._hostname).to_string()
         registered_user_id = await self._auth_handler.check_user_exists(user_id)
 
+        displayname = cas_response.attributes.get(self._cas_displayname_attribute, None)
+
         # If the user does not exist, register it.
         if not registered_user_id:
             registered_user_id = await self._registration_handler.register_user(
                 localpart=localpart,
-                default_display_name=display_name,
+                default_display_name=displayname,
                 user_agent_ips=[(user_agent, ip_address)],
             )
 
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 548b02211b..b0a8c8c7d2 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -101,7 +101,11 @@ class SsoHandler:
         self._username_mapping_sessions = {}  # type: Dict[str, UsernameMappingSession]
 
     def render_error(
-        self, request, error: str, error_description: Optional[str] = None
+        self,
+        request: Request,
+        error: str,
+        error_description: Optional[str] = None,
+        code: int = 400,
     ) -> None:
         """Renders the error template and responds with it.
 
@@ -113,11 +117,12 @@ class SsoHandler:
                 We'll respond with an HTML page describing the error.
             error: A technical identifier for this error.
             error_description: A human-readable description of the error.
+            code: The integer error code (an HTTP response code)
         """
         html = self._error_template.render(
             error=error, error_description=error_description
         )
-        respond_with_html(request, 400, html)
+        respond_with_html(request, code, html)
 
     async def get_sso_user_by_remote_user_id(
         self, auth_provider_id: str, remote_user_id: str