diff options
Diffstat (limited to 'synapse/rest/client')
-rw-r--r-- | synapse/rest/client/v1/login.py | 145 |
1 files changed, 144 insertions, 1 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 4ea06c1434..5a2cedacb0 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -22,6 +22,7 @@ from base import ClientV1RestServlet, client_path_pattern import simplejson as json import urllib +import urlparse import logging from saml2 import BINDING_HTTP_POST @@ -39,6 +40,7 @@ class LoginRestServlet(ClientV1RestServlet): PASS_TYPE = "m.login.password" SAML2_TYPE = "m.login.saml2" CAS_TYPE = "m.login.cas" + TOKEN_TYPE = "m.login.token" def __init__(self, hs): super(LoginRestServlet, self).__init__(hs) @@ -58,6 +60,7 @@ class LoginRestServlet(ClientV1RestServlet): flows.append({"type": LoginRestServlet.CAS_TYPE}) if self.password_enabled: flows.append({"type": LoginRestServlet.PASS_TYPE}) + flows.append({"type": LoginRestServlet.TOKEN_TYPE}) return (200, {"flows": flows}) def on_OPTIONS(self, request): @@ -83,6 +86,7 @@ class LoginRestServlet(ClientV1RestServlet): "uri": "%s%s" % (self.idp_redirect_url, relay_state) } defer.returnValue((200, result)) + # TODO Delete this after all CAS clients switch to token login instead elif self.cas_enabled and (login_submission["type"] == LoginRestServlet.CAS_TYPE): # TODO: get this from the homeserver rather than creating a new one for @@ -96,6 +100,9 @@ class LoginRestServlet(ClientV1RestServlet): body = yield http_client.get_raw(uri, args) result = yield self.do_cas_login(body) defer.returnValue(result) + elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: + result = yield self.do_token_login(login_submission) + defer.returnValue(result) else: raise SynapseError(400, "Bad login type.") except KeyError: @@ -132,6 +139,26 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) @defer.inlineCallbacks + def do_token_login(self, login_submission): + token = login_submission['token'] + auth_handler = self.handlers.auth_handler + user_id = ( + yield auth_handler.validate_short_term_login_token_and_get_user_id(token) + ) + user_id, access_token, refresh_token = ( + yield auth_handler.login_with_user_id(user_id) + ) + result = { + "user_id": user_id, # may have changed + "access_token": access_token, + "refresh_token": refresh_token, + "home_server": self.hs.hostname, + } + + defer.returnValue((200, result)) + + # TODO Delete this after all CAS clients switch to token login instead + @defer.inlineCallbacks def do_cas_login(self, cas_response_body): user, attributes = self.parse_cas_response(cas_response_body) @@ -152,7 +179,7 @@ class LoginRestServlet(ClientV1RestServlet): user_exists = yield auth_handler.does_user_exist(user_id) if user_exists: user_id, access_token, refresh_token = ( - yield auth_handler.login_with_cas_user_id(user_id) + yield auth_handler.login_with_user_id(user_id) ) result = { "user_id": user_id, # may have changed @@ -173,6 +200,7 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) + # TODO Delete this after all CAS clients switch to token login instead def parse_cas_response(self, cas_response_body): root = ET.fromstring(cas_response_body) if not root.tag.endswith("serviceResponse"): @@ -243,6 +271,7 @@ class SAML2RestServlet(ClientV1RestServlet): defer.returnValue((200, {"status": "not_authenticated"})) +# TODO Delete this after all CAS clients switch to token login instead class CasRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/login/cas") @@ -254,6 +283,118 @@ class CasRestServlet(ClientV1RestServlet): return (200, {"serverUrl": self.cas_server_url}) +class CasRedirectServlet(ClientV1RestServlet): + PATTERN = client_path_pattern("/login/cas/redirect") + + def __init__(self, hs): + super(CasRedirectServlet, self).__init__(hs) + self.cas_server_url = hs.config.cas_server_url + self.cas_service_url = hs.config.cas_service_url + + def on_GET(self, request): + args = request.args + if "redirectUrl" not in args: + return (400, "Redirect URL not specified for CAS auth") + clientRedirectUrlParam = urllib.urlencode({ + "redirectUrl": args["redirectUrl"][0] + }) + hsRedirectUrl = self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket" + serviceParam = urllib.urlencode({ + "service": "%s?%s" % (hsRedirectUrl, clientRedirectUrlParam) + }) + request.redirect("%s?%s" % (self.cas_server_url, serviceParam)) + request.finish() + defer.returnValue(None) + + +class CasTicketServlet(ClientV1RestServlet): + PATTERN = client_path_pattern("/login/cas/ticket") + + def __init__(self, hs): + super(CasTicketServlet, self).__init__(hs) + self.cas_server_url = hs.config.cas_server_url + self.cas_service_url = hs.config.cas_service_url + self.cas_required_attributes = hs.config.cas_required_attributes + + @defer.inlineCallbacks + def on_GET(self, request): + clientRedirectUrl = request.args["redirectUrl"][0] + # TODO: get this from the homeserver rather than creating a new one for + # each request + http_client = SimpleHttpClient(self.hs) + uri = self.cas_server_url + "/proxyValidate" + args = { + "ticket": request.args["ticket"], + "service": self.cas_service_url + } + body = yield http_client.get_raw(uri, args) + result = yield self.handle_cas_response(request, body, clientRedirectUrl) + defer.returnValue(result) + + @defer.inlineCallbacks + def handle_cas_response(self, request, cas_response_body, clientRedirectUrl): + user, attributes = self.parse_cas_response(cas_response_body) + + for required_attribute, required_value in self.cas_required_attributes.items(): + # If required attribute was not in CAS Response - Forbidden + if required_attribute not in attributes: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + # Also need to check value + if required_value is not None: + actual_value = attributes[required_attribute] + # If required attribute value does not match expected - Forbidden + if required_value != actual_value: + raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + + user_id = UserID.create(user, self.hs.hostname).to_string() + auth_handler = self.handlers.auth_handler + user_exists = yield auth_handler.does_user_exist(user_id) + if not user_exists: + user_id, ignored = ( + yield self.handlers.registration_handler.register(localpart=user) + ) + + login_token = auth_handler.generate_short_term_login_token(user_id) + redirectUrl = self.add_login_token_to_redirect_url(clientRedirectUrl, login_token) + request.redirect(redirectUrl) + request.finish() + defer.returnValue(None) + + def add_login_token_to_redirect_url(self, url, token): + url_parts = list(urlparse.urlparse(url)) + query = dict(urlparse.parse_qsl(url_parts[4])) + query.update({"loginToken": token}) + url_parts[4] = urllib.urlencode(query) + return urlparse.urlunparse(url_parts) + + def parse_cas_response(self, cas_response_body): + root = ET.fromstring(cas_response_body) + if not root.tag.endswith("serviceResponse"): + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) + if not root[0].tag.endswith("authenticationSuccess"): + raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) + for child in root[0]: + if child.tag.endswith("user"): + user = child.text + if child.tag.endswith("attributes"): + attributes = {} + for attribute in child: + # ElementTree library expands the namespace in attribute tags + # to the full URL of the namespace. + # See (https://docs.python.org/2/library/xml.etree.elementtree.html) + # We don't care about namespace here and it will always be encased in + # curly braces, so we remove them. + if "}" in attribute.tag: + attributes[attribute.tag.split("}")[1]] = attribute.text + else: + attributes[attribute.tag] = attribute.text + if user is None or attributes is None: + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) + + return (user, attributes) + + def _parse_json(request): try: content = json.loads(request.content.read()) @@ -269,5 +410,7 @@ def register_servlets(hs, http_server): if hs.config.saml2_enabled: SAML2RestServlet(hs).register(http_server) if hs.config.cas_enabled: + CasRedirectServlet(hs).register(http_server) + CasTicketServlet(hs).register(http_server) CasRestServlet(hs).register(http_server) # TODO PasswordResetRestServlet(hs).register(http_server) |