diff options
Diffstat (limited to 'synapse/rest/client/v2_alpha')
26 files changed, 1152 insertions, 805 deletions
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 8250ae0ae1..bc11b4dda4 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -32,7 +32,7 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False): Args: path_regex (str): The regex string to match. This should NOT have a ^ - as this will be prefixed. + as this will be prefixed. Returns: SRE_Pattern """ @@ -78,7 +78,7 @@ def interactive_auth_handler(orig): """ def wrapped(*args, **kwargs): - res = defer.maybeDeferred(orig, *args, **kwargs) + res = defer.ensureDeferred(orig(*args, **kwargs)) res.addErrback(_catch_incomplete_interactive_auth) return res diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 2ea515d2f6..d4f721b6b9 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -18,8 +18,6 @@ import logging from six.moves import http_client -from twisted.internet import defer - from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError, ThreepidValidationError from synapse.config.emailconfig import ThreepidBehaviour @@ -32,6 +30,7 @@ from synapse.http.servlet import ( ) from synapse.push.mailer import Mailer, load_jinja2_templates from synapse.util.msisdn import phone_number_to_msisdn +from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.threepids import check_3pid_allowed from ._base import client_patterns, interactive_auth_handler @@ -67,11 +66,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): template_text=template_text, ) - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "User password resets have been disabled due to lack of email config" ) raise SynapseError( @@ -84,6 +82,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Extract params from body client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + email = body["email"] send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param @@ -95,25 +95,23 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email ) if existing_user_id is None: + if self.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + return 200, {"sid": random_string(16)} + raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - # Have the configured identity server handle the request - if not self.hs.config.account_threepid_delegate_email: - logger.warn( - "No upstream email account_threepid_delegate configured on the server to " - "handle this request" - ) - raise SynapseError( - 400, "Password reset by email is not supported on this homeserver" - ) + assert self.hs.config.account_threepid_delegate_email - ret = yield self.identity_handler.requestEmailToken( + # Have the configured identity server handle the request + ret = await self.identity_handler.requestEmailToken( self.hs.config.account_threepid_delegate_email, email, client_secret, @@ -122,7 +120,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): ) else: # Send password reset emails from Synapse - sid = yield self.identity_handler.send_threepid_validation( + sid = await self.identity_handler.send_threepid_validation( email, client_secret, send_attempt, @@ -136,71 +134,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): return 200, ret -class MsisdnPasswordRequestTokenRestServlet(RestServlet): - PATTERNS = client_patterns("/account/password/msisdn/requestToken$") - - def __init__(self, hs): - super(MsisdnPasswordRequestTokenRestServlet, self).__init__() - self.hs = hs - self.datastore = self.hs.get_datastore() - self.identity_handler = hs.get_handlers().identity_handler - - @defer.inlineCallbacks - def on_POST(self, request): - body = parse_json_object_from_request(request) - - assert_params_in_dict( - body, ["client_secret", "country", "phone_number", "send_attempt"] - ) - client_secret = body["client_secret"] - country = body["country"] - phone_number = body["phone_number"] - send_attempt = body["send_attempt"] - next_link = body.get("next_link") # Optional param - - msisdn = phone_number_to_msisdn(country, phone_number) - - if not check_3pid_allowed(self.hs, "msisdn", msisdn): - raise SynapseError( - 403, - "Account phone numbers are not authorized on this server", - Codes.THREEPID_DENIED, - ) - - existing_user_id = yield self.datastore.get_user_id_by_threepid( - "msisdn", msisdn - ) - - if existing_user_id is None: - raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND) - - if not self.hs.config.account_threepid_delegate_msisdn: - logger.warn( - "No upstream msisdn account_threepid_delegate configured on the server to " - "handle this request" - ) - raise SynapseError( - 400, - "Password reset by phone number is not supported on this homeserver", - ) - - ret = yield self.identity_handler.requestMsisdnToken( - self.hs.config.account_threepid_delegate_msisdn, - country, - phone_number, - client_secret, - send_attempt, - next_link, - ) - - return 200, ret - - class PasswordResetSubmitTokenServlet(RestServlet): """Handles 3PID validation token submission""" PATTERNS = client_patterns( - "/password_reset/(?P<medium>[^/]*)/submit_token/*$", releases=(), unstable=True + "/password_reset/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True ) def __init__(self, hs): @@ -214,9 +152,13 @@ class PasswordResetSubmitTokenServlet(RestServlet): self.config = hs.config self.clock = hs.get_clock() self.store = hs.get_datastore() + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + (self.failure_email_template,) = load_jinja2_templates( + self.config.email_template_dir, + [self.config.email_password_reset_template_failure_html], + ) - @defer.inlineCallbacks - def on_GET(self, request, medium): + async def on_GET(self, request, medium): # We currently only handle threepid token submissions for email if medium != "email": raise SynapseError( @@ -224,7 +166,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): ) if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Password reset emails have been disabled due to lack of an email config" ) raise SynapseError( @@ -232,20 +174,21 @@ class PasswordResetSubmitTokenServlet(RestServlet): ) sid = parse_string(request, "sid", required=True) - client_secret = parse_string(request, "client_secret", required=True) token = parse_string(request, "token", required=True) + client_secret = parse_string(request, "client_secret", required=True) + assert_valid_client_secret(client_secret) # Attempt to validate a 3PID session try: # Mark the session as valid - next_link = yield self.store.validate_threepid_session( + next_link = await self.store.validate_threepid_session( sid, client_secret, token, self.clock.time_msec() ) # Perform a 302 redirect if next_link is set if next_link: if next_link.startswith("file:///"): - logger.warn( + logger.warning( "Not redirecting to next_link as it is a local file: address" ) else: @@ -261,34 +204,12 @@ class PasswordResetSubmitTokenServlet(RestServlet): request.setResponseCode(e.code) # Show a failure page with a reason - html_template, = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_password_reset_template_failure_html], - ) - template_vars = {"failure_reason": e.msg} - html = html_template.render(**template_vars) + html = self.failure_email_template.render(**template_vars) request.write(html.encode("utf-8")) finish_request(request) - @defer.inlineCallbacks - def on_POST(self, request, medium): - if medium != "email": - raise SynapseError( - 400, "This medium is currently not supported for password resets" - ) - - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["sid", "client_secret", "token"]) - - valid, _ = yield self.store.validate_threepid_session( - body["sid"], body["client_secret"], body["token"], self.clock.time_msec() - ) - response_code = 200 if valid else 400 - - return response_code, {"success": valid} - class PasswordRestServlet(RestServlet): PATTERNS = client_patterns("/account/password$") @@ -299,13 +220,27 @@ class PasswordRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.datastore = self.hs.get_datastore() + self.password_policy_handler = hs.get_password_policy_handler() self._set_password_handler = hs.get_set_password_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) + # we do basic sanity checks here because the auth layer will store these + # in sessions. Pull out the new password provided to us. + if "new_password" in body: + new_password = body.pop("new_password") + if not isinstance(new_password, str) or len(new_password) > 512: + raise SynapseError(400, "Invalid password") + self.password_policy_handler.validate_password(new_password) + + # If the password is valid, hash it and store it back on the body. + # This ensures that only the hashed password is handled everywhere. + if "new_password_hash" in body: + raise SynapseError(400, "Unexpected property: new_password_hash") + body["new_password_hash"] = await self.auth_handler.hash(new_password) + # there are two possibilities here. Either the user does not have an # access token, and needs to do a password reset; or they have one and # need to validate their identity. @@ -317,17 +252,23 @@ class PasswordRestServlet(RestServlet): # In the second case, we require a password to confirm their identity. if self.auth.has_access_token(request): - requester = yield self.auth.get_user_by_req(request) - params = yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester = await self.auth.get_user_by_req(request) + params = await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + self.hs.get_ip_from_request(request), + "modify your account password", ) user_id = requester.user.to_string() else: requester = None - result, params, _ = yield self.auth_handler.check_auth( - [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]], + result, params, _ = await self.auth_handler.check_auth( + [[LoginType.EMAIL_IDENTITY]], + request, body, self.hs.get_ip_from_request(request), + "modify your account password", ) if LoginType.EMAIL_IDENTITY in result: @@ -340,7 +281,7 @@ class PasswordRestServlet(RestServlet): # (See add_threepid in synapse/handlers/auth.py) threepid["address"] = threepid["address"].lower() # if using email, we must know about the email they're authing with! - threepid_user_id = yield self.datastore.get_user_id_by_threepid( + threepid_user_id = await self.datastore.get_user_id_by_threepid( threepid["medium"], threepid["address"] ) if not threepid_user_id: @@ -350,10 +291,13 @@ class PasswordRestServlet(RestServlet): logger.error("Auth succeeded but no known type! %r", result.keys()) raise SynapseError(500, "", Codes.UNKNOWN) - assert_params_in_dict(params, ["new_password"]) - new_password = params["new_password"] + assert_params_in_dict(params, ["new_password_hash"]) + new_password_hash = params["new_password_hash"] + logout_devices = params.get("logout_devices", True) - yield self._set_password_handler.set_password(user_id, new_password, requester) + await self._set_password_handler.set_password( + user_id, new_password_hash, logout_devices, requester + ) return 200, {} @@ -372,8 +316,7 @@ class DeactivateAccountRestServlet(RestServlet): self._deactivate_account_handler = hs.get_deactivate_account_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) erase = body.get("erase", False) if not isinstance(erase, bool): @@ -383,19 +326,23 @@ class DeactivateAccountRestServlet(RestServlet): Codes.BAD_JSON, ) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) # allow ASes to dectivate their own users if requester.app_service: - yield self._deactivate_account_handler.deactivate_account( + await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase ) return 200, {} - yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + self.hs.get_ip_from_request(request), + "deactivate your account", ) - result = yield self._deactivate_account_handler.deactivate_account( + result = await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase, id_server=body.get("id_server") ) if result: @@ -416,14 +363,37 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): self.identity_handler = hs.get_handlers().identity_handler self.store = self.hs.get_datastore() - @defer.inlineCallbacks - def on_POST(self, request): + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + template_html, template_text = load_jinja2_templates( + self.config.email_template_dir, + [ + self.config.email_add_threepid_template_html, + self.config.email_add_threepid_template_text, + ], + public_baseurl=self.config.public_baseurl, + ) + self.mailer = Mailer( + hs=self.hs, + app_name=self.config.email_app_name, + template_html=template_html, + template_text=template_text, + ) + + async def on_POST(self, request): + if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.local_threepid_handling_disabled_due_to_email_config: + logger.warning( + "Adding emails have been disabled due to lack of an email config" + ) + raise SynapseError( + 400, "Adding an email to your account is disabled on this server" + ) + body = parse_json_object_from_request(request) - assert_params_in_dict( - body, ["id_server", "client_secret", "email", "send_attempt"] - ) - id_server = "https://" + body["id_server"] # Assume https + assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + email = body["email"] send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param @@ -435,16 +405,42 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.store.get_user_id_by_threepid( + existing_user_id = await self.store.get_user_id_by_threepid( "email", body["email"] ) if existing_user_id is not None: + if self.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + return 200, {"sid": random_string(16)} + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) - ret = yield self.identity_handler.requestEmailToken( - id_server, email, client_secret, send_attempt, next_link - ) + if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + assert self.hs.config.account_threepid_delegate_email + + # Have the configured identity server handle the request + ret = await self.identity_handler.requestEmailToken( + self.hs.config.account_threepid_delegate_email, + email, + client_secret, + send_attempt, + next_link, + ) + else: + # Send threepid validation emails from Synapse + sid = await self.identity_handler.send_threepid_validation( + email, + client_secret, + send_attempt, + self.mailer.send_add_threepid_mail, + next_link, + ) + + # Wrap the session id in a JSON object + ret = {"sid": sid} + return 200, ret @@ -457,15 +453,14 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): self.store = self.hs.get_datastore() self.identity_handler = hs.get_handlers().identity_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict( - body, - ["id_server", "client_secret", "country", "phone_number", "send_attempt"], + body, ["client_secret", "country", "phone_number", "send_attempt"] ) - id_server = "https://" + body["id_server"] # Assume https client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + country = body["country"] phone_number = body["phone_number"] send_attempt = body["send_attempt"] @@ -480,17 +475,156 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.store.get_user_id_by_threepid("msisdn", msisdn) + existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) if existing_user_id is not None: + if self.hs.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + return 200, {"sid": random_string(16)} + raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) - ret = yield self.identity_handler.requestMsisdnToken( - id_server, country, phone_number, client_secret, send_attempt, next_link + if not self.hs.config.account_threepid_delegate_msisdn: + logger.warning( + "No upstream msisdn account_threepid_delegate configured on the server to " + "handle this request" + ) + raise SynapseError( + 400, + "Adding phone numbers to user account is not supported by this homeserver", + ) + + ret = await self.identity_handler.requestMsisdnToken( + self.hs.config.account_threepid_delegate_msisdn, + country, + phone_number, + client_secret, + send_attempt, + next_link, ) + return 200, ret +class AddThreepidEmailSubmitTokenServlet(RestServlet): + """Handles 3PID validation token submission for adding an email to a user's account""" + + PATTERNS = client_patterns( + "/add_threepid/email/submit_token$", releases=(), unstable=True + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.config = hs.config + self.clock = hs.get_clock() + self.store = hs.get_datastore() + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + (self.failure_email_template,) = load_jinja2_templates( + self.config.email_template_dir, + [self.config.email_add_threepid_template_failure_html], + ) + + async def on_GET(self, request): + if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.local_threepid_handling_disabled_due_to_email_config: + logger.warning( + "Adding emails have been disabled due to lack of an email config" + ) + raise SynapseError( + 400, "Adding an email to your account is disabled on this server" + ) + elif self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + raise SynapseError( + 400, + "This homeserver is not validating threepids. Use an identity server " + "instead.", + ) + + sid = parse_string(request, "sid", required=True) + token = parse_string(request, "token", required=True) + client_secret = parse_string(request, "client_secret", required=True) + assert_valid_client_secret(client_secret) + + # Attempt to validate a 3PID session + try: + # Mark the session as valid + next_link = await self.store.validate_threepid_session( + sid, client_secret, token, self.clock.time_msec() + ) + + # Perform a 302 redirect if next_link is set + if next_link: + if next_link.startswith("file:///"): + logger.warning( + "Not redirecting to next_link as it is a local file: address" + ) + else: + request.setResponseCode(302) + request.setHeader("Location", next_link) + finish_request(request) + return None + + # Otherwise show the success template + html = self.config.email_add_threepid_template_success_html_content + request.setResponseCode(200) + except ThreepidValidationError as e: + request.setResponseCode(e.code) + + # Show a failure page with a reason + template_vars = {"failure_reason": e.msg} + html = self.failure_email_template.render(**template_vars) + + request.write(html.encode("utf-8")) + finish_request(request) + + +class AddThreepidMsisdnSubmitTokenServlet(RestServlet): + """Handles 3PID validation token submission for adding a phone number to a user's + account + """ + + PATTERNS = client_patterns( + "/add_threepid/msisdn/submit_token$", releases=(), unstable=True + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.config = hs.config + self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.identity_handler = hs.get_handlers().identity_handler + + async def on_POST(self, request): + if not self.config.account_threepid_delegate_msisdn: + raise SynapseError( + 400, + "This homeserver is not validating phone numbers. Use an identity server " + "instead.", + ) + + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ["client_secret", "sid", "token"]) + assert_valid_client_secret(body["client_secret"]) + + # Proxy submit_token request to msisdn threepid delegate + response = await self.identity_handler.proxy_msisdn_submit_token( + self.config.account_threepid_delegate_msisdn, + body["client_secret"], + body["sid"], + body["token"], + ) + return 200, response + + class ThreepidRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid$") @@ -502,16 +636,21 @@ class ThreepidRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.datastore = self.hs.get_datastore() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) - threepids = yield self.datastore.user_get_threepids(requester.user.to_string()) + threepids = await self.datastore.user_get_threepids(requester.user.to_string()) return 200, {"threepids": threepids} - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() body = parse_json_object_from_request(request) threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds") @@ -519,34 +658,111 @@ class ThreepidRestServlet(RestServlet): raise SynapseError( 400, "Missing param three_pid_creds", Codes.MISSING_PARAM ) + assert_params_in_dict(threepid_creds, ["client_secret", "sid"]) + + sid = threepid_creds["sid"] + client_secret = threepid_creds["client_secret"] + assert_valid_client_secret(client_secret) + + validation_session = await self.identity_handler.validate_threepid_session( + client_secret, sid + ) + if validation_session: + await self.auth_handler.add_threepid( + user_id, + validation_session["medium"], + validation_session["address"], + validation_session["validated_at"], + ) + return 200, {} + + raise SynapseError( + 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED + ) + + +class ThreepidAddRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid/add$", releases=(), unstable=True) + + def __init__(self, hs): + super(ThreepidAddRestServlet, self).__init__() + self.hs = hs + self.identity_handler = hs.get_handlers().identity_handler + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + + @interactive_auth_handler + async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() + body = parse_json_object_from_request(request) - # Specify None as the identity server to retrieve it from the request body instead - threepid = yield self.identity_handler.threepid_from_creds(None, threepid_creds) + assert_params_in_dict(body, ["client_secret", "sid"]) + sid = body["sid"] + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) - if not threepid: - raise SynapseError(400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED) + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + self.hs.get_ip_from_request(request), + "add a third-party identifier to your account", + ) - for reqd in ["medium", "address", "validated_at"]: - if reqd not in threepid: - logger.warn("Couldn't add 3pid: invalid response from ID server") - raise SynapseError(500, "Invalid response from ID Server") + validation_session = await self.identity_handler.validate_threepid_session( + client_secret, sid + ) + if validation_session: + await self.auth_handler.add_threepid( + user_id, + validation_session["medium"], + validation_session["address"], + validation_session["validated_at"], + ) + return 200, {} - yield self.auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"] + raise SynapseError( + 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED ) - if "bind" in body and body["bind"]: - logger.debug("Binding threepid %s to %s", threepid, user_id) - yield self.identity_handler.bind_threepid(threepid_creds, user_id) + +class ThreepidBindRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid/bind$", releases=(), unstable=True) + + def __init__(self, hs): + super(ThreepidBindRestServlet, self).__init__() + self.hs = hs + self.identity_handler = hs.get_handlers().identity_handler + self.auth = hs.get_auth() + + async def on_POST(self, request): + body = parse_json_object_from_request(request) + + assert_params_in_dict(body, ["id_server", "sid", "client_secret"]) + id_server = body["id_server"] + sid = body["sid"] + id_access_token = body.get("id_access_token") # optional + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + await self.identity_handler.bind_threepid( + client_secret, sid, user_id, id_server, id_access_token + ) return 200, {} class ThreepidUnbindRestServlet(RestServlet): - PATTERNS = client_patterns("/account/3pid/unbind$") + PATTERNS = client_patterns("/account/3pid/unbind$", releases=(), unstable=True) def __init__(self, hs): super(ThreepidUnbindRestServlet, self).__init__() @@ -555,12 +771,11 @@ class ThreepidUnbindRestServlet(RestServlet): self.auth = hs.get_auth() self.datastore = self.hs.get_datastore() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): """Unbind the given 3pid from a specific identity server, or identity servers that are known to have this 3pid bound """ - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) @@ -570,7 +785,7 @@ class ThreepidUnbindRestServlet(RestServlet): # Attempt to unbind the threepid from an identity server. If id_server is None, try to # unbind from all identity servers this threepid has been added to in the past - result = yield self.identity_handler.try_unbind_threepid( + result = await self.identity_handler.try_unbind_threepid( requester.user.to_string(), {"address": address, "medium": medium, "id_server": id_server}, ) @@ -582,19 +797,24 @@ class ThreepidDeleteRestServlet(RestServlet): def __init__(self, hs): super(ThreepidDeleteRestServlet, self).__init__() + self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() try: - ret = yield self.auth_handler.delete_threepid( + ret = await self.auth_handler.delete_threepid( user_id, body["medium"], body["address"], body.get("id_server") ) except Exception: @@ -619,22 +839,24 @@ class WhoamiRestServlet(RestServlet): super(WhoamiRestServlet, self).__init__() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) return 200, {"user_id": requester.user.to_string()} def register_servlets(hs, http_server): EmailPasswordRequestTokenRestServlet(hs).register(http_server) - MsisdnPasswordRequestTokenRestServlet(hs).register(http_server) PasswordResetSubmitTokenServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) EmailThreepidRequestTokenRestServlet(hs).register(http_server) MsisdnThreepidRequestTokenRestServlet(hs).register(http_server) + AddThreepidEmailSubmitTokenServlet(hs).register(http_server) + AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server) ThreepidRestServlet(hs).register(http_server) + ThreepidAddRestServlet(hs).register(http_server) + ThreepidBindRestServlet(hs).register(http_server) ThreepidUnbindRestServlet(hs).register(http_server) ThreepidDeleteRestServlet(hs).register(http_server) WhoamiRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index f0db204ffa..c1d4cd0caf 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -40,16 +38,19 @@ class AccountDataServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() + self._is_worker = hs.config.worker_app is not None + + async def on_PUT(self, request, user_id, account_data_type): + if self._is_worker: + raise Exception("Cannot handle PUT /account_data on worker") - @defer.inlineCallbacks - def on_PUT(self, request, user_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") body = parse_json_object_from_request(request) - max_id = yield self.store.add_account_data_for_user( + max_id = await self.store.add_account_data_for_user( user_id, account_data_type, body ) @@ -57,13 +58,12 @@ class AccountDataServlet(RestServlet): return 200, {} - @defer.inlineCallbacks - def on_GET(self, request, user_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id, account_data_type): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") - event = yield self.store.get_global_account_data_by_type_for_user( + event = await self.store.get_global_account_data_by_type_for_user( account_data_type, user_id ) @@ -90,10 +90,13 @@ class RoomAccountDataServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() + self._is_worker = hs.config.worker_app is not None + + async def on_PUT(self, request, user_id, room_id, account_data_type): + if self._is_worker: + raise Exception("Cannot handle PUT /account_data on worker") - @defer.inlineCallbacks - def on_PUT(self, request, user_id, room_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") @@ -106,7 +109,7 @@ class RoomAccountDataServlet(RestServlet): " Use /rooms/!roomId:server.name/read_markers", ) - max_id = yield self.store.add_account_data_to_room( + max_id = await self.store.add_account_data_to_room( user_id, room_id, account_data_type, body ) @@ -114,13 +117,12 @@ class RoomAccountDataServlet(RestServlet): return 200, {} - @defer.inlineCallbacks - def on_GET(self, request, user_id, room_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id, room_id, account_data_type): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") - event = yield self.store.get_account_data_for_room_and_type( + event = await self.store.get_account_data_for_room_and_type( user_id, room_id, account_data_type ) diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 33f6a23028..2f10fa64e2 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError, SynapseError from synapse.http.server import finish_request from synapse.http.servlet import RestServlet @@ -45,13 +43,12 @@ class AccountValidityRenewServlet(RestServlet): self.success_html = hs.config.account_validity.account_renewed_html_content self.failure_html = hs.config.account_validity.invalid_token_html_content - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): if b"token" not in request.args: raise SynapseError(400, "Missing renewal token") renewal_token = request.args[b"token"][0] - token_valid = yield self.account_activity_handler.renew_account( + token_valid = await self.account_activity_handler.renew_account( renewal_token.decode("utf8") ) @@ -67,7 +64,6 @@ class AccountValidityRenewServlet(RestServlet): request.setHeader(b"Content-Length", b"%d" % (len(response),)) request.write(response.encode("utf8")) finish_request(request) - defer.returnValue(None) class AccountValiditySendMailServlet(RestServlet): @@ -85,18 +81,17 @@ class AccountValiditySendMailServlet(RestServlet): self.auth = hs.get_auth() self.account_validity = self.hs.config.account_validity - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if not self.account_validity.renew_by_email_enabled: raise AuthError( 403, "Account renewal via email is disabled on this server." ) - requester = yield self.auth.get_user_by_req(request, allow_expired=True) + requester = await self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() - yield self.account_activity_handler.send_renewal_email_to_user(user_id) + await self.account_activity_handler.send_renewal_email_to_user(user_id) - defer.returnValue((200, {})) + return 200, {} def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index f21aff39e5..75590ebaeb 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import LoginType from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_API_PREFIX @@ -132,7 +130,22 @@ class AuthRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() - def on_GET(self, request, stagetype): + # SSO configuration. + self._cas_enabled = hs.config.cas_enabled + if self._cas_enabled: + self._cas_handler = hs.get_cas_handler() + self._cas_server_url = hs.config.cas_server_url + self._cas_service_url = hs.config.cas_service_url + self._saml_enabled = hs.config.saml2_enabled + if self._saml_enabled: + self._saml_handler = hs.get_saml_handler() + self._oidc_enabled = hs.config.oidc_enabled + if self._oidc_enabled: + self._oidc_handler = hs.get_oidc_handler() + self._cas_server_url = hs.config.cas_server_url + self._cas_service_url = hs.config.cas_service_url + + async def on_GET(self, request, stagetype): session = parse_string(request, "session") if not session: raise SynapseError(400, "No session supplied") @@ -144,14 +157,6 @@ class AuthRestServlet(RestServlet): % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), "sitekey": self.hs.config.recaptcha_public_key, } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - return None elif stagetype == LoginType.TERMS: html = TERMS_TEMPLATE % { "session": session, @@ -160,19 +165,51 @@ class AuthRestServlet(RestServlet): "myurl": "%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - return None + + elif stagetype == LoginType.SSO: + # Display a confirmation page which prompts the user to + # re-authenticate with their SSO provider. + if self._cas_enabled: + # Generate a request to CAS that redirects back to an endpoint + # to verify the successful authentication. + sso_redirect_url = self._cas_handler.get_redirect_url( + {"session": session}, + ) + + elif self._saml_enabled: + # Some SAML identity providers (e.g. Google) require a + # RelayState parameter on requests. It is not necessary here, so + # pass in a dummy redirect URL (which will never get used). + client_redirect_url = b"unused" + sso_redirect_url = self._saml_handler.handle_redirect_request( + client_redirect_url, session + ) + + elif self._oidc_enabled: + client_redirect_url = b"" + sso_redirect_url = await self._oidc_handler.handle_redirect_request( + request, client_redirect_url, session + ) + + else: + raise SynapseError(400, "Homeserver not configured for SSO.") + + html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) + else: raise SynapseError(404, "Unknown auth stage type") - @defer.inlineCallbacks - def on_POST(self, request, stagetype): + # Render the HTML and return. + html_bytes = html.encode("utf8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + finish_request(request) + return None + + async def on_POST(self, request, stagetype): session = parse_string(request, "session") if not session: @@ -186,7 +223,7 @@ class AuthRestServlet(RestServlet): authdict = {"response": response, "session": session} - success = yield self.auth_handler.add_oob_auth( + success = await self.auth_handler.add_oob_auth( LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request) ) @@ -199,23 +236,10 @@ class AuthRestServlet(RestServlet): % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), "sitekey": self.hs.config.recaptcha_public_key, } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - - return None elif stagetype == LoginType.TERMS: - if ("session" not in request.args or len(request.args["session"])) == 0: - raise SynapseError(400, "No session supplied") - - session = request.args["session"][0] authdict = {"session": session} - success = yield self.auth_handler.add_oob_auth( + success = await self.auth_handler.add_oob_auth( LoginType.TERMS, authdict, self.hs.get_ip_from_request(request) ) @@ -232,17 +256,22 @@ class AuthRestServlet(RestServlet): "myurl": "%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - return None + elif stagetype == LoginType.SSO: + # The SSO fallback workflow should not post here, + raise SynapseError(404, "Fallback SSO auth does not support POST requests.") else: raise SynapseError(404, "Unknown auth stage type") + # Render the HTML and return. + html_bytes = html.encode("utf8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + finish_request(request) + return None + def on_OPTIONS(self, _): return 200, {} diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py index acd58af193..fe9d019c44 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py @@ -14,8 +14,6 @@ # limitations under the License. import logging -from twisted.internet import defer - from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import RestServlet @@ -40,10 +38,9 @@ class CapabilitiesRestServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) - user = yield self.store.get_user_by_id(requester.user.to_string()) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + user = await self.store.get_user_by_id(requester.user.to_string()) change_password = bool(user["password_hash"]) response = { diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 26d0235208..c0714fcfb1 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api import errors from synapse.http.servlet import ( RestServlet, @@ -42,10 +40,9 @@ class DevicesRestServlet(RestServlet): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) - devices = yield self.device_handler.get_devices_by_user( + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + devices = await self.device_handler.get_devices_by_user( requester.user.to_string() ) return 200, {"devices": devices} @@ -67,9 +64,8 @@ class DeleteDevicesRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) try: body = parse_json_object_from_request(request) @@ -84,11 +80,15 @@ class DeleteDevicesRestServlet(RestServlet): assert_params_in_dict(body, ["devices"]) - yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + self.hs.get_ip_from_request(request), + "remove device(s) from your account", ) - yield self.device_handler.delete_devices( + await self.device_handler.delete_devices( requester.user.to_string(), body["devices"] ) return 200, {} @@ -108,18 +108,16 @@ class DeviceRestServlet(RestServlet): self.device_handler = hs.get_device_handler() self.auth_handler = hs.get_auth_handler() - @defer.inlineCallbacks - def on_GET(self, request, device_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) - device = yield self.device_handler.get_device( + async def on_GET(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + device = await self.device_handler.get_device( requester.user.to_string(), device_id ) return 200, device @interactive_auth_handler - @defer.inlineCallbacks - def on_DELETE(self, request, device_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, device_id): + requester = await self.auth.get_user_by_req(request) try: body = parse_json_object_from_request(request) @@ -132,19 +130,22 @@ class DeviceRestServlet(RestServlet): else: raise - yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + self.hs.get_ip_from_request(request), + "remove a device from your account", ) - yield self.device_handler.delete_device(requester.user.to_string(), device_id) + await self.device_handler.delete_device(requester.user.to_string(), device_id) return 200, {} - @defer.inlineCallbacks - def on_PUT(self, request, device_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_PUT(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) body = parse_json_object_from_request(request) - yield self.device_handler.update_device( + await self.device_handler.update_device( requester.user.to_string(), device_id, body ) return 200, {} diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index c6ddf24c8d..b28da017cd 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -15,9 +15,7 @@ import logging -from twisted.internet import defer - -from synapse.api.errors import AuthError, Codes, StoreError, SynapseError +from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID @@ -35,10 +33,9 @@ class GetFilterRestServlet(RestServlet): self.auth = hs.get_auth() self.filtering = hs.get_filtering() - @defer.inlineCallbacks - def on_GET(self, request, user_id, filter_id): + async def on_GET(self, request, user_id, filter_id): target_user = UserID.from_string(user_id) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) if target_user != requester.user: raise AuthError(403, "Cannot get filters for other users") @@ -52,13 +49,15 @@ class GetFilterRestServlet(RestServlet): raise SynapseError(400, "Invalid filter_id") try: - filter = yield self.filtering.get_user_filter( + filter_collection = await self.filtering.get_user_filter( user_localpart=target_user.localpart, filter_id=filter_id ) + except StoreError as e: + if e.code != 404: + raise + raise NotFoundError("No such filter") - return 200, filter.get_filter_json() - except (KeyError, StoreError): - raise SynapseError(400, "No such filter", errcode=Codes.NOT_FOUND) + return 200, filter_collection.get_filter_json() class CreateFilterRestServlet(RestServlet): @@ -70,11 +69,10 @@ class CreateFilterRestServlet(RestServlet): self.auth = hs.get_auth() self.filtering = hs.get_filtering() - @defer.inlineCallbacks - def on_POST(self, request, user_id): + async def on_POST(self, request, user_id): target_user = UserID.from_string(user_id) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) if target_user != requester.user: raise AuthError(403, "Cannot create filters for other users") @@ -85,7 +83,7 @@ class CreateFilterRestServlet(RestServlet): content = parse_json_object_from_request(request) set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit) - filter_id = yield self.filtering.add_user_filter( + filter_id = await self.filtering.add_user_filter( user_localpart=target_user.localpart, user_filter=content ) diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 999a0fa80c..d84a6d7e11 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import GroupID @@ -38,24 +36,22 @@ class GroupServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - group_description = yield self.groups_handler.get_group_profile( + group_description = await self.groups_handler.get_group_profile( group_id, requester_user_id ) return 200, group_description - @defer.inlineCallbacks - def on_POST(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - yield self.groups_handler.update_group_profile( + await self.groups_handler.update_group_profile( group_id, requester_user_id, content ) @@ -74,12 +70,11 @@ class GroupSummaryServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - get_group_summary = yield self.groups_handler.get_group_summary( + get_group_summary = await self.groups_handler.get_group_summary( group_id, requester_user_id ) @@ -106,13 +101,12 @@ class GroupSummaryRoomsCatServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, category_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, category_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_summary_room( + resp = await self.groups_handler.update_group_summary_room( group_id, requester_user_id, room_id=room_id, @@ -122,12 +116,11 @@ class GroupSummaryRoomsCatServlet(RestServlet): return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, category_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, category_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - resp = yield self.groups_handler.delete_group_summary_room( + resp = await self.groups_handler.delete_group_summary_room( group_id, requester_user_id, room_id=room_id, category_id=category_id ) @@ -148,35 +141,32 @@ class GroupCategoryServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id, category_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id, category_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_category( + category = await self.groups_handler.get_group_category( group_id, requester_user_id, category_id=category_id ) return 200, category - @defer.inlineCallbacks - def on_PUT(self, request, group_id, category_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, category_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_category( + resp = await self.groups_handler.update_group_category( group_id, requester_user_id, category_id=category_id, content=content ) return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, category_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, category_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - resp = yield self.groups_handler.delete_group_category( + resp = await self.groups_handler.delete_group_category( group_id, requester_user_id, category_id=category_id ) @@ -195,12 +185,11 @@ class GroupCategoriesServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_categories( + category = await self.groups_handler.get_group_categories( group_id, requester_user_id ) @@ -219,35 +208,32 @@ class GroupRoleServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id, role_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id, role_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_role( + category = await self.groups_handler.get_group_role( group_id, requester_user_id, role_id=role_id ) return 200, category - @defer.inlineCallbacks - def on_PUT(self, request, group_id, role_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, role_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_role( + resp = await self.groups_handler.update_group_role( group_id, requester_user_id, role_id=role_id, content=content ) return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, role_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, role_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - resp = yield self.groups_handler.delete_group_role( + resp = await self.groups_handler.delete_group_role( group_id, requester_user_id, role_id=role_id ) @@ -266,12 +252,11 @@ class GroupRolesServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_roles( + category = await self.groups_handler.get_group_roles( group_id, requester_user_id ) @@ -298,13 +283,12 @@ class GroupSummaryUsersRoleServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, role_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, role_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_summary_user( + resp = await self.groups_handler.update_group_summary_user( group_id, requester_user_id, user_id=user_id, @@ -314,12 +298,11 @@ class GroupSummaryUsersRoleServlet(RestServlet): return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, role_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, role_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - resp = yield self.groups_handler.delete_group_summary_user( + resp = await self.groups_handler.delete_group_summary_user( group_id, requester_user_id, user_id=user_id, role_id=role_id ) @@ -338,12 +321,11 @@ class GroupRoomServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_rooms_in_group( + result = await self.groups_handler.get_rooms_in_group( group_id, requester_user_id ) @@ -362,12 +344,11 @@ class GroupUsersServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_users_in_group( + result = await self.groups_handler.get_users_in_group( group_id, requester_user_id ) @@ -386,12 +367,11 @@ class GroupInvitedUsersServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_invited_users_in_group( + result = await self.groups_handler.get_invited_users_in_group( group_id, requester_user_id ) @@ -409,14 +389,13 @@ class GroupSettingJoinPolicyServlet(RestServlet): self.auth = hs.get_auth() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.set_group_join_policy( + result = await self.groups_handler.set_group_join_policy( group_id, requester_user_id, content ) @@ -436,9 +415,8 @@ class GroupCreateServlet(RestServlet): self.groups_handler = hs.get_groups_local_handler() self.server_name = hs.hostname - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() # TODO: Create group on remote server @@ -446,7 +424,7 @@ class GroupCreateServlet(RestServlet): localpart = content.pop("localpart") group_id = GroupID(localpart, self.server_name).to_string() - result = yield self.groups_handler.create_group( + result = await self.groups_handler.create_group( group_id, requester_user_id, content ) @@ -467,24 +445,22 @@ class GroupAdminRoomsServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.add_room_to_group( + result = await self.groups_handler.add_room_to_group( group_id, requester_user_id, room_id, content ) return 200, result - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.remove_room_from_group( + result = await self.groups_handler.remove_room_from_group( group_id, requester_user_id, room_id ) @@ -506,13 +482,12 @@ class GroupAdminRoomsConfigServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, room_id, config_key): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, room_id, config_key): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.update_room_in_group( + result = await self.groups_handler.update_room_in_group( group_id, requester_user_id, room_id, config_key, content ) @@ -535,14 +510,13 @@ class GroupAdminUsersInviteServlet(RestServlet): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id - @defer.inlineCallbacks - def on_PUT(self, request, group_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) config = content.get("config", {}) - result = yield self.groups_handler.invite( + result = await self.groups_handler.invite( group_id, user_id, requester_user_id, config ) @@ -563,13 +537,12 @@ class GroupAdminUsersKickServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.remove_user_from_group( + result = await self.groups_handler.remove_user_from_group( group_id, user_id, requester_user_id, content ) @@ -588,13 +561,12 @@ class GroupSelfLeaveServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.remove_user_from_group( + result = await self.groups_handler.remove_user_from_group( group_id, requester_user_id, requester_user_id, content ) @@ -613,13 +585,12 @@ class GroupSelfJoinServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.join_group( + result = await self.groups_handler.join_group( group_id, requester_user_id, content ) @@ -638,13 +609,12 @@ class GroupSelfAcceptInviteServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.accept_invite( + result = await self.groups_handler.accept_invite( group_id, requester_user_id, content ) @@ -663,14 +633,13 @@ class GroupSelfUpdatePublicityServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) publicise = content["publicise"] - yield self.store.update_group_publicity(group_id, requester_user_id, publicise) + await self.store.update_group_publicity(group_id, requester_user_id, publicise) return 200, {} @@ -688,11 +657,10 @@ class PublicisedGroupsForUserServlet(RestServlet): self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, user_id): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, user_id): + await self.auth.get_user_by_req(request, allow_guest=True) - result = yield self.groups_handler.get_publicised_groups_for_user(user_id) + result = await self.groups_handler.get_publicised_groups_for_user(user_id) return 200, result @@ -710,14 +678,13 @@ class PublicisedGroupsForUsersServlet(RestServlet): self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) user_ids = content["user_ids"] - result = yield self.groups_handler.bulk_get_publicised_groups(user_ids) + result = await self.groups_handler.bulk_get_publicised_groups(user_ids) return 200, result @@ -734,12 +701,11 @@ class GroupsForUserServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_joined_groups(requester_user_id) + result = await self.groups_handler.get_joined_groups(requester_user_id) return 200, result diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 2e680134a0..24bb090822 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import ( RestServlet, @@ -27,7 +26,7 @@ from synapse.http.servlet import ( from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.types import StreamToken -from ._base import client_patterns +from ._base import client_patterns, interactive_auth_handler logger = logging.getLogger(__name__) @@ -43,7 +42,7 @@ class KeyUploadServlet(RestServlet): "device_id": "<device_id>", "valid_until_ts": <millisecond_timestamp>, "algorithms": [ - "m.olm.curve25519-aes-sha256", + "m.olm.curve25519-aes-sha2", ] "keys": { "<algorithm>:<device_id>": "<key_base64>", @@ -70,9 +69,8 @@ class KeyUploadServlet(RestServlet): self.e2e_keys_handler = hs.get_e2e_keys_handler() @trace(opname="upload_keys") - @defer.inlineCallbacks - def on_POST(self, request, device_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -102,7 +100,7 @@ class KeyUploadServlet(RestServlet): 400, "To upload keys, you must pass device_id when authenticating" ) - result = yield self.e2e_keys_handler.upload_keys_for_user( + result = await self.e2e_keys_handler.upload_keys_for_user( user_id, device_id, body ) return 200, result @@ -126,7 +124,7 @@ class KeyQueryServlet(RestServlet): "device_id": "<device_id>", // Duplicated to be signed "valid_until_ts": <millisecond_timestamp>, "algorithms": [ // List of supported algorithms - "m.olm.curve25519-aes-sha256", + "m.olm.curve25519-aes-sha2", ], "keys": { // Must include a ed25519 signing key "<algorithm>:<key_id>": "<key_base64>", @@ -153,12 +151,12 @@ class KeyQueryServlet(RestServlet): self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - @defer.inlineCallbacks - def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + user_id = requester.user.to_string() timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.query_devices(body, timeout) + result = await self.e2e_keys_handler.query_devices(body, timeout, user_id) return 200, result @@ -183,9 +181,8 @@ class KeyChangesServlet(RestServlet): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) from_token_string = parse_string(request, "from") set_tag("from", from_token_string) @@ -198,7 +195,7 @@ class KeyChangesServlet(RestServlet): user_id = requester.user.to_string() - results = yield self.device_handler.get_user_ids_changed(user_id, from_token) + results = await self.device_handler.get_user_ids_changed(user_id, from_token) return 200, results @@ -229,12 +226,100 @@ class OneTimeKeyServlet(RestServlet): self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - @defer.inlineCallbacks - def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.claim_one_time_keys(body, timeout) + result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout) + return 200, result + + +class SigningKeyUploadServlet(RestServlet): + """ + POST /keys/device_signing/upload HTTP/1.1 + Content-Type: application/json + + { + } + """ + + PATTERNS = client_patterns("/keys/device_signing/upload$", releases=()) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(SigningKeyUploadServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.e2e_keys_handler = hs.get_e2e_keys_handler() + self.auth_handler = hs.get_auth_handler() + + @interactive_auth_handler + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + self.hs.get_ip_from_request(request), + "add a device signing key to your account", + ) + + result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) + return 200, result + + +class SignaturesUploadServlet(RestServlet): + """ + POST /keys/signatures/upload HTTP/1.1 + Content-Type: application/json + + { + "@alice:example.com": { + "<device_id>": { + "user_id": "<user_id>", + "device_id": "<device_id>", + "algorithms": [ + "m.olm.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2" + ], + "keys": { + "<algorithm>:<device_id>": "<key_base64>", + }, + "signatures": { + "<signing_user_id>": { + "<algorithm>:<signing_key_base64>": "<signature_base64>>" + } + } + } + } + } + """ + + PATTERNS = client_patterns("/keys/signatures/upload$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(SignaturesUploadServlet, self).__init__() + self.auth = hs.get_auth() + self.e2e_keys_handler = hs.get_e2e_keys_handler() + + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + result = await self.e2e_keys_handler.upload_signatures_for_device_keys( + user_id, body + ) return 200, result @@ -243,3 +328,5 @@ def register_servlets(hs, http_server): KeyQueryServlet(hs).register(http_server) KeyChangesServlet(hs).register(http_server) OneTimeKeyServlet(hs).register(http_server) + SigningKeyUploadServlet(hs).register(http_server) + SignaturesUploadServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index 10c1ad5b07..aa911d75ee 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.http.servlet import RestServlet, parse_integer, parse_string @@ -35,9 +33,8 @@ class NotificationsServlet(RestServlet): self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() from_token = parse_string(request, "from", required=False) @@ -46,16 +43,16 @@ class NotificationsServlet(RestServlet): limit = min(limit, 500) - push_actions = yield self.store.get_push_actions_for_user( + push_actions = await self.store.get_push_actions_for_user( user_id, from_token, limit, only_highlight=(only == "highlight") ) - receipts_by_room = yield self.store.get_receipts_for_user_with_orderings( + receipts_by_room = await self.store.get_receipts_for_user_with_orderings( user_id, "m.read" ) notif_event_ids = [pa["event_id"] for pa in push_actions] - notif_events = yield self.store.get_events(notif_event_ids) + notif_events = await self.store.get_events(notif_event_ids) returned_push_actions = [] @@ -68,7 +65,7 @@ class NotificationsServlet(RestServlet): "actions": pa["actions"], "ts": pa["received_ts"], "event": ( - yield self._event_serializer.serialize_event( + await self._event_serializer.serialize_event( notif_events[pa["event_id"]], self.clock.time_msec(), event_format=format_event_for_client_v2_without_room_id, diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py index b4925c0f59..6ae9a5a8e9 100644 --- a/synapse/rest/client/v2_alpha/openid.py +++ b/synapse/rest/client/v2_alpha/openid.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.util.stringutils import random_string @@ -68,9 +66,8 @@ class IdTokenServlet(RestServlet): self.clock = hs.get_clock() self.server_name = hs.config.server_name - @defer.inlineCallbacks - def on_POST(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, user_id): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot request tokens for other users.") @@ -81,7 +78,7 @@ class IdTokenServlet(RestServlet): token = random_string(24) ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS - yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) + await self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) return ( 200, diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py new file mode 100644 index 0000000000..968403cca4 --- /dev/null +++ b/synapse/rest/client/v2_alpha/password_policy.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.http.servlet import RestServlet + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class PasswordPolicyServlet(RestServlet): + PATTERNS = client_patterns("/password_policy$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(PasswordPolicyServlet, self).__init__() + + self.policy = hs.config.password_policy + self.enabled = hs.config.password_policy_enabled + + def on_GET(self, request): + if not self.enabled or not self.policy: + return (200, {}) + + policy = {} + + for param in [ + "minimum_length", + "require_digit", + "require_symbol", + "require_lowercase", + "require_uppercase", + ]: + if param in self.policy: + policy["m.%s" % param] = self.policy[param] + + return (200, policy) + + +def register_servlets(hs, http_server): + PasswordPolicyServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index b3bf8567e1..67cbc37312 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import RestServlet, parse_json_object_from_request from ._base import client_patterns @@ -34,17 +32,16 @@ class ReadMarkerRestServlet(RestServlet): self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() - @defer.inlineCallbacks - def on_POST(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id): + requester = await self.auth.get_user_by_req(request) - yield self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time(requester.user) body = parse_json_object_from_request(request) read_event_id = body.get("m.read", None) if read_event_id: - yield self.receipts_handler.received_client_receipt( + await self.receipts_handler.received_client_receipt( room_id, "m.read", user_id=requester.user.to_string(), @@ -53,7 +50,7 @@ class ReadMarkerRestServlet(RestServlet): read_marker_event_id = body.get("m.fully_read", None) if read_marker_event_id: - yield self.read_marker_handler.received_client_read_marker( + await self.read_marker_handler.received_client_read_marker( room_id, user_id=requester.user.to_string(), event_id=read_marker_event_id, diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 0dab03d227..92555bd4a9 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet @@ -39,16 +37,15 @@ class ReceiptRestServlet(RestServlet): self.receipts_handler = hs.get_receipts_handler() self.presence_handler = hs.get_presence_handler() - @defer.inlineCallbacks - def on_POST(self, request, room_id, receipt_type, event_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id, receipt_type, event_id): + requester = await self.auth.get_user_by_req(request) if receipt_type != "m.read": raise SynapseError(400, "Receipt type must be 'm.read'") - yield self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time(requester.user) - yield self.receipts_handler.received_client_receipt( + await self.receipts_handler.received_client_receipt( room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 5c7a5f3579..b9ffe86b2a 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -16,24 +16,28 @@ import hmac import logging +from typing import List, Union from six import string_types -from twisted.internet import defer - import synapse +import synapse.api.auth import synapse.types from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, - LimitExceededError, SynapseError, ThreepidValidationError, UnrecognizedRequestError, ) +from synapse.config import ConfigError +from synapse.config.captcha import CaptchaConfig +from synapse.config.consent_config import ConsentConfig from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.ratelimiting import FederationRateLimitConfig +from synapse.config.registration import RegistrationConfig from synapse.config.server import is_threepid_reserved +from synapse.handlers.auth import AuthHandler from synapse.http.server import finish_request from synapse.http.servlet import ( RestServlet, @@ -44,6 +48,7 @@ from synapse.http.servlet import ( from synapse.push.mailer import load_jinja2_templates from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter +from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.threepids import check_3pid_allowed from ._base import client_patterns, interactive_auth_handler @@ -96,11 +101,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): template_text=template_text, ) - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.hs.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Email registration has been disabled due to lack of email config" ) raise SynapseError( @@ -112,6 +116,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): # Extract params from body client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + email = body["email"] send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param @@ -123,24 +129,23 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", body["email"] ) if existing_user_id is not None: + if self.hs.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + return 200, {"sid": random_string(16)} + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - if not self.hs.config.account_threepid_delegate_email: - logger.warn( - "No upstream email account_threepid_delegate configured on the server to " - "handle this request" - ) - raise SynapseError( - 400, "Registration by email is not supported on this homeserver" - ) + assert self.hs.config.account_threepid_delegate_email - ret = yield self.identity_handler.requestEmailToken( + # Have the configured identity server handle the request + ret = await self.identity_handler.requestEmailToken( self.hs.config.account_threepid_delegate_email, email, client_secret, @@ -149,7 +154,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): ) else: # Send registration emails from Synapse - sid = yield self.identity_handler.send_threepid_validation( + sid = await self.identity_handler.send_threepid_validation( email, client_secret, send_attempt, @@ -175,8 +180,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): self.hs = hs self.identity_handler = hs.get_handlers().identity_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict( @@ -197,17 +201,22 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "msisdn", msisdn ) if existing_user_id is not None: + if self.hs.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + return 200, {"sid": random_string(16)} + raise SynapseError( 400, "Phone number is already in use", Codes.THREEPID_IN_USE ) if not self.hs.config.account_threepid_delegate_msisdn: - logger.warn( + logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" ) @@ -215,7 +224,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): 400, "Registration by phone number is not supported on this homeserver" ) - ret = yield self.identity_handler.requestMsisdnToken( + ret = await self.identity_handler.requestMsisdnToken( self.hs.config.account_threepid_delegate_msisdn, country, phone_number, @@ -246,15 +255,26 @@ class RegistrationSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_GET(self, request, medium): + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + (self.failure_email_template,) = load_jinja2_templates( + self.config.email_template_dir, + [self.config.email_registration_template_failure_html], + ) + + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + (self.failure_email_template,) = load_jinja2_templates( + self.config.email_template_dir, + [self.config.email_registration_template_failure_html], + ) + + async def on_GET(self, request, medium): if medium != "email": raise SynapseError( 400, "This medium is currently not supported for registration" ) if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "User registration via email has been disabled due to lack of email config" ) raise SynapseError( @@ -268,14 +288,14 @@ class RegistrationSubmitTokenServlet(RestServlet): # Attempt to validate a 3PID session try: # Mark the session as valid - next_link = yield self.store.validate_threepid_session( + next_link = await self.store.validate_threepid_session( sid, client_secret, token, self.clock.time_msec() ) # Perform a 302 redirect if next_link is set if next_link: if next_link.startswith("file:///"): - logger.warn( + logger.warning( "Not redirecting to next_link as it is a local file: address" ) else: @@ -289,17 +309,11 @@ class RegistrationSubmitTokenServlet(RestServlet): request.setResponseCode(200) except ThreepidValidationError as e: - # Show a failure page with a reason request.setResponseCode(e.code) # Show a failure page with a reason - html_template, = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_registration_template_failure_html], - ) - template_vars = {"failure_reason": e.msg} - html = html_template.render(**template_vars) + html = self.failure_email_template.render(**template_vars) request.write(html.encode("utf-8")) finish_request(request) @@ -332,15 +346,19 @@ class UsernameAvailabilityRestServlet(RestServlet): ), ) - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): + if not self.hs.config.enable_registration: + raise SynapseError( + 403, "Registration has been disabled", errcode=Codes.FORBIDDEN + ) + ip = self.hs.get_ip_from_request(request) with self.ratelimiter.ratelimit(ip) as wait_deferred: - yield wait_deferred + await wait_deferred username = parse_string(request, "username", required=True) - yield self.registration_handler.check_username(username) + await self.registration_handler.check_username(username) return 200, {"available": True} @@ -364,50 +382,46 @@ class RegisterRestServlet(RestServlet): self.room_member_handler = hs.get_room_member_handler() self.macaroon_gen = hs.get_macaroon_generator() self.ratelimiter = hs.get_registration_ratelimiter() + self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() + self._registration_flows = _calculate_registration_flows( + hs.config, self.auth_handler + ) + @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) client_addr = request.getClientIP() - time_now = self.clock.time() - - allowed, time_allowed = self.ratelimiter.can_do_action( - client_addr, - time_now_s=time_now, - rate_hz=self.hs.config.rc_registration.per_second, - burst_count=self.hs.config.rc_registration.burst_count, - update=False, - ) - - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)) - ) + self.ratelimiter.ratelimit(client_addr, update=False) kind = b"user" if b"kind" in request.args: kind = request.args[b"kind"][0] if kind == b"guest": - ret = yield self._do_guest_registration(body, address=client_addr) + ret = await self._do_guest_registration(body, address=client_addr) return ret elif kind != b"user": raise UnrecognizedRequestError( - "Do not understand membership kind: %s" % (kind,) + "Do not understand membership kind: %s" % (kind.decode("utf8"),) ) # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the username/password provided to us. if "password" in body: - if ( - not isinstance(body["password"], string_types) - or len(body["password"]) > 512 - ): + password = body.pop("password") + if not isinstance(password, string_types) or len(password) > 512: raise SynapseError(400, "Invalid password") + self.password_policy_handler.validate_password(password) + + # If the password is valid, hash it and store it back on the body. + # This ensures that only the hashed password is handled everywhere. + if "password_hash" in body: + raise SynapseError(400, "Unexpected property: password_hash") + body["password_hash"] = await self.auth_handler.hash(password) desired_username = None if "username" in body: @@ -420,7 +434,7 @@ class RegisterRestServlet(RestServlet): appservice = None if self.auth.has_access_token(request): - appservice = yield self.auth.get_appservice_by_req(request) + appservice = await self.auth.get_appservice_by_req(request) # fork off as soon as possible for ASes which have completely # different registration flows to normal users @@ -440,7 +454,7 @@ class RegisterRestServlet(RestServlet): access_token = self.auth.get_access_token_from_request(request) if isinstance(desired_username, string_types): - result = yield self._do_appservice_registration( + result = await self._do_appservice_registration( desired_username, access_token, body ) return 200, result # we throw for non 200 responses @@ -460,12 +474,12 @@ class RegisterRestServlet(RestServlet): guest_access_token = body.get("guest_access_token", None) - if "initial_device_display_name" in body and "password" not in body: + if "initial_device_display_name" in body and "password_hash" not in body: # ignore 'initial_device_display_name' if sent without # a password to work around a client bug where it sent # the 'initial_device_display_name' param alone, wiping out # the original registration params - logger.warn("Ignoring initial_device_display_name without password") + logger.warning("Ignoring initial_device_display_name without password") del body["initial_device_display_name"] session_id = self.auth_handler.get_session_id(body) @@ -475,80 +489,23 @@ class RegisterRestServlet(RestServlet): # registered a user for this session, so we could just return the # user here. We carry on and go through the auth checks though, # for paranoia. - registered_user_id = self.auth_handler.get_session_data( + registered_user_id = await self.auth_handler.get_session_data( session_id, "registered_user_id", None ) if desired_username is not None: - yield self.registration_handler.check_username( + await self.registration_handler.check_username( desired_username, guest_access_token=guest_access_token, assigned_user_id=registered_user_id, ) - # FIXME: need a better error than "no auth flow found" for scenarios - # where we required 3PID for registration but the user didn't give one - require_email = "email" in self.hs.config.registrations_require_3pid - require_msisdn = "msisdn" in self.hs.config.registrations_require_3pid - - show_msisdn = True - if self.hs.config.disable_msisdn_registration: - show_msisdn = False - require_msisdn = False - - flows = [] - if self.hs.config.enable_registration_captcha: - # only support 3PIDless registration if no 3PIDs are required - if not require_email and not require_msisdn: - # Also add a dummy flow here, otherwise if a client completes - # recaptcha first we'll assume they were going for this flow - # and complete the request, when they could have been trying to - # complete one of the flows with email/msisdn auth. - flows.extend([[LoginType.RECAPTCHA, LoginType.DUMMY]]) - # only support the email-only flow if we don't require MSISDN 3PIDs - if not require_msisdn: - flows.extend([[LoginType.RECAPTCHA, LoginType.EMAIL_IDENTITY]]) - - if show_msisdn: - # only support the MSISDN-only flow if we don't require email 3PIDs - if not require_email: - flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]]) - # always let users provide both MSISDN & email - flows.extend( - [[LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY]] - ) - else: - # only support 3PIDless registration if no 3PIDs are required - if not require_email and not require_msisdn: - flows.extend([[LoginType.DUMMY]]) - # only support the email-only flow if we don't require MSISDN 3PIDs - if not require_msisdn: - flows.extend([[LoginType.EMAIL_IDENTITY]]) - - if show_msisdn: - # only support the MSISDN-only flow if we don't require email 3PIDs - if not require_email or require_msisdn: - flows.extend([[LoginType.MSISDN]]) - # always let users provide both MSISDN & email - flows.extend([[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]]) - - # Append m.login.terms to all flows if we're requiring consent - if self.hs.config.user_consent_at_registration: - new_flows = [] - for flow in flows: - inserted = False - # m.login.terms should go near the end but before msisdn or email auth - for i, stage in enumerate(flow): - if stage == LoginType.EMAIL_IDENTITY or stage == LoginType.MSISDN: - flow.insert(i, LoginType.TERMS) - inserted = True - break - if not inserted: - flow.append(LoginType.TERMS) - flows.extend(new_flows) - - auth_result, params, session_id = yield self.auth_handler.check_auth( - flows, body, self.hs.get_ip_from_request(request) + auth_result, params, session_id = await self.auth_handler.check_auth( + self._registration_flows, + request, + body, + self.hs.get_ip_from_request(request), + "register a new account", ) # Check that we're not trying to register a denied 3pid. @@ -579,11 +536,11 @@ class RegisterRestServlet(RestServlet): registered = False else: # NB: This may be from the auth handler and NOT from the POST - assert_params_in_dict(params, ["password"]) + assert_params_in_dict(params, ["password_hash"]) desired_username = params.get("username", None) guest_access_token = params.get("guest_access_token", None) - new_password = params.get("password", None) + new_password_hash = params.get("password_hash", None) if desired_username is not None: desired_username = desired_username.lower() @@ -603,7 +560,7 @@ class RegisterRestServlet(RestServlet): medium = auth_result[login_type]["medium"] address = auth_result[login_type]["address"] - existing_user_id = yield self.store.get_user_id_by_threepid( + existing_user_id = await self.store.get_user_id_by_threepid( medium, address ) @@ -614,9 +571,9 @@ class RegisterRestServlet(RestServlet): Codes.THREEPID_IN_USE, ) - registered_user_id = yield self.registration_handler.register_user( + registered_user_id = await self.registration_handler.register_user( localpart=desired_username, - password=new_password, + password_hash=new_password_hash, guest_access_token=guest_access_token, threepid=threepid, address=client_addr, @@ -627,22 +584,22 @@ class RegisterRestServlet(RestServlet): if is_threepid_reserved( self.hs.config.mau_limits_reserved_threepids, threepid ): - yield self.store.upsert_monthly_active_user(registered_user_id) + await self.store.upsert_monthly_active_user(registered_user_id) # remember that we've now registered that user account, and with # what user ID (since the user may not have specified) - self.auth_handler.set_session_data( + await self.auth_handler.set_session_data( session_id, "registered_user_id", registered_user_id ) registered = True - return_dict = yield self._create_registration_details( + return_dict = await self._create_registration_details( registered_user_id, params ) if registered: - yield self.registration_handler.post_registration_actions( + await self.registration_handler.post_registration_actions( user_id=registered_user_id, auth_result=auth_result, access_token=return_dict.get("access_token"), @@ -653,15 +610,13 @@ class RegisterRestServlet(RestServlet): def on_OPTIONS(self, _): return 200, {} - @defer.inlineCallbacks - def _do_appservice_registration(self, username, as_token, body): - user_id = yield self.registration_handler.appservice_register( + async def _do_appservice_registration(self, username, as_token, body): + user_id = await self.registration_handler.appservice_register( username, as_token ) - return (yield self._create_registration_details(user_id, body)) + return await self._create_registration_details(user_id, body) - @defer.inlineCallbacks - def _create_registration_details(self, user_id, params): + async def _create_registration_details(self, user_id, params): """Complete registration of newly-registered user Allocates device_id if one was not given; also creates access_token. @@ -677,18 +632,17 @@ class RegisterRestServlet(RestServlet): if not params.get("inhibit_login", False): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=False ) result.update({"access_token": access_token, "device_id": device_id}) return result - @defer.inlineCallbacks - def _do_guest_registration(self, params, address=None): + async def _do_guest_registration(self, params, address=None): if not self.hs.config.allow_guest_access: raise SynapseError(403, "Guest access is disabled") - user_id = yield self.registration_handler.register_user( + user_id = await self.registration_handler.register_user( make_guest=True, address=address ) @@ -696,7 +650,7 @@ class RegisterRestServlet(RestServlet): # we have nowhere to store it. device_id = synapse.api.auth.GUEST_DEVICE_ID initial_display_name = params.get("initial_device_display_name") - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=True ) @@ -711,6 +665,83 @@ class RegisterRestServlet(RestServlet): ) +def _calculate_registration_flows( + # technically `config` has to provide *all* of these interfaces, not just one + config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig], + auth_handler: AuthHandler, +) -> List[List[str]]: + """Get a suitable flows list for registration + + Args: + config: server configuration + auth_handler: authorization handler + + Returns: a list of supported flows + """ + # FIXME: need a better error than "no auth flow found" for scenarios + # where we required 3PID for registration but the user didn't give one + require_email = "email" in config.registrations_require_3pid + require_msisdn = "msisdn" in config.registrations_require_3pid + + show_msisdn = True + show_email = True + + if config.disable_msisdn_registration: + show_msisdn = False + require_msisdn = False + + enabled_auth_types = auth_handler.get_enabled_auth_types() + if LoginType.EMAIL_IDENTITY not in enabled_auth_types: + show_email = False + if require_email: + raise ConfigError( + "Configuration requires email address at registration, but email " + "validation is not configured" + ) + + if LoginType.MSISDN not in enabled_auth_types: + show_msisdn = False + if require_msisdn: + raise ConfigError( + "Configuration requires msisdn at registration, but msisdn " + "validation is not configured" + ) + + flows = [] + + # only support 3PIDless registration if no 3PIDs are required + if not require_email and not require_msisdn: + # Add a dummy step here, otherwise if a client completes + # recaptcha first we'll assume they were going for this flow + # and complete the request, when they could have been trying to + # complete one of the flows with email/msisdn auth. + flows.append([LoginType.DUMMY]) + + # only support the email-only flow if we don't require MSISDN 3PIDs + if show_email and not require_msisdn: + flows.append([LoginType.EMAIL_IDENTITY]) + + # only support the MSISDN-only flow if we don't require email 3PIDs + if show_msisdn and not require_email: + flows.append([LoginType.MSISDN]) + + if show_email and show_msisdn: + # always let users provide both MSISDN & email + flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY]) + + # Prepend m.login.terms to all flows if we're requiring consent + if config.user_consent_at_registration: + for flow in flows: + flow.insert(0, LoginType.TERMS) + + # Prepend recaptcha to all flows if we're requiring captcha + if config.enable_registration_captcha: + for flow in flows: + flow.insert(0, LoginType.RECAPTCHA) + + return flows + + def register_servlets(hs, http_server): EmailRegisterRequestTokenRestServlet(hs).register(http_server) MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 040b37c504..89002ffbff 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -21,8 +21,6 @@ any time to reflect changes in the MSC. import logging -from twisted.internet import defer - from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.http.servlet import ( @@ -86,11 +84,10 @@ class RelationSendServlet(RestServlet): request, self.on_PUT_or_POST, request, *args, **kwargs ) - @defer.inlineCallbacks - def on_PUT_or_POST( + async def on_PUT_or_POST( self, request, room_id, parent_id, relation_type, event_type, txn_id=None ): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = await self.auth.get_user_by_req(request, allow_guest=True) if event_type == EventTypes.Member: # Add relations to a membership is meaningless, so we just deny it @@ -114,7 +111,7 @@ class RelationSendServlet(RestServlet): "sender": requester.user.to_string(), } - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event, _ = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict=event_dict, txn_id=txn_id ) @@ -140,17 +137,18 @@ class RelationPaginationServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() - @defer.inlineCallbacks - def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET( + self, request, room_id, parent_id, relation_type=None, event_type=None + ): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - yield self.auth.check_in_room_or_world_readable( - room_id, requester.user.to_string() + await self.auth.check_user_in_room_or_world_readable( + room_id, requester.user.to_string(), allow_departed_users=True ) # This gets the original event and checks that a) the event exists and # b) the user is allowed to view it. - event = yield self.event_handler.get_event(requester.user, room_id, parent_id) + event = await self.event_handler.get_event(requester.user, room_id, parent_id) limit = parse_integer(request, "limit", default=5) from_token = parse_string(request, "from") @@ -167,7 +165,7 @@ class RelationPaginationServlet(RestServlet): if to_token: to_token = RelationPaginationToken.from_string(to_token) - pagination_chunk = yield self.store.get_relations_for_event( + pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, relation_type=relation_type, event_type=event_type, @@ -176,7 +174,7 @@ class RelationPaginationServlet(RestServlet): to_token=to_token, ) - events = yield self.store.get_events_as_list( + events = await self.store.get_events_as_list( [c["event_id"] for c in pagination_chunk.chunk] ) @@ -184,13 +182,13 @@ class RelationPaginationServlet(RestServlet): # We set bundle_aggregations to False when retrieving the original # event because we want the content before relations were applied to # it. - original_event = yield self._event_serializer.serialize_event( + original_event = await self._event_serializer.serialize_event( event, now, bundle_aggregations=False ) # Similarly, we don't allow relations to be applied to relations, so we # return the original relations without any aggregations on top of them # here. - events = yield self._event_serializer.serialize_events( + events = await self._event_serializer.serialize_events( events, now, bundle_aggregations=False ) @@ -232,17 +230,18 @@ class RelationAggregationPaginationServlet(RestServlet): self.store = hs.get_datastore() self.event_handler = hs.get_event_handler() - @defer.inlineCallbacks - def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET( + self, request, room_id, parent_id, relation_type=None, event_type=None + ): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - yield self.auth.check_in_room_or_world_readable( - room_id, requester.user.to_string() + await self.auth.check_user_in_room_or_world_readable( + room_id, requester.user.to_string(), allow_departed_users=True, ) # This checks that a) the event exists and b) the user is allowed to # view it. - event = yield self.event_handler.get_event(requester.user, room_id, parent_id) + event = await self.event_handler.get_event(requester.user, room_id, parent_id) if relation_type not in (RelationTypes.ANNOTATION, None): raise SynapseError(400, "Relation type must be 'annotation'") @@ -262,7 +261,7 @@ class RelationAggregationPaginationServlet(RestServlet): if to_token: to_token = AggregationPaginationToken.from_string(to_token) - pagination_chunk = yield self.store.get_aggregation_groups_for_event( + pagination_chunk = await self.store.get_aggregation_groups_for_event( event_id=parent_id, event_type=event_type, limit=limit, @@ -311,17 +310,16 @@ class RelationAggregationGroupPaginationServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() - @defer.inlineCallbacks - def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - yield self.auth.check_in_room_or_world_readable( - room_id, requester.user.to_string() + await self.auth.check_user_in_room_or_world_readable( + room_id, requester.user.to_string(), allow_departed_users=True, ) # This checks that a) the event exists and b) the user is allowed to # view it. - yield self.event_handler.get_event(requester.user, room_id, parent_id) + await self.event_handler.get_event(requester.user, room_id, parent_id) if relation_type != RelationTypes.ANNOTATION: raise SynapseError(400, "Relation type must be 'annotation'") @@ -336,7 +334,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): if to_token: to_token = RelationPaginationToken.from_string(to_token) - result = yield self.store.get_relations_for_event( + result = await self.store.get_relations_for_event( event_id=parent_id, relation_type=relation_type, event_type=event_type, @@ -346,12 +344,12 @@ class RelationAggregationGroupPaginationServlet(RestServlet): to_token=to_token, ) - events = yield self.store.get_events_as_list( + events = await self.store.get_events_as_list( [c["event_id"] for c in result.chunk] ) now = self.clock.time_msec() - events = yield self._event_serializer.serialize_events(events, now) + events = await self._event_serializer.serialize_events(events, now) return_value = result.to_dict() return_value["chunk"] = events diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index e7449864cd..f067b5edac 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -18,8 +18,6 @@ import logging from six import string_types from six.moves import http_client -from twisted.internet import defer - from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( RestServlet, @@ -42,9 +40,8 @@ class ReportEventRestServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_POST(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -63,7 +60,7 @@ class ReportEventRestServlet(RestServlet): Codes.BAD_JSON, ) - yield self.store.add_event_report( + await self.store.add_event_report( room_id=room_id, event_id=event_id, user_id=user_id, diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index df4f44cd36..59529707df 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, @@ -43,8 +41,7 @@ class RoomKeysServlet(RestServlet): self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - @defer.inlineCallbacks - def on_PUT(self, request, room_id, session_id): + async def on_PUT(self, request, room_id, session_id): """ Uploads one or more encrypted E2E room keys for backup purposes. room_id: the ID of the room the keys are for (optional) @@ -123,7 +120,7 @@ class RoomKeysServlet(RestServlet): } } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() body = parse_json_object_from_request(request) version = parse_string(request, "version") @@ -134,11 +131,10 @@ class RoomKeysServlet(RestServlet): if room_id: body = {"rooms": {room_id: body}} - yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) - return 200, {} + ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) + return 200, ret - @defer.inlineCallbacks - def on_GET(self, request, room_id, session_id): + async def on_GET(self, request, room_id, session_id): """ Retrieves one or more encrypted E2E room keys for backup purposes. Symmetric with the PUT version of the API. @@ -190,11 +186,11 @@ class RoomKeysServlet(RestServlet): } } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - version = parse_string(request, "version") + version = parse_string(request, "version", required=True) - room_keys = yield self.e2e_room_keys_handler.get_room_keys( + room_keys = await self.e2e_room_keys_handler.get_room_keys( user_id, version, room_id, session_id ) @@ -220,8 +216,7 @@ class RoomKeysServlet(RestServlet): return 200, room_keys - @defer.inlineCallbacks - def on_DELETE(self, request, room_id, session_id): + async def on_DELETE(self, request, room_id, session_id): """ Deletes one or more encrypted E2E room keys for a user for backup purposes. @@ -235,14 +230,14 @@ class RoomKeysServlet(RestServlet): the version must already have been created via the /change_secret API. """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() version = parse_string(request, "version") - yield self.e2e_room_keys_handler.delete_room_keys( + ret = await self.e2e_room_keys_handler.delete_room_keys( user_id, version, room_id, session_id ) - return 200, {} + return 200, ret class RoomKeysNewVersionServlet(RestServlet): @@ -257,8 +252,7 @@ class RoomKeysNewVersionServlet(RestServlet): self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): """ Create a new backup version for this user's room_keys with the given info. The version is allocated by the server and returned to the user @@ -288,11 +282,11 @@ class RoomKeysNewVersionServlet(RestServlet): "version": 12345 } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() info = parse_json_object_from_request(request) - new_version = yield self.e2e_room_keys_handler.create_version(user_id, info) + new_version = await self.e2e_room_keys_handler.create_version(user_id, info) return 200, {"version": new_version} # we deliberately don't have a PUT /version, as these things really should @@ -311,8 +305,7 @@ class RoomKeysVersionServlet(RestServlet): self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - @defer.inlineCallbacks - def on_GET(self, request, version): + async def on_GET(self, request, version): """ Retrieve the version information about a given version of the user's room_keys backup. If the version part is missing, returns info about the @@ -330,18 +323,17 @@ class RoomKeysVersionServlet(RestServlet): "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() try: - info = yield self.e2e_room_keys_handler.get_version_info(user_id, version) + info = await self.e2e_room_keys_handler.get_version_info(user_id, version) except SynapseError as e: if e.code == 404: raise SynapseError(404, "No backup found", Codes.NOT_FOUND) return 200, info - @defer.inlineCallbacks - def on_DELETE(self, request, version): + async def on_DELETE(self, request, version): """ Delete the information about a given version of the user's room_keys backup. If the version part is missing, deletes the most @@ -354,14 +346,13 @@ class RoomKeysVersionServlet(RestServlet): if version is None: raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND) - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - yield self.e2e_room_keys_handler.delete_version(user_id, version) + await self.e2e_room_keys_handler.delete_version(user_id, version) return 200, {} - @defer.inlineCallbacks - def on_PUT(self, request, version): + async def on_PUT(self, request, version): """ Update the information about a given version of the user's room_keys backup. @@ -375,14 +366,14 @@ class RoomKeysVersionServlet(RestServlet): "ed25519:something": "hijklmnop" } }, - "version": "42" + "version": "12345" } HTTP/1.1 200 OK Content-Type: application/json {} """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() info = parse_json_object_from_request(request) @@ -391,7 +382,7 @@ class RoomKeysVersionServlet(RestServlet): 400, "No version specified to update", Codes.MISSING_PARAM ) - yield self.e2e_room_keys_handler.update_version(user_id, version, info) + await self.e2e_room_keys_handler.update_version(user_id, version, info) return 200, {} diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index d2c3316eb7..f357015a70 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import ( @@ -59,22 +57,22 @@ class RoomUpgradeRestServlet(RestServlet): self._room_creation_handler = hs.get_room_creation_handler() self._auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request, room_id): - requester = yield self._auth.get_user_by_req(request) + async def on_POST(self, request, room_id): + requester = await self._auth.get_user_by_req(request) content = parse_json_object_from_request(request) assert_params_in_dict(content, ("new_version",)) new_version = content["new_version"] - if new_version not in KNOWN_ROOM_VERSIONS: + new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"]) + if new_version is None: raise SynapseError( 400, "Your homeserver does not support this room version", Codes.UNSUPPORTED_ROOM_VERSION, ) - new_room_id = yield self._room_creation_handler.upgrade_room( + new_room_id = await self._room_creation_handler.upgrade_room( requester, room_id, new_version ) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index d90e52ed1a..db829f3098 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -14,8 +14,7 @@ # limitations under the License. import logging - -from twisted.internet import defer +from typing import Tuple from synapse.http import servlet from synapse.http.servlet import parse_json_object_from_request @@ -51,19 +50,18 @@ class SendToDeviceRestServlet(servlet.RestServlet): request, self._put, request, message_type, txn_id ) - @defer.inlineCallbacks - def _put(self, request, message_type, txn_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def _put(self, request, message_type, txn_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) sender_user_id = requester.user.to_string() - yield self.device_message_handler.send_device_message( + await self.device_message_handler.send_device_message( sender_user_id, message_type, content["messages"] ) - response = (200, {}) + response = (200, {}) # type: Tuple[int, dict] return response diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index c98c5a3802..8fa68dd37f 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -18,10 +18,8 @@ import logging from canonicaljson import json -from twisted.internet import defer - from synapse.api.constants import PresenceState -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.events.utils import ( format_event_for_client_v2_without_room_id, @@ -74,7 +72,7 @@ class SyncRestServlet(RestServlet): """ PATTERNS = client_patterns("/sync$") - ALLOWED_PRESENCE = set(["online", "offline", "unavailable"]) + ALLOWED_PRESENCE = {"online", "offline", "unavailable"} def __init__(self, hs): super(SyncRestServlet, self).__init__() @@ -87,8 +85,7 @@ class SyncRestServlet(RestServlet): self._server_notices_sender = hs.get_server_notices_sender() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): if b"from" in request.args: # /events used to use 'from', but /sync uses 'since'. # Lets be helpful and whine if we see a 'from'. @@ -96,7 +93,7 @@ class SyncRestServlet(RestServlet): 400, "'from' is not a valid query parameter. Did you mean 'since'?" ) - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = await self.auth.get_user_by_req(request, allow_guest=True) user = requester.user device_id = requester.device_id @@ -112,32 +109,44 @@ class SyncRestServlet(RestServlet): full_state = parse_boolean(request, "full_state", default=False) logger.debug( - "/sync: user=%r, timeout=%r, since=%r," - " set_presence=%r, filter_id=%r, device_id=%r" - % (user, timeout, since, set_presence, filter_id, device_id) + "/sync: user=%r, timeout=%r, since=%r, " + "set_presence=%r, filter_id=%r, device_id=%r", + user, + timeout, + since, + set_presence, + filter_id, + device_id, ) request_key = (user, timeout, since, filter_id, full_state, device_id) - if filter_id: - if filter_id.startswith("{"): - try: - filter_object = json.loads(filter_id) - set_timeline_upper_limit( - filter_object, self.hs.config.filter_timeline_limit - ) - except Exception: - raise SynapseError(400, "Invalid filter JSON") - self.filtering.check_valid_filter(filter_object) - filter = FilterCollection(filter_object) - else: - filter = yield self.filtering.get_user_filter(user.localpart, filter_id) + if filter_id is None: + filter_collection = DEFAULT_FILTER_COLLECTION + elif filter_id.startswith("{"): + try: + filter_object = json.loads(filter_id) + set_timeline_upper_limit( + filter_object, self.hs.config.filter_timeline_limit + ) + except Exception: + raise SynapseError(400, "Invalid filter JSON") + self.filtering.check_valid_filter(filter_object) + filter_collection = FilterCollection(filter_object) else: - filter = DEFAULT_FILTER_COLLECTION + try: + filter_collection = await self.filtering.get_user_filter( + user.localpart, filter_id + ) + except StoreError as err: + if err.code != 404: + raise + # fix up the description and errcode to be more useful + raise SynapseError(400, "No such filter", errcode=Codes.INVALID_PARAM) sync_config = SyncConfig( user=user, - filter_collection=filter, + filter_collection=filter_collection, is_guest=requester.is_guest, request_key=request_key, device_id=device_id, @@ -149,20 +158,20 @@ class SyncRestServlet(RestServlet): since_token = None # send any outstanding server notices to the user. - yield self._server_notices_sender.on_user_syncing(user.to_string()) + await self._server_notices_sender.on_user_syncing(user.to_string()) affect_presence = set_presence != PresenceState.OFFLINE if affect_presence: - yield self.presence_handler.set_state( + await self.presence_handler.set_state( user, {"presence": set_presence}, True ) - context = yield self.presence_handler.user_syncing( + context = await self.presence_handler.user_syncing( user.to_string(), affect_presence=affect_presence ) with context: - sync_result = yield self.sync_handler.wait_for_sync_for_user( + sync_result = await self.sync_handler.wait_for_sync_for_user( sync_config, since_token=since_token, timeout=timeout, @@ -170,14 +179,13 @@ class SyncRestServlet(RestServlet): ) time_now = self.clock.time_msec() - response_content = yield self.encode_response( - time_now, sync_result, requester.access_token_id, filter + response_content = await self.encode_response( + time_now, sync_result, requester.access_token_id, filter_collection ) return 200, response_content - @defer.inlineCallbacks - def encode_response(self, time_now, sync_result, access_token_id, filter): + async def encode_response(self, time_now, sync_result, access_token_id, filter): if filter.event_format == "client": event_formatter = format_event_for_client_v2_without_room_id elif filter.event_format == "federation": @@ -185,7 +193,7 @@ class SyncRestServlet(RestServlet): else: raise Exception("Unknown event format %s" % (filter.event_format,)) - joined = yield self.encode_joined( + joined = await self.encode_joined( sync_result.joined, time_now, access_token_id, @@ -193,11 +201,11 @@ class SyncRestServlet(RestServlet): event_formatter, ) - invited = yield self.encode_invited( + invited = await self.encode_invited( sync_result.invited, time_now, access_token_id, event_formatter ) - archived = yield self.encode_archived( + archived = await self.encode_archived( sync_result.archived, time_now, access_token_id, @@ -238,8 +246,9 @@ class SyncRestServlet(RestServlet): ] } - @defer.inlineCallbacks - def encode_joined(self, rooms, time_now, token_id, event_fields, event_formatter): + async def encode_joined( + self, rooms, time_now, token_id, event_fields, event_formatter + ): """ Encode the joined rooms in a sync result @@ -260,7 +269,7 @@ class SyncRestServlet(RestServlet): """ joined = {} for room in rooms: - joined[room.room_id] = yield self.encode_room( + joined[room.room_id] = await self.encode_room( room, time_now, token_id, @@ -271,8 +280,7 @@ class SyncRestServlet(RestServlet): return joined - @defer.inlineCallbacks - def encode_invited(self, rooms, time_now, token_id, event_formatter): + async def encode_invited(self, rooms, time_now, token_id, event_formatter): """ Encode the invited rooms in a sync result @@ -292,7 +300,7 @@ class SyncRestServlet(RestServlet): """ invited = {} for room in rooms: - invite = yield self._event_serializer.serialize_event( + invite = await self._event_serializer.serialize_event( room.invite, time_now, token_id=token_id, @@ -307,8 +315,9 @@ class SyncRestServlet(RestServlet): return invited - @defer.inlineCallbacks - def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatter): + async def encode_archived( + self, rooms, time_now, token_id, event_fields, event_formatter + ): """ Encode the archived rooms in a sync result @@ -329,7 +338,7 @@ class SyncRestServlet(RestServlet): """ joined = {} for room in rooms: - joined[room.room_id] = yield self.encode_room( + joined[room.room_id] = await self.encode_room( room, time_now, token_id, @@ -340,8 +349,7 @@ class SyncRestServlet(RestServlet): return joined - @defer.inlineCallbacks - def encode_room( + async def encode_room( self, room, time_now, token_id, joined, only_fields, event_formatter ): """ @@ -382,15 +390,15 @@ class SyncRestServlet(RestServlet): # We've had bug reports that events were coming down under the # wrong room. if event.room_id != room.room_id: - logger.warn( + logger.warning( "Event %r is under room %r instead of %r", event.event_id, room.room_id, event.room_id, ) - serialized_state = yield serialize(state_events) - serialized_timeline = yield serialize(timeline_events) + serialized_state = await serialize(state_events) + serialized_timeline = await serialize(timeline_events) account_data = room.account_data diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index 3b555669a0..a3f12e8a77 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -37,13 +35,12 @@ class TagListServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_GET(self, request, user_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id, room_id): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get tags for other users.") - tags = yield self.store.get_tags_for_room(user_id, room_id) + tags = await self.store.get_tags_for_room(user_id, room_id) return 200, {"tags": tags} @@ -64,27 +61,25 @@ class TagServlet(RestServlet): self.store = hs.get_datastore() self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def on_PUT(self, request, user_id, room_id, tag): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id, room_id, tag): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") body = parse_json_object_from_request(request) - max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) + max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body) self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, user_id, room_id, tag): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, user_id, room_id, tag): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") - max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) + max_id = await self.store.remove_tag_from_room(user_id, room_id, tag) self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 2e8d672471..23709960ad 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import ThirdPartyEntityKind from synapse.http.servlet import RestServlet @@ -35,11 +33,10 @@ class ThirdPartyProtocolsServlet(RestServlet): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) - protocols = yield self.appservice_handler.get_3pe_protocols() + protocols = await self.appservice_handler.get_3pe_protocols() return 200, protocols @@ -52,11 +49,10 @@ class ThirdPartyProtocolServlet(RestServlet): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, protocol): + await self.auth.get_user_by_req(request, allow_guest=True) - protocols = yield self.appservice_handler.get_3pe_protocols( + protocols = await self.appservice_handler.get_3pe_protocols( only_protocol=protocol ) if protocol in protocols: @@ -74,14 +70,13 @@ class ThirdPartyUserServlet(RestServlet): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, protocol): + await self.auth.get_user_by_req(request, allow_guest=True) fields = request.args fields.pop(b"access_token", None) - results = yield self.appservice_handler.query_3pe( + results = await self.appservice_handler.query_3pe( ThirdPartyEntityKind.USER, protocol, fields ) @@ -97,14 +92,13 @@ class ThirdPartyLocationServlet(RestServlet): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, protocol): + await self.auth.get_user_by_req(request, allow_guest=True) fields = request.args fields.pop(b"access_token", None) - results = yield self.appservice_handler.query_3pe( + results = await self.appservice_handler.query_3pe( ThirdPartyEntityKind.LOCATION, protocol, fields ) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 2da0f55811..83f3b6b70a 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet @@ -32,8 +30,7 @@ class TokenRefreshRestServlet(RestServlet): def __init__(self, hs): super(TokenRefreshRestServlet, self).__init__() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): raise AuthError(403, "tokenrefresh is no longer supported.") diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index 2863affbab..bef91a2d3e 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -38,8 +36,7 @@ class UserDirectorySearchRestServlet(RestServlet): self.auth = hs.get_auth() self.user_directory_handler = hs.get_user_directory_handler() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): """Searches for users in directory Returns: @@ -56,7 +53,7 @@ class UserDirectorySearchRestServlet(RestServlet): ] } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() if not self.hs.config.user_directory_search_enabled: @@ -72,7 +69,7 @@ class UserDirectorySearchRestServlet(RestServlet): except Exception: raise SynapseError(400, "`search_term` is required field") - results = yield self.user_directory_handler.search_users( + results = await self.user_directory_handler.search_users( user_id, search_term, limit ) |