summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/rest/client/v1/login.py158
1 files changed, 33 insertions, 125 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d8c76a3465..b31e27f7b3 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -54,10 +54,6 @@ class LoginRestServlet(ClientV1RestServlet):
         self.jwt_secret = hs.config.jwt_secret
         self.jwt_algorithm = hs.config.jwt_algorithm
         self.cas_enabled = hs.config.cas_enabled
-        self.cas_server_url = hs.config.cas_server_url
-        self.cas_required_attributes = hs.config.cas_required_attributes
-        self.servername = hs.config.server_name
-        self.http_client = hs.get_simple_http_client()
         self.auth_handler = self.hs.get_auth_handler()
         self.device_handler = self.hs.get_device_handler()
 
@@ -110,17 +106,6 @@ class LoginRestServlet(ClientV1RestServlet):
                                        LoginRestServlet.JWT_TYPE):
                 result = yield self.do_jwt_login(login_submission)
                 defer.returnValue(result)
-            # TODO Delete this after all CAS clients switch to token login instead
-            elif self.cas_enabled and (login_submission["type"] ==
-                                       LoginRestServlet.CAS_TYPE):
-                uri = "%s/proxyValidate" % (self.cas_server_url,)
-                args = {
-                    "ticket": login_submission["ticket"],
-                    "service": login_submission["service"]
-                }
-                body = yield self.http_client.get_raw(uri, args)
-                result = yield self.do_cas_login(body)
-                defer.returnValue(result)
             elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
                 result = yield self.do_token_login(login_submission)
                 defer.returnValue(result)
@@ -191,51 +176,6 @@ class LoginRestServlet(ClientV1RestServlet):
 
         defer.returnValue((200, result))
 
-    # TODO Delete this after all CAS clients switch to token login instead
-    @defer.inlineCallbacks
-    def do_cas_login(self, cas_response_body):
-        user, attributes = self.parse_cas_response(cas_response_body)
-
-        for required_attribute, required_value in self.cas_required_attributes.items():
-            # If required attribute was not in CAS Response - Forbidden
-            if required_attribute not in attributes:
-                raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
-            # Also need to check value
-            if required_value is not None:
-                actual_value = attributes[required_attribute]
-                # If required attribute value does not match expected - Forbidden
-                if required_value != actual_value:
-                    raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
-        user_id = UserID.create(user, self.hs.hostname).to_string()
-        auth_handler = self.auth_handler
-        registered_user_id = yield auth_handler.check_user_exists(user_id)
-        if registered_user_id:
-            access_token, refresh_token = (
-                yield auth_handler.get_login_tuple_for_user_id(
-                    registered_user_id
-                )
-            )
-            result = {
-                "user_id": registered_user_id,  # may have changed
-                "access_token": access_token,
-                "refresh_token": refresh_token,
-                "home_server": self.hs.hostname,
-            }
-
-        else:
-            user_id, access_token = (
-                yield self.handlers.registration_handler.register(localpart=user)
-            )
-            result = {
-                "user_id": user_id,  # may have changed
-                "access_token": access_token,
-                "home_server": self.hs.hostname,
-            }
-
-        defer.returnValue((200, result))
-
     @defer.inlineCallbacks
     def do_jwt_login(self, login_submission):
         token = login_submission.get("token", None)
@@ -293,33 +233,6 @@ class LoginRestServlet(ClientV1RestServlet):
 
         defer.returnValue((200, result))
 
-    # TODO Delete this after all CAS clients switch to token login instead
-    def parse_cas_response(self, cas_response_body):
-        root = ET.fromstring(cas_response_body)
-        if not root.tag.endswith("serviceResponse"):
-            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
-        if not root[0].tag.endswith("authenticationSuccess"):
-            raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
-        for child in root[0]:
-            if child.tag.endswith("user"):
-                user = child.text
-            if child.tag.endswith("attributes"):
-                attributes = {}
-                for attribute in child:
-                    # ElementTree library expands the namespace in attribute tags
-                    # to the full URL of the namespace.
-                    # See (https://docs.python.org/2/library/xml.etree.elementtree.html)
-                    # We don't care about namespace here and it will always be encased in
-                    # curly braces, so we remove them.
-                    if "}" in attribute.tag:
-                        attributes[attribute.tag.split("}")[1]] = attribute.text
-                    else:
-                        attributes[attribute.tag] = attribute.text
-        if user is None or attributes is None:
-            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
-
-        return (user, attributes)
-
     def _register_device(self, user_id, login_submission):
         """Register a device for a user.
 
@@ -384,18 +297,6 @@ class SAML2RestServlet(ClientV1RestServlet):
         defer.returnValue((200, {"status": "not_authenticated"}))
 
 
-# TODO Delete this after all CAS clients switch to token login instead
-class CasRestServlet(ClientV1RestServlet):
-    PATTERNS = client_path_patterns("/login/cas", releases=())
-
-    def __init__(self, hs):
-        super(CasRestServlet, self).__init__(hs)
-        self.cas_server_url = hs.config.cas_server_url
-
-    def on_GET(self, request):
-        return (200, {"serverUrl": self.cas_server_url})
-
-
 class CasRedirectServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
 
@@ -480,30 +381,39 @@ class CasTicketServlet(ClientV1RestServlet):
         return urlparse.urlunparse(url_parts)
 
     def parse_cas_response(self, cas_response_body):
-        root = ET.fromstring(cas_response_body)
-        if not root.tag.endswith("serviceResponse"):
-            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
-        if not root[0].tag.endswith("authenticationSuccess"):
-            raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
-        for child in root[0]:
-            if child.tag.endswith("user"):
-                user = child.text
-            if child.tag.endswith("attributes"):
-                attributes = {}
-                for attribute in child:
-                    # ElementTree library expands the namespace in attribute tags
-                    # to the full URL of the namespace.
-                    # See (https://docs.python.org/2/library/xml.etree.elementtree.html)
-                    # We don't care about namespace here and it will always be encased in
-                    # curly braces, so we remove them.
-                    if "}" in attribute.tag:
-                        attributes[attribute.tag.split("}")[1]] = attribute.text
-                    else:
-                        attributes[attribute.tag] = attribute.text
-        if user is None or attributes is None:
-            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
-
-        return (user, attributes)
+        user = None
+        attributes = None
+        try:
+            root = ET.fromstring(cas_response_body)
+            if not root.tag.endswith("serviceResponse"):
+                raise Exception("root of CAS response is not serviceResponse")
+            success = (root[0].tag.endswith("authenticationSuccess"))
+            for child in root[0]:
+                if child.tag.endswith("user"):
+                    user = child.text
+                if child.tag.endswith("attributes"):
+                    attributes = {}
+                    for attribute in child:
+                        # ElementTree library expands the namespace in
+                        # attribute tags to the full URL of the namespace.
+                        # We don't care about namespace here and it will always
+                        # be encased in curly braces, so we remove them.
+                        tag = attribute.tag
+                        if "}" in tag:
+                            tag = tag.split("}")[1]
+                        attributes[tag] = attribute.text
+            if user is None:
+                raise Exception("CAS response does not contain user")
+            if attributes is None:
+                raise Exception("CAS response does not contain attributes")
+        except Exception:
+            logger.error("Error parsing CAS response", exc_info=1)
+            raise LoginError(401, "Invalid CAS response",
+                             errcode=Codes.UNAUTHORIZED)
+        if not success:
+            raise LoginError(401, "Unsuccessful CAS response",
+                             errcode=Codes.UNAUTHORIZED)
+        return user, attributes
 
 
 def register_servlets(hs, http_server):
@@ -513,5 +423,3 @@ def register_servlets(hs, http_server):
     if hs.config.cas_enabled:
         CasRedirectServlet(hs).register(http_server)
         CasTicketServlet(hs).register(http_server)
-        CasRestServlet(hs).register(http_server)
-    # TODO PasswordResetRestServlet(hs).register(http_server)