diff options
author | Richard van der Hoff <richard@matrix.org> | 2019-06-26 22:34:41 +0100 |
---|---|---|
committer | Richard van der Hoff <richard@matrix.org> | 2019-06-26 22:34:41 +0100 |
commit | a4daa899ec4cd195fc10936f68df5c78314b366c (patch) | |
tree | 35e88ff388b0f7652773a79930b732aa04f16bde /synapse/rest/client/v1/login.py | |
parent | changelog (diff) | |
parent | Improve docs on choosing server_name (#5558) (diff) | |
download | synapse-a4daa899ec4cd195fc10936f68df5c78314b366c.tar.xz |
Merge branch 'develop' into rav/saml2_client
Diffstat (limited to 'synapse/rest/client/v1/login.py')
-rw-r--r-- | synapse/rest/client/v1/login.py | 130 |
1 files changed, 57 insertions, 73 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 1a886cbbbf..a31d277935 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -44,10 +44,7 @@ def login_submission_legacy_convert(submission): to a typed object. """ if "user" in submission: - submission["identifier"] = { - "type": "m.id.user", - "user": submission["user"], - } + submission["identifier"] = {"type": "m.id.user", "user": submission["user"]} del submission["user"] if "medium" in submission and "address" in submission: @@ -73,11 +70,7 @@ def login_id_thirdparty_from_phone(identifier): msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"]) - return { - "type": "m.id.thirdparty", - "medium": "msisdn", - "address": msisdn, - } + return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn} class LoginRestServlet(RestServlet): @@ -124,9 +117,9 @@ class LoginRestServlet(RestServlet): # login flow types returned. flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - flows.extend(( - {"type": t} for t in self.auth_handler.get_supported_login_types() - )) + flows.extend( + ({"type": t} for t in self.auth_handler.get_supported_login_types()) + ) return (200, {"flows": flows}) @@ -136,7 +129,8 @@ class LoginRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): self._address_ratelimiter.ratelimit( - request.getClientIP(), time_now_s=self.hs.clock.time(), + request.getClientIP(), + time_now_s=self.hs.clock.time(), rate_hz=self.hs.config.rc_login_address.per_second, burst_count=self.hs.config.rc_login_address.burst_count, update=True, @@ -144,8 +138,9 @@ class LoginRestServlet(RestServlet): login_submission = parse_json_object_from_request(request) try: - if self.jwt_enabled and (login_submission["type"] == - LoginRestServlet.JWT_TYPE): + if self.jwt_enabled and ( + login_submission["type"] == LoginRestServlet.JWT_TYPE + ): result = yield self.do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: result = yield self.do_token_login(login_submission) @@ -174,10 +169,10 @@ class LoginRestServlet(RestServlet): # field) logger.info( "Got login request with identifier: %r, medium: %r, address: %r, user: %r", - login_submission.get('identifier'), - login_submission.get('medium'), - login_submission.get('address'), - login_submission.get('user'), + login_submission.get("identifier"), + login_submission.get("medium"), + login_submission.get("address"), + login_submission.get("user"), ) login_submission_legacy_convert(login_submission) @@ -194,13 +189,13 @@ class LoginRestServlet(RestServlet): # convert threepid identifiers to user IDs if identifier["type"] == "m.id.thirdparty": - address = identifier.get('address') - medium = identifier.get('medium') + address = identifier.get("address") + medium = identifier.get("medium") if medium is None or address is None: raise SynapseError(400, "Invalid thirdparty identifier") - if medium == 'email': + if medium == "email": # For emails, transform the address to lowercase. # We store all email addreses as lowercase in the DB. # (See add_threepid in synapse/handlers/auth.py) @@ -209,34 +204,28 @@ class LoginRestServlet(RestServlet): # Check for login providers that support 3pid login types canonical_user_id, callback_3pid = ( yield self.auth_handler.check_password_provider_3pid( - medium, - address, - login_submission["password"], + medium, address, login_submission["password"] ) ) if canonical_user_id: # Authentication through password provider and 3pid succeeded result = yield self._register_device_with_callback( - canonical_user_id, login_submission, callback_3pid, + canonical_user_id, login_submission, callback_3pid ) defer.returnValue(result) # No password providers were able to handle this 3pid # Check local store user_id = yield self.hs.get_datastore().get_user_id_by_threepid( - medium, address, + medium, address ) if not user_id: logger.warn( - "unknown 3pid identifier medium %s, address %r", - medium, address, + "unknown 3pid identifier medium %s, address %r", medium, address ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) - identifier = { - "type": "m.id.user", - "user": user_id, - } + identifier = {"type": "m.id.user", "user": user_id} # by this point, the identifier should be an m.id.user: if it's anything # else, we haven't understood it. @@ -246,22 +235,16 @@ class LoginRestServlet(RestServlet): raise SynapseError(400, "User identifier is missing 'user' key") canonical_user_id, callback = yield self.auth_handler.validate_login( - identifier["user"], - login_submission, + identifier["user"], login_submission ) result = yield self._register_device_with_callback( - canonical_user_id, login_submission, callback, + canonical_user_id, login_submission, callback ) defer.returnValue(result) @defer.inlineCallbacks - def _register_device_with_callback( - self, - user_id, - login_submission, - callback=None, - ): + def _register_device_with_callback(self, user_id, login_submission, callback=None): """ Registers a device with a given user_id. Optionally run a callback function after registration has completed. @@ -277,7 +260,7 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, + user_id, device_id, initial_display_name ) result = { @@ -294,7 +277,7 @@ class LoginRestServlet(RestServlet): @defer.inlineCallbacks def do_token_login(self, login_submission): - token = login_submission['token'] + token = login_submission["token"] auth_handler = self.auth_handler user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) @@ -303,7 +286,7 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, + user_id, device_id, initial_display_name ) result = { @@ -320,15 +303,16 @@ class LoginRestServlet(RestServlet): token = login_submission.get("token", None) if token is None: raise LoginError( - 401, "Token field for JWT is missing", - errcode=Codes.UNAUTHORIZED + 401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED ) import jwt from jwt.exceptions import InvalidTokenError try: - payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm]) + payload = jwt.decode( + token, self.jwt_secret, algorithms=[self.jwt_algorithm] + ) except jwt.ExpiredSignatureError: raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED) except InvalidTokenError: @@ -346,7 +330,7 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - registered_user_id, device_id, initial_display_name, + registered_user_id, device_id, initial_display_name ) result = { @@ -362,7 +346,7 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - registered_user_id, device_id, initial_display_name, + registered_user_id, device_id, initial_display_name ) result = { @@ -376,6 +360,7 @@ class LoginRestServlet(RestServlet): class BaseSsoRedirectServlet(RestServlet): """Common base class for /login/sso/redirect impls""" + PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) def on_GET(self, request): @@ -401,21 +386,20 @@ class BaseSsoRedirectServlet(RestServlet): raise NotImplementedError() -class CasRedirectServlet(RestServlet): +class CasRedirectServlet(BaseSsoRedirectServlet): def __init__(self, hs): super(CasRedirectServlet, self).__init__() - self.cas_server_url = hs.config.cas_server_url.encode('ascii') - self.cas_service_url = hs.config.cas_service_url.encode('ascii') + self.cas_server_url = hs.config.cas_server_url.encode("ascii") + self.cas_service_url = hs.config.cas_service_url.encode("ascii") def get_sso_url(self, client_redirect_url): - client_redirect_url_param = urllib.parse.urlencode({ - b"redirectUrl": client_redirect_url - }).encode('ascii') - hs_redirect_url = (self.cas_service_url + - b"/_matrix/client/r0/login/cas/ticket") - service_param = urllib.parse.urlencode({ - b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param) - }).encode('ascii') + client_redirect_url_param = urllib.parse.urlencode( + {b"redirectUrl": client_redirect_url} + ).encode("ascii") + hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket" + service_param = urllib.parse.urlencode( + {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)} + ).encode("ascii") return b"%s/login?%s" % (self.cas_server_url, service_param) @@ -436,7 +420,7 @@ class CasTicketServlet(RestServlet): uri = self.cas_server_url + "/proxyValidate" args = { "ticket": parse_string(request, "ticket", required=True), - "service": self.cas_service_url + "service": self.cas_service_url, } try: body = yield self._http_client.get_raw(uri, args) @@ -463,7 +447,7 @@ class CasTicketServlet(RestServlet): raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) return self._sso_auth_handler.on_successful_auth( - user, request, client_redirect_url, + user, request, client_redirect_url ) def parse_cas_response(self, cas_response_body): @@ -473,7 +457,7 @@ class CasTicketServlet(RestServlet): root = ET.fromstring(cas_response_body) if not root.tag.endswith("serviceResponse"): raise Exception("root of CAS response is not serviceResponse") - success = (root[0].tag.endswith("authenticationSuccess")) + success = root[0].tag.endswith("authenticationSuccess") for child in root[0]: if child.tag.endswith("user"): user = child.text @@ -491,11 +475,11 @@ class CasTicketServlet(RestServlet): raise Exception("CAS response does not contain user") except Exception: logger.error("Error parsing CAS response", exc_info=1) - raise LoginError(401, "Invalid CAS response", - errcode=Codes.UNAUTHORIZED) + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) if not success: - raise LoginError(401, "Unsuccessful CAS response", - errcode=Codes.UNAUTHORIZED) + raise LoginError( + 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED + ) return user, attributes @@ -507,11 +491,11 @@ class SAMLRedirectServlet(BaseSsoRedirectServlet): def get_sso_url(self, client_redirect_url): reqid, info = self._saml_client.prepare_for_authenticate( - relay_state=client_redirect_url, + relay_state=client_redirect_url ) - for key, value in info['headers']: - if key == 'Location': + for key, value in info["headers"]: + if key == "Location": return value # this shouldn't happen! @@ -526,6 +510,7 @@ class SSOAuthHandler(object): Args: hs (synapse.server.HomeServer) """ + def __init__(self, hs): self._hostname = hs.hostname self._auth_handler = hs.get_auth_handler() @@ -534,8 +519,7 @@ class SSOAuthHandler(object): @defer.inlineCallbacks def on_successful_auth( - self, username, request, client_redirect_url, - user_display_name=None, + self, username, request, client_redirect_url, user_display_name=None ): """Called once the user has successfully authenticated with the SSO. |