diff options
Diffstat (limited to 'synapse/rest/client/v1/login.py')
-rw-r--r-- | synapse/rest/client/v1/login.py | 84 |
1 files changed, 57 insertions, 27 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index a99dcaab6f..2e3e4f39f3 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -45,8 +45,8 @@ class LoginRestServlet(ClientV1RestServlet): self.idp_redirect_url = hs.config.saml2_idp_redirect_url self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled - self.cas_server_url = hs.config.cas_server_url + self.cas_required_attributes = hs.config.cas_required_attributes self.servername = hs.config.server_name def on_GET(self, request): @@ -125,6 +125,47 @@ class LoginRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def do_cas_login(self, cas_response_body): + user, attributes = self.parse_cas_response(cas_response_body) + + for required_attribute, required_value in self.cas_required_attributes.items(): + # If required attribute was not in CAS Response - Forbidden + if required_attribute not in attributes: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + # Also need to check value + if required_value is not None: + actual_value = attributes[required_attribute] + # If required attribute value does not match expected - Forbidden + if required_value != actual_value: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + user_id = UserID.create(user, self.hs.hostname).to_string() + auth_handler = self.handlers.auth_handler + user_exists = yield auth_handler.does_user_exist(user_id) + if user_exists: + user_id, access_token, refresh_token = ( + yield auth_handler.login_with_cas_user_id(user_id) + ) + result = { + "user_id": user_id, # may have changed + "access_token": access_token, + "refresh_token": refresh_token, + "home_server": self.hs.hostname, + } + + else: + user_id, access_token = ( + yield self.handlers.registration_handler.register(localpart=user) + ) + result = { + "user_id": user_id, # may have changed + "access_token": access_token, + "home_server": self.hs.hostname, + } + + defer.returnValue((200, result)) + + def parse_cas_response(self, cas_response_body): root = ET.fromstring(cas_response_body) if not root.tag.endswith("serviceResponse"): raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) @@ -133,33 +174,22 @@ class LoginRestServlet(ClientV1RestServlet): for child in root[0]: if child.tag.endswith("user"): user = child.text - user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.handlers.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if user_exists: - user_id, access_token, refresh_token = ( - yield auth_handler.login_with_cas_user_id(user_id) - ) - result = { - "user_id": user_id, # may have changed - "access_token": access_token, - "refresh_token": refresh_token, - "home_server": self.hs.hostname, - } - - else: - user_id, access_token = ( - yield self.handlers.registration_handler.register(localpart=user) - ) - result = { - "user_id": user_id, # may have changed - "access_token": access_token, - "home_server": self.hs.hostname, - } - - defer.returnValue((200, result)) + if child.tag.endswith("attributes"): + attributes = {} + for attribute in child: + # ElementTree library expands the namespace in attribute tags + # to the full URL of the namespace. + # See (https://docs.python.org/2/library/xml.etree.elementtree.html) + # We don't care about namespace here and it will always be encased in + # curly braces, so we remove them. + if "}" in attribute.tag: + attributes[attribute.tag.split("}")[1]] = attribute.text + else: + attributes[attribute.tag] = attribute.text + if user is None or attributes is None: + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) + return (user, attributes) class LoginFallbackRestServlet(ClientV1RestServlet): |