From 1c1242acba9694a3a4b1eb3b14ec0bac11ee4ff8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 Mar 2020 07:39:34 -0400 Subject: Validate that the session is not modified during UI-Auth (#7068) --- synapse/handlers/auth.py | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) (limited to 'synapse/handlers/auth.py') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7860f9625e..2ce1425dfa 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -125,7 +125,11 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def validate_user_via_ui_auth( - self, requester: Requester, request_body: Dict[str, Any], clientip: str + self, + requester: Requester, + request: SynapseRequest, + request_body: Dict[str, Any], + clientip: str, ): """ Checks that the user is who they claim to be, via a UI auth. @@ -137,6 +141,8 @@ class AuthHandler(BaseHandler): Args: requester: The user, as given by the access token + request: The request sent by the client. + request_body: The body of the request sent by the client clientip: The IP address of the client. @@ -172,7 +178,9 @@ class AuthHandler(BaseHandler): flows = [[login_type] for login_type in self._supported_login_types] try: - result, params, _ = yield self.check_auth(flows, request_body, clientip) + result, params, _ = yield self.check_auth( + flows, request, request_body, clientip + ) except LoginError: # Update the ratelimite to say we failed (`can_do_action` doesn't raise). self._failed_uia_attempts_ratelimiter.can_do_action( @@ -211,7 +219,11 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def check_auth( - self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str + self, + flows: List[List[str]], + request: SynapseRequest, + clientdict: Dict[str, Any], + clientip: str, ): """ Takes a dictionary sent by the client in the login / registration @@ -231,6 +243,8 @@ class AuthHandler(BaseHandler): strings representing auth-types. At least one full flow must be completed in order for auth to be successful. + request: The request sent by the client. + clientdict: The dictionary from the client root level, not the 'auth' key: this method prompts for auth if none is sent. @@ -270,13 +284,27 @@ class AuthHandler(BaseHandler): # email auth link on there). It's probably too open to abuse # because it lets unauthenticated clients store arbitrary objects # on a homeserver. - # Revisit: Assumimg the REST APIs do sensible validation, the data + # Revisit: Assuming the REST APIs do sensible validation, the data # isn't arbintrary. session["clientdict"] = clientdict self._save_session(session) elif "clientdict" in session: clientdict = session["clientdict"] + # Ensure that the queried operation does not vary between stages of + # the UI authentication session. This is done by generating a stable + # comparator based on the URI, method, and body (minus the auth dict) + # and storing it during the initial query. Subsequent queries ensure + # that this comparator has not changed. + comparator = (request.uri, request.method, clientdict) + if "ui_auth" not in session: + session["ui_auth"] = comparator + elif session["ui_auth"] != comparator: + raise SynapseError( + 403, + "Requested operation has changed during the UI authentication session.", + ) + if not authdict: raise InteractiveAuthIncompleteError( self._auth_dict_for_flows(flows, session) @@ -322,6 +350,7 @@ class AuthHandler(BaseHandler): creds, list(clientdict), ) + return creds, clientdict, session["id"] ret = self._auth_dict_for_flows(flows, session) -- cgit 1.5.1 From b9930d24a05e47c36845d8607b12a45eea889be0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 1 Apr 2020 08:48:00 -0400 Subject: Support SAML in the user interactive authentication workflow. (#7102) --- CHANGES.md | 8 ++ changelog.d/7102.feature | 1 + synapse/api/constants.py | 1 + synapse/handlers/auth.py | 116 +++++++++++++++++++++++++++- synapse/handlers/saml_handler.py | 51 +++++++++--- synapse/res/templates/sso_auth_confirm.html | 14 ++++ synapse/rest/client/v2_alpha/account.py | 19 ++++- synapse/rest/client/v2_alpha/auth.py | 42 +++++----- synapse/rest/client/v2_alpha/devices.py | 12 ++- synapse/rest/client/v2_alpha/keys.py | 6 +- synapse/rest/client/v2_alpha/register.py | 1 + 11 files changed, 227 insertions(+), 44 deletions(-) create mode 100644 changelog.d/7102.feature create mode 100644 synapse/res/templates/sso_auth_confirm.html (limited to 'synapse/handlers/auth.py') diff --git a/CHANGES.md b/CHANGES.md index f794c585b7..b997af1630 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,11 @@ +Next version +============ + +* A new template (`sso_auth_confirm.html`) was added to Synapse. If your Synapse + is configured to use SSO and a custom `sso_redirect_confirm_template_dir` + configuration then this template will need to be duplicated into that + directory. + Synapse 1.12.0 (2020-03-23) =========================== diff --git a/changelog.d/7102.feature b/changelog.d/7102.feature new file mode 100644 index 0000000000..01057aa396 --- /dev/null +++ b/changelog.d/7102.feature @@ -0,0 +1 @@ +Support SSO in the user interactive authentication workflow. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index cc8577552b..fda2c2e5bb 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -61,6 +61,7 @@ class LoginType(object): MSISDN = "m.login.msisdn" RECAPTCHA = "m.login.recaptcha" TERMS = "m.login.terms" + SSO = "org.matrix.login.sso" DUMMY = "m.login.dummy" # Only for C/S API v1 diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 2ce1425dfa..7c09d15a72 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -53,6 +53,31 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) +SUCCESS_TEMPLATE = """ + + +Success! + + + + + +
+

Thank you

+

You may now close this window and return to the application

+
+ + +""" + + class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 @@ -91,6 +116,7 @@ class AuthHandler(BaseHandler): self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.password_enabled + self._saml2_enabled = hs.config.saml2_enabled # we keep this as a list despite the O(N^2) implication so that we can # keep PASSWORD first and avoid confusing clients which pick the first @@ -106,6 +132,13 @@ class AuthHandler(BaseHandler): if t not in login_types: login_types.append(t) self._supported_login_types = login_types + # Login types and UI Auth types have a heavy overlap, but are not + # necessarily identical. Login types have SSO (and other login types) + # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET. + ui_auth_types = login_types.copy() + if self._saml2_enabled: + ui_auth_types.append(LoginType.SSO) + self._supported_ui_auth_types = ui_auth_types # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. @@ -113,10 +146,21 @@ class AuthHandler(BaseHandler): self._clock = self.hs.get_clock() - # Load the SSO redirect confirmation page HTML template + # Load the SSO HTML templates. + + # The following template is shown to the user during a client login via SSO, + # after the SSO completes and before redirecting them back to their client. + # It notifies the user they are about to give access to their matrix account + # to the client. self._sso_redirect_confirm_template = load_jinja2_templates( hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"], )[0] + # The following template is shown during user interactive authentication + # in the fallback auth scenario. It notifies the user that they are + # authenticating for an operation to occur on their account. + self._sso_auth_confirm_template = load_jinja2_templates( + hs.config.sso_redirect_confirm_template_dir, ["sso_auth_confirm.html"], + )[0] self._server_name = hs.config.server_name @@ -130,6 +174,7 @@ class AuthHandler(BaseHandler): request: SynapseRequest, request_body: Dict[str, Any], clientip: str, + description: str, ): """ Checks that the user is who they claim to be, via a UI auth. @@ -147,6 +192,9 @@ class AuthHandler(BaseHandler): clientip: The IP address of the client. + description: A human readable string to be displayed to the user that + describes the operation happening on their account. + Returns: defer.Deferred[dict]: the parameters for this request (which may have been given only in a previous call). @@ -175,11 +223,11 @@ class AuthHandler(BaseHandler): ) # build a list of supported flows - flows = [[login_type] for login_type in self._supported_login_types] + flows = [[login_type] for login_type in self._supported_ui_auth_types] try: result, params, _ = yield self.check_auth( - flows, request, request_body, clientip + flows, request, request_body, clientip, description ) except LoginError: # Update the ratelimite to say we failed (`can_do_action` doesn't raise). @@ -193,7 +241,7 @@ class AuthHandler(BaseHandler): raise # find the completed login type - for login_type in self._supported_login_types: + for login_type in self._supported_ui_auth_types: if login_type not in result: continue @@ -224,6 +272,7 @@ class AuthHandler(BaseHandler): request: SynapseRequest, clientdict: Dict[str, Any], clientip: str, + description: str, ): """ Takes a dictionary sent by the client in the login / registration @@ -250,6 +299,9 @@ class AuthHandler(BaseHandler): clientip: The IP address of the client. + description: A human readable string to be displayed to the user that + describes the operation happening on their account. + Returns: defer.Deferred[dict, dict, str]: a deferred tuple of (creds, params, session_id). @@ -299,12 +351,18 @@ class AuthHandler(BaseHandler): comparator = (request.uri, request.method, clientdict) if "ui_auth" not in session: session["ui_auth"] = comparator + self._save_session(session) elif session["ui_auth"] != comparator: raise SynapseError( 403, "Requested operation has changed during the UI authentication session.", ) + # Add a human readable description to the session. + if "description" not in session: + session["description"] = description + self._save_session(session) + if not authdict: raise InteractiveAuthIncompleteError( self._auth_dict_for_flows(flows, session) @@ -991,6 +1049,56 @@ class AuthHandler(BaseHandler): else: return defer.succeed(False) + def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: + """ + Get the HTML for the SSO redirect confirmation page. + + Args: + redirect_url: The URL to redirect to the SSO provider. + session_id: The user interactive authentication session ID. + + Returns: + The HTML to render. + """ + session = self._get_session_info(session_id) + # Get the human readable operation of what is occurring, falling back to + # a generic message if it isn't available for some reason. + description = session.get("description", "modify your account") + return self._sso_auth_confirm_template.render( + description=description, redirect_url=redirect_url, + ) + + def complete_sso_ui_auth( + self, registered_user_id: str, session_id: str, request: SynapseRequest, + ): + """Having figured out a mxid for this user, complete the HTTP request + + Args: + registered_user_id: The registered user ID to complete SSO login for. + request: The request to complete. + client_redirect_url: The URL to which to redirect the user at the end of the + process. + """ + # Mark the stage of the authentication as successful. + sess = self._get_session_info(session_id) + if "creds" not in sess: + sess["creds"] = {} + creds = sess["creds"] + + # Save the user who authenticated with SSO, this will be used to ensure + # that the account be modified is also the person who logged in. + creds[LoginType.SSO] = registered_user_id + self._save_session(sess) + + # Render the HTML and return. + html_bytes = SUCCESS_TEMPLATE.encode("utf8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + finish_request(request) + def complete_sso_login( self, registered_user_id: str, diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index dc04b53f43..4741c82f61 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import re -from typing import Tuple +from typing import Optional, Tuple import attr import saml2 @@ -44,11 +44,15 @@ class Saml2SessionData: # time the session was created, in milliseconds creation_time = attr.ib() + # The user interactive authentication session ID associated with this SAML + # session (or None if this SAML session is for an initial login). + ui_auth_session_id = attr.ib(type=Optional[str], default=None) class SamlHandler: def __init__(self, hs): self._saml_client = Saml2Client(hs.config.saml2_sp_config) + self._auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() @@ -77,12 +81,14 @@ class SamlHandler: self._error_html_content = hs.config.saml2_error_html_content - def handle_redirect_request(self, client_redirect_url): + def handle_redirect_request(self, client_redirect_url, ui_auth_session_id=None): """Handle an incoming request to /login/sso/redirect Args: client_redirect_url (bytes): the URL that we should redirect the client to when everything is done + ui_auth_session_id (Optional[str]): The session ID of the ongoing UI Auth (or + None if this is a login). Returns: bytes: URL to redirect to @@ -92,7 +98,9 @@ class SamlHandler: ) now = self._clock.time_msec() - self._outstanding_requests_dict[reqid] = Saml2SessionData(creation_time=now) + self._outstanding_requests_dict[reqid] = Saml2SessionData( + creation_time=now, ui_auth_session_id=ui_auth_session_id, + ) for key, value in info["headers"]: if key == "Location": @@ -119,7 +127,9 @@ class SamlHandler: self.expire_sessions() try: - user_id = await self._map_saml_response_to_user(resp_bytes, relay_state) + user_id, current_session = await self._map_saml_response_to_user( + resp_bytes, relay_state + ) except RedirectException: # Raise the exception as per the wishes of the SAML module response raise @@ -137,9 +147,28 @@ class SamlHandler: finish_request(request) return - self._auth_handler.complete_sso_login(user_id, request, relay_state) + # Complete the interactive auth session or the login. + if current_session and current_session.ui_auth_session_id: + self._auth_handler.complete_sso_ui_auth( + user_id, current_session.ui_auth_session_id, request + ) + + else: + self._auth_handler.complete_sso_login(user_id, request, relay_state) + + async def _map_saml_response_to_user( + self, resp_bytes: str, client_redirect_url: str + ) -> Tuple[str, Optional[Saml2SessionData]]: + """ + Given a sample response, retrieve the cached session and user for it. - async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url): + Args: + resp_bytes: The SAML response. + client_redirect_url: The redirect URL passed in by the client. + + Returns: + Tuple of the user ID and SAML session associated with this response. + """ try: saml2_auth = self._saml_client.parse_authn_request_response( resp_bytes, @@ -167,7 +196,9 @@ class SamlHandler: logger.info("SAML2 mapped attributes: %s", saml2_auth.ava) - self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None) + current_session = self._outstanding_requests_dict.pop( + saml2_auth.in_response_to, None + ) remote_user_id = self._user_mapping_provider.get_remote_user_id( saml2_auth, client_redirect_url @@ -188,7 +219,7 @@ class SamlHandler: ) if registered_user_id is not None: logger.info("Found existing mapping %s", registered_user_id) - return registered_user_id + return registered_user_id, current_session # backwards-compatibility hack: see if there is an existing user with a # suitable mapping from the uid @@ -213,7 +244,7 @@ class SamlHandler: await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) - return registered_user_id + return registered_user_id, current_session # Map saml response to user attributes using the configured mapping provider for i in range(1000): @@ -260,7 +291,7 @@ class SamlHandler: await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) - return registered_user_id + return registered_user_id, current_session def expire_sessions(self): expire_before = self._clock.time_msec() - self._saml2_session_lifetime diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html new file mode 100644 index 0000000000..0d9de9d465 --- /dev/null +++ b/synapse/res/templates/sso_auth_confirm.html @@ -0,0 +1,14 @@ + + + Authentication + + +
+

+ A client is trying to {{ description | e }}. To confirm this action, + re-authenticate with single sign-on. + If you did not expect this, your account may be compromised! +

+
+ + diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index f80b5e40ea..31435b1e1c 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -234,7 +234,11 @@ class PasswordRestServlet(RestServlet): if self.auth.has_access_token(request): requester = await self.auth.get_user_by_req(request) params = await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, self.hs.get_ip_from_request(request), + requester, + request, + body, + self.hs.get_ip_from_request(request), + "modify your account password", ) user_id = requester.user.to_string() else: @@ -244,6 +248,7 @@ class PasswordRestServlet(RestServlet): request, body, self.hs.get_ip_from_request(request), + "modify your account password", ) if LoginType.EMAIL_IDENTITY in result: @@ -311,7 +316,11 @@ class DeactivateAccountRestServlet(RestServlet): return 200, {} await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, self.hs.get_ip_from_request(request), + requester, + request, + body, + self.hs.get_ip_from_request(request), + "deactivate your account", ) result = await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase, id_server=body.get("id_server") @@ -669,7 +678,11 @@ class ThreepidAddRestServlet(RestServlet): assert_valid_client_secret(client_secret) await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, self.hs.get_ip_from_request(request), + requester, + request, + body, + self.hs.get_ip_from_request(request), + "add a third-party identifier to your account", ) validation_session = await self.identity_handler.validate_threepid_session( diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 85cf5a14c6..1787562b90 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -18,6 +18,7 @@ import logging from synapse.api.constants import LoginType from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_API_PREFIX +from synapse.handlers.auth import SUCCESS_TEMPLATE from synapse.http.server import finish_request from synapse.http.servlet import RestServlet, parse_string @@ -89,30 +90,6 @@ TERMS_TEMPLATE = """ """ -SUCCESS_TEMPLATE = """ - - -Success! - - - - - -
-

Thank you

-

You may now close this window and return to the application

-
- - -""" - class AuthRestServlet(RestServlet): """ @@ -130,6 +107,11 @@ class AuthRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() + # SSO configuration. + self._saml_enabled = hs.config.saml2_enabled + if self._saml_enabled: + self._saml_handler = hs.get_saml_handler() + def on_GET(self, request, stagetype): session = parse_string(request, "session") if not session: @@ -150,6 +132,15 @@ class AuthRestServlet(RestServlet): "myurl": "%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), } + + elif stagetype == LoginType.SSO and self._saml_enabled: + # Display a confirmation page which prompts the user to + # re-authenticate with their SSO provider. + client_redirect_url = "" + sso_redirect_url = self._saml_handler.handle_redirect_request( + client_redirect_url, session + ) + html = self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) else: raise SynapseError(404, "Unknown auth stage type") @@ -210,6 +201,9 @@ class AuthRestServlet(RestServlet): "myurl": "%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), } + elif stagetype == LoginType.SSO: + # The SSO fallback workflow should not post here, + raise SynapseError(404, "Fallback SSO auth does not support POST requests.") else: raise SynapseError(404, "Unknown auth stage type") diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 119d979052..c0714fcfb1 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -81,7 +81,11 @@ class DeleteDevicesRestServlet(RestServlet): assert_params_in_dict(body, ["devices"]) await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, self.hs.get_ip_from_request(request), + requester, + request, + body, + self.hs.get_ip_from_request(request), + "remove device(s) from your account", ) await self.device_handler.delete_devices( @@ -127,7 +131,11 @@ class DeviceRestServlet(RestServlet): raise await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, self.hs.get_ip_from_request(request), + requester, + request, + body, + self.hs.get_ip_from_request(request), + "remove a device from your account", ) await self.device_handler.delete_device(requester.user.to_string(), device_id) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 5eb7ef35a4..8f41a3edbf 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -263,7 +263,11 @@ class SigningKeyUploadServlet(RestServlet): body = parse_json_object_from_request(request) await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, self.hs.get_ip_from_request(request), + requester, + request, + body, + self.hs.get_ip_from_request(request), + "add a device signing key to your account", ) result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 66fc8ec179..431ecf4f84 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -505,6 +505,7 @@ class RegisterRestServlet(RestServlet): request, body, self.hs.get_ip_from_request(request), + "register a new account", ) # Check that we're not trying to register a denied 3pid. -- cgit 1.5.1 From 694d8bed0e56366f080a49db0f930d635ca6cdf4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 3 Apr 2020 15:35:05 -0400 Subject: Support CAS in UI Auth flows. (#7186) --- changelog.d/7186.feature | 1 + synapse/handlers/auth.py | 4 +- synapse/handlers/cas_handler.py | 161 +++++++++++++++++++---------------- synapse/rest/client/v1/login.py | 20 ++++- synapse/rest/client/v2_alpha/auth.py | 28 ++++-- 5 files changed, 131 insertions(+), 83 deletions(-) create mode 100644 changelog.d/7186.feature (limited to 'synapse/handlers/auth.py') diff --git a/changelog.d/7186.feature b/changelog.d/7186.feature new file mode 100644 index 0000000000..01057aa396 --- /dev/null +++ b/changelog.d/7186.feature @@ -0,0 +1 @@ +Support SSO in the user interactive authentication workflow. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7c09d15a72..892adb00b9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -116,7 +116,7 @@ class AuthHandler(BaseHandler): self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.password_enabled - self._saml2_enabled = hs.config.saml2_enabled + self._sso_enabled = hs.config.saml2_enabled or hs.config.cas_enabled # we keep this as a list despite the O(N^2) implication so that we can # keep PASSWORD first and avoid confusing clients which pick the first @@ -136,7 +136,7 @@ class AuthHandler(BaseHandler): # necessarily identical. Login types have SSO (and other login types) # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET. ui_auth_types = login_types.copy() - if self._saml2_enabled: + if self._sso_enabled: ui_auth_types.append(LoginType.SSO) self._supported_ui_auth_types = ui_auth_types diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index f8dc274b78..d977badf35 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -15,7 +15,7 @@ import logging import xml.etree.ElementTree as ET -from typing import AnyStr, Dict, Optional, Tuple +from typing import Dict, Optional, Tuple from six.moves import urllib @@ -48,26 +48,47 @@ class CasHandler: self._http_client = hs.get_proxied_http_client() - def _build_service_param(self, client_redirect_url: AnyStr) -> str: + def _build_service_param(self, args: Dict[str, str]) -> str: + """ + Generates a value to use as the "service" parameter when redirecting or + querying the CAS service. + + Args: + args: Additional arguments to include in the final redirect URL. + + Returns: + The URL to use as a "service" parameter. + """ return "%s%s?%s" % ( self._cas_service_url, "/_matrix/client/r0/login/cas/ticket", - urllib.parse.urlencode({"redirectUrl": client_redirect_url}), + urllib.parse.urlencode(args), ) - async def _handle_cas_response( - self, request: SynapseRequest, cas_response_body: str, client_redirect_url: str - ) -> None: + async def _validate_ticket( + self, ticket: str, service_args: Dict[str, str] + ) -> Tuple[str, Optional[str]]: """ - Retrieves the user and display name from the CAS response and continues with the authentication. + Validate a CAS ticket with the server, parse the response, and return the user and display name. Args: - request: The original client request. - cas_response_body: The response from the CAS server. - client_redirect_url: The URl to redirect the client to when - everything is done. + ticket: The CAS ticket from the client. + service_args: Additional arguments to include in the service URL. + Should be the same as those passed to `get_redirect_url`. """ - user, attributes = self._parse_cas_response(cas_response_body) + uri = self._cas_server_url + "/proxyValidate" + args = { + "ticket": ticket, + "service": self._build_service_param(service_args), + } + try: + body = await self._http_client.get_raw(uri, args) + except PartialDownloadError as pde: + # Twisted raises this error if the connection is closed, + # even if that's being used old-http style to signal end-of-data + body = pde.response + + user, attributes = self._parse_cas_response(body) displayname = attributes.pop(self._cas_displayname_attribute, None) for required_attribute, required_value in self._cas_required_attributes.items(): @@ -82,7 +103,7 @@ class CasHandler: if required_value != actual_value: raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - await self._on_successful_auth(user, request, client_redirect_url, displayname) + return user, displayname def _parse_cas_response( self, cas_response_body: str @@ -127,78 +148,74 @@ class CasHandler: ) return user, attributes - async def _on_successful_auth( - self, - username: str, - request: SynapseRequest, - client_redirect_url: str, - user_display_name: Optional[str] = None, - ) -> None: - """Called once the user has successfully authenticated with the SSO. - - Registers the user if necessary, and then returns a redirect (with - a login token) to the client. + def get_redirect_url(self, service_args: Dict[str, str]) -> str: + """ + Generates a URL for the CAS server where the client should be redirected. Args: - username: the remote user id. We'll map this onto - something sane for a MXID localpath. + service_args: Additional arguments to include in the final redirect URL. - request: the incoming request from the browser. We'll - respond to it with a redirect. + Returns: + The URL to redirect the client to. + """ + args = urllib.parse.urlencode( + {"service": self._build_service_param(service_args)} + ) - client_redirect_url: the redirect_url the client gave us when - it first started the process. + return "%s/login?%s" % (self._cas_server_url, args) - user_display_name: if set, and we have to register a new user, - we will set their displayname to this. + async def handle_ticket( + self, + request: SynapseRequest, + ticket: str, + client_redirect_url: Optional[str], + session: Optional[str], + ) -> None: """ - localpart = map_username_to_mxid_localpart(username) - user_id = UserID(localpart, self._hostname).to_string() - registered_user_id = await self._auth_handler.check_user_exists(user_id) - if not registered_user_id: - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=user_display_name - ) + Called once the user has successfully authenticated with the SSO. + Validates a CAS ticket sent by the client and completes the auth process. - self._auth_handler.complete_sso_login( - registered_user_id, request, client_redirect_url - ) + If the user interactive authentication session is provided, marks the + UI Auth session as complete, then returns an HTML page notifying the + user they are done. - def handle_redirect_request(self, client_redirect_url: bytes) -> bytes: - """ - Generates a URL to the CAS server where the client should be redirected. + Otherwise, this registers the user if necessary, and then returns a + redirect (with a login token) to the client. Args: - client_redirect_url: The final URL the client should go to after the - user has negotiated SSO. + request: the incoming request from the browser. We'll + respond to it with a redirect or an HTML page. - Returns: - The URL to redirect to. - """ - args = urllib.parse.urlencode( - {"service": self._build_service_param(client_redirect_url)} - ) + ticket: The CAS ticket provided by the client. - return ("%s/login?%s" % (self._cas_server_url, args)).encode("ascii") + client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given. + This should be the same as the redirectUrl from the original `/login/sso/redirect` request. - async def handle_ticket_request( - self, request: SynapseRequest, client_redirect_url: str, ticket: str - ) -> None: + session: The session parameter from the `/cas/ticket` HTTP request, if given. + This should be the UI Auth session id. """ - Validates a CAS ticket sent by the client for login/registration. + args = {} + if client_redirect_url: + args["redirectUrl"] = client_redirect_url + if session: + args["session"] = session + username, user_display_name = await self._validate_ticket(ticket, args) - On a successful request, writes a redirect to the request. - """ - uri = self._cas_server_url + "/proxyValidate" - args = { - "ticket": ticket, - "service": self._build_service_param(client_redirect_url), - } - try: - body = await self._http_client.get_raw(uri, args) - except PartialDownloadError as pde: - # Twisted raises this error if the connection is closed, - # even if that's being used old-http style to signal end-of-data - body = pde.response + localpart = map_username_to_mxid_localpart(username) + user_id = UserID(localpart, self._hostname).to_string() + registered_user_id = await self._auth_handler.check_user_exists(user_id) - await self._handle_cas_response(request, body, client_redirect_url) + if session: + self._auth_handler.complete_sso_ui_auth( + registered_user_id, session, request, + ) + + else: + if not registered_user_id: + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, default_display_name=user_display_name + ) + + self._auth_handler.complete_sso_login( + registered_user_id, request, client_redirect_url + ) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 59593cbf6e..4de2f97d06 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -425,7 +425,9 @@ class CasRedirectServlet(BaseSSORedirectServlet): self._cas_handler = hs.get_cas_handler() def get_sso_url(self, client_redirect_url: bytes) -> bytes: - return self._cas_handler.handle_redirect_request(client_redirect_url) + return self._cas_handler.get_redirect_url( + {"redirectUrl": client_redirect_url} + ).encode("ascii") class CasTicketServlet(RestServlet): @@ -436,10 +438,20 @@ class CasTicketServlet(RestServlet): self._cas_handler = hs.get_cas_handler() async def on_GET(self, request: SynapseRequest) -> None: - client_redirect_url = parse_string(request, "redirectUrl", required=True) + client_redirect_url = parse_string(request, "redirectUrl") ticket = parse_string(request, "ticket", required=True) - await self._cas_handler.handle_ticket_request( - request, client_redirect_url, ticket + + # Maybe get a session ID (if this ticket is from user interactive + # authentication). + session = parse_string(request, "session") + + # Either client_redirect_url or session must be provided. + if not client_redirect_url and not session: + message = "Missing string query parameter redirectUrl or session" + raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + + await self._cas_handler.handle_ticket( + request, ticket, client_redirect_url, session ) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 1787562b90..13f9604407 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -111,6 +111,11 @@ class AuthRestServlet(RestServlet): self._saml_enabled = hs.config.saml2_enabled if self._saml_enabled: self._saml_handler = hs.get_saml_handler() + self._cas_enabled = hs.config.cas_enabled + if self._cas_enabled: + self._cas_handler = hs.get_cas_handler() + self._cas_server_url = hs.config.cas_server_url + self._cas_service_url = hs.config.cas_service_url def on_GET(self, request, stagetype): session = parse_string(request, "session") @@ -133,14 +138,27 @@ class AuthRestServlet(RestServlet): % (CLIENT_API_PREFIX, LoginType.TERMS), } - elif stagetype == LoginType.SSO and self._saml_enabled: + elif stagetype == LoginType.SSO: # Display a confirmation page which prompts the user to # re-authenticate with their SSO provider. - client_redirect_url = "" - sso_redirect_url = self._saml_handler.handle_redirect_request( - client_redirect_url, session - ) + if self._cas_enabled: + # Generate a request to CAS that redirects back to an endpoint + # to verify the successful authentication. + sso_redirect_url = self._cas_handler.get_redirect_url( + {"session": session}, + ) + + elif self._saml_enabled: + client_redirect_url = "" + sso_redirect_url = self._saml_handler.handle_redirect_request( + client_redirect_url, session + ) + + else: + raise SynapseError(400, "Homeserver not configured for SSO.") + html = self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) + else: raise SynapseError(404, "Unknown auth stage type") -- cgit 1.5.1 From b85d7652ff084fee997e0bb44ecd46c2789abbdd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 9 Apr 2020 13:28:13 -0400 Subject: Do not allow a deactivated user to login via SSO. (#7240) --- changelog.d/7240.bugfix | 1 + synapse/config/sso.py | 7 ++++ synapse/handlers/auth.py | 34 +++++++++++++++--- synapse/handlers/cas_handler.py | 2 +- synapse/handlers/saml_handler.py | 2 +- synapse/module_api/__init__.py | 22 +++++++++++- synapse/res/templates/sso_account_deactivated.html | 10 ++++++ tests/rest/client/v1/test_login.py | 42 ++++++++++++++++++++-- 8 files changed, 110 insertions(+), 10 deletions(-) create mode 100644 changelog.d/7240.bugfix create mode 100644 synapse/res/templates/sso_account_deactivated.html (limited to 'synapse/handlers/auth.py') diff --git a/changelog.d/7240.bugfix b/changelog.d/7240.bugfix new file mode 100644 index 0000000000..83b18d3e11 --- /dev/null +++ b/changelog.d/7240.bugfix @@ -0,0 +1 @@ +Do not allow a deactivated user to login via SSO. diff --git a/synapse/config/sso.py b/synapse/config/sso.py index ec3dca9efc..686678a3b7 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import Any, Dict import pkg_resources @@ -36,6 +37,12 @@ class SSOConfig(Config): template_dir = pkg_resources.resource_filename("synapse", "res/templates",) self.sso_redirect_confirm_template_dir = template_dir + self.sso_account_deactivated_template = self.read_file( + os.path.join( + self.sso_redirect_confirm_template_dir, "sso_account_deactivated.html" + ), + "sso_account_deactivated_template", + ) self.sso_client_whitelist = sso_config.get("client_whitelist") or [] diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 892adb00b9..fbfbd44a2e 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -161,6 +161,9 @@ class AuthHandler(BaseHandler): self._sso_auth_confirm_template = load_jinja2_templates( hs.config.sso_redirect_confirm_template_dir, ["sso_auth_confirm.html"], )[0] + self._sso_account_deactivated_template = ( + hs.config.sso_account_deactivated_template + ) self._server_name = hs.config.server_name @@ -644,9 +647,6 @@ class AuthHandler(BaseHandler): Returns: defer.Deferred: (unicode) canonical_user_id, or None if zero or multiple matches - - Raises: - UserDeactivatedError if a user is found but is deactivated. """ res = yield self._find_user_id_and_pwd_hash(user_id) if res is not None: @@ -1099,7 +1099,7 @@ class AuthHandler(BaseHandler): request.write(html_bytes) finish_request(request) - def complete_sso_login( + async def complete_sso_login( self, registered_user_id: str, request: SynapseRequest, @@ -1113,6 +1113,32 @@ class AuthHandler(BaseHandler): client_redirect_url: The URL to which to redirect the user at the end of the process. """ + # If the account has been deactivated, do not proceed with the login + # flow. + deactivated = await self.store.get_user_deactivated_status(registered_user_id) + if deactivated: + html = self._sso_account_deactivated_template.encode("utf-8") + + request.setResponseCode(403) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html),)) + request.write(html) + finish_request(request) + return + + self._complete_sso_login(registered_user_id, request, client_redirect_url) + + def _complete_sso_login( + self, + registered_user_id: str, + request: SynapseRequest, + client_redirect_url: str, + ): + """ + The synchronous portion of complete_sso_login. + + This exists purely for backwards compatibility of synapse.module_api.ModuleApi. + """ # Create a login token login_token = self.macaroon_gen.generate_short_term_login_token( registered_user_id diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index d977badf35..5cb3f9d133 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -216,6 +216,6 @@ class CasHandler: localpart=localpart, default_display_name=user_display_name ) - self._auth_handler.complete_sso_login( + await self._auth_handler.complete_sso_login( registered_user_id, request, client_redirect_url ) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 4741c82f61..7c9454b504 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -154,7 +154,7 @@ class SamlHandler: ) else: - self._auth_handler.complete_sso_login(user_id, request, relay_state) + await self._auth_handler.complete_sso_login(user_id, request, relay_state) async def _map_saml_response_to_user( self, resp_bytes: str, client_redirect_url: str diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index c7fffd72f2..afc3598e11 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -220,6 +220,26 @@ class ModuleApi(object): want their access token sent to `client_redirect_url`, or redirect them to that URL with a token directly if the URL matches with one of the whitelisted clients. + This is deprecated in favor of complete_sso_login_async. + + Args: + registered_user_id: The MXID that has been registered as a previous step of + of this SSO login. + request: The request to respond to. + client_redirect_url: The URL to which to offer to redirect the user (or to + redirect them directly if whitelisted). + """ + self._auth_handler._complete_sso_login( + registered_user_id, request, client_redirect_url, + ) + + async def complete_sso_login_async( + self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str + ): + """Complete a SSO login by redirecting the user to a page to confirm whether they + want their access token sent to `client_redirect_url`, or redirect them to that + URL with a token directly if the URL matches with one of the whitelisted clients. + Args: registered_user_id: The MXID that has been registered as a previous step of of this SSO login. @@ -227,6 +247,6 @@ class ModuleApi(object): client_redirect_url: The URL to which to offer to redirect the user (or to redirect them directly if whitelisted). """ - self._auth_handler.complete_sso_login( + await self._auth_handler.complete_sso_login( registered_user_id, request, client_redirect_url, ) diff --git a/synapse/res/templates/sso_account_deactivated.html b/synapse/res/templates/sso_account_deactivated.html new file mode 100644 index 0000000000..4eb8db9fb4 --- /dev/null +++ b/synapse/res/templates/sso_account_deactivated.html @@ -0,0 +1,10 @@ + + + + + SSO account deactivated + + +

This account has been deactivated.

+ + diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index aed8853d6e..1856c7ffd5 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -257,7 +257,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.code, 200, channel.result) -class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): +class CASTestCase(unittest.HomeserverTestCase): servlets = [ login.register_servlets, @@ -274,6 +274,9 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): "service_url": "https://matrix.goodserver.com:8448", } + cas_user_id = "username" + self.user_id = "@%s:test" % cas_user_id + async def get_raw(uri, args): """Return an example response payload from a call to the `/proxyValidate` endpoint of a CAS server, copied from @@ -282,10 +285,11 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): This needs to be returned by an async function (as opposed to set as the mock's return value) because the corresponding Synapse code awaits on it. """ - return """ + return ( + """ - username + %s PGTIOU-84678-8a9d... https://proxy2/pgtUrl @@ -294,6 +298,8 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): """ + % cas_user_id + ) mocked_http_client = Mock(spec=["get_raw"]) mocked_http_client.get_raw.side_effect = get_raw @@ -304,6 +310,9 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): return self.hs + def prepare(self, reactor, clock, hs): + self.deactivate_account_handler = hs.get_deactivate_account_handler() + def test_cas_redirect_confirm(self): """Tests that the SSO login flow serves a confirmation page before redirecting a user to the redirect URL. @@ -370,3 +379,30 @@ class CASRedirectConfirmTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 302) location_headers = channel.headers.getRawHeaders("Location") self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) + + @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) + def test_deactivated_user(self): + """Logging in as a deactivated account should error.""" + redirect_url = "https://legit-site.com/" + + # First login (to create the user). + self._test_redirect(redirect_url) + + # Deactivate the account. + self.get_success( + self.deactivate_account_handler.deactivate_account(self.user_id, False) + ) + + # Request the CAS ticket. + cas_ticket_url = ( + "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" + % (urllib.parse.quote(redirect_url)) + ) + + # Get Synapse to call the fake CAS and serve the template. + request, channel = self.make_request("GET", cas_ticket_url) + self.render(request) + + # Because the user is deactivated they are served an error template. + self.assertEqual(channel.code, 403) + self.assertIn(b"SSO account deactivated", channel.result["body"]) -- cgit 1.5.1 From eed7c5b89eee6951ac17861b1695817470bace36 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 15 Apr 2020 12:40:18 -0400 Subject: Convert auth handler to async/await (#7261) --- changelog.d/7261.misc | 1 + synapse/handlers/auth.py | 173 ++++++++++++++++++--------------------- synapse/handlers/device.py | 12 ++- synapse/handlers/register.py | 28 +++++-- synapse/handlers/set_password.py | 13 ++- synapse/module_api/__init__.py | 6 +- tests/api/test_auth.py | 64 +++++++++------ tests/handlers/test_auth.py | 80 +++++++++++------- tests/handlers/test_register.py | 4 +- tests/utils.py | 13 ++- 10 files changed, 224 insertions(+), 170 deletions(-) create mode 100644 changelog.d/7261.misc (limited to 'synapse/handlers/auth.py') diff --git a/changelog.d/7261.misc b/changelog.d/7261.misc new file mode 100644 index 0000000000..88165f0105 --- /dev/null +++ b/changelog.d/7261.misc @@ -0,0 +1 @@ +Convert auth handler to async/await. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index fbfbd44a2e..0aae929ecc 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,14 +18,12 @@ import logging import time import unicodedata import urllib.parse -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import attr import bcrypt # type: ignore[import] import pymacaroons -from twisted.internet import defer - import synapse.util.stringutils as stringutils from synapse.api.constants import LoginType from synapse.api.errors import ( @@ -170,15 +168,14 @@ class AuthHandler(BaseHandler): # cast to tuple for use with str.startswith self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) - @defer.inlineCallbacks - def validate_user_via_ui_auth( + async def validate_user_via_ui_auth( self, requester: Requester, request: SynapseRequest, request_body: Dict[str, Any], clientip: str, description: str, - ): + ) -> dict: """ Checks that the user is who they claim to be, via a UI auth. @@ -199,7 +196,7 @@ class AuthHandler(BaseHandler): describes the operation happening on their account. Returns: - defer.Deferred[dict]: the parameters for this request (which may + The parameters for this request (which may have been given only in a previous call). Raises: @@ -229,7 +226,7 @@ class AuthHandler(BaseHandler): flows = [[login_type] for login_type in self._supported_ui_auth_types] try: - result, params, _ = yield self.check_auth( + result, params, _ = await self.check_auth( flows, request, request_body, clientip, description ) except LoginError: @@ -268,15 +265,14 @@ class AuthHandler(BaseHandler): """ return self.checkers.keys() - @defer.inlineCallbacks - def check_auth( + async def check_auth( self, flows: List[List[str]], request: SynapseRequest, clientdict: Dict[str, Any], clientip: str, description: str, - ): + ) -> Tuple[dict, dict, str]: """ Takes a dictionary sent by the client in the login / registration protocol and handles the User-Interactive Auth flow. @@ -306,8 +302,7 @@ class AuthHandler(BaseHandler): describes the operation happening on their account. Returns: - defer.Deferred[dict, dict, str]: a deferred tuple of - (creds, params, session_id). + A tuple of (creds, params, session_id). 'creds' contains the authenticated credentials of each stage. @@ -380,7 +375,7 @@ class AuthHandler(BaseHandler): if "type" in authdict: login_type = authdict["type"] # type: str try: - result = yield self._check_auth_dict(authdict, clientip) + result = await self._check_auth_dict(authdict, clientip) if result: creds[login_type] = result self._save_session(session) @@ -419,8 +414,9 @@ class AuthHandler(BaseHandler): ret.update(errordict) raise InteractiveAuthIncompleteError(ret) - @defer.inlineCallbacks - def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str): + async def add_oob_auth( + self, stagetype: str, authdict: Dict[str, Any], clientip: str + ) -> bool: """ Adds the result of out-of-band authentication into an existing auth session. Currently used for adding the result of fallback auth. @@ -435,7 +431,7 @@ class AuthHandler(BaseHandler): sess["creds"] = {} creds = sess["creds"] - result = yield self.checkers[stagetype].check_auth(authdict, clientip) + result = await self.checkers[stagetype].check_auth(authdict, clientip) if result: creds[stagetype] = result self._save_session(sess) @@ -489,8 +485,9 @@ class AuthHandler(BaseHandler): sess = self._get_session_info(session_id) return sess.setdefault("serverdict", {}).get(key, default) - @defer.inlineCallbacks - def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str): + async def _check_auth_dict( + self, authdict: Dict[str, Any], clientip: str + ) -> Union[Dict[str, Any], str]: """Attempt to validate the auth dict provided by a client Args: @@ -498,7 +495,7 @@ class AuthHandler(BaseHandler): clientip: IP address of the client Returns: - Deferred: result of the stage verification. + Result of the stage verification. Raises: StoreError if there was a problem accessing the database @@ -508,7 +505,7 @@ class AuthHandler(BaseHandler): login_type = authdict["type"] checker = self.checkers.get(login_type) if checker is not None: - res = yield checker.check_auth(authdict, clientip=clientip) + res = await checker.check_auth(authdict, clientip=clientip) return res # build a v1-login-style dict out of the authdict and fall back to the @@ -518,7 +515,7 @@ class AuthHandler(BaseHandler): if user_id is None: raise SynapseError(400, "", Codes.MISSING_PARAM) - (canonical_id, callback) = yield self.validate_login(user_id, authdict) + (canonical_id, callback) = await self.validate_login(user_id, authdict) return canonical_id def _get_params_recaptcha(self) -> dict: @@ -584,8 +581,7 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] - @defer.inlineCallbacks - def get_access_token_for_user_id( + async def get_access_token_for_user_id( self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int] ): """ @@ -615,10 +611,10 @@ class AuthHandler(BaseHandler): ) logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry) - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) access_token = self.macaroon_gen.generate_access_token(user_id) - yield self.store.add_access_token_to_user( + await self.store.add_access_token_to_user( user_id, access_token, device_id, valid_until_ms ) @@ -628,15 +624,14 @@ class AuthHandler(BaseHandler): # device, so we double-check it here. if device_id is not None: try: - yield self.store.get_device(user_id, device_id) + await self.store.get_device(user_id, device_id) except StoreError: - yield self.store.delete_access_token(access_token) + await self.store.delete_access_token(access_token) raise StoreError(400, "Login raced against device deletion") return access_token - @defer.inlineCallbacks - def check_user_exists(self, user_id: str): + async def check_user_exists(self, user_id: str) -> Optional[str]: """ Checks to see if a user with the given id exists. Will check case insensitively, but return None if there are multiple inexact matches. @@ -645,25 +640,25 @@ class AuthHandler(BaseHandler): user_id: complete @user:id Returns: - defer.Deferred: (unicode) canonical_user_id, or None if zero or - multiple matches + The canonical_user_id, or None if zero or multiple matches """ - res = yield self._find_user_id_and_pwd_hash(user_id) + res = await self._find_user_id_and_pwd_hash(user_id) if res is not None: return res[0] return None - @defer.inlineCallbacks - def _find_user_id_and_pwd_hash(self, user_id: str): + async def _find_user_id_and_pwd_hash( + self, user_id: str + ) -> Optional[Tuple[str, str]]: """Checks to see if a user with the given id exists. Will check case insensitively, but will return None if there are multiple inexact matches. Returns: - tuple: A 2-tuple of `(canonical_user_id, password_hash)` - None: if there is not exactly one match + A 2-tuple of `(canonical_user_id, password_hash)` or `None` + if there is not exactly one match """ - user_infos = yield self.store.get_users_by_id_case_insensitive(user_id) + user_infos = await self.store.get_users_by_id_case_insensitive(user_id) result = None if not user_infos: @@ -696,8 +691,9 @@ class AuthHandler(BaseHandler): """ return self._supported_login_types - @defer.inlineCallbacks - def validate_login(self, username: str, login_submission: Dict[str, Any]): + async def validate_login( + self, username: str, login_submission: Dict[str, Any] + ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]: """Authenticates the user for the /login API Also used by the user-interactive auth flow to validate @@ -708,7 +704,7 @@ class AuthHandler(BaseHandler): login_submission: the whole of the login submission (including 'type' and other relevant fields) Returns: - Deferred[str, func]: canonical user id, and optional callback + A tuple of the canonical user id, and optional callback to be called once the access token and device id are issued Raises: StoreError if there was a problem accessing the database @@ -737,7 +733,7 @@ class AuthHandler(BaseHandler): for provider in self.password_providers: if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD: known_login_type = True - is_valid = yield provider.check_password(qualified_user_id, password) + is_valid = await provider.check_password(qualified_user_id, password) if is_valid: return qualified_user_id, None @@ -769,7 +765,7 @@ class AuthHandler(BaseHandler): % (login_type, missing_fields), ) - result = yield provider.check_auth(username, login_type, login_dict) + result = await provider.check_auth(username, login_type, login_dict) if result: if isinstance(result, str): result = (result, None) @@ -778,8 +774,8 @@ class AuthHandler(BaseHandler): if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: known_login_type = True - canonical_user_id = yield self._check_local_password( - qualified_user_id, password + canonical_user_id = await self._check_local_password( + qualified_user_id, password # type: ignore ) if canonical_user_id: @@ -792,8 +788,9 @@ class AuthHandler(BaseHandler): # login, it turns all LoginErrors into a 401 anyway. raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN) - @defer.inlineCallbacks - def check_password_provider_3pid(self, medium: str, address: str, password: str): + async def check_password_provider_3pid( + self, medium: str, address: str, password: str + ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]: """Check if a password provider is able to validate a thirdparty login Args: @@ -802,9 +799,8 @@ class AuthHandler(BaseHandler): password: The password of the user. Returns: - Deferred[(str|None, func|None)]: A tuple of `(user_id, - callback)`. If authentication is successful, `user_id` is a `str` - containing the authenticated, canonical user ID. `callback` is + A tuple of `(user_id, callback)`. If authentication is successful, + `user_id`is the authenticated, canonical user ID. `callback` is then either a function to be later run after the server has completed login/registration, or `None`. If authentication was unsuccessful, `user_id` and `callback` are both `None`. @@ -816,7 +812,7 @@ class AuthHandler(BaseHandler): # success, to a str (which is the user_id) or a tuple of # (user_id, callback_func), where callback_func should be run # after we've finished everything else - result = yield provider.check_3pid_auth(medium, address, password) + result = await provider.check_3pid_auth(medium, address, password) if result: # Check if the return value is a str or a tuple if isinstance(result, str): @@ -826,8 +822,7 @@ class AuthHandler(BaseHandler): return None, None - @defer.inlineCallbacks - def _check_local_password(self, user_id: str, password: str): + async def _check_local_password(self, user_id: str, password: str) -> Optional[str]: """Authenticate a user against the local password database. user_id is checked case insensitively, but will return None if there are @@ -837,28 +832,26 @@ class AuthHandler(BaseHandler): user_id: complete @user:id password: the provided password Returns: - Deferred[unicode] the canonical_user_id, or Deferred[None] if - unknown user/bad password + The canonical_user_id, or None if unknown user/bad password """ - lookupres = yield self._find_user_id_and_pwd_hash(user_id) + lookupres = await self._find_user_id_and_pwd_hash(user_id) if not lookupres: return None (user_id, password_hash) = lookupres # If the password hash is None, the account has likely been deactivated if not password_hash: - deactivated = yield self.store.get_user_deactivated_status(user_id) + deactivated = await self.store.get_user_deactivated_status(user_id) if deactivated: raise UserDeactivatedError("This account has been deactivated") - result = yield self.validate_hash(password, password_hash) + result = await self.validate_hash(password, password_hash) if not result: logger.warning("Failed password login for user %s", user_id) return None return user_id - @defer.inlineCallbacks - def validate_short_term_login_token_and_get_user_id(self, login_token: str): + async def validate_short_term_login_token_and_get_user_id(self, login_token: str): auth_api = self.hs.get_auth() user_id = None try: @@ -868,26 +861,23 @@ class AuthHandler(BaseHandler): except Exception: raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) return user_id - @defer.inlineCallbacks - def delete_access_token(self, access_token: str): + async def delete_access_token(self, access_token: str): """Invalidate a single access token Args: access_token: access token to be deleted - Returns: - Deferred """ - user_info = yield self.auth.get_user_by_access_token(access_token) - yield self.store.delete_access_token(access_token) + user_info = await self.auth.get_user_by_access_token(access_token) + await self.store.delete_access_token(access_token) # see if any of our auth providers want to know about this for provider in self.password_providers: if hasattr(provider, "on_logged_out"): - yield provider.on_logged_out( + await provider.on_logged_out( user_id=str(user_info["user"]), device_id=user_info["device_id"], access_token=access_token, @@ -895,12 +885,11 @@ class AuthHandler(BaseHandler): # delete pushers associated with this access token if user_info["token_id"] is not None: - yield self.hs.get_pusherpool().remove_pushers_by_access_token( + await self.hs.get_pusherpool().remove_pushers_by_access_token( str(user_info["user"]), (user_info["token_id"],) ) - @defer.inlineCallbacks - def delete_access_tokens_for_user( + async def delete_access_tokens_for_user( self, user_id: str, except_token_id: Optional[str] = None, @@ -914,10 +903,8 @@ class AuthHandler(BaseHandler): device_id: ID of device the tokens are associated with. If None, tokens associated with any device (or no device) will be deleted - Returns: - Deferred """ - tokens_and_devices = yield self.store.user_delete_access_tokens( + tokens_and_devices = await self.store.user_delete_access_tokens( user_id, except_token_id=except_token_id, device_id=device_id ) @@ -925,17 +912,18 @@ class AuthHandler(BaseHandler): for provider in self.password_providers: if hasattr(provider, "on_logged_out"): for token, token_id, device_id in tokens_and_devices: - yield provider.on_logged_out( + await provider.on_logged_out( user_id=user_id, device_id=device_id, access_token=token ) # delete pushers associated with the access tokens - yield self.hs.get_pusherpool().remove_pushers_by_access_token( + await self.hs.get_pusherpool().remove_pushers_by_access_token( user_id, (token_id for _, token_id, _ in tokens_and_devices) ) - @defer.inlineCallbacks - def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int): + async def add_threepid( + self, user_id: str, medium: str, address: str, validated_at: int + ): # check if medium has a valid value if medium not in ["email", "msisdn"]: raise SynapseError( @@ -956,14 +944,13 @@ class AuthHandler(BaseHandler): if medium == "email": address = address.lower() - yield self.store.user_add_threepid( + await self.store.user_add_threepid( user_id, medium, address, validated_at, self.hs.get_clock().time_msec() ) - @defer.inlineCallbacks - def delete_threepid( + async def delete_threepid( self, user_id: str, medium: str, address: str, id_server: Optional[str] = None - ): + ) -> bool: """Attempts to unbind the 3pid on the identity servers and deletes it from the local database. @@ -976,7 +963,7 @@ class AuthHandler(BaseHandler): identity server specified when binding (if known). Returns: - Deferred[bool]: Returns True if successfully unbound the 3pid on + Returns True if successfully unbound the 3pid on the identity server, False if identity server doesn't support the unbind API. """ @@ -986,11 +973,11 @@ class AuthHandler(BaseHandler): address = address.lower() identity_handler = self.hs.get_handlers().identity_handler - result = yield identity_handler.try_unbind_threepid( + result = await identity_handler.try_unbind_threepid( user_id, {"medium": medium, "address": address, "id_server": id_server} ) - yield self.store.user_delete_threepid(user_id, medium, address) + await self.store.user_delete_threepid(user_id, medium, address) return result def _save_session(self, session: Dict[str, Any]) -> None: @@ -1000,14 +987,14 @@ class AuthHandler(BaseHandler): session["last_used"] = self.hs.get_clock().time_msec() self.sessions[session["id"]] = session - def hash(self, password: str): + async def hash(self, password: str) -> str: """Computes a secure hash of password. Args: password: Password to hash. Returns: - Deferred(unicode): Hashed password. + Hashed password. """ def _do_hash(): @@ -1019,9 +1006,11 @@ class AuthHandler(BaseHandler): bcrypt.gensalt(self.bcrypt_rounds), ).decode("ascii") - return defer_to_thread(self.hs.get_reactor(), _do_hash) + return await defer_to_thread(self.hs.get_reactor(), _do_hash) - def validate_hash(self, password: str, stored_hash: bytes): + async def validate_hash( + self, password: str, stored_hash: Union[bytes, str] + ) -> bool: """Validates that self.hash(password) == stored_hash. Args: @@ -1029,7 +1018,7 @@ class AuthHandler(BaseHandler): stored_hash: Expected hash value. Returns: - Deferred(bool): Whether self.hash(password) == stored_hash. + Whether self.hash(password) == stored_hash. """ def _do_validate_hash(): @@ -1045,9 +1034,9 @@ class AuthHandler(BaseHandler): if not isinstance(stored_hash, bytes): stored_hash = stored_hash.encode("ascii") - return defer_to_thread(self.hs.get_reactor(), _do_validate_hash) + return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash) else: - return defer.succeed(False) + return False def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: """ diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 993499f446..9bd941b5a0 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -338,8 +338,10 @@ class DeviceHandler(DeviceWorkerHandler): else: raise - yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id + yield defer.ensureDeferred( + self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id + ) ) yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) @@ -391,8 +393,10 @@ class DeviceHandler(DeviceWorkerHandler): # Delete access tokens and e2e keys for each device. Not optimised as it is not # considered as part of a critical path. for device_id in device_ids: - yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id + yield defer.ensureDeferred( + self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id + ) ) yield self.store.delete_e2e_keys_by_device( user_id=user_id, device_id=device_id diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 7ffc194f0c..3a65b46ecd 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -166,7 +166,9 @@ class RegistrationHandler(BaseHandler): yield self.auth.check_auth_blocking(threepid=threepid) password_hash = None if password: - password_hash = yield self._auth_handler.hash(password) + password_hash = yield defer.ensureDeferred( + self._auth_handler.hash(password) + ) if localpart is not None: yield self.check_username(localpart, guest_access_token=guest_access_token) @@ -540,8 +542,10 @@ class RegistrationHandler(BaseHandler): user_id, ["guest = true"] ) else: - access_token = yield self._auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, valid_until_ms=valid_until_ms + access_token = yield defer.ensureDeferred( + self._auth_handler.get_access_token_for_user_id( + user_id, device_id=device_id, valid_until_ms=valid_until_ms + ) ) return (device_id, access_token) @@ -617,8 +621,13 @@ class RegistrationHandler(BaseHandler): logger.info("Can't add incomplete 3pid") return - yield self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"] + yield defer.ensureDeferred( + self._auth_handler.add_threepid( + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], + ) ) # And we add an email pusher for them by default, but only @@ -670,6 +679,11 @@ class RegistrationHandler(BaseHandler): return None raise - yield self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"] + yield defer.ensureDeferred( + self._auth_handler.add_threepid( + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], + ) ) diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 7d1263caf2..63d8f9aa0d 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -15,8 +15,6 @@ import logging from typing import Optional -from twisted.internet import defer - from synapse.api.errors import Codes, StoreError, SynapseError from synapse.types import Requester @@ -34,8 +32,7 @@ class SetPasswordHandler(BaseHandler): self._device_handler = hs.get_device_handler() self._password_policy_handler = hs.get_password_policy_handler() - @defer.inlineCallbacks - def set_password( + async def set_password( self, user_id: str, new_password: str, @@ -46,10 +43,10 @@ class SetPasswordHandler(BaseHandler): raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) self._password_policy_handler.validate_password(new_password) - password_hash = yield self._auth_handler.hash(new_password) + password_hash = await self._auth_handler.hash(new_password) try: - yield self.store.user_set_password_hash(user_id, password_hash) + await self.store.user_set_password_hash(user_id, password_hash) except StoreError as e: if e.code == 404: raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) @@ -61,12 +58,12 @@ class SetPasswordHandler(BaseHandler): except_access_token_id = requester.access_token_id if requester else None # First delete all of their other devices. - yield self._device_handler.delete_all_devices_for_user( + await self._device_handler.delete_all_devices_for_user( user_id, except_device_id=except_device_id ) # and now delete any access tokens which weren't associated with # devices (or were associated with this device). - yield self._auth_handler.delete_access_tokens_for_user( + await self._auth_handler.delete_access_tokens_for_user( user_id, except_token_id=except_access_token_id ) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index afc3598e11..d678c0eb9b 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -86,7 +86,7 @@ class ModuleApi(object): Deferred[str|None]: Canonical (case-corrected) user_id, or None if the user is not registered. """ - return self._auth_handler.check_user_exists(user_id) + return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id)) @defer.inlineCallbacks def register(self, localpart, displayname=None, emails=[]): @@ -196,7 +196,9 @@ class ModuleApi(object): yield self._hs.get_device_handler().delete_device(user_id, device_id) else: # no associated device. Just delete the access token. - yield self._auth_handler.delete_access_token(access_token) + yield defer.ensureDeferred( + self._auth_handler.delete_access_token(access_token) + ) def run_db_interaction(self, desc, func, *args, **kwargs): """Run a function with a database connection diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 6121efcfa9..cc0b10e7f6 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -68,7 +68,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): @@ -105,7 +105,7 @@ class AuthTestCase(unittest.TestCase): request.getClientIP.return_value = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) @defer.inlineCallbacks @@ -125,7 +125,7 @@ class AuthTestCase(unittest.TestCase): request.getClientIP.return_value = "192.168.10.10" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_appservice_valid_token_bad_ip(self): @@ -188,7 +188,7 @@ class AuthTestCase(unittest.TestCase): request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals( requester.user.to_string(), masquerading_user_id.decode("utf8") ) @@ -225,7 +225,9 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) - user_info = yield self.auth.get_user_by_access_token(macaroon.serialize()) + user_info = yield defer.ensureDeferred( + self.auth.get_user_by_access_token(macaroon.serialize()) + ) user = user_info["user"] self.assertEqual(UserID.from_string(user_id), user) @@ -250,7 +252,9 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("guest = true") serialized = macaroon.serialize() - user_info = yield self.auth.get_user_by_access_token(serialized) + user_info = yield defer.ensureDeferred( + self.auth.get_user_by_access_token(serialized) + ) user = user_info["user"] is_guest = user_info["is_guest"] self.assertEqual(UserID.from_string(user_id), user) @@ -260,10 +264,13 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cannot_use_regular_token_as_guest(self): USER_ID = "@percy:matrix.org" - self.store.add_access_token_to_user = Mock() + self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None)) + self.store.get_device = Mock(return_value=defer.succeed(None)) - token = yield self.hs.handlers.auth_handler.get_access_token_for_user_id( - USER_ID, "DEVICE", valid_until_ms=None + token = yield defer.ensureDeferred( + self.hs.handlers.auth_handler.get_access_token_for_user_id( + USER_ID, "DEVICE", valid_until_ms=None + ) ) self.store.add_access_token_to_user.assert_called_with( USER_ID, token, "DEVICE", None @@ -286,7 +293,9 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args[b"access_token"] = [token.encode("ascii")] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = yield defer.ensureDeferred( + self.auth.get_user_by_req(request, allow_guest=True) + ) self.assertEqual(UserID.from_string(USER_ID), requester.user) self.assertFalse(requester.is_guest) @@ -301,7 +310,9 @@ class AuthTestCase(unittest.TestCase): request.requestHeaders.getRawHeaders = mock_getRawHeaders() with self.assertRaises(InvalidClientCredentialsError) as cm: - yield self.auth.get_user_by_req(request, allow_guest=True) + yield defer.ensureDeferred( + self.auth.get_user_by_req(request, allow_guest=True) + ) self.assertEqual(401, cm.exception.code) self.assertEqual("Guest access token used for regular user", cm.exception.msg) @@ -316,7 +327,7 @@ class AuthTestCase(unittest.TestCase): small_number_of_users = 1 # Ensure no error thrown - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.hs.config.limit_usage_by_mau = True @@ -325,7 +336,7 @@ class AuthTestCase(unittest.TestCase): ) with self.assertRaises(ResourceLimitError) as e: - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) @@ -334,7 +345,7 @@ class AuthTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(small_number_of_users) ) - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) @defer.inlineCallbacks def test_blocking_mau__depending_on_user_type(self): @@ -343,15 +354,19 @@ class AuthTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) # Support users allowed - yield self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT) + yield defer.ensureDeferred( + self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT) + ) self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) # Bots not allowed with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking(user_type=UserTypes.BOT) + yield defer.ensureDeferred( + self.auth.check_auth_blocking(user_type=UserTypes.BOT) + ) self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) # Real users not allowed with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) @defer.inlineCallbacks def test_reserved_threepid(self): @@ -362,21 +377,22 @@ class AuthTestCase(unittest.TestCase): unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} self.hs.config.mau_limits_reserved_threepids = [threepid] - yield self.store.register_user(user_id="user1", password_hash=None) with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking(threepid=unknown_threepid) + yield defer.ensureDeferred( + self.auth.check_auth_blocking(threepid=unknown_threepid) + ) - yield self.auth.check_auth_blocking(threepid=threepid) + yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid)) @defer.inlineCallbacks def test_hs_disabled(self): self.hs.config.hs_disabled = True self.hs.config.hs_disabled_message = "Reason for being disabled" with self.assertRaises(ResourceLimitError) as e: - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) @@ -393,7 +409,7 @@ class AuthTestCase(unittest.TestCase): self.hs.config.hs_disabled = True self.hs.config.hs_disabled_message = "Reason for being disabled" with self.assertRaises(ResourceLimitError) as e: - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) @@ -404,4 +420,4 @@ class AuthTestCase(unittest.TestCase): user = "@user:server" self.hs.config.server_notices_mxid = user self.hs.config.hs_disabled_message = "Reason for being disabled" - yield self.auth.check_auth_blocking(user) + yield defer.ensureDeferred(self.auth.check_auth_blocking(user)) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index b03103d96f..52c4ac8b11 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -82,16 +82,16 @@ class AuthTestCase(unittest.TestCase): self.hs.clock.now = 1000 token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) - user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - token + user_id = yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id(token) ) self.assertEqual("a_user", user_id) # when we advance the clock, the token should be rejected self.hs.clock.now = 6000 with self.assertRaises(synapse.api.errors.AuthError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - token + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id(token) ) @defer.inlineCallbacks @@ -99,8 +99,10 @@ class AuthTestCase(unittest.TestCase): token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) macaroon = pymacaroons.Macaroon.deserialize(token) - user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - macaroon.serialize() + user_id = yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) ) self.assertEqual("a_user", user_id) @@ -109,20 +111,26 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("user_id = b_user") with self.assertRaises(synapse.api.errors.AuthError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - macaroon.serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) ) @defer.inlineCallbacks def test_mau_limits_disabled(self): self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) @defer.inlineCallbacks @@ -133,16 +141,20 @@ class AuthTestCase(unittest.TestCase): ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) @defer.inlineCallbacks @@ -154,16 +166,20 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(self.hs.config.max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) # If in monthly active cohort self.hs.get_datastore().user_last_seen_monthly_active = Mock( @@ -172,8 +188,10 @@ class AuthTestCase(unittest.TestCase): self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) @@ -181,8 +199,10 @@ class AuthTestCase(unittest.TestCase): self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) @defer.inlineCallbacks @@ -193,15 +213,19 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(self.small_number_of_users) ) # Ensure does not raise exception - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) ) - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) def _get_macaroon(self): diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e7b638dbfe..f1dc51d6c9 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -294,7 +294,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): create_profile_with_displayname=user.localpart, ) else: - yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) + yield defer.ensureDeferred( + self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) + ) yield self.store.add_access_token_to_user( user_id=user_id, token=token, device_id=None, valid_until_ms=None diff --git a/tests/utils.py b/tests/utils.py index 968d109f77..2079e0143d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -332,10 +332,15 @@ def setup_test_homeserver( # Need to let the HS build an auth handler and then mess with it # because AuthHandler's constructor requires the HS, so we can't make one # beforehand and pass it in to the HS's constructor (chicken / egg) - hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode("utf8")).hexdigest() - hs.get_auth_handler().validate_hash = ( - lambda p, h: hashlib.md5(p.encode("utf8")).hexdigest() == h - ) + async def hash(p): + return hashlib.md5(p.encode("utf8")).hexdigest() + + hs.get_auth_handler().hash = hash + + async def validate_hash(p, h): + return hashlib.md5(p.encode("utf8")).hexdigest() == h + + hs.get_auth_handler().validate_hash = validate_hash fed = kargs.get("resource_for_federation", None) if fed: -- cgit 1.5.1 From 054c231e58eb8a93ff04a81341190aa3b6bcb9f7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 17 Apr 2020 13:34:55 -0400 Subject: Use a template for the SSO success page to allow for customization. (#7279) --- CHANGES.md | 9 +++--- changelog.d/7279.feature | 1 + synapse/config/sso.py | 6 ++++ synapse/handlers/auth.py | 44 ++++++++--------------------- synapse/res/templates/sso_auth_success.html | 18 ++++++++++++ synapse/rest/client/v2_alpha/auth.py | 25 +++++++++++++++- 6 files changed, 66 insertions(+), 37 deletions(-) create mode 100644 changelog.d/7279.feature create mode 100644 synapse/res/templates/sso_auth_success.html (limited to 'synapse/handlers/auth.py') diff --git a/CHANGES.md b/CHANGES.md index 6f25b26a55..b41a627cb8 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,10 +1,11 @@ Next version ============ -* Two new templates (`sso_auth_confirm.html` and `sso_account_deactivated.html`) - were added to Synapse. If your Synapse is configured to use SSO and a custom - `sso_redirect_confirm_template_dir` configuration then these templates will - need to be duplicated into that directory. +* New templates (`sso_auth_confirm.html`, `sso_auth_success.html`, and + `sso_account_deactivated.html`) were added to Synapse. If your Synapse is + configured to use SSO and a custom `sso_redirect_confirm_template_dir` + configuration then these templates will need to be duplicated into that + directory. * Plugins using the `complete_sso_login` method of `synapse.module_api.ModuleApi` should update to using the async/await version `complete_sso_login_async` which diff --git a/changelog.d/7279.feature b/changelog.d/7279.feature new file mode 100644 index 0000000000..9aed075474 --- /dev/null +++ b/changelog.d/7279.feature @@ -0,0 +1 @@ + Support SSO in the user interactive authentication workflow. diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 686678a3b7..6cd37d4324 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -43,6 +43,12 @@ class SSOConfig(Config): ), "sso_account_deactivated_template", ) + self.sso_auth_success_template = self.read_file( + os.path.join( + self.sso_redirect_confirm_template_dir, "sso_auth_success.html" + ), + "sso_auth_success_template", + ) self.sso_client_whitelist = sso_config.get("client_whitelist") or [] diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0aae929ecc..bda279ab8b 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -51,31 +51,6 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) -SUCCESS_TEMPLATE = """ - - -Success! - - - - - -
-

Thank you

-

You may now close this window and return to the application

-
- - -""" - - class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 @@ -159,6 +134,11 @@ class AuthHandler(BaseHandler): self._sso_auth_confirm_template = load_jinja2_templates( hs.config.sso_redirect_confirm_template_dir, ["sso_auth_confirm.html"], )[0] + # The following template is shown after a successful user interactive + # authentication session. It tells the user they can close the window. + self._sso_auth_success_template = hs.config.sso_auth_success_template + # The following template is shown during the SSO authentication process if + # the account is deactivated. self._sso_account_deactivated_template = ( hs.config.sso_account_deactivated_template ) @@ -1080,7 +1060,7 @@ class AuthHandler(BaseHandler): self._save_session(sess) # Render the HTML and return. - html_bytes = SUCCESS_TEMPLATE.encode("utf8") + html_bytes = self._sso_auth_success_template.encode("utf-8") request.setResponseCode(200) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) @@ -1106,12 +1086,12 @@ class AuthHandler(BaseHandler): # flow. deactivated = await self.store.get_user_deactivated_status(registered_user_id) if deactivated: - html = self._sso_account_deactivated_template.encode("utf-8") + html_bytes = self._sso_account_deactivated_template.encode("utf-8") request.setResponseCode(403) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html),)) - request.write(html) + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + request.write(html_bytes) finish_request(request) return @@ -1153,7 +1133,7 @@ class AuthHandler(BaseHandler): # URL we redirect users to. redirect_url_no_params = client_redirect_url.split("?")[0] - html = self._sso_redirect_confirm_template.render( + html_bytes = self._sso_redirect_confirm_template.render( display_url=redirect_url_no_params, redirect_url=redirect_url, server_name=self._server_name, @@ -1161,8 +1141,8 @@ class AuthHandler(BaseHandler): request.setResponseCode(200) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html),)) - request.write(html) + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + request.write(html_bytes) finish_request(request) @staticmethod diff --git a/synapse/res/templates/sso_auth_success.html b/synapse/res/templates/sso_auth_success.html new file mode 100644 index 0000000000..03f1419467 --- /dev/null +++ b/synapse/res/templates/sso_auth_success.html @@ -0,0 +1,18 @@ + + + Authentication Successful + + + +
+

Thank you

+

You may now close this window and return to the application

+
+ + diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 13f9604407..11599f5005 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -18,7 +18,6 @@ import logging from synapse.api.constants import LoginType from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_API_PREFIX -from synapse.handlers.auth import SUCCESS_TEMPLATE from synapse.http.server import finish_request from synapse.http.servlet import RestServlet, parse_string @@ -90,6 +89,30 @@ TERMS_TEMPLATE = """ """ +SUCCESS_TEMPLATE = """ + + +Success! + + + + + +
+

Thank you

+

You may now close this window and return to the application

+
+ + +""" + class AuthRestServlet(RestServlet): """ -- cgit 1.5.1 From f5ea8b48bd57c19dae126a5cc631dc79cf5ce332 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 20 Apr 2020 08:54:42 -0400 Subject: Reject unknown UI auth sessions (instead of silently generating a new one) (#7268) --- changelog.d/7268.bugfix | 1 + synapse/handlers/auth.py | 159 ++++++++++++++++++++++++++++------------------- 2 files changed, 95 insertions(+), 65 deletions(-) create mode 100644 changelog.d/7268.bugfix (limited to 'synapse/handlers/auth.py') diff --git a/changelog.d/7268.bugfix b/changelog.d/7268.bugfix new file mode 100644 index 0000000000..ab280da18e --- /dev/null +++ b/changelog.d/7268.bugfix @@ -0,0 +1 @@ +Reject unknown session IDs during user interactive authentication instead of silently creating a new session. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index bda279ab8b..dbe165ce1e 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -257,10 +257,6 @@ class AuthHandler(BaseHandler): Takes a dictionary sent by the client in the login / registration protocol and handles the User-Interactive Auth flow. - As a side effect, this function fills in the 'creds' key on the user's - session with a map, which maps each auth-type (str) to the relevant - identity authenticated by that auth-type (mostly str, but for captcha, bool). - If no auth flows have been completed successfully, raises an InteractiveAuthIncompleteError. To handle this, you can use synapse.rest.client.v2_alpha._base.interactive_auth_handler as a @@ -304,50 +300,47 @@ class AuthHandler(BaseHandler): del clientdict["auth"] if "session" in authdict: sid = authdict["session"] - session = self._get_session_info(sid) - - if len(clientdict) > 0: - # This was designed to allow the client to omit the parameters - # and just supply the session in subsequent calls so it split - # auth between devices by just sharing the session, (eg. so you - # could continue registration from your phone having clicked the - # email auth link on there). It's probably too open to abuse - # because it lets unauthenticated clients store arbitrary objects - # on a homeserver. - # Revisit: Assuming the REST APIs do sensible validation, the data - # isn't arbintrary. - session["clientdict"] = clientdict - self._save_session(session) - elif "clientdict" in session: - clientdict = session["clientdict"] - - # Ensure that the queried operation does not vary between stages of - # the UI authentication session. This is done by generating a stable - # comparator based on the URI, method, and body (minus the auth dict) - # and storing it during the initial query. Subsequent queries ensure - # that this comparator has not changed. - comparator = (request.uri, request.method, clientdict) - if "ui_auth" not in session: - session["ui_auth"] = comparator - self._save_session(session) - elif session["ui_auth"] != comparator: - raise SynapseError( - 403, - "Requested operation has changed during the UI authentication session.", + + # If there's no session ID, create a new session. + if not sid: + session = self._create_session( + clientdict, (request.uri, request.method, clientdict), description ) + session_id = session["id"] - # Add a human readable description to the session. - if "description" not in session: - session["description"] = description - self._save_session(session) + else: + session = self._get_session_info(sid) + session_id = sid + + if not clientdict: + # This was designed to allow the client to omit the parameters + # and just supply the session in subsequent calls so it split + # auth between devices by just sharing the session, (eg. so you + # could continue registration from your phone having clicked the + # email auth link on there). It's probably too open to abuse + # because it lets unauthenticated clients store arbitrary objects + # on a homeserver. + # Revisit: Assuming the REST APIs do sensible validation, the data + # isn't arbitrary. + clientdict = session["clientdict"] + + # Ensure that the queried operation does not vary between stages of + # the UI authentication session. This is done by generating a stable + # comparator based on the URI, method, and body (minus the auth dict) + # and storing it during the initial query. Subsequent queries ensure + # that this comparator has not changed. + comparator = (request.uri, request.method, clientdict) + if session["ui_auth"] != comparator: + raise SynapseError( + 403, + "Requested operation has changed during the UI authentication session.", + ) if not authdict: raise InteractiveAuthIncompleteError( - self._auth_dict_for_flows(flows, session) + self._auth_dict_for_flows(flows, session_id) ) - if "creds" not in session: - session["creds"] = {} creds = session["creds"] # check auth type currently being presented @@ -387,9 +380,9 @@ class AuthHandler(BaseHandler): list(clientdict), ) - return creds, clientdict, session["id"] + return creds, clientdict, session_id - ret = self._auth_dict_for_flows(flows, session) + ret = self._auth_dict_for_flows(flows, session_id) ret["completed"] = list(creds) ret.update(errordict) raise InteractiveAuthIncompleteError(ret) @@ -407,8 +400,6 @@ class AuthHandler(BaseHandler): raise LoginError(400, "", Codes.MISSING_PARAM) sess = self._get_session_info(authdict["session"]) - if "creds" not in sess: - sess["creds"] = {} creds = sess["creds"] result = await self.checkers[stagetype].check_auth(authdict, clientip) @@ -448,7 +439,7 @@ class AuthHandler(BaseHandler): value: The data to store """ sess = self._get_session_info(session_id) - sess.setdefault("serverdict", {})[key] = value + sess["serverdict"][key] = value self._save_session(sess) def get_session_data( @@ -463,7 +454,7 @@ class AuthHandler(BaseHandler): default: Value to return if the key has not been set """ sess = self._get_session_info(session_id) - return sess.setdefault("serverdict", {}).get(key, default) + return sess["serverdict"].get(key, default) async def _check_auth_dict( self, authdict: Dict[str, Any], clientip: str @@ -519,7 +510,7 @@ class AuthHandler(BaseHandler): } def _auth_dict_for_flows( - self, flows: List[List[str]], session: Dict[str, Any] + self, flows: List[List[str]], session_id: str, ) -> Dict[str, Any]: public_flows = [] for f in flows: @@ -538,29 +529,72 @@ class AuthHandler(BaseHandler): params[stage] = get_params[stage]() return { - "session": session["id"], + "session": session_id, "flows": [{"stages": f} for f in public_flows], "params": params, } - def _get_session_info(self, session_id: Optional[str]) -> dict: + def _create_session( + self, + clientdict: Dict[str, Any], + ui_auth: Tuple[bytes, bytes, Dict[str, Any]], + description: str, + ) -> dict: """ - Gets or creates a session given a session ID. + Creates a new user interactive authentication session. The session can be used to track data across multiple requests, e.g. for interactive authentication. - """ - if session_id not in self.sessions: - session_id = None - if not session_id: - # create a new session - while session_id is None or session_id in self.sessions: - session_id = stringutils.random_string(24) - self.sessions[session_id] = {"id": session_id} + Each session has the following keys: + + id: + A unique identifier for this session. Passed back to the client + and returned for each stage. + clientdict: + The dictionary from the client root level, not the 'auth' key. + ui_auth: + A tuple which is checked at each stage of the authentication to + ensure that the asked for operation has not changed. + creds: + A map, which maps each auth-type (str) to the relevant identity + authenticated by that auth-type (mostly str, but for captcha, bool). + serverdict: + A map of data that is stored server-side and cannot be modified + by the client. + description: + A string description of the operation that the current + authentication is authorising. + Returns: + The newly created session. + """ + session_id = None + while session_id is None or session_id in self.sessions: + session_id = stringutils.random_string(24) + + self.sessions[session_id] = { + "id": session_id, + "clientdict": clientdict, + "ui_auth": ui_auth, + "creds": {}, + "serverdict": {}, + "description": description, + } return self.sessions[session_id] + def _get_session_info(self, session_id: str) -> dict: + """ + Gets a session given a session ID. + + The session can be used to track data across multiple requests, e.g. for + interactive authentication. + """ + try: + return self.sessions[session_id] + except KeyError: + raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) + async def get_access_token_for_user_id( self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int] ): @@ -1030,11 +1064,8 @@ class AuthHandler(BaseHandler): The HTML to render. """ session = self._get_session_info(session_id) - # Get the human readable operation of what is occurring, falling back to - # a generic message if it isn't available for some reason. - description = session.get("description", "modify your account") return self._sso_auth_confirm_template.render( - description=description, redirect_url=redirect_url, + description=session["description"], redirect_url=redirect_url, ) def complete_sso_ui_auth( @@ -1050,8 +1081,6 @@ class AuthHandler(BaseHandler): """ # Mark the stage of the authentication as successful. sess = self._get_session_info(session_id) - if "creds" not in sess: - sess["creds"] = {} creds = sess["creds"] # Save the user who authenticated with SSO, this will be used to ensure -- cgit 1.5.1 From 627b0f5f2753e6910adb7a877541d50f5936b8a5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Apr 2020 13:47:49 -0400 Subject: Persist user interactive authentication sessions (#7302) By persisting the user interactive authentication sessions to the database, this fixes situations where a user hits different works throughout their auth session and also allows sessions to persist through restarts of Synapse. --- changelog.d/7302.bugfix | 1 + synapse/app/generic_worker.py | 2 + synapse/handlers/auth.py | 175 +++++-------- synapse/handlers/cas_handler.py | 2 +- synapse/handlers/saml_handler.py | 2 +- synapse/rest/client/v2_alpha/auth.py | 4 +- synapse/rest/client/v2_alpha/register.py | 4 +- synapse/storage/data_stores/main/__init__.py | 2 + .../main/schema/delta/58/03persist_ui_auth.sql | 36 +++ synapse/storage/data_stores/main/ui_auth.py | 279 +++++++++++++++++++++ synapse/storage/engines/sqlite.py | 1 + tests/rest/client/v2_alpha/test_auth.py | 40 +++ tests/utils.py | 8 +- tox.ini | 3 +- 14 files changed, 434 insertions(+), 125 deletions(-) create mode 100644 changelog.d/7302.bugfix create mode 100644 synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql create mode 100644 synapse/storage/data_stores/main/ui_auth.py (limited to 'synapse/handlers/auth.py') diff --git a/changelog.d/7302.bugfix b/changelog.d/7302.bugfix new file mode 100644 index 0000000000..820646d1f9 --- /dev/null +++ b/changelog.d/7302.bugfix @@ -0,0 +1 @@ +Persist user interactive authentication sessions across workers and Synapse restarts. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index d125327f08..0ace7b787d 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -127,6 +127,7 @@ from synapse.storage.data_stores.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, ) from synapse.storage.data_stores.main.presence import UserPresenceState +from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore from synapse.storage.data_stores.main.user_directory import UserDirectoryStore from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer @@ -439,6 +440,7 @@ class GenericWorkerSlavedStore( # FIXME(#3714): We need to add UserDirectoryStore as we write directly # rather than going via the correct worker. UserDirectoryStore, + UIAuthWorkerStore, SlavedDeviceInboxStore, SlavedDeviceStore, SlavedReceiptsStore, diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index dbe165ce1e..7613e5b6ab 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -41,10 +41,10 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.http.server import finish_request from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi from synapse.push.mailer import load_jinja2_templates from synapse.types import Requester, UserID -from synapse.util.caches.expiringcache import ExpiringCache from ._base import BaseHandler @@ -69,15 +69,6 @@ class AuthHandler(BaseHandler): self.bcrypt_rounds = hs.config.bcrypt_rounds - # This is not a cache per se, but a store of all current sessions that - # expire after N hours - self.sessions = ExpiringCache( - cache_name="register_sessions", - clock=hs.get_clock(), - expiry_ms=self.SESSION_EXPIRE_MS, - reset_expiry_on_get=True, - ) - account_handler = ModuleApi(hs, self) self.password_providers = [ module(config=config, account_handler=account_handler) @@ -119,6 +110,15 @@ class AuthHandler(BaseHandler): self._clock = self.hs.get_clock() + # Expire old UI auth sessions after a period of time. + if hs.config.worker_app is None: + self._clock.looping_call( + run_as_background_process, + 5 * 60 * 1000, + "expire_old_sessions", + self._expire_old_sessions, + ) + # Load the SSO HTML templates. # The following template is shown to the user during a client login via SSO, @@ -301,16 +301,21 @@ class AuthHandler(BaseHandler): if "session" in authdict: sid = authdict["session"] + # Convert the URI and method to strings. + uri = request.uri.decode("utf-8") + method = request.uri.decode("utf-8") + # If there's no session ID, create a new session. if not sid: - session = self._create_session( - clientdict, (request.uri, request.method, clientdict), description + session = await self.store.create_ui_auth_session( + clientdict, uri, method, description ) - session_id = session["id"] else: - session = self._get_session_info(sid) - session_id = sid + try: + session = await self.store.get_ui_auth_session(sid) + except StoreError: + raise SynapseError(400, "Unknown session ID: %s" % (sid,)) if not clientdict: # This was designed to allow the client to omit the parameters @@ -322,15 +327,15 @@ class AuthHandler(BaseHandler): # on a homeserver. # Revisit: Assuming the REST APIs do sensible validation, the data # isn't arbitrary. - clientdict = session["clientdict"] + clientdict = session.clientdict # Ensure that the queried operation does not vary between stages of # the UI authentication session. This is done by generating a stable # comparator based on the URI, method, and body (minus the auth dict) # and storing it during the initial query. Subsequent queries ensure # that this comparator has not changed. - comparator = (request.uri, request.method, clientdict) - if session["ui_auth"] != comparator: + comparator = (uri, method, clientdict) + if (session.uri, session.method, session.clientdict) != comparator: raise SynapseError( 403, "Requested operation has changed during the UI authentication session.", @@ -338,11 +343,9 @@ class AuthHandler(BaseHandler): if not authdict: raise InteractiveAuthIncompleteError( - self._auth_dict_for_flows(flows, session_id) + self._auth_dict_for_flows(flows, session.session_id) ) - creds = session["creds"] - # check auth type currently being presented errordict = {} # type: Dict[str, Any] if "type" in authdict: @@ -350,8 +353,9 @@ class AuthHandler(BaseHandler): try: result = await self._check_auth_dict(authdict, clientip) if result: - creds[login_type] = result - self._save_session(session) + await self.store.mark_ui_auth_stage_complete( + session.session_id, login_type, result + ) except LoginError as e: if login_type == LoginType.EMAIL_IDENTITY: # riot used to have a bug where it would request a new @@ -367,6 +371,7 @@ class AuthHandler(BaseHandler): # so that the client can have another go. errordict = e.error_dict() + creds = await self.store.get_completed_ui_auth_stages(session.session_id) for f in flows: if len(set(f) - set(creds)) == 0: # it's very useful to know what args are stored, but this can @@ -380,9 +385,9 @@ class AuthHandler(BaseHandler): list(clientdict), ) - return creds, clientdict, session_id + return creds, clientdict, session.session_id - ret = self._auth_dict_for_flows(flows, session_id) + ret = self._auth_dict_for_flows(flows, session.session_id) ret["completed"] = list(creds) ret.update(errordict) raise InteractiveAuthIncompleteError(ret) @@ -399,13 +404,11 @@ class AuthHandler(BaseHandler): if "session" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) - sess = self._get_session_info(authdict["session"]) - creds = sess["creds"] - result = await self.checkers[stagetype].check_auth(authdict, clientip) if result: - creds[stagetype] = result - self._save_session(sess) + await self.store.mark_ui_auth_stage_complete( + authdict["session"], stagetype, result + ) return True return False @@ -427,7 +430,7 @@ class AuthHandler(BaseHandler): sid = authdict["session"] return sid - def set_session_data(self, session_id: str, key: str, value: Any) -> None: + async def set_session_data(self, session_id: str, key: str, value: Any) -> None: """ Store a key-value pair into the sessions data associated with this request. This data is stored server-side and cannot be modified by @@ -438,11 +441,12 @@ class AuthHandler(BaseHandler): key: The key to store the data under value: The data to store """ - sess = self._get_session_info(session_id) - sess["serverdict"][key] = value - self._save_session(sess) + try: + await self.store.set_ui_auth_session_data(session_id, key, value) + except StoreError: + raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) - def get_session_data( + async def get_session_data( self, session_id: str, key: str, default: Optional[Any] = None ) -> Any: """ @@ -453,8 +457,18 @@ class AuthHandler(BaseHandler): key: The key to store the data under default: Value to return if the key has not been set """ - sess = self._get_session_info(session_id) - return sess["serverdict"].get(key, default) + try: + return await self.store.get_ui_auth_session_data(session_id, key, default) + except StoreError: + raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) + + async def _expire_old_sessions(self): + """ + Invalidate any user interactive authentication sessions that have expired. + """ + now = self._clock.time_msec() + expiration_time = now - self.SESSION_EXPIRE_MS + await self.store.delete_old_ui_auth_sessions(expiration_time) async def _check_auth_dict( self, authdict: Dict[str, Any], clientip: str @@ -534,67 +548,6 @@ class AuthHandler(BaseHandler): "params": params, } - def _create_session( - self, - clientdict: Dict[str, Any], - ui_auth: Tuple[bytes, bytes, Dict[str, Any]], - description: str, - ) -> dict: - """ - Creates a new user interactive authentication session. - - The session can be used to track data across multiple requests, e.g. for - interactive authentication. - - Each session has the following keys: - - id: - A unique identifier for this session. Passed back to the client - and returned for each stage. - clientdict: - The dictionary from the client root level, not the 'auth' key. - ui_auth: - A tuple which is checked at each stage of the authentication to - ensure that the asked for operation has not changed. - creds: - A map, which maps each auth-type (str) to the relevant identity - authenticated by that auth-type (mostly str, but for captcha, bool). - serverdict: - A map of data that is stored server-side and cannot be modified - by the client. - description: - A string description of the operation that the current - authentication is authorising. - Returns: - The newly created session. - """ - session_id = None - while session_id is None or session_id in self.sessions: - session_id = stringutils.random_string(24) - - self.sessions[session_id] = { - "id": session_id, - "clientdict": clientdict, - "ui_auth": ui_auth, - "creds": {}, - "serverdict": {}, - "description": description, - } - - return self.sessions[session_id] - - def _get_session_info(self, session_id: str) -> dict: - """ - Gets a session given a session ID. - - The session can be used to track data across multiple requests, e.g. for - interactive authentication. - """ - try: - return self.sessions[session_id] - except KeyError: - raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) - async def get_access_token_for_user_id( self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int] ): @@ -994,13 +947,6 @@ class AuthHandler(BaseHandler): await self.store.user_delete_threepid(user_id, medium, address) return result - def _save_session(self, session: Dict[str, Any]) -> None: - """Update the last used time on the session to now and add it back to the session store.""" - # TODO: Persistent storage - logger.debug("Saving session %s", session) - session["last_used"] = self.hs.get_clock().time_msec() - self.sessions[session["id"]] = session - async def hash(self, password: str) -> str: """Computes a secure hash of password. @@ -1052,7 +998,7 @@ class AuthHandler(BaseHandler): else: return False - def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: + async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: """ Get the HTML for the SSO redirect confirmation page. @@ -1063,12 +1009,15 @@ class AuthHandler(BaseHandler): Returns: The HTML to render. """ - session = self._get_session_info(session_id) + try: + session = await self.store.get_ui_auth_session(session_id) + except StoreError: + raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) return self._sso_auth_confirm_template.render( - description=session["description"], redirect_url=redirect_url, + description=session.description, redirect_url=redirect_url, ) - def complete_sso_ui_auth( + async def complete_sso_ui_auth( self, registered_user_id: str, session_id: str, request: SynapseRequest, ): """Having figured out a mxid for this user, complete the HTTP request @@ -1080,13 +1029,11 @@ class AuthHandler(BaseHandler): process. """ # Mark the stage of the authentication as successful. - sess = self._get_session_info(session_id) - creds = sess["creds"] - # Save the user who authenticated with SSO, this will be used to ensure # that the account be modified is also the person who logged in. - creds[LoginType.SSO] = registered_user_id - self._save_session(sess) + await self.store.mark_ui_auth_stage_complete( + session_id, LoginType.SSO, registered_user_id + ) # Render the HTML and return. html_bytes = self._sso_auth_success_template.encode("utf-8") diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 5cb3f9d133..64aaa1335c 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -206,7 +206,7 @@ class CasHandler: registered_user_id = await self._auth_handler.check_user_exists(user_id) if session: - self._auth_handler.complete_sso_ui_auth( + await self._auth_handler.complete_sso_ui_auth( registered_user_id, session, request, ) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 7c9454b504..96f2dd36ad 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -149,7 +149,7 @@ class SamlHandler: # Complete the interactive auth session or the login. if current_session and current_session.ui_auth_session_id: - self._auth_handler.complete_sso_ui_auth( + await self._auth_handler.complete_sso_ui_auth( user_id, current_session.ui_auth_session_id, request ) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 11599f5005..24dd3d3e96 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -140,7 +140,7 @@ class AuthRestServlet(RestServlet): self._cas_server_url = hs.config.cas_server_url self._cas_service_url = hs.config.cas_service_url - def on_GET(self, request, stagetype): + async def on_GET(self, request, stagetype): session = parse_string(request, "session") if not session: raise SynapseError(400, "No session supplied") @@ -180,7 +180,7 @@ class AuthRestServlet(RestServlet): else: raise SynapseError(400, "Homeserver not configured for SSO.") - html = self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) + html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) else: raise SynapseError(404, "Unknown auth stage type") diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index d1b5c49989..af08cc6cce 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -499,7 +499,7 @@ class RegisterRestServlet(RestServlet): # registered a user for this session, so we could just return the # user here. We carry on and go through the auth checks though, # for paranoia. - registered_user_id = self.auth_handler.get_session_data( + registered_user_id = await self.auth_handler.get_session_data( session_id, "registered_user_id", None ) @@ -598,7 +598,7 @@ class RegisterRestServlet(RestServlet): # remember that we've now registered that user account, and with # what user ID (since the user may not have specified) - self.auth_handler.set_session_data( + await self.auth_handler.set_session_data( session_id, "registered_user_id", registered_user_id ) diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index bd7c3a00ea..ceba10882c 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -66,6 +66,7 @@ from .stats import StatsStore from .stream import StreamStore from .tags import TagsStore from .transactions import TransactionStore +from .ui_auth import UIAuthStore from .user_directory import UserDirectoryStore from .user_erasure_store import UserErasureStore @@ -112,6 +113,7 @@ class DataStore( StatsStore, RelationsStore, CacheInvalidationStore, + UIAuthStore, ): def __init__(self, database: Database, db_conn, hs): self.hs = hs diff --git a/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql b/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql new file mode 100644 index 0000000000..dcb593fc2d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql @@ -0,0 +1,36 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS ui_auth_sessions( + session_id TEXT NOT NULL, -- The session ID passed to the client. + creation_time BIGINT NOT NULL, -- The time this session was created (epoch time in milliseconds). + serverdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data added by Synapse. + clientdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data from the client. + uri TEXT NOT NULL, -- The URI the UI authentication session is using. + method TEXT NOT NULL, -- The HTTP method the UI authentication session is using. + -- The clientdict, uri, and method make up an tuple that must be immutable + -- throughout the lifetime of the UI Auth session. + description TEXT NOT NULL, -- A human readable description of the operation which caused the UI Auth flow to occur. + UNIQUE (session_id) +); + +CREATE TABLE IF NOT EXISTS ui_auth_sessions_credentials( + session_id TEXT NOT NULL, -- The corresponding UI Auth session. + stage_type TEXT NOT NULL, -- The stage type. + result TEXT NOT NULL, -- The result of the stage verification, stored as JSON. + UNIQUE (session_id, stage_type), + FOREIGN KEY (session_id) + REFERENCES ui_auth_sessions (session_id) +); diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py new file mode 100644 index 0000000000..c8eebc9378 --- /dev/null +++ b/synapse/storage/data_stores/main/ui_auth.py @@ -0,0 +1,279 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from typing import Any, Dict, Optional, Union + +import attr + +import synapse.util.stringutils as stringutils +from synapse.api.errors import StoreError +from synapse.storage._base import SQLBaseStore +from synapse.types import JsonDict + + +@attr.s +class UIAuthSessionData: + session_id = attr.ib(type=str) + # The dictionary from the client root level, not the 'auth' key. + clientdict = attr.ib(type=JsonDict) + # The URI and method the session was intiatied with. These are checked at + # each stage of the authentication to ensure that the asked for operation + # has not changed. + uri = attr.ib(type=str) + method = attr.ib(type=str) + # A string description of the operation that the current authentication is + # authorising. + description = attr.ib(type=str) + + +class UIAuthWorkerStore(SQLBaseStore): + """ + Manage user interactive authentication sessions. + """ + + async def create_ui_auth_session( + self, clientdict: JsonDict, uri: str, method: str, description: str, + ) -> UIAuthSessionData: + """ + Creates a new user interactive authentication session. + + The session can be used to track the stages necessary to authenticate a + user across multiple HTTP requests. + + Args: + clientdict: + The dictionary from the client root level, not the 'auth' key. + uri: + The URI this session was initiated with, this is checked at each + stage of the authentication to ensure that the asked for + operation has not changed. + method: + The method this session was initiated with, this is checked at each + stage of the authentication to ensure that the asked for + operation has not changed. + description: + A string description of the operation that the current + authentication is authorising. + Returns: + The newly created session. + Raises: + StoreError if a unique session ID cannot be generated. + """ + # The clientdict gets stored as JSON. + clientdict_json = json.dumps(clientdict) + + # autogen a session ID and try to create it. We may clash, so just + # try a few times till one goes through, giving up eventually. + attempts = 0 + while attempts < 5: + session_id = stringutils.random_string(24) + + try: + await self.db.simple_insert( + table="ui_auth_sessions", + values={ + "session_id": session_id, + "clientdict": clientdict_json, + "uri": uri, + "method": method, + "description": description, + "serverdict": "{}", + "creation_time": self.hs.get_clock().time_msec(), + }, + desc="create_ui_auth_session", + ) + return UIAuthSessionData( + session_id, clientdict, uri, method, description + ) + except self.db.engine.module.IntegrityError: + attempts += 1 + raise StoreError(500, "Couldn't generate a session ID.") + + async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData: + """Retrieve a UI auth session. + + Args: + session_id: The ID of the session. + Returns: + A dict containing the device information. + Raises: + StoreError if the session is not found. + """ + result = await self.db.simple_select_one( + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("clientdict", "uri", "method", "description"), + desc="get_ui_auth_session", + ) + + result["clientdict"] = json.loads(result["clientdict"]) + + return UIAuthSessionData(session_id, **result) + + async def mark_ui_auth_stage_complete( + self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict], + ): + """ + Mark a session stage as completed. + + Args: + session_id: The ID of the corresponding session. + stage_type: The completed stage type. + result: The result of the stage verification. + Raises: + StoreError if the session cannot be found. + """ + # Add (or update) the results of the current stage to the database. + # + # Note that we need to allow for the same stage to complete multiple + # times here so that registration is idempotent. + try: + await self.db.simple_upsert( + table="ui_auth_sessions_credentials", + keyvalues={"session_id": session_id, "stage_type": stage_type}, + values={"result": json.dumps(result)}, + desc="mark_ui_auth_stage_complete", + ) + except self.db.engine.module.IntegrityError: + raise StoreError(400, "Unknown session ID: %s" % (session_id,)) + + async def get_completed_ui_auth_stages( + self, session_id: str + ) -> Dict[str, Union[str, bool, JsonDict]]: + """ + Retrieve the completed stages of a UI authentication session. + + Args: + session_id: The ID of the session. + Returns: + The completed stages mapped to the result of the verification of + that auth-type. + """ + results = {} + for row in await self.db.simple_select_list( + table="ui_auth_sessions_credentials", + keyvalues={"session_id": session_id}, + retcols=("stage_type", "result"), + desc="get_completed_ui_auth_stages", + ): + results[row["stage_type"]] = json.loads(row["result"]) + + return results + + async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any): + """ + Store a key-value pair into the sessions data associated with this + request. This data is stored server-side and cannot be modified by + the client. + + Args: + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + value: The data to store + Raises: + StoreError if the session cannot be found. + """ + await self.db.runInteraction( + "set_ui_auth_session_data", + self._set_ui_auth_session_data_txn, + session_id, + key, + value, + ) + + def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any): + # Get the current value. + result = self.db.simple_select_one_txn( + txn, + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("serverdict",), + ) + + # Update it and add it back to the database. + serverdict = json.loads(result["serverdict"]) + serverdict[key] = value + + self.db.simple_update_one_txn( + txn, + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + updatevalues={"serverdict": json.dumps(serverdict)}, + ) + + async def get_ui_auth_session_data( + self, session_id: str, key: str, default: Optional[Any] = None + ) -> Any: + """ + Retrieve data stored with set_session_data + + Args: + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + default: Value to return if the key has not been set + Raises: + StoreError if the session cannot be found. + """ + result = await self.db.simple_select_one( + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("serverdict",), + desc="get_ui_auth_session_data", + ) + + serverdict = json.loads(result["serverdict"]) + + return serverdict.get(key, default) + + +class UIAuthStore(UIAuthWorkerStore): + def delete_old_ui_auth_sessions(self, expiration_time: int): + """ + Remove sessions which were last used earlier than the expiration time. + + Args: + expiration_time: The latest time that is still considered valid. + This is an epoch time in milliseconds. + + """ + return self.db.runInteraction( + "delete_old_ui_auth_sessions", + self._delete_old_ui_auth_sessions_txn, + expiration_time, + ) + + def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int): + # Get the expired sessions. + sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?" + txn.execute(sql, [expiration_time]) + session_ids = [r[0] for r in txn.fetchall()] + + # Delete the corresponding completed credentials. + self.db.simple_delete_many_txn( + txn, + table="ui_auth_sessions_credentials", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) + + # Finally, delete the sessions. + self.db.simple_delete_many_txn( + txn, + table="ui_auth_sessions", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 3bc2e8b986..215a949442 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -85,6 +85,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): prepare_database(db_conn, self, config=None) db_conn.create_function("rank", 1, _rank) + db_conn.execute("PRAGMA foreign_keys = ON;") def is_deadlock(self, error): return False diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 624bf5ada2..587be7b2e7 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -181,3 +181,43 @@ class FallbackAuthTests(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(channel.code, 403) + + def test_complete_operation_unknown_session(self): + """ + Attempting to mark an invalid session as complete should error. + """ + + # Make the initial request to register. (Later on a different password + # will be used.) + request, channel = self.make_request( + "POST", + "register", + {"username": "user", "type": "m.login.password", "password": "bar"}, + ) + self.render(request) + + # Returns a 401 as per the spec + self.assertEqual(request.code, 401) + # Grab the session + session = channel.json_body["session"] + # Assert our configured public key is being given + self.assertEqual( + channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" + ) + + request, channel = self.make_request( + "GET", "auth/m.login.recaptcha/fallback/web?session=" + session + ) + self.render(request) + self.assertEqual(request.code, 200) + + # Attempt to complete an unknown session, which should return an error. + unknown_session = session + "unknown" + request, channel = self.make_request( + "POST", + "auth/m.login.recaptcha/fallback/web?session=" + + unknown_session + + "&g-recaptcha-response=a", + ) + self.render(request) + self.assertEqual(request.code, 400) diff --git a/tests/utils.py b/tests/utils.py index 037cb134f0..f9be62b499 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -512,8 +512,8 @@ class MockClock(object): return t - def looping_call(self, function, interval): - self.loopers.append([function, interval / 1000.0, self.now]) + def looping_call(self, function, interval, *args, **kwargs): + self.loopers.append([function, interval / 1000.0, self.now, args, kwargs]) def cancel_call_later(self, timer, ignore_errs=False): if timer[2]: @@ -543,9 +543,9 @@ class MockClock(object): self.timers.append(t) for looped in self.loopers: - func, interval, last = looped + func, interval, last, args, kwargs = looped if last + interval < self.now: - func() + func(*args, **kwargs) looped[2] = self.now def advance_time_msec(self, ms): diff --git a/tox.ini b/tox.ini index 2630857436..eccc44e436 100644 --- a/tox.ini +++ b/tox.ini @@ -200,8 +200,9 @@ commands = mypy \ synapse/replication \ synapse/rest \ synapse/spam_checker_api \ - synapse/storage/engines \ + synapse/storage/data_stores/main/ui_auth.py \ synapse/storage/database.py \ + synapse/storage/engines \ synapse/streams \ synapse/util/caches/stream_change_cache.py \ tests/replication/tcp/streams \ -- cgit 1.5.1 From 0ad6d28b0dec06d5e7478984280b4e81ef0f0256 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 8 May 2020 16:08:58 -0400 Subject: Rework UI Auth session validation for registration (#7455) Be less strict about validation of UI authentication sessions during registration to match client expecations. --- changelog.d/7455.bugfix | 1 + synapse/handlers/auth.py | 54 +++-- synapse/rest/client/v2_alpha/register.py | 1 + synapse/storage/data_stores/main/ui_auth.py | 21 ++ tests/rest/client/v2_alpha/test_auth.py | 304 ++++++++++++++++++++-------- tox.ini | 1 + 6 files changed, 280 insertions(+), 102 deletions(-) create mode 100644 changelog.d/7455.bugfix (limited to 'synapse/handlers/auth.py') diff --git a/changelog.d/7455.bugfix b/changelog.d/7455.bugfix new file mode 100644 index 0000000000..d1693a7f22 --- /dev/null +++ b/changelog.d/7455.bugfix @@ -0,0 +1 @@ +Ensure that a user inteactive authentication session is tied to a single request. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7613e5b6ab..9c71702371 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -252,6 +252,7 @@ class AuthHandler(BaseHandler): clientdict: Dict[str, Any], clientip: str, description: str, + validate_clientdict: bool = True, ) -> Tuple[dict, dict, str]: """ Takes a dictionary sent by the client in the login / registration @@ -277,6 +278,10 @@ class AuthHandler(BaseHandler): description: A human readable string to be displayed to the user that describes the operation happening on their account. + validate_clientdict: Whether to validate that the operation happening + on the account has not changed. If this is false, + the client dict is persisted instead of validated. + Returns: A tuple of (creds, params, session_id). @@ -317,30 +322,51 @@ class AuthHandler(BaseHandler): except StoreError: raise SynapseError(400, "Unknown session ID: %s" % (sid,)) + # If the client provides parameters, update what is persisted, + # otherwise use whatever was last provided. + # + # This was designed to allow the client to omit the parameters + # and just supply the session in subsequent calls so it split + # auth between devices by just sharing the session, (eg. so you + # could continue registration from your phone having clicked the + # email auth link on there). It's probably too open to abuse + # because it lets unauthenticated clients store arbitrary objects + # on a homeserver. + # + # Revisit: Assuming the REST APIs do sensible validation, the data + # isn't arbitrary. + # + # Note that the registration endpoint explicitly removes the + # "initial_device_display_name" parameter if it is provided + # without a "password" parameter. See the changes to + # synapse.rest.client.v2_alpha.register.RegisterRestServlet.on_POST + # in commit 544722bad23fc31056b9240189c3cbbbf0ffd3f9. if not clientdict: - # This was designed to allow the client to omit the parameters - # and just supply the session in subsequent calls so it split - # auth between devices by just sharing the session, (eg. so you - # could continue registration from your phone having clicked the - # email auth link on there). It's probably too open to abuse - # because it lets unauthenticated clients store arbitrary objects - # on a homeserver. - # Revisit: Assuming the REST APIs do sensible validation, the data - # isn't arbitrary. clientdict = session.clientdict # Ensure that the queried operation does not vary between stages of # the UI authentication session. This is done by generating a stable - # comparator based on the URI, method, and body (minus the auth dict) - # and storing it during the initial query. Subsequent queries ensure - # that this comparator has not changed. - comparator = (uri, method, clientdict) - if (session.uri, session.method, session.clientdict) != comparator: + # comparator based on the URI, method, and client dict (minus the + # auth dict) and storing it during the initial query. Subsequent + # queries ensure that this comparator has not changed. + if validate_clientdict: + session_comparator = (session.uri, session.method, session.clientdict) + comparator = (uri, method, clientdict) + else: + session_comparator = (session.uri, session.method) # type: ignore + comparator = (uri, method) # type: ignore + + if session_comparator != comparator: raise SynapseError( 403, "Requested operation has changed during the UI authentication session.", ) + # For backwards compatibility the registration endpoint persists + # changes to the client dict instead of validating them. + if not validate_clientdict: + await self.store.set_ui_auth_clientdict(sid, clientdict) + if not authdict: raise InteractiveAuthIncompleteError( self._auth_dict_for_flows(flows, session.session_id) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index af08cc6cce..e77dd6bf92 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -516,6 +516,7 @@ class RegisterRestServlet(RestServlet): body, self.hs.get_ip_from_request(request), "register a new account", + validate_clientdict=False, ) # Check that we're not trying to register a denied 3pid. diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py index c8eebc9378..1d8ee22fb1 100644 --- a/synapse/storage/data_stores/main/ui_auth.py +++ b/synapse/storage/data_stores/main/ui_auth.py @@ -172,6 +172,27 @@ class UIAuthWorkerStore(SQLBaseStore): return results + async def set_ui_auth_clientdict( + self, session_id: str, clientdict: JsonDict + ) -> None: + """ + Store an updated clientdict for a given session ID. + + Args: + session_id: The ID of this session as returned from check_auth + clientdict: + The dictionary from the client root level, not the 'auth' key. + """ + # The clientdict gets stored as JSON. + clientdict_json = json.dumps(clientdict) + + self.db.simple_update_one( + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + updatevalues={"clientdict": clientdict_json}, + desc="set_ui_auth_client_dict", + ) + async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any): """ Store a key-value pair into the sessions data associated with this diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 587be7b2e7..a56c50a5b7 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -12,16 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from typing import List, Union from twisted.internet.defer import succeed import synapse.rest.admin from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker -from synapse.rest.client.v2_alpha import auth, register +from synapse.http.site import SynapseRequest +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import auth, devices, register +from synapse.types import JsonDict from tests import unittest +from tests.server import FakeChannel class DummyRecaptchaChecker(UserInteractiveAuthChecker): @@ -34,11 +38,15 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker): return succeed(True) +class DummyPasswordChecker(UserInteractiveAuthChecker): + def check_auth(self, authdict, clientip): + return succeed(authdict["identifier"]["user"]) + + class FallbackAuthTests(unittest.HomeserverTestCase): servlets = [ auth.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, register.register_servlets, ] hijack_auth = False @@ -59,79 +67,84 @@ class FallbackAuthTests(unittest.HomeserverTestCase): auth_handler = hs.get_auth_handler() auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker - @unittest.INFO - def test_fallback_captcha(self): - + def register(self, expected_response: int, body: JsonDict) -> FakeChannel: + """Make a register request.""" request, channel = self.make_request( - "POST", - "register", - {"username": "user", "type": "m.login.password", "password": "bar"}, - ) + "POST", "register", body + ) # type: SynapseRequest, FakeChannel self.render(request) - # Returns a 401 as per the spec - self.assertEqual(request.code, 401) - # Grab the session - session = channel.json_body["session"] - # Assert our configured public key is being given - self.assertEqual( - channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" - ) + self.assertEqual(request.code, expected_response) + return channel + + def recaptcha( + self, session: str, expected_post_response: int, post_session: str = None + ) -> None: + """Get and respond to a fallback recaptcha. Returns the second request.""" + if post_session is None: + post_session = session request, channel = self.make_request( "GET", "auth/m.login.recaptcha/fallback/web?session=" + session - ) + ) # type: SynapseRequest, FakeChannel self.render(request) self.assertEqual(request.code, 200) request, channel = self.make_request( "POST", "auth/m.login.recaptcha/fallback/web?session=" - + session + + post_session + "&g-recaptcha-response=a", ) self.render(request) - self.assertEqual(request.code, 200) + self.assertEqual(request.code, expected_post_response) # The recaptcha handler is called with the response given attempts = self.recaptcha_checker.recaptcha_attempts self.assertEqual(len(attempts), 1) self.assertEqual(attempts[0][0]["response"], "a") - # also complete the dummy auth - request, channel = self.make_request( - "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} + @unittest.INFO + def test_fallback_captcha(self): + """Ensure that fallback auth via a captcha works.""" + # Returns a 401 as per the spec + channel = self.register( + 401, {"username": "user", "type": "m.login.password", "password": "bar"}, ) - self.render(request) + + # Grab the session + session = channel.json_body["session"] + # Assert our configured public key is being given + self.assertEqual( + channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" + ) + + # Complete the recaptcha step. + self.recaptcha(session, 200) + + # also complete the dummy auth + self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) # Now we should have fulfilled a complete auth flow, including # the recaptcha fallback step, we can then send a # request to the register API with the session in the authdict. - request, channel = self.make_request( - "POST", "register", {"auth": {"session": session}} - ) - self.render(request) - self.assertEqual(channel.code, 200) + channel = self.register(200, {"auth": {"session": session}}) # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") - def test_cannot_change_operation(self): + def test_legacy_registration(self): """ - The initial requested operation cannot be modified during the user interactive authentication session. + Registration allows the parameters to vary through the process. """ # Make the initial request to register. (Later on a different password # will be used.) - request, channel = self.make_request( - "POST", - "register", - {"username": "user", "type": "m.login.password", "password": "bar"}, + # Returns a 401 as per the spec + channel = self.register( + 401, {"username": "user", "type": "m.login.password", "password": "bar"}, ) - self.render(request) - # Returns a 401 as per the spec - self.assertEqual(request.code, 401) # Grab the session session = channel.json_body["session"] # Assert our configured public key is being given @@ -139,65 +152,39 @@ class FallbackAuthTests(unittest.HomeserverTestCase): channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" ) - request, channel = self.make_request( - "GET", "auth/m.login.recaptcha/fallback/web?session=" + session - ) - self.render(request) - self.assertEqual(request.code, 200) - - request, channel = self.make_request( - "POST", - "auth/m.login.recaptcha/fallback/web?session=" - + session - + "&g-recaptcha-response=a", - ) - self.render(request) - self.assertEqual(request.code, 200) - - # The recaptcha handler is called with the response given - attempts = self.recaptcha_checker.recaptcha_attempts - self.assertEqual(len(attempts), 1) - self.assertEqual(attempts[0][0]["response"], "a") + # Complete the recaptcha step. + self.recaptcha(session, 200) # also complete the dummy auth - request, channel = self.make_request( - "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} - ) - self.render(request) + self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) # Now we should have fulfilled a complete auth flow, including # the recaptcha fallback step. Make the initial request again, but - # with a different password. This causes the request to fail since the - # operaiton was modified during the ui auth session. - request, channel = self.make_request( - "POST", - "register", + # with a changed password. This still completes. + channel = self.register( + 200, { "username": "user", "type": "m.login.password", - "password": "foo", # Note this doesn't match the original request. + "password": "foo", # Note that this is different. "auth": {"session": session}, }, ) - self.render(request) - self.assertEqual(channel.code, 403) + + # We're given a registered user. + self.assertEqual(channel.json_body["user_id"], "@user:test") def test_complete_operation_unknown_session(self): """ Attempting to mark an invalid session as complete should error. """ - # Make the initial request to register. (Later on a different password # will be used.) - request, channel = self.make_request( - "POST", - "register", - {"username": "user", "type": "m.login.password", "password": "bar"}, + # Returns a 401 as per the spec + channel = self.register( + 401, {"username": "user", "type": "m.login.password", "password": "bar"} ) - self.render(request) - # Returns a 401 as per the spec - self.assertEqual(request.code, 401) # Grab the session session = channel.json_body["session"] # Assert our configured public key is being given @@ -205,19 +192,160 @@ class FallbackAuthTests(unittest.HomeserverTestCase): channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" ) + # Attempt to complete the recaptcha step with an unknown session. + # This results in an error. + self.recaptcha(session, 400, session + "unknown") + + +class UIAuthTests(unittest.HomeserverTestCase): + servlets = [ + auth.register_servlets, + devices.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + register.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + auth_handler = hs.get_auth_handler() + auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs) + + self.user_pass = "pass" + self.user = self.register_user("test", self.user_pass) + self.user_tok = self.login("test", self.user_pass) + + def get_device_ids(self) -> List[str]: + # Get the list of devices so one can be deleted. request, channel = self.make_request( - "GET", "auth/m.login.recaptcha/fallback/web?session=" + session - ) + "GET", "devices", access_token=self.user_tok, + ) # type: SynapseRequest, FakeChannel self.render(request) + + # Get the ID of the device. self.assertEqual(request.code, 200) + return [d["device_id"] for d in channel.json_body["devices"]] - # Attempt to complete an unknown session, which should return an error. - unknown_session = session + "unknown" + def delete_device( + self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b"" + ) -> FakeChannel: + """Delete an individual device.""" request, channel = self.make_request( - "POST", - "auth/m.login.recaptcha/fallback/web?session=" - + unknown_session - + "&g-recaptcha-response=a", - ) + "DELETE", "devices/" + device, body, access_token=self.user_tok + ) # type: SynapseRequest, FakeChannel self.render(request) - self.assertEqual(request.code, 400) + + # Ensure the response is sane. + self.assertEqual(request.code, expected_response) + + return channel + + def delete_devices(self, expected_response: int, body: JsonDict) -> FakeChannel: + """Delete 1 or more devices.""" + # Note that this uses the delete_devices endpoint so that we can modify + # the payload half-way through some tests. + request, channel = self.make_request( + "POST", "delete_devices", body, access_token=self.user_tok, + ) # type: SynapseRequest, FakeChannel + self.render(request) + + # Ensure the response is sane. + self.assertEqual(request.code, expected_response) + + return channel + + def test_ui_auth(self): + """ + Test user interactive authentication outside of registration. + """ + device_id = self.get_device_ids()[0] + + # Attempt to delete this device. + # Returns a 401 as per the spec + channel = self.delete_device(device_id, 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow. + self.delete_device( + device_id, + 200, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + def test_cannot_change_body(self): + """ + The initial requested client dict cannot be modified during the user interactive authentication session. + """ + # Create a second login. + self.login("test", self.user_pass) + + device_ids = self.get_device_ids() + self.assertEqual(len(device_ids), 2) + + # Attempt to delete the first device. + # Returns a 401 as per the spec + channel = self.delete_devices(401, {"devices": [device_ids[0]]}) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow, but try to delete the + # second device. This results in an error. + self.delete_devices( + 403, + { + "devices": [device_ids[1]], + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + def test_cannot_change_uri(self): + """ + The initial requested URI cannot be modified during the user interactive authentication session. + """ + # Create a second login. + self.login("test", self.user_pass) + + device_ids = self.get_device_ids() + self.assertEqual(len(device_ids), 2) + + # Attempt to delete the first device. + # Returns a 401 as per the spec + channel = self.delete_device(device_ids[0], 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow, but try to delete the + # second device. This results in an error. + self.delete_device( + device_ids[1], + 403, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) diff --git a/tox.ini b/tox.ini index eccc44e436..8aef52021d 100644 --- a/tox.ini +++ b/tox.ini @@ -207,6 +207,7 @@ commands = mypy \ synapse/util/caches/stream_change_cache.py \ tests/replication/tcp/streams \ tests/test_utils \ + tests/rest/client/v2_alpha/test_auth.py \ tests/util/test_stream_change_cache.py # To find all folders that pass mypy you run: -- cgit 1.5.1 From 5d64fefd6c7790dac0209c6c32cdb97cd6cd8820 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 13 May 2020 14:26:44 -0400 Subject: Do not validate that the client dict is stable during UI Auth. (#7483) This backs out some of the validation for the client dictionary and logs if this changes during a user interactive authentication session instead. --- changelog.d/7483.bugfix | 1 + synapse/handlers/auth.py | 37 +++++++++++---------- synapse/rest/client/v2_alpha/register.py | 1 - tests/rest/client/v2_alpha/test_auth.py | 55 ++++++-------------------------- 4 files changed, 29 insertions(+), 65 deletions(-) create mode 100644 changelog.d/7483.bugfix (limited to 'synapse/handlers/auth.py') diff --git a/changelog.d/7483.bugfix b/changelog.d/7483.bugfix new file mode 100644 index 0000000000..e1bc324617 --- /dev/null +++ b/changelog.d/7483.bugfix @@ -0,0 +1 @@ +Restore compatibility with non-compliant clients during the user interactive authentication process. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 9c71702371..5c20e29171 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -252,7 +252,6 @@ class AuthHandler(BaseHandler): clientdict: Dict[str, Any], clientip: str, description: str, - validate_clientdict: bool = True, ) -> Tuple[dict, dict, str]: """ Takes a dictionary sent by the client in the login / registration @@ -278,10 +277,6 @@ class AuthHandler(BaseHandler): description: A human readable string to be displayed to the user that describes the operation happening on their account. - validate_clientdict: Whether to validate that the operation happening - on the account has not changed. If this is false, - the client dict is persisted instead of validated. - Returns: A tuple of (creds, params, session_id). @@ -346,26 +341,30 @@ class AuthHandler(BaseHandler): # Ensure that the queried operation does not vary between stages of # the UI authentication session. This is done by generating a stable - # comparator based on the URI, method, and client dict (minus the - # auth dict) and storing it during the initial query. Subsequent + # comparator and storing it during the initial query. Subsequent # queries ensure that this comparator has not changed. - if validate_clientdict: - session_comparator = (session.uri, session.method, session.clientdict) - comparator = (uri, method, clientdict) - else: - session_comparator = (session.uri, session.method) # type: ignore - comparator = (uri, method) # type: ignore - - if session_comparator != comparator: + # + # The comparator is based on the requested URI and HTTP method. The + # client dict (minus the auth dict) should also be checked, but some + # clients are not spec compliant, just warn for now if the client + # dict changes. + if (session.uri, session.method) != (uri, method): raise SynapseError( 403, "Requested operation has changed during the UI authentication session.", ) - # For backwards compatibility the registration endpoint persists - # changes to the client dict instead of validating them. - if not validate_clientdict: - await self.store.set_ui_auth_clientdict(sid, clientdict) + if session.clientdict != clientdict: + logger.warning( + "Requested operation has changed during the UI " + "authentication session. A future version of Synapse " + "will remove this capability." + ) + + # For backwards compatibility, changes to the client dict are + # persisted as clients modify them throughout their user interactive + # authentication flow. + await self.store.set_ui_auth_clientdict(sid, clientdict) if not authdict: raise InteractiveAuthIncompleteError( diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index e77dd6bf92..af08cc6cce 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -516,7 +516,6 @@ class RegisterRestServlet(RestServlet): body, self.hs.get_ip_from_request(request), "register a new account", - validate_clientdict=False, ) # Check that we're not trying to register a denied 3pid. diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index a56c50a5b7..293ccfba2b 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -133,47 +133,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase): # We're given a registered user. self.assertEqual(channel.json_body["user_id"], "@user:test") - def test_legacy_registration(self): - """ - Registration allows the parameters to vary through the process. - """ - - # Make the initial request to register. (Later on a different password - # will be used.) - # Returns a 401 as per the spec - channel = self.register( - 401, {"username": "user", "type": "m.login.password", "password": "bar"}, - ) - - # Grab the session - session = channel.json_body["session"] - # Assert our configured public key is being given - self.assertEqual( - channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" - ) - - # Complete the recaptcha step. - self.recaptcha(session, 200) - - # also complete the dummy auth - self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) - - # Now we should have fulfilled a complete auth flow, including - # the recaptcha fallback step. Make the initial request again, but - # with a changed password. This still completes. - channel = self.register( - 200, - { - "username": "user", - "type": "m.login.password", - "password": "foo", # Note that this is different. - "auth": {"session": session}, - }, - ) - - # We're given a registered user. - self.assertEqual(channel.json_body["user_id"], "@user:test") - def test_complete_operation_unknown_session(self): """ Attempting to mark an invalid session as complete should error. @@ -282,9 +241,15 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) - def test_cannot_change_body(self): + def test_can_change_body(self): """ - The initial requested client dict cannot be modified during the user interactive authentication session. + The client dict can be modified during the user interactive authentication session. + + Note that it is not spec compliant to modify the client dict during a + user interactive authentication session, but many clients currently do. + + When Synapse is updated to be spec compliant, the call to re-use the + session ID should be rejected. """ # Create a second login. self.login("test", self.user_pass) @@ -302,9 +267,9 @@ class UIAuthTests(unittest.HomeserverTestCase): self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) # Make another request providing the UI auth flow, but try to delete the - # second device. This results in an error. + # second device. self.delete_devices( - 403, + 200, { "devices": [device_ids[1]], "auth": { -- cgit 1.5.1