summary refs log tree commit diff
path: root/synapse/rest/client/v1/login.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/v1/login.py')
-rw-r--r--synapse/rest/client/v1/login.py178
1 files changed, 24 insertions, 154 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d0d4999795..4de2f97d06 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,11 +14,6 @@
 # limitations under the License.
 
 import logging
-import xml.etree.ElementTree as ET
-
-from six.moves import urllib
-
-from twisted.web.client import PartialDownloadError
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
@@ -28,10 +23,10 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
     parse_string,
 )
-from synapse.push.mailer import load_jinja2_templates
+from synapse.http.site import SynapseRequest
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.well_known import WellKnownBuilder
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import UserID
 from synapse.util.msisdn import phone_number_to_msisdn
 
 logger = logging.getLogger(__name__)
@@ -402,7 +397,7 @@ class BaseSSORedirectServlet(RestServlet):
 
     PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
 
-    def on_GET(self, request):
+    def on_GET(self, request: SynapseRequest):
         args = request.args
         if b"redirectUrl" not in args:
             return 400, "Redirect URL not specified for SSO auth"
@@ -411,15 +406,15 @@ class BaseSSORedirectServlet(RestServlet):
         request.redirect(sso_url)
         finish_request(request)
 
-    def get_sso_url(self, client_redirect_url):
+    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
         """Get the URL to redirect to, to perform SSO auth
 
         Args:
-            client_redirect_url (bytes): the URL that we should redirect the
+            client_redirect_url: the URL that we should redirect the
                 client to when everything is done
 
         Returns:
-            bytes: URL to redirect to
+            URL to redirect to
         """
         # to be implemented by subclasses
         raise NotImplementedError()
@@ -427,19 +422,12 @@ class BaseSSORedirectServlet(RestServlet):
 
 class CasRedirectServlet(BaseSSORedirectServlet):
     def __init__(self, hs):
-        super(CasRedirectServlet, self).__init__()
-        self.cas_server_url = hs.config.cas_server_url.encode("ascii")
-        self.cas_service_url = hs.config.cas_service_url.encode("ascii")
+        self._cas_handler = hs.get_cas_handler()
 
-    def get_sso_url(self, client_redirect_url):
-        client_redirect_url_param = urllib.parse.urlencode(
-            {b"redirectUrl": client_redirect_url}
+    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+        return self._cas_handler.get_redirect_url(
+            {"redirectUrl": client_redirect_url}
         ).encode("ascii")
-        hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
-        service_param = urllib.parse.urlencode(
-            {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)}
-        ).encode("ascii")
-        return b"%s/login?%s" % (self.cas_server_url, service_param)
 
 
 class CasTicketServlet(RestServlet):
@@ -447,81 +435,25 @@ class CasTicketServlet(RestServlet):
 
     def __init__(self, hs):
         super(CasTicketServlet, self).__init__()
-        self.cas_server_url = hs.config.cas_server_url
-        self.cas_service_url = hs.config.cas_service_url
-        self.cas_displayname_attribute = hs.config.cas_displayname_attribute
-        self.cas_required_attributes = hs.config.cas_required_attributes
-        self._sso_auth_handler = SSOAuthHandler(hs)
-        self._http_client = hs.get_proxied_http_client()
-
-    async def on_GET(self, request):
-        client_redirect_url = parse_string(request, "redirectUrl", required=True)
-        uri = self.cas_server_url + "/proxyValidate"
-        args = {
-            "ticket": parse_string(request, "ticket", required=True),
-            "service": self.cas_service_url,
-        }
-        try:
-            body = await self._http_client.get_raw(uri, args)
-        except PartialDownloadError as pde:
-            # Twisted raises this error if the connection is closed,
-            # even if that's being used old-http style to signal end-of-data
-            body = pde.response
-        result = await self.handle_cas_response(request, body, client_redirect_url)
-        return result
+        self._cas_handler = hs.get_cas_handler()
 
-    def handle_cas_response(self, request, cas_response_body, client_redirect_url):
-        user, attributes = self.parse_cas_response(cas_response_body)
-        displayname = attributes.pop(self.cas_displayname_attribute, None)
+    async def on_GET(self, request: SynapseRequest) -> None:
+        client_redirect_url = parse_string(request, "redirectUrl")
+        ticket = parse_string(request, "ticket", required=True)
 
-        for required_attribute, required_value in self.cas_required_attributes.items():
-            # If required attribute was not in CAS Response - Forbidden
-            if required_attribute not in attributes:
-                raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+        # Maybe get a session ID (if this ticket is from user interactive
+        # authentication).
+        session = parse_string(request, "session")
 
-            # Also need to check value
-            if required_value is not None:
-                actual_value = attributes[required_attribute]
-                # If required attribute value does not match expected - Forbidden
-                if required_value != actual_value:
-                    raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+        # Either client_redirect_url or session must be provided.
+        if not client_redirect_url and not session:
+            message = "Missing string query parameter redirectUrl or session"
+            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
 
-        return self._sso_auth_handler.on_successful_auth(
-            user, request, client_redirect_url, displayname
+        await self._cas_handler.handle_ticket(
+            request, ticket, client_redirect_url, session
         )
 
-    def parse_cas_response(self, cas_response_body):
-        user = None
-        attributes = {}
-        try:
-            root = ET.fromstring(cas_response_body)
-            if not root.tag.endswith("serviceResponse"):
-                raise Exception("root of CAS response is not serviceResponse")
-            success = root[0].tag.endswith("authenticationSuccess")
-            for child in root[0]:
-                if child.tag.endswith("user"):
-                    user = child.text
-                if child.tag.endswith("attributes"):
-                    for attribute in child:
-                        # ElementTree library expands the namespace in
-                        # attribute tags to the full URL of the namespace.
-                        # We don't care about namespace here and it will always
-                        # be encased in curly braces, so we remove them.
-                        tag = attribute.tag
-                        if "}" in tag:
-                            tag = tag.split("}")[1]
-                        attributes[tag] = attribute.text
-            if user is None:
-                raise Exception("CAS response does not contain user")
-        except Exception:
-            logger.exception("Error parsing CAS response")
-            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
-        if not success:
-            raise LoginError(
-                401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
-            )
-        return user, attributes
-
 
 class SAMLRedirectServlet(BaseSSORedirectServlet):
     PATTERNS = client_patterns("/login/sso/redirect", v1=True)
@@ -529,72 +461,10 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
     def __init__(self, hs):
         self._saml_handler = hs.get_saml_handler()
 
-    def get_sso_url(self, client_redirect_url):
+    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
         return self._saml_handler.handle_redirect_request(client_redirect_url)
 
 
-class SSOAuthHandler(object):
-    """
-    Utility class for Resources and Servlets which handle the response from a SSO
-    service
-
-    Args:
-        hs (synapse.server.HomeServer)
-    """
-
-    def __init__(self, hs):
-        self._hostname = hs.hostname
-        self._auth_handler = hs.get_auth_handler()
-        self._registration_handler = hs.get_registration_handler()
-        self._macaroon_gen = hs.get_macaroon_generator()
-
-        # Load the redirect page HTML template
-        self._template = load_jinja2_templates(
-            hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
-        )[0]
-
-        self._server_name = hs.config.server_name
-
-        # cast to tuple for use with str.startswith
-        self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
-
-    async def on_successful_auth(
-        self, username, request, client_redirect_url, user_display_name=None
-    ):
-        """Called once the user has successfully authenticated with the SSO.
-
-        Registers the user if necessary, and then returns a redirect (with
-        a login token) to the client.
-
-        Args:
-            username (unicode|bytes): the remote user id. We'll map this onto
-                something sane for a MXID localpath.
-
-            request (SynapseRequest): the incoming request from the browser. We'll
-                respond to it with a redirect.
-
-            client_redirect_url (unicode): the redirect_url the client gave us when
-                it first started the process.
-
-            user_display_name (unicode|None): if set, and we have to register a new user,
-                we will set their displayname to this.
-
-        Returns:
-            Deferred[none]: Completes once we have handled the request.
-        """
-        localpart = map_username_to_mxid_localpart(username)
-        user_id = UserID(localpart, self._hostname).to_string()
-        registered_user_id = await self._auth_handler.check_user_exists(user_id)
-        if not registered_user_id:
-            registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart, default_display_name=user_display_name
-            )
-
-        self._auth_handler.complete_sso_login(
-            registered_user_id, request, client_redirect_url
-        )
-
-
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
     if hs.config.cas_enabled: