From 8822b331114a2f6fdcd5916f0c91991c0acae07e Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 5 Nov 2019 10:56:39 +0000 Subject: Update copyrights --- synapse/rest/client/versions.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index bb30ce3f34..2a477ad22e 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. -- cgit 1.4.1 From 541f1b92d946093fef17ea8b95a7cb595fc5ffc4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 5 Nov 2019 17:39:16 +0000 Subject: Only do `rc_login` ratelimiting on succesful login. We were doing this in a number of places which meant that some login code paths incremented the counter multiple times. It was also applying ratelimiting to UIA endpoints, which was probably not intentional. In particular, some custom auth modules were calling `check_user_exists`, which incremented the counters, meaning that people would fail to login sometimes. --- synapse/handlers/auth.py | 55 +------------------- synapse/rest/client/v1/login.py | 111 +++++++++++++++++++++++++++++++++------- 2 files changed, 94 insertions(+), 72 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7a0f54ca24..14c6387b6a 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -35,7 +35,6 @@ from synapse.api.errors import ( SynapseError, UserDeactivatedError, ) -from synapse.api.ratelimiting import Ratelimiter from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.logging.context import defer_to_thread @@ -102,9 +101,6 @@ class AuthHandler(BaseHandler): login_types.append(t) self._supported_login_types = login_types - self._account_ratelimiter = Ratelimiter() - self._failed_attempts_ratelimiter = Ratelimiter() - self._clock = self.hs.get_clock() @defer.inlineCallbacks @@ -501,11 +497,8 @@ class AuthHandler(BaseHandler): multiple matches Raises: - LimitExceededError if the ratelimiter's login requests count for this - user is too high too proceed. UserDeactivatedError if a user is found but is deactivated. """ - self.ratelimit_login_per_account(user_id) res = yield self._find_user_id_and_pwd_hash(user_id) if res is not None: return res[0] @@ -572,8 +565,6 @@ class AuthHandler(BaseHandler): StoreError if there was a problem accessing the database SynapseError if there was a problem with the request LoginError if there was an authentication problem. - LimitExceededError if the ratelimiter's login requests count for this - user is too high too proceed. """ if username.startswith("@"): @@ -581,8 +572,6 @@ class AuthHandler(BaseHandler): else: qualified_user_id = UserID(username, self.hs.hostname).to_string() - self.ratelimit_login_per_account(qualified_user_id) - login_type = login_submission.get("type") known_login_type = False @@ -650,15 +639,6 @@ class AuthHandler(BaseHandler): if not known_login_type: raise SynapseError(400, "Unknown login type %s" % login_type) - # unknown username or invalid password. - self._failed_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=True, - ) - # We raise a 403 here, but note that if we're doing user-interactive # login, it turns all LoginErrors into a 401 anyway. raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN) @@ -710,10 +690,6 @@ class AuthHandler(BaseHandler): Returns: Deferred[unicode] the canonical_user_id, or Deferred[None] if unknown user/bad password - - Raises: - LimitExceededError if the ratelimiter's login requests count for this - user is too high too proceed. """ lookupres = yield self._find_user_id_and_pwd_hash(user_id) if not lookupres: @@ -742,7 +718,7 @@ class AuthHandler(BaseHandler): auth_api.validate_macaroon(macaroon, "login", user_id) except Exception: raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) - self.ratelimit_login_per_account(user_id) + yield self.auth.check_auth_blocking(user_id) return user_id @@ -912,35 +888,6 @@ class AuthHandler(BaseHandler): else: return defer.succeed(False) - def ratelimit_login_per_account(self, user_id): - """Checks whether the process must be stopped because of ratelimiting. - - Checks against two ratelimiters: the generic one for login attempts per - account and the one specific to failed attempts. - - Args: - user_id (unicode): complete @user:id - - Raises: - LimitExceededError if one of the ratelimiters' login requests count - for this user is too high too proceed. - """ - self._failed_attempts_ratelimiter.ratelimit( - user_id.lower(), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=False, - ) - - self._account_ratelimiter.ratelimit( - user_id.lower(), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_account.per_second, - burst_count=self.hs.config.rc_login_account.burst_count, - update=True, - ) - @attr.s class MacaroonGenerator(object): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 24a0ce74f2..abc210da57 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -92,8 +92,11 @@ class LoginRestServlet(RestServlet): self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() + self._clock = hs.get_clock() self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter() + self._account_ratelimiter = Ratelimiter() + self._failed_attempts_ratelimiter = Ratelimiter() def on_GET(self, request): flows = [] @@ -202,6 +205,16 @@ class LoginRestServlet(RestServlet): # (See add_threepid in synapse/handlers/auth.py) address = address.lower() + # We also apply account rate limiting using the 3PID as a key, as + # otherwise using 3PID bypasses the ratelimiting based on user ID. + self._failed_attempts_ratelimiter.ratelimit( + (medium, address), + time_now_s=self._clock.time(), + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + update=False, + ) + # Check for login providers that support 3pid login types ( canonical_user_id, @@ -211,7 +224,8 @@ class LoginRestServlet(RestServlet): ) if canonical_user_id: # Authentication through password provider and 3pid succeeded - result = yield self._register_device_with_callback( + + result = yield self._complete_login( canonical_user_id, login_submission, callback_3pid ) return result @@ -225,6 +239,21 @@ class LoginRestServlet(RestServlet): logger.warning( "unknown 3pid identifier medium %s, address %r", medium, address ) + # We mark that we've failed to log in here, as + # `check_password_provider_3pid` might have returned `None` due + # to an incorrect password, rather than the account not + # existing. + # + # If it returned None but the 3PID was bound then we won't hit + # this code path, which is fine as then the per-user ratelimit + # will kick in below. + self._failed_attempts_ratelimiter.can_do_action( + (medium, address), + time_now_s=self._clock.time(), + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + update=True, + ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) identifier = {"type": "m.id.user", "user": user_id} @@ -236,29 +265,84 @@ class LoginRestServlet(RestServlet): if "user" not in identifier: raise SynapseError(400, "User identifier is missing 'user' key") - canonical_user_id, callback = yield self.auth_handler.validate_login( - identifier["user"], login_submission + if identifier["user"].startswith("@"): + qualified_user_id = identifier["user"] + else: + qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string() + + # Check if we've hit the failed ratelimit (but don't update it) + self._failed_attempts_ratelimiter.ratelimit( + qualified_user_id.lower(), + time_now_s=self._clock.time(), + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + update=False, ) - result = yield self._register_device_with_callback( + try: + canonical_user_id, callback = yield self.auth_handler.validate_login( + identifier["user"], login_submission + ) + except LoginError: + # The user has failed to log in, so we need to update the rate + # limiter. Using `can_do_action` avoids us raising a ratelimit + # exception and masking the LoginError. The actual ratelimiting + # should have happened above. + self._failed_attempts_ratelimiter.can_do_action( + qualified_user_id.lower(), + time_now_s=self._clock.time(), + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + update=True, + ) + raise + + result = yield self._complete_login( canonical_user_id, login_submission, callback ) return result @defer.inlineCallbacks - def _register_device_with_callback(self, user_id, login_submission, callback=None): - """ Registers a device with a given user_id. Optionally run a callback - function after registration has completed. + def _complete_login( + self, user_id, login_submission, callback=None, create_non_existant_users=False + ): + """Called when we've successfully authed the user and now need to + actually login them in (e.g. create devices). This gets called on + all succesful logins. + + Applies the ratelimiting for succesful login attempts against an + account. Args: user_id (str): ID of the user to register. login_submission (dict): Dictionary of login information. callback (func|None): Callback function to run after registration. + create_non_existant_users (bool): Whether to create the user if + they don't exist. Defaults to False. Returns: result (Dict[str,str]): Dictionary of account information after successful registration. """ + + # Before we actually log them in we check if they've already logged in + # too often. This happens here rather than before as we don't + # necessarily know the user before now. + self._account_ratelimiter.ratelimit( + user_id.lower(), + time_now_s=self._clock.time(), + rate_hz=self.hs.config.rc_login_account.per_second, + burst_count=self.hs.config.rc_login_account.burst_count, + update=True, + ) + + if create_non_existant_users: + user_id = yield self.auth_handler.check_user_exists(user_id) + if not user_id: + user_id = yield self.registration_handler.register_user( + localpart=UserID.from_string(user_id).localpart + ) + device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( @@ -285,7 +369,7 @@ class LoginRestServlet(RestServlet): token ) - result = yield self._register_device_with_callback(user_id, login_submission) + result = yield self._complete_login(user_id, login_submission) return result @defer.inlineCallbacks @@ -313,16 +397,7 @@ class LoginRestServlet(RestServlet): raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) user_id = UserID(user, self.hs.hostname).to_string() - - registered_user_id = yield self.auth_handler.check_user_exists(user_id) - if not registered_user_id: - registered_user_id = yield self.registration_handler.register_user( - localpart=user - ) - - result = yield self._register_device_with_callback( - registered_user_id, login_submission - ) + result = yield self._complete_login(user_id, login_submission) return result -- cgit 1.4.1 From c7376cdfe3efe05942964efcdf8886d66342383c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 18 Nov 2019 17:10:16 +0000 Subject: Apply suggestions from code review Co-Authored-By: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Co-Authored-By: Brendan Abolivier --- synapse/handlers/auth.py | 4 ++-- synapse/rest/client/v1/login.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 20c62bd780..0955cf9dba 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -135,8 +135,8 @@ class AuthHandler(BaseHandler): AuthError if the client has completed a login flow, and it gives a different user to `requester` - LimitExceededError if the ratelimiter's failed requests count for this - user is too high too proceed + LimitExceededError if the ratelimiter's failed request count for this + user is too high to proceed """ diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index abc210da57..f8d58afb29 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -397,7 +397,7 @@ class LoginRestServlet(RestServlet): raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) user_id = UserID(user, self.hs.hostname).to_string() - result = yield self._complete_login(user_id, login_submission) + result = yield self._complete_login(user_id, login_submission, create_non_existant_users=True) return result -- cgit 1.4.1 From 271c322d08a3d13c986d97cbc40e72eef50e92ba Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 20 Nov 2019 09:29:48 +0000 Subject: Lint --- synapse/rest/client/v1/login.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index f8d58afb29..19eb15003d 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -397,7 +397,9 @@ class LoginRestServlet(RestServlet): raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) user_id = UserID(user, self.hs.hostname).to_string() - result = yield self._complete_login(user_id, login_submission, create_non_existant_users=True) + result = yield self._complete_login( + user_id, login_submission, create_non_existant_users=True + ) return result -- cgit 1.4.1 From 0d27aba900136514a8801b902f9a8ac69150e2c0 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 27 Nov 2019 16:14:44 -0500 Subject: add etag and count to key backup endpoints (#5858) --- changelog.d/5858.feature | 1 + synapse/handlers/e2e_room_keys.py | 130 +++++++----- synapse/rest/client/v2_alpha/room_keys.py | 8 +- synapse/storage/data_stores/main/e2e_room_keys.py | 226 +++++++++++++++------ .../main/schema/delta/56/room_key_etag.sql | 17 ++ tests/handlers/test_e2e_room_keys.py | 31 +++ tests/storage/test_e2e_room_keys.py | 8 +- 7 files changed, 297 insertions(+), 124 deletions(-) create mode 100644 changelog.d/5858.feature create mode 100644 synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql (limited to 'synapse/rest/client') diff --git a/changelog.d/5858.feature b/changelog.d/5858.feature new file mode 100644 index 0000000000..55ee93051e --- /dev/null +++ b/changelog.d/5858.feature @@ -0,0 +1 @@ +Add etag and count fields to key backup endpoints to help clients guess if there are new keys. diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 0cea445f0d..f1b4424a02 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017, 2018 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -103,14 +104,35 @@ class E2eRoomKeysHandler(object): rooms session_id(string): session ID to delete keys for, for None to delete keys for all sessions + Raises: + NotFoundError: if the backup version does not exist Returns: - A deferred of the deletion transaction + A dict containing the count and etag for the backup version """ # lock for consistency with uploading with (yield self._upload_linearizer.queue(user_id)): + # make sure the backup version exists + try: + version_info = yield self.store.get_e2e_room_keys_version_info( + user_id, version + ) + except StoreError as e: + if e.code == 404: + raise NotFoundError("Unknown backup version") + else: + raise + yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) + version_etag = version_info["etag"] + 1 + yield self.store.update_e2e_room_keys_version( + user_id, version, None, version_etag + ) + + count = yield self.store.count_e2e_room_keys(user_id, version) + return {"etag": str(version_etag), "count": count} + @trace @defer.inlineCallbacks def upload_room_keys(self, user_id, version, room_keys): @@ -138,6 +160,9 @@ class E2eRoomKeysHandler(object): } } + Returns: + A dict containing the count and etag for the backup version + Raises: NotFoundError: if there are no versions defined RoomKeysVersionError: if the uploaded version is not the current version @@ -171,59 +196,62 @@ class E2eRoomKeysHandler(object): else: raise - # go through the room_keys. - # XXX: this should/could be done concurrently, given we're in a lock. + # Fetch any existing room keys for the sessions that have been + # submitted. Then compare them with the submitted keys. If the + # key is new, insert it; if the key should be updated, then update + # it; otherwise, drop it. + existing_keys = yield self.store.get_e2e_room_keys_multi( + user_id, version, room_keys["rooms"] + ) + to_insert = [] # batch the inserts together + changed = False # if anything has changed, we need to update the etag for room_id, room in iteritems(room_keys["rooms"]): - for session_id, session in iteritems(room["sessions"]): - yield self._upload_room_key( - user_id, version, room_id, session_id, session + for session_id, room_key in iteritems(room["sessions"]): + log_kv( + { + "message": "Trying to upload room key", + "room_id": room_id, + "session_id": session_id, + "user_id": user_id, + } ) - - @defer.inlineCallbacks - def _upload_room_key(self, user_id, version, room_id, session_id, room_key): - """Upload a given room_key for a given room and session into a given - version of the backup. Merges the key with any which might already exist. - - Args: - user_id(str): the user whose backup we're setting - version(str): the version ID of the backup we're updating - room_id(str): the ID of the room whose keys we're setting - session_id(str): the session whose room_key we're setting - room_key(dict): the room_key being set - """ - log_kv( - { - "message": "Trying to upload room key", - "room_id": room_id, - "session_id": session_id, - "user_id": user_id, - } - ) - # get the room_key for this particular row - current_room_key = None - try: - current_room_key = yield self.store.get_e2e_room_key( - user_id, version, room_id, session_id - ) - except StoreError as e: - if e.code == 404: - log_kv( - { - "message": "Room key not found.", - "room_id": room_id, - "user_id": user_id, - } + current_room_key = existing_keys.get(room_id, {}).get(session_id) + if current_room_key: + if self._should_replace_room_key(current_room_key, room_key): + log_kv({"message": "Replacing room key."}) + # updates are done one at a time in the DB, so send + # updates right away rather than batching them up, + # like we do with the inserts + yield self.store.update_e2e_room_key( + user_id, version, room_id, session_id, room_key + ) + changed = True + else: + log_kv({"message": "Not replacing room_key."}) + else: + log_kv( + { + "message": "Room key not found.", + "room_id": room_id, + "user_id": user_id, + } + ) + log_kv({"message": "Replacing room key."}) + to_insert.append((room_id, session_id, room_key)) + changed = True + + if len(to_insert): + yield self.store.add_e2e_room_keys(user_id, version, to_insert) + + version_etag = version_info["etag"] + if changed: + version_etag = version_etag + 1 + yield self.store.update_e2e_room_keys_version( + user_id, version, None, version_etag ) - else: - raise - if self._should_replace_room_key(current_room_key, room_key): - log_kv({"message": "Replacing room key."}) - yield self.store.set_e2e_room_key( - user_id, version, room_id, session_id, room_key - ) - else: - log_kv({"message": "Not replacing room_key."}) + count = yield self.store.count_e2e_room_keys(user_id, version) + return {"etag": str(version_etag), "count": count} @staticmethod def _should_replace_room_key(current_room_key, room_key): @@ -314,6 +342,8 @@ class E2eRoomKeysHandler(object): raise NotFoundError("Unknown backup version") else: raise + + res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"]) return res @trace diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index d596786430..d83ac8e3c5 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -134,8 +134,8 @@ class RoomKeysServlet(RestServlet): if room_id: body = {"rooms": {room_id: body}} - yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) - return 200, {} + ret = yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) + return 200, ret @defer.inlineCallbacks def on_GET(self, request, room_id, session_id): @@ -239,10 +239,10 @@ class RoomKeysServlet(RestServlet): user_id = requester.user.to_string() version = parse_string(request, "version") - yield self.e2e_room_keys_handler.delete_room_keys( + ret = yield self.e2e_room_keys_handler.delete_room_keys( user_id, version, room_id, session_id ) - return 200, {} + return 200, ret class RoomKeysNewVersionServlet(RestServlet): diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index 1cbbae5b63..113224fd7c 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,49 +25,8 @@ from synapse.storage._base import SQLBaseStore class EndToEndRoomKeyStore(SQLBaseStore): @defer.inlineCallbacks - def get_e2e_room_key(self, user_id, version, room_id, session_id): - """Get the encrypted E2E room key for a given session from a given - backup version of room_keys. We only store the 'best' room key for a given - session at a given time, as determined by the handler. - - Args: - user_id(str): the user whose backup we're querying - version(str): the version ID of the backup for the set of keys we're querying - room_id(str): the ID of the room whose keys we're querying. - This is a bit redundant as it's implied by the session_id, but - we include for consistency with the rest of the API. - session_id(str): the session whose room_key we're querying. - - Returns: - A deferred dict giving the session_data and message metadata for - this room key. - """ - - row = yield self._simple_select_one( - table="e2e_room_keys", - keyvalues={ - "user_id": user_id, - "version": version, - "room_id": room_id, - "session_id": session_id, - }, - retcols=( - "first_message_index", - "forwarded_count", - "is_verified", - "session_data", - ), - desc="get_e2e_room_key", - ) - - row["session_data"] = json.loads(row["session_data"]) - - return row - - @defer.inlineCallbacks - def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key): - """Replaces or inserts the encrypted E2E room key for a given session in - a given backup + def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key): + """Replaces the encrypted E2E room key for a given session in a given backup Args: user_id(str): the user whose backup we're setting @@ -78,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): StoreError """ - yield self._simple_upsert( + yield self._simple_update_one( table="e2e_room_keys", keyvalues={ "user_id": user_id, @@ -86,21 +46,51 @@ class EndToEndRoomKeyStore(SQLBaseStore): "room_id": room_id, "session_id": session_id, }, - values={ + updatevalues={ "first_message_index": room_key["first_message_index"], "forwarded_count": room_key["forwarded_count"], "is_verified": room_key["is_verified"], "session_data": json.dumps(room_key["session_data"]), }, - lock=False, + desc="update_e2e_room_key", ) - log_kv( - { - "message": "Set room key", - "room_id": room_id, - "session_id": session_id, - "room_key": room_key, - } + + @defer.inlineCallbacks + def add_e2e_room_keys(self, user_id, version, room_keys): + """Bulk add room keys to a given backup. + + Args: + user_id (str): the user whose backup we're adding to + version (str): the version ID of the backup for the set of keys we're adding to + room_keys (iterable[(str, str, dict)]): the keys to add, in the form + (roomID, sessionID, keyData) + """ + + values = [] + for (room_id, session_id, room_key) in room_keys: + values.append( + { + "user_id": user_id, + "version": version, + "room_id": room_id, + "session_id": session_id, + "first_message_index": room_key["first_message_index"], + "forwarded_count": room_key["forwarded_count"], + "is_verified": room_key["is_verified"], + "session_data": json.dumps(room_key["session_data"]), + } + ) + log_kv( + { + "message": "Set room key", + "room_id": room_id, + "session_id": session_id, + "room_key": room_key, + } + ) + + yield self._simple_insert_many( + table="e2e_room_keys", values=values, desc="add_e2e_room_keys" ) @trace @@ -110,11 +100,11 @@ class EndToEndRoomKeyStore(SQLBaseStore): room, or a given session. Args: - user_id(str): the user whose backup we're querying - version(str): the version ID of the backup for the set of keys we're querying - room_id(str): Optional. the ID of the room whose keys we're querying, if any. + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup for the set of keys we're querying + room_id (str): Optional. the ID of the room whose keys we're querying, if any. If not specified, we return the keys for all the rooms in the backup. - session_id(str): Optional. the session whose room_key we're querying, if any. + session_id (str): Optional. the session whose room_key we're querying, if any. If specified, we also require the room_id to be specified. If not specified, we return all the keys in this version of the backup (or for the specified room) @@ -162,6 +152,95 @@ class EndToEndRoomKeyStore(SQLBaseStore): return sessions + def get_e2e_room_keys_multi(self, user_id, version, room_keys): + """Get multiple room keys at a time. The difference between this function and + get_e2e_room_keys is that this function can be used to retrieve + multiple specific keys at a time, whereas get_e2e_room_keys is used for + getting all the keys in a backup version, all the keys for a room, or a + specific key. + + Args: + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup we're querying about + room_keys (dict[str, dict[str, iterable[str]]]): a map from + room ID -> {"session": [session ids]} indicating the session IDs + that we want to query + + Returns: + Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key + """ + + return self.runInteraction( + "get_e2e_room_keys_multi", + self._get_e2e_room_keys_multi_txn, + user_id, + version, + room_keys, + ) + + @staticmethod + def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys): + if not room_keys: + return {} + + where_clauses = [] + params = [user_id, version] + for room_id, room in room_keys.items(): + sessions = list(room["sessions"]) + if not sessions: + continue + params.append(room_id) + params.extend(sessions) + where_clauses.append( + "(room_id = ? AND session_id IN (%s))" + % (",".join(["?" for _ in sessions]),) + ) + + # check if we're actually querying something + if not where_clauses: + return {} + + sql = """ + SELECT room_id, session_id, first_message_index, forwarded_count, + is_verified, session_data + FROM e2e_room_keys + WHERE user_id = ? AND version = ? AND (%s) + """ % ( + " OR ".join(where_clauses) + ) + + txn.execute(sql, params) + + ret = {} + + for row in txn: + room_id = row[0] + session_id = row[1] + ret.setdefault(room_id, {}) + ret[room_id][session_id] = { + "first_message_index": row[2], + "forwarded_count": row[3], + "is_verified": row[4], + "session_data": json.loads(row[5]), + } + + return ret + + def count_e2e_room_keys(self, user_id, version): + """Get the number of keys in a backup version. + + Args: + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup we're querying about + """ + + return self._simple_select_one_onecol( + table="e2e_room_keys", + keyvalues={"user_id": user_id, "version": version}, + retcol="COUNT(*)", + desc="count_e2e_room_keys", + ) + @trace @defer.inlineCallbacks def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): @@ -219,6 +298,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): version(str) algorithm(str) auth_data(object): opaque dict supplied by the client + etag(int): tag of the keys in the backup """ def _get_e2e_room_keys_version_info_txn(txn): @@ -236,10 +316,12 @@ class EndToEndRoomKeyStore(SQLBaseStore): txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, - retcols=("version", "algorithm", "auth_data"), + retcols=("version", "algorithm", "auth_data", "etag"), ) result["auth_data"] = json.loads(result["auth_data"]) result["version"] = str(result["version"]) + if result["etag"] is None: + result["etag"] = 0 return result return self.runInteraction( @@ -288,21 +370,33 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @trace - def update_e2e_room_keys_version(self, user_id, version, info): + def update_e2e_room_keys_version( + self, user_id, version, info=None, version_etag=None + ): """Update a given backup version Args: user_id(str): the user whose backup version we're updating version(str): the version ID of the backup version we're updating - info(dict): the new backup version info to store + info (dict): the new backup version info to store. If None, then + the backup version info is not updated + version_etag (Optional[int]): etag of the keys in the backup. If + None, then the etag is not updated """ + updatevalues = {} - return self._simple_update( - table="e2e_room_keys_versions", - keyvalues={"user_id": user_id, "version": version}, - updatevalues={"auth_data": json.dumps(info["auth_data"])}, - desc="update_e2e_room_keys_version", - ) + if info is not None and "auth_data" in info: + updatevalues["auth_data"] = json.dumps(info["auth_data"]) + if version_etag is not None: + updatevalues["etag"] = version_etag + + if updatevalues: + return self._simple_update( + table="e2e_room_keys_versions", + keyvalues={"user_id": user_id, "version": version}, + updatevalues=updatevalues, + desc="update_e2e_room_keys_version", + ) @trace def delete_e2e_room_keys_version(self, user_id, version=None): diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql new file mode 100644 index 0000000000..7d70dd071e --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- store the current etag of backup version +ALTER TABLE e2e_room_keys_versions ADD COLUMN etag BIGINT; diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 0bb96674a2..70f172eb02 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd # Copyright 2017 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -94,23 +95,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + version_etag = res["etag"] + del res["etag"] self.assertDictEqual( res, { "version": "1", "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", + "count": 0, }, ) # check we can retrieve it as a specific version res = yield self.handler.get_version_info(self.local_user, "1") + self.assertEqual(res["etag"], version_etag) + del res["etag"] self.assertDictEqual( res, { "version": "1", "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", + "count": 0, }, ) @@ -126,12 +133,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["etag"] self.assertDictEqual( res, { "version": "2", "algorithm": "m.megolm_backup.v1", "auth_data": "second_version_auth_data", + "count": 0, }, ) @@ -158,12 +167,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["etag"] self.assertDictEqual( res, { "algorithm": "m.megolm_backup.v1", "auth_data": "revised_first_version_auth_data", "version": version, + "count": 0, }, ) @@ -207,12 +218,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["etag"] # etag is opaque, so don't test its contents self.assertDictEqual( res, { "algorithm": "m.megolm_backup.v1", "auth_data": "revised_first_version_auth_data", "version": version, + "count": 0, }, ) @@ -409,6 +422,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): yield self.handler.upload_room_keys(self.local_user, version, room_keys) + # get the etag to compare to future versions + res = yield self.handler.get_version_info(self.local_user) + backup_etag = res["etag"] + self.assertEqual(res["count"], 1) + new_room_keys = copy.deepcopy(room_keys) new_room_key = new_room_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"] @@ -423,6 +441,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): "SSBBTSBBIEZJU0gK", ) + # the etag should be the same since the session did not change + res = yield self.handler.get_version_info(self.local_user) + self.assertEqual(res["etag"], backup_etag) + # test that marking the session as verified however /does/ replace it new_room_key["is_verified"] = True yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) @@ -432,6 +454,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) + # the etag should NOT be equal now, since the key changed + res = yield self.handler.get_version_info(self.local_user) + self.assertNotEqual(res["etag"], backup_etag) + backup_etag = res["etag"] + # test that a session with a higher forwarded_count doesn't replace one # with a lower forwarding count new_room_key["forwarded_count"] = 2 @@ -443,6 +470,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) + # the etag should be the same since the session did not change + res = yield self.handler.get_version_info(self.local_user) + self.assertEqual(res["etag"], backup_etag) + # TODO: check edge cases as well as the common variations here @defer.inlineCallbacks diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index d128fde441..35dafbb904 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -39,8 +39,8 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.store.set_e2e_room_key( - "user_id", version1, "room", "session", room_key + self.store.add_e2e_room_keys( + "user_id", version1, [("room", "session", room_key)] ) ) @@ -51,8 +51,8 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.store.set_e2e_room_key( - "user_id", version2, "room", "session", room_key + self.store.add_e2e_room_keys( + "user_id", version2, [("room", "session", room_key)] ) ) -- cgit 1.4.1 From 69d8fb83c6da35d7e1f04fa3afba0fd5406bd9d9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Nov 2019 11:02:04 +0000 Subject: MSC2367 Allow reason field on all member events --- synapse/rest/client/v1/room.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 86bbcc0eea..711d4ad304 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -714,7 +714,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): target = UserID.from_string(content["user_id"]) event_content = None - if "reason" in content and membership_action in ["kick", "ban"]: + if "reason" in content: event_content = {"reason": content["reason"]} await self.room_member_handler.update_membership( -- cgit 1.4.1 From 23ea5721259059d50b80083bb7240a8cb56cf297 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 29 Nov 2019 13:51:14 +0000 Subject: Add User-Interactive Auth to /account/3pid/add (#6119) --- changelog.d/6119.feature | 1 + synapse/rest/client/v2_alpha/account.py | 5 +++++ 2 files changed, 6 insertions(+) create mode 100644 changelog.d/6119.feature (limited to 'synapse/rest/client') diff --git a/changelog.d/6119.feature b/changelog.d/6119.feature new file mode 100644 index 0000000000..1492e83c5a --- /dev/null +++ b/changelog.d/6119.feature @@ -0,0 +1 @@ +Require User-Interactive Authentication for `/account/3pid/add`, meaning the user's password will be required to add a third-party ID to their account. \ No newline at end of file diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index f26eae794c..ad674239ab 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -642,6 +642,7 @@ class ThreepidAddRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) @@ -652,6 +653,10 @@ class ThreepidAddRestServlet(RestServlet): client_secret = body["client_secret"] sid = body["sid"] + yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request) + ) + validation_session = yield self.identity_handler.validate_threepid_session( client_secret, sid ) -- cgit 1.4.1 From 1a0997bbd5f9134b21b1105c6b14e09480f193a9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 5 Dec 2019 15:53:06 +0000 Subject: Port rest/v1 to async/await --- synapse/app/synchrotron.py | 2 +- synapse/rest/client/v1/directory.py | 53 +++++++++++++----------------- synapse/rest/client/v1/events.py | 18 ++++------ synapse/rest/client/v1/initial_sync.py | 8 ++--- synapse/rest/client/v1/login.py | 60 +++++++++++++++------------------- synapse/rest/client/v1/logout.py | 20 +++++------- synapse/rest/client/v1/presence.py | 18 ++++------ synapse/rest/client/v1/profile.py | 48 ++++++++++++--------------- synapse/rest/client/v1/push_rule.py | 24 ++++++-------- synapse/rest/client/v1/pusher.py | 27 +++++++-------- synapse/rest/client/v1/voip.py | 7 ++-- 11 files changed, 118 insertions(+), 167 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index b14da09f47..288ee64b42 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -151,7 +151,7 @@ class SynchrotronPresence(object): def set_state(self, user, state, ignore_status_msg=False): # TODO Hows this supposed to work? - pass + return defer.succeed(None) get_states = __func__(PresenceHandler.get_states) get_state = __func__(PresenceHandler.get_state) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 4ea3666874..5934b1fe8b 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import ( AuthError, Codes, @@ -47,17 +45,15 @@ class ClientDirectoryServer(RestServlet): self.handlers = hs.get_handlers() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_alias): + async def on_GET(self, request, room_alias): room_alias = RoomAlias.from_string(room_alias) dir_handler = self.handlers.directory_handler - res = yield dir_handler.get_association(room_alias) + res = await dir_handler.get_association(room_alias) return 200, res - @defer.inlineCallbacks - def on_PUT(self, request, room_alias): + async def on_PUT(self, request, room_alias): room_alias = RoomAlias.from_string(room_alias) content = parse_json_object_from_request(request) @@ -77,26 +73,25 @@ class ClientDirectoryServer(RestServlet): # TODO(erikj): Check types. - room = yield self.store.get_room(room_id) + room = await self.store.get_room(room_id) if room is None: raise SynapseError(400, "Room does not exist") - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) - yield self.handlers.directory_handler.create_association( + await self.handlers.directory_handler.create_association( requester, room_alias, room_id, servers ) return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, room_alias): + async def on_DELETE(self, request, room_alias): dir_handler = self.handlers.directory_handler try: - service = yield self.auth.get_appservice_by_req(request) + service = await self.auth.get_appservice_by_req(request) room_alias = RoomAlias.from_string(room_alias) - yield dir_handler.delete_appservice_association(service, room_alias) + await dir_handler.delete_appservice_association(service, room_alias) logger.info( "Application service at %s deleted alias %s", service.url, @@ -107,12 +102,12 @@ class ClientDirectoryServer(RestServlet): # fallback to default user behaviour if they aren't an AS pass - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user = requester.user room_alias = RoomAlias.from_string(room_alias) - yield dir_handler.delete_association(requester, room_alias) + await dir_handler.delete_association(requester, room_alias) logger.info( "User %s deleted alias %s", user.to_string(), room_alias.to_string() @@ -130,32 +125,29 @@ class ClientDirectoryListServer(RestServlet): self.handlers = hs.get_handlers() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - room = yield self.store.get_room(room_id) + async def on_GET(self, request, room_id): + room = await self.store.get_room(room_id) if room is None: raise NotFoundError("Unknown room") return 200, {"visibility": "public" if room["is_public"] else "private"} - @defer.inlineCallbacks - def on_PUT(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, room_id): + requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) visibility = content.get("visibility", "public") - yield self.handlers.directory_handler.edit_published_room_list( + await self.handlers.directory_handler.edit_published_room_list( requester, room_id, visibility ) return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, room_id): + requester = await self.auth.get_user_by_req(request) - yield self.handlers.directory_handler.edit_published_room_list( + await self.handlers.directory_handler.edit_published_room_list( requester, room_id, "private" ) @@ -181,15 +173,14 @@ class ClientAppserviceDirectoryListServer(RestServlet): def on_DELETE(self, request, network_id, room_id): return self._edit(request, network_id, room_id, "private") - @defer.inlineCallbacks - def _edit(self, request, network_id, room_id, visibility): - requester = yield self.auth.get_user_by_req(request) + async def _edit(self, request, network_id, room_id, visibility): + requester = await self.auth.get_user_by_req(request) if not requester.app_service: raise AuthError( 403, "Only appservices can edit the appservice published room list" ) - yield self.handlers.directory_handler.edit_published_appservice_room_list( + await self.handlers.directory_handler.edit_published_appservice_room_list( requester.app_service.id, network_id, room_id, visibility ) diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 6651b4cf07..4beb617733 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -16,8 +16,6 @@ """This module contains REST servlets to do with event streaming, /events.""" import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet from synapse.rest.client.v2_alpha._base import client_patterns @@ -36,9 +34,8 @@ class EventStreamRestServlet(RestServlet): self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) is_guest = requester.is_guest room_id = None if is_guest: @@ -57,7 +54,7 @@ class EventStreamRestServlet(RestServlet): as_client_event = b"raw" not in request.args - chunk = yield self.event_stream_handler.get_stream( + chunk = await self.event_stream_handler.get_stream( requester.user.to_string(), pagin_config, timeout=timeout, @@ -83,14 +80,13 @@ class EventRestServlet(RestServlet): self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks - def on_GET(self, request, event_id): - requester = yield self.auth.get_user_by_req(request) - event = yield self.event_handler.get_event(requester.user, None, event_id) + async def on_GET(self, request, event_id): + requester = await self.auth.get_user_by_req(request) + event = await self.event_handler.get_event(requester.user, None, event_id) time_now = self.clock.time_msec() if event: - event = yield self._event_serializer.serialize_event(event, time_now) + event = await self._event_serializer.serialize_event(event, time_now) return 200, event else: return 404, "Event not found." diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 2da3cd7511..910b3b4eeb 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_boolean from synapse.rest.client.v2_alpha._base import client_patterns @@ -29,13 +28,12 @@ class InitialSyncRestServlet(RestServlet): self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) as_client_event = b"raw" not in request.args pagination_config = PaginationConfig.from_request(request) include_archived = parse_boolean(request, "archived", default=False) - content = yield self.initial_sync_handler.snapshot_all_rooms( + content = await self.initial_sync_handler.snapshot_all_rooms( user_id=requester.user.to_string(), pagin_config=pagination_config, as_client_event=as_client_event, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 19eb15003d..ff9c978fe7 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -18,7 +18,6 @@ import xml.etree.ElementTree as ET from six.moves import urllib -from twisted.internet import defer from twisted.web.client import PartialDownloadError from synapse.api.errors import Codes, LoginError, SynapseError @@ -130,8 +129,7 @@ class LoginRestServlet(RestServlet): def on_OPTIONS(self, request): return 200, {} - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): self._address_ratelimiter.ratelimit( request.getClientIP(), time_now_s=self.hs.clock.time(), @@ -145,11 +143,11 @@ class LoginRestServlet(RestServlet): if self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE ): - result = yield self.do_jwt_login(login_submission) + result = await self.do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: - result = yield self.do_token_login(login_submission) + result = await self.do_token_login(login_submission) else: - result = yield self._do_other_login(login_submission) + result = await self._do_other_login(login_submission) except KeyError: raise SynapseError(400, "Missing JSON keys.") @@ -158,8 +156,7 @@ class LoginRestServlet(RestServlet): result["well_known"] = well_known_data return 200, result - @defer.inlineCallbacks - def _do_other_login(self, login_submission): + async def _do_other_login(self, login_submission): """Handle non-token/saml/jwt logins Args: @@ -219,20 +216,20 @@ class LoginRestServlet(RestServlet): ( canonical_user_id, callback_3pid, - ) = yield self.auth_handler.check_password_provider_3pid( + ) = await self.auth_handler.check_password_provider_3pid( medium, address, login_submission["password"] ) if canonical_user_id: # Authentication through password provider and 3pid succeeded - result = yield self._complete_login( + result = await self._complete_login( canonical_user_id, login_submission, callback_3pid ) return result # No password providers were able to handle this 3pid # Check local store - user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + user_id = await self.hs.get_datastore().get_user_id_by_threepid( medium, address ) if not user_id: @@ -280,7 +277,7 @@ class LoginRestServlet(RestServlet): ) try: - canonical_user_id, callback = yield self.auth_handler.validate_login( + canonical_user_id, callback = await self.auth_handler.validate_login( identifier["user"], login_submission ) except LoginError: @@ -297,13 +294,12 @@ class LoginRestServlet(RestServlet): ) raise - result = yield self._complete_login( + result = await self._complete_login( canonical_user_id, login_submission, callback ) return result - @defer.inlineCallbacks - def _complete_login( + async def _complete_login( self, user_id, login_submission, callback=None, create_non_existant_users=False ): """Called when we've successfully authed the user and now need to @@ -337,15 +333,15 @@ class LoginRestServlet(RestServlet): ) if create_non_existant_users: - user_id = yield self.auth_handler.check_user_exists(user_id) + user_id = await self.auth_handler.check_user_exists(user_id) if not user_id: - user_id = yield self.registration_handler.register_user( + user_id = await self.registration_handler.register_user( localpart=UserID.from_string(user_id).localpart ) device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name ) @@ -357,23 +353,21 @@ class LoginRestServlet(RestServlet): } if callback is not None: - yield callback(result) + await callback(result) return result - @defer.inlineCallbacks - def do_token_login(self, login_submission): + async def do_token_login(self, login_submission): token = login_submission["token"] auth_handler = self.auth_handler - user_id = yield auth_handler.validate_short_term_login_token_and_get_user_id( + user_id = await auth_handler.validate_short_term_login_token_and_get_user_id( token ) - result = yield self._complete_login(user_id, login_submission) + result = await self._complete_login(user_id, login_submission) return result - @defer.inlineCallbacks - def do_jwt_login(self, login_submission): + async def do_jwt_login(self, login_submission): token = login_submission.get("token", None) if token is None: raise LoginError( @@ -397,7 +391,7 @@ class LoginRestServlet(RestServlet): raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) user_id = UserID(user, self.hs.hostname).to_string() - result = yield self._complete_login( + result = await self._complete_login( user_id, login_submission, create_non_existant_users=True ) return result @@ -460,8 +454,7 @@ class CasTicketServlet(RestServlet): self._sso_auth_handler = SSOAuthHandler(hs) self._http_client = hs.get_proxied_http_client() - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): client_redirect_url = parse_string(request, "redirectUrl", required=True) uri = self.cas_server_url + "/proxyValidate" args = { @@ -469,12 +462,12 @@ class CasTicketServlet(RestServlet): "service": self.cas_service_url, } try: - body = yield self._http_client.get_raw(uri, args) + body = await self._http_client.get_raw(uri, args) except PartialDownloadError as pde: # Twisted raises this error if the connection is closed, # even if that's being used old-http style to signal end-of-data body = pde.response - result = yield self.handle_cas_response(request, body, client_redirect_url) + result = await self.handle_cas_response(request, body, client_redirect_url) return result def handle_cas_response(self, request, cas_response_body, client_redirect_url): @@ -555,8 +548,7 @@ class SSOAuthHandler(object): self._registration_handler = hs.get_registration_handler() self._macaroon_gen = hs.get_macaroon_generator() - @defer.inlineCallbacks - def on_successful_auth( + async def on_successful_auth( self, username, request, client_redirect_url, user_display_name=None ): """Called once the user has successfully authenticated with the SSO. @@ -582,9 +574,9 @@ class SSOAuthHandler(object): """ localpart = map_username_to_mxid_localpart(username) user_id = UserID(localpart, self._hostname).to_string() - registered_user_id = yield self._auth_handler.check_user_exists(user_id) + registered_user_id = await self._auth_handler.check_user_exists(user_id) if not registered_user_id: - registered_user_id = yield self._registration_handler.register_user( + registered_user_id = await self._registration_handler.register_user( localpart=localpart, default_display_name=user_display_name ) diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 4785a34d75..1cf3caf832 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import RestServlet from synapse.rest.client.v2_alpha._base import client_patterns @@ -35,17 +33,16 @@ class LogoutRestServlet(RestServlet): def on_OPTIONS(self, request): return 200, {} - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) if requester.device_id is None: # the acccess token wasn't associated with a device. # Just delete the access token access_token = self.auth.get_access_token_from_request(request) - yield self._auth_handler.delete_access_token(access_token) + await self._auth_handler.delete_access_token(access_token) else: - yield self._device_handler.delete_device( + await self._device_handler.delete_device( requester.user.to_string(), requester.device_id ) @@ -64,17 +61,16 @@ class LogoutAllRestServlet(RestServlet): def on_OPTIONS(self, request): return 200, {} - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() # first delete all of the user's devices - yield self._device_handler.delete_all_devices_for_user(user_id) + await self._device_handler.delete_all_devices_for_user(user_id) # .. and then delete any access tokens which weren't associated with # devices. - yield self._auth_handler.delete_access_tokens_for_user(user_id) + await self._auth_handler.delete_access_tokens_for_user(user_id) return 200, {} diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 0153525cef..eec16f8ad8 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -19,8 +19,6 @@ import logging from six import string_types -from twisted.internet import defer - from synapse.api.errors import AuthError, SynapseError from synapse.handlers.presence import format_user_presence_state from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -40,27 +38,25 @@ class PresenceStatusRestServlet(RestServlet): self.clock = hs.get_clock() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id): + requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if requester.user != user: - allowed = yield self.presence_handler.is_visible( + allowed = await self.presence_handler.is_visible( observed_user=user, observer_user=requester.user ) if not allowed: raise AuthError(403, "You are not allowed to see their presence.") - state = yield self.presence_handler.get_state(target_user=user) + state = await self.presence_handler.get_state(target_user=user) state = format_user_presence_state(state, self.clock.time_msec()) return 200, state - @defer.inlineCallbacks - def on_PUT(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if requester.user != user: @@ -86,7 +82,7 @@ class PresenceStatusRestServlet(RestServlet): raise SynapseError(400, "Unable to parse state") if self.hs.config.use_presence: - yield self.presence_handler.set_state(user, state) + await self.presence_handler.set_state(user, state) return 200, {} diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index bbce2e2b71..1eac8a44c5 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -14,7 +14,6 @@ # limitations under the License. """ This module contains REST servlets to do with profile: /profile/ """ -from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.rest.client.v2_alpha._base import client_patterns @@ -30,19 +29,18 @@ class ProfileDisplaynameRestServlet(RestServlet): self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): + async def on_GET(self, request, user_id): requester_user = None if self.hs.config.require_auth_for_profile_requests: - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) requester_user = requester.user user = UserID.from_string(user_id) - yield self.profile_handler.check_profile_query_allowed(user, requester_user) + await self.profile_handler.check_profile_query_allowed(user, requester_user) - displayname = yield self.profile_handler.get_displayname(user) + displayname = await self.profile_handler.get_displayname(user) ret = {} if displayname is not None: @@ -50,11 +48,10 @@ class ProfileDisplaynameRestServlet(RestServlet): return 200, ret - @defer.inlineCallbacks - def on_PUT(self, request, user_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) user = UserID.from_string(user_id) - is_admin = yield self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester.user) content = parse_json_object_from_request(request) @@ -63,7 +60,7 @@ class ProfileDisplaynameRestServlet(RestServlet): except Exception: return 400, "Unable to parse name" - yield self.profile_handler.set_displayname(user, requester, new_name, is_admin) + await self.profile_handler.set_displayname(user, requester, new_name, is_admin) return 200, {} @@ -80,19 +77,18 @@ class ProfileAvatarURLRestServlet(RestServlet): self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): + async def on_GET(self, request, user_id): requester_user = None if self.hs.config.require_auth_for_profile_requests: - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) requester_user = requester.user user = UserID.from_string(user_id) - yield self.profile_handler.check_profile_query_allowed(user, requester_user) + await self.profile_handler.check_profile_query_allowed(user, requester_user) - avatar_url = yield self.profile_handler.get_avatar_url(user) + avatar_url = await self.profile_handler.get_avatar_url(user) ret = {} if avatar_url is not None: @@ -100,11 +96,10 @@ class ProfileAvatarURLRestServlet(RestServlet): return 200, ret - @defer.inlineCallbacks - def on_PUT(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) - is_admin = yield self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester.user) content = parse_json_object_from_request(request) try: @@ -112,7 +107,7 @@ class ProfileAvatarURLRestServlet(RestServlet): except Exception: return 400, "Unable to parse name" - yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin) + await self.profile_handler.set_avatar_url(user, requester, new_name, is_admin) return 200, {} @@ -129,20 +124,19 @@ class ProfileRestServlet(RestServlet): self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): + async def on_GET(self, request, user_id): requester_user = None if self.hs.config.require_auth_for_profile_requests: - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) requester_user = requester.user user = UserID.from_string(user_id) - yield self.profile_handler.check_profile_query_allowed(user, requester_user) + await self.profile_handler.check_profile_query_allowed(user, requester_user) - displayname = yield self.profile_handler.get_displayname(user) - avatar_url = yield self.profile_handler.get_avatar_url(user) + displayname = await self.profile_handler.get_displayname(user) + avatar_url = await self.profile_handler.get_avatar_url(user) ret = {} if displayname is not None: diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 9f8c3d09e3..4f74600239 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer from synapse.api.errors import ( NotFoundError, @@ -46,8 +45,7 @@ class PushRuleRestServlet(RestServlet): self.notifier = hs.get_notifier() self._is_worker = hs.config.worker_app is not None - @defer.inlineCallbacks - def on_PUT(self, request, path): + async def on_PUT(self, request, path): if self._is_worker: raise Exception("Cannot handle PUT /push_rules on worker") @@ -57,7 +55,7 @@ class PushRuleRestServlet(RestServlet): except InvalidRuleException as e: raise SynapseError(400, str(e)) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) if "/" in spec["rule_id"] or "\\" in spec["rule_id"]: raise SynapseError(400, "rule_id may not contain slashes") @@ -67,7 +65,7 @@ class PushRuleRestServlet(RestServlet): user_id = requester.user.to_string() if "attr" in spec: - yield self.set_rule_attr(user_id, spec, content) + await self.set_rule_attr(user_id, spec, content) self.notify_user(user_id) return 200, {} @@ -91,7 +89,7 @@ class PushRuleRestServlet(RestServlet): after = _namespaced_rule_id(spec, after) try: - yield self.store.add_push_rule( + await self.store.add_push_rule( user_id=user_id, rule_id=_namespaced_rule_id_from_spec(spec), priority_class=priority_class, @@ -108,20 +106,19 @@ class PushRuleRestServlet(RestServlet): return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, path): + async def on_DELETE(self, request, path): if self._is_worker: raise Exception("Cannot handle DELETE /push_rules on worker") spec = _rule_spec_from_path([x for x in path.split("/")]) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() namespaced_rule_id = _namespaced_rule_id_from_spec(spec) try: - yield self.store.delete_push_rule(user_id, namespaced_rule_id) + await self.store.delete_push_rule(user_id, namespaced_rule_id) self.notify_user(user_id) return 200, {} except StoreError as e: @@ -130,15 +127,14 @@ class PushRuleRestServlet(RestServlet): else: raise - @defer.inlineCallbacks - def on_GET(self, request, path): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, path): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is # is probably not going to make a whole lot of difference - rules = yield self.store.get_push_rules_for_user(user_id) + rules = await self.store.get_push_rules_for_user(user_id) rules = format_push_rules_for_user(requester.user, rules) diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 41660682d9..0791866f55 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import Codes, StoreError, SynapseError from synapse.http.server import finish_request from synapse.http.servlet import ( @@ -39,12 +37,11 @@ class PushersRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) user = requester.user - pushers = yield self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) + pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) allowed_keys = [ "app_display_name", @@ -78,9 +75,8 @@ class PushersSetRestServlet(RestServlet): self.notifier = hs.get_notifier() self.pusher_pool = self.hs.get_pusherpool() - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) user = requester.user content = parse_json_object_from_request(request) @@ -91,7 +87,7 @@ class PushersSetRestServlet(RestServlet): and "kind" in content and content["kind"] is None ): - yield self.pusher_pool.remove_pusher( + await self.pusher_pool.remove_pusher( content["app_id"], content["pushkey"], user_id=user.to_string() ) return 200, {} @@ -117,14 +113,14 @@ class PushersSetRestServlet(RestServlet): append = content["append"] if not append: - yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( + await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( app_id=content["app_id"], pushkey=content["pushkey"], not_user_id=user.to_string(), ) try: - yield self.pusher_pool.add_pusher( + await self.pusher_pool.add_pusher( user_id=user.to_string(), access_token=requester.access_token_id, kind=content["kind"], @@ -164,16 +160,15 @@ class PushersRemoveRestServlet(RestServlet): self.auth = hs.get_auth() self.pusher_pool = self.hs.get_pusherpool() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, rights="delete_pusher") + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, rights="delete_pusher") user = requester.user app_id = parse_string(request, "app_id", required=True) pushkey = parse_string(request, "pushkey", required=True) try: - yield self.pusher_pool.remove_pusher( + await self.pusher_pool.remove_pusher( app_id=app_id, pushkey=pushkey, user_id=user.to_string() ) except StoreError as se: diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 2afdbb89e5..747d46eac2 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -17,8 +17,6 @@ import base64 import hashlib import hmac -from twisted.internet import defer - from synapse.http.servlet import RestServlet from synapse.rest.client.v2_alpha._base import client_patterns @@ -31,9 +29,8 @@ class VoipRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req( + async def on_GET(self, request): + requester = await self.auth.get_user_by_req( request, self.hs.config.turn_allow_guests ) -- cgit 1.4.1 From 9c41ba4c5fe6e554bc885a1bef8ed51d14ec0f3e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 5 Dec 2019 16:46:37 +0000 Subject: Port rest.client.v2 --- synapse/rest/client/v2_alpha/_base.py | 2 +- synapse/rest/client/v2_alpha/account.py | 119 +++++------ synapse/rest/client/v2_alpha/account_data.py | 30 ++- synapse/rest/client/v2_alpha/account_validity.py | 17 +- synapse/rest/client/v2_alpha/auth.py | 9 +- synapse/rest/client/v2_alpha/capabilities.py | 9 +- synapse/rest/client/v2_alpha/devices.py | 41 ++-- synapse/rest/client/v2_alpha/filter.py | 16 +- synapse/rest/client/v2_alpha/groups.py | 226 +++++++++------------ synapse/rest/client/v2_alpha/keys.py | 46 ++--- synapse/rest/client/v2_alpha/notifications.py | 15 +- synapse/rest/client/v2_alpha/openid.py | 9 +- synapse/rest/client/v2_alpha/register.py | 72 +++---- synapse/rest/client/v2_alpha/relations.py | 56 +++-- synapse/rest/client/v2_alpha/report_event.py | 9 +- synapse/rest/client/v2_alpha/room_keys.py | 51 ++--- .../client/v2_alpha/room_upgrade_rest_servlet.py | 9 +- synapse/rest/client/v2_alpha/sendtodevice.py | 9 +- synapse/rest/client/v2_alpha/sync.py | 54 +++-- synapse/rest/client/v2_alpha/tags.py | 23 +-- synapse/rest/client/v2_alpha/thirdparty.py | 30 ++- synapse/rest/client/v2_alpha/tokenrefresh.py | 5 +- synapse/rest/client/v2_alpha/user_directory.py | 9 +- 23 files changed, 361 insertions(+), 505 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 8250ae0ae1..2a3f4dd58f 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -78,7 +78,7 @@ def interactive_auth_handler(orig): """ def wrapped(*args, **kwargs): - res = defer.maybeDeferred(orig, *args, **kwargs) + res = defer.ensureDeferred(orig(*args, **kwargs)) res.addErrback(_catch_incomplete_interactive_auth) return res diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index ad674239ab..fc240f5cf8 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -18,8 +18,6 @@ import logging from six.moves import http_client -from twisted.internet import defer - from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError, ThreepidValidationError from synapse.config.emailconfig import ThreepidBehaviour @@ -67,8 +65,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): template_text=template_text, ) - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -95,7 +92,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email ) @@ -106,7 +103,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): assert self.hs.config.account_threepid_delegate_email # Have the configured identity server handle the request - ret = yield self.identity_handler.requestEmailToken( + ret = await self.identity_handler.requestEmailToken( self.hs.config.account_threepid_delegate_email, email, client_secret, @@ -115,7 +112,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): ) else: # Send password reset emails from Synapse - sid = yield self.identity_handler.send_threepid_validation( + sid = await self.identity_handler.send_threepid_validation( email, client_secret, send_attempt, @@ -153,8 +150,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): [self.config.email_password_reset_template_failure_html], ) - @defer.inlineCallbacks - def on_GET(self, request, medium): + async def on_GET(self, request, medium): # We currently only handle threepid token submissions for email if medium != "email": raise SynapseError( @@ -176,7 +172,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): # Attempt to validate a 3PID session try: # Mark the session as valid - next_link = yield self.store.validate_threepid_session( + next_link = await self.store.validate_threepid_session( sid, client_secret, token, self.clock.time_msec() ) @@ -218,8 +214,7 @@ class PasswordRestServlet(RestServlet): self._set_password_handler = hs.get_set_password_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) # there are two possibilities here. Either the user does not have an @@ -233,14 +228,14 @@ class PasswordRestServlet(RestServlet): # In the second case, we require a password to confirm their identity. if self.auth.has_access_token(request): - requester = yield self.auth.get_user_by_req(request) - params = yield self.auth_handler.validate_user_via_ui_auth( + requester = await self.auth.get_user_by_req(request) + params = await self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) ) user_id = requester.user.to_string() else: requester = None - result, params, _ = yield self.auth_handler.check_auth( + result, params, _ = await self.auth_handler.check_auth( [[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request) ) @@ -254,7 +249,7 @@ class PasswordRestServlet(RestServlet): # (See add_threepid in synapse/handlers/auth.py) threepid["address"] = threepid["address"].lower() # if using email, we must know about the email they're authing with! - threepid_user_id = yield self.datastore.get_user_id_by_threepid( + threepid_user_id = await self.datastore.get_user_id_by_threepid( threepid["medium"], threepid["address"] ) if not threepid_user_id: @@ -267,7 +262,7 @@ class PasswordRestServlet(RestServlet): assert_params_in_dict(params, ["new_password"]) new_password = params["new_password"] - yield self._set_password_handler.set_password(user_id, new_password, requester) + await self._set_password_handler.set_password(user_id, new_password, requester) return 200, {} @@ -286,8 +281,7 @@ class DeactivateAccountRestServlet(RestServlet): self._deactivate_account_handler = hs.get_deactivate_account_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) erase = body.get("erase", False) if not isinstance(erase, bool): @@ -297,19 +291,19 @@ class DeactivateAccountRestServlet(RestServlet): Codes.BAD_JSON, ) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) # allow ASes to dectivate their own users if requester.app_service: - yield self._deactivate_account_handler.deactivate_account( + await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase ) return 200, {} - yield self.auth_handler.validate_user_via_ui_auth( + await self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) ) - result = yield self._deactivate_account_handler.deactivate_account( + result = await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase, id_server=body.get("id_server") ) if result: @@ -346,8 +340,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): template_text=template_text, ) - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -371,7 +364,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.store.get_user_id_by_threepid( + existing_user_id = await self.store.get_user_id_by_threepid( "email", body["email"] ) @@ -382,7 +375,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): assert self.hs.config.account_threepid_delegate_email # Have the configured identity server handle the request - ret = yield self.identity_handler.requestEmailToken( + ret = await self.identity_handler.requestEmailToken( self.hs.config.account_threepid_delegate_email, email, client_secret, @@ -391,7 +384,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): ) else: # Send threepid validation emails from Synapse - sid = yield self.identity_handler.send_threepid_validation( + sid = await self.identity_handler.send_threepid_validation( email, client_secret, send_attempt, @@ -414,8 +407,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): self.store = self.hs.get_datastore() self.identity_handler = hs.get_handlers().identity_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict( body, ["client_secret", "country", "phone_number", "send_attempt"] @@ -435,7 +427,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.store.get_user_id_by_threepid("msisdn", msisdn) + existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) if existing_user_id is not None: raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) @@ -450,7 +442,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): "Adding phone numbers to user account is not supported by this homeserver", ) - ret = yield self.identity_handler.requestMsisdnToken( + ret = await self.identity_handler.requestMsisdnToken( self.hs.config.account_threepid_delegate_msisdn, country, phone_number, @@ -484,8 +476,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): [self.config.email_add_threepid_template_failure_html], ) - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -508,7 +499,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): # Attempt to validate a 3PID session try: # Mark the session as valid - next_link = yield self.store.validate_threepid_session( + next_link = await self.store.validate_threepid_session( sid, client_secret, token, self.clock.time_msec() ) @@ -558,8 +549,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): self.store = hs.get_datastore() self.identity_handler = hs.get_handlers().identity_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if not self.config.account_threepid_delegate_msisdn: raise SynapseError( 400, @@ -571,7 +561,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): assert_params_in_dict(body, ["client_secret", "sid", "token"]) # Proxy submit_token request to msisdn threepid delegate - response = yield self.identity_handler.proxy_msisdn_submit_token( + response = await self.identity_handler.proxy_msisdn_submit_token( self.config.account_threepid_delegate_msisdn, body["client_secret"], body["sid"], @@ -591,17 +581,15 @@ class ThreepidRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.datastore = self.hs.get_datastore() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) - threepids = yield self.datastore.user_get_threepids(requester.user.to_string()) + threepids = await self.datastore.user_get_threepids(requester.user.to_string()) return 200, {"threepids": threepids} - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -615,11 +603,11 @@ class ThreepidRestServlet(RestServlet): client_secret = threepid_creds["client_secret"] sid = threepid_creds["sid"] - validation_session = yield self.identity_handler.validate_threepid_session( + validation_session = await self.identity_handler.validate_threepid_session( client_secret, sid ) if validation_session: - yield self.auth_handler.add_threepid( + await self.auth_handler.add_threepid( user_id, validation_session["medium"], validation_session["address"], @@ -643,9 +631,8 @@ class ThreepidAddRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -653,15 +640,15 @@ class ThreepidAddRestServlet(RestServlet): client_secret = body["client_secret"] sid = body["sid"] - yield self.auth_handler.validate_user_via_ui_auth( + await self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) ) - validation_session = yield self.identity_handler.validate_threepid_session( + validation_session = await self.identity_handler.validate_threepid_session( client_secret, sid ) if validation_session: - yield self.auth_handler.add_threepid( + await self.auth_handler.add_threepid( user_id, validation_session["medium"], validation_session["address"], @@ -683,8 +670,7 @@ class ThreepidBindRestServlet(RestServlet): self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict(body, ["id_server", "sid", "client_secret"]) @@ -693,10 +679,10 @@ class ThreepidBindRestServlet(RestServlet): client_secret = body["client_secret"] id_access_token = body.get("id_access_token") # optional - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() - yield self.identity_handler.bind_threepid( + await self.identity_handler.bind_threepid( client_secret, sid, user_id, id_server, id_access_token ) @@ -713,12 +699,11 @@ class ThreepidUnbindRestServlet(RestServlet): self.auth = hs.get_auth() self.datastore = self.hs.get_datastore() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): """Unbind the given 3pid from a specific identity server, or identity servers that are known to have this 3pid bound """ - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) @@ -728,7 +713,7 @@ class ThreepidUnbindRestServlet(RestServlet): # Attempt to unbind the threepid from an identity server. If id_server is None, try to # unbind from all identity servers this threepid has been added to in the past - result = yield self.identity_handler.try_unbind_threepid( + result = await self.identity_handler.try_unbind_threepid( requester.user.to_string(), {"address": address, "medium": medium, "id_server": id_server}, ) @@ -743,16 +728,15 @@ class ThreepidDeleteRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() try: - ret = yield self.auth_handler.delete_threepid( + ret = await self.auth_handler.delete_threepid( user_id, body["medium"], body["address"], body.get("id_server") ) except Exception: @@ -777,9 +761,8 @@ class WhoamiRestServlet(RestServlet): super(WhoamiRestServlet, self).__init__() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) return 200, {"user_id": requester.user.to_string()} diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index f0db204ffa..64eb7fec3b 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -41,15 +39,14 @@ class AccountDataServlet(RestServlet): self.store = hs.get_datastore() self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def on_PUT(self, request, user_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id, account_data_type): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") body = parse_json_object_from_request(request) - max_id = yield self.store.add_account_data_for_user( + max_id = await self.store.add_account_data_for_user( user_id, account_data_type, body ) @@ -57,13 +54,12 @@ class AccountDataServlet(RestServlet): return 200, {} - @defer.inlineCallbacks - def on_GET(self, request, user_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id, account_data_type): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") - event = yield self.store.get_global_account_data_by_type_for_user( + event = await self.store.get_global_account_data_by_type_for_user( account_data_type, user_id ) @@ -91,9 +87,8 @@ class RoomAccountDataServlet(RestServlet): self.store = hs.get_datastore() self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def on_PUT(self, request, user_id, room_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id, room_id, account_data_type): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") @@ -106,7 +101,7 @@ class RoomAccountDataServlet(RestServlet): " Use /rooms/!roomId:server.name/read_markers", ) - max_id = yield self.store.add_account_data_to_room( + max_id = await self.store.add_account_data_to_room( user_id, room_id, account_data_type, body ) @@ -114,13 +109,12 @@ class RoomAccountDataServlet(RestServlet): return 200, {} - @defer.inlineCallbacks - def on_GET(self, request, user_id, room_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id, room_id, account_data_type): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") - event = yield self.store.get_account_data_for_room_and_type( + event = await self.store.get_account_data_for_room_and_type( user_id, room_id, account_data_type ) diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 33f6a23028..2f10fa64e2 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError, SynapseError from synapse.http.server import finish_request from synapse.http.servlet import RestServlet @@ -45,13 +43,12 @@ class AccountValidityRenewServlet(RestServlet): self.success_html = hs.config.account_validity.account_renewed_html_content self.failure_html = hs.config.account_validity.invalid_token_html_content - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): if b"token" not in request.args: raise SynapseError(400, "Missing renewal token") renewal_token = request.args[b"token"][0] - token_valid = yield self.account_activity_handler.renew_account( + token_valid = await self.account_activity_handler.renew_account( renewal_token.decode("utf8") ) @@ -67,7 +64,6 @@ class AccountValidityRenewServlet(RestServlet): request.setHeader(b"Content-Length", b"%d" % (len(response),)) request.write(response.encode("utf8")) finish_request(request) - defer.returnValue(None) class AccountValiditySendMailServlet(RestServlet): @@ -85,18 +81,17 @@ class AccountValiditySendMailServlet(RestServlet): self.auth = hs.get_auth() self.account_validity = self.hs.config.account_validity - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if not self.account_validity.renew_by_email_enabled: raise AuthError( 403, "Account renewal via email is disabled on this server." ) - requester = yield self.auth.get_user_by_req(request, allow_expired=True) + requester = await self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() - yield self.account_activity_handler.send_renewal_email_to_user(user_id) + await self.account_activity_handler.send_renewal_email_to_user(user_id) - defer.returnValue((200, {})) + return 200, {} def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index f21aff39e5..7a256b6ecb 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import LoginType from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_API_PREFIX @@ -171,8 +169,7 @@ class AuthRestServlet(RestServlet): else: raise SynapseError(404, "Unknown auth stage type") - @defer.inlineCallbacks - def on_POST(self, request, stagetype): + async def on_POST(self, request, stagetype): session = parse_string(request, "session") if not session: @@ -186,7 +183,7 @@ class AuthRestServlet(RestServlet): authdict = {"response": response, "session": session} - success = yield self.auth_handler.add_oob_auth( + success = await self.auth_handler.add_oob_auth( LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request) ) @@ -215,7 +212,7 @@ class AuthRestServlet(RestServlet): session = request.args["session"][0] authdict = {"session": session} - success = yield self.auth_handler.add_oob_auth( + success = await self.auth_handler.add_oob_auth( LoginType.TERMS, authdict, self.hs.get_ip_from_request(request) ) diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py index acd58af193..fe9d019c44 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py @@ -14,8 +14,6 @@ # limitations under the License. import logging -from twisted.internet import defer - from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import RestServlet @@ -40,10 +38,9 @@ class CapabilitiesRestServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) - user = yield self.store.get_user_by_id(requester.user.to_string()) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + user = await self.store.get_user_by_id(requester.user.to_string()) change_password = bool(user["password_hash"]) response = { diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 26d0235208..94ff73f384 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api import errors from synapse.http.servlet import ( RestServlet, @@ -42,10 +40,9 @@ class DevicesRestServlet(RestServlet): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) - devices = yield self.device_handler.get_devices_by_user( + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + devices = await self.device_handler.get_devices_by_user( requester.user.to_string() ) return 200, {"devices": devices} @@ -67,9 +64,8 @@ class DeleteDevicesRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) try: body = parse_json_object_from_request(request) @@ -84,11 +80,11 @@ class DeleteDevicesRestServlet(RestServlet): assert_params_in_dict(body, ["devices"]) - yield self.auth_handler.validate_user_via_ui_auth( + await self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) ) - yield self.device_handler.delete_devices( + await self.device_handler.delete_devices( requester.user.to_string(), body["devices"] ) return 200, {} @@ -108,18 +104,16 @@ class DeviceRestServlet(RestServlet): self.device_handler = hs.get_device_handler() self.auth_handler = hs.get_auth_handler() - @defer.inlineCallbacks - def on_GET(self, request, device_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) - device = yield self.device_handler.get_device( + async def on_GET(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + device = await self.device_handler.get_device( requester.user.to_string(), device_id ) return 200, device @interactive_auth_handler - @defer.inlineCallbacks - def on_DELETE(self, request, device_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, device_id): + requester = await self.auth.get_user_by_req(request) try: body = parse_json_object_from_request(request) @@ -132,19 +126,18 @@ class DeviceRestServlet(RestServlet): else: raise - yield self.auth_handler.validate_user_via_ui_auth( + await self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) ) - yield self.device_handler.delete_device(requester.user.to_string(), device_id) + await self.device_handler.delete_device(requester.user.to_string(), device_id) return 200, {} - @defer.inlineCallbacks - def on_PUT(self, request, device_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_PUT(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) body = parse_json_object_from_request(request) - yield self.device_handler.update_device( + await self.device_handler.update_device( requester.user.to_string(), device_id, body ) return 200, {} diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 17a8bc7366..b28da017cd 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID @@ -35,10 +33,9 @@ class GetFilterRestServlet(RestServlet): self.auth = hs.get_auth() self.filtering = hs.get_filtering() - @defer.inlineCallbacks - def on_GET(self, request, user_id, filter_id): + async def on_GET(self, request, user_id, filter_id): target_user = UserID.from_string(user_id) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) if target_user != requester.user: raise AuthError(403, "Cannot get filters for other users") @@ -52,7 +49,7 @@ class GetFilterRestServlet(RestServlet): raise SynapseError(400, "Invalid filter_id") try: - filter_collection = yield self.filtering.get_user_filter( + filter_collection = await self.filtering.get_user_filter( user_localpart=target_user.localpart, filter_id=filter_id ) except StoreError as e: @@ -72,11 +69,10 @@ class CreateFilterRestServlet(RestServlet): self.auth = hs.get_auth() self.filtering = hs.get_filtering() - @defer.inlineCallbacks - def on_POST(self, request, user_id): + async def on_POST(self, request, user_id): target_user = UserID.from_string(user_id) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) if target_user != requester.user: raise AuthError(403, "Cannot create filters for other users") @@ -87,7 +83,7 @@ class CreateFilterRestServlet(RestServlet): content = parse_json_object_from_request(request) set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit) - filter_id = yield self.filtering.add_user_filter( + filter_id = await self.filtering.add_user_filter( user_localpart=target_user.localpart, user_filter=content ) diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 999a0fa80c..d84a6d7e11 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import GroupID @@ -38,24 +36,22 @@ class GroupServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - group_description = yield self.groups_handler.get_group_profile( + group_description = await self.groups_handler.get_group_profile( group_id, requester_user_id ) return 200, group_description - @defer.inlineCallbacks - def on_POST(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - yield self.groups_handler.update_group_profile( + await self.groups_handler.update_group_profile( group_id, requester_user_id, content ) @@ -74,12 +70,11 @@ class GroupSummaryServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - get_group_summary = yield self.groups_handler.get_group_summary( + get_group_summary = await self.groups_handler.get_group_summary( group_id, requester_user_id ) @@ -106,13 +101,12 @@ class GroupSummaryRoomsCatServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, category_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, category_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_summary_room( + resp = await self.groups_handler.update_group_summary_room( group_id, requester_user_id, room_id=room_id, @@ -122,12 +116,11 @@ class GroupSummaryRoomsCatServlet(RestServlet): return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, category_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, category_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - resp = yield self.groups_handler.delete_group_summary_room( + resp = await self.groups_handler.delete_group_summary_room( group_id, requester_user_id, room_id=room_id, category_id=category_id ) @@ -148,35 +141,32 @@ class GroupCategoryServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id, category_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id, category_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_category( + category = await self.groups_handler.get_group_category( group_id, requester_user_id, category_id=category_id ) return 200, category - @defer.inlineCallbacks - def on_PUT(self, request, group_id, category_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, category_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_category( + resp = await self.groups_handler.update_group_category( group_id, requester_user_id, category_id=category_id, content=content ) return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, category_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, category_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - resp = yield self.groups_handler.delete_group_category( + resp = await self.groups_handler.delete_group_category( group_id, requester_user_id, category_id=category_id ) @@ -195,12 +185,11 @@ class GroupCategoriesServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_categories( + category = await self.groups_handler.get_group_categories( group_id, requester_user_id ) @@ -219,35 +208,32 @@ class GroupRoleServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id, role_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id, role_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_role( + category = await self.groups_handler.get_group_role( group_id, requester_user_id, role_id=role_id ) return 200, category - @defer.inlineCallbacks - def on_PUT(self, request, group_id, role_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, role_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_role( + resp = await self.groups_handler.update_group_role( group_id, requester_user_id, role_id=role_id, content=content ) return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, role_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, role_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - resp = yield self.groups_handler.delete_group_role( + resp = await self.groups_handler.delete_group_role( group_id, requester_user_id, role_id=role_id ) @@ -266,12 +252,11 @@ class GroupRolesServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_roles( + category = await self.groups_handler.get_group_roles( group_id, requester_user_id ) @@ -298,13 +283,12 @@ class GroupSummaryUsersRoleServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, role_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, role_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_summary_user( + resp = await self.groups_handler.update_group_summary_user( group_id, requester_user_id, user_id=user_id, @@ -314,12 +298,11 @@ class GroupSummaryUsersRoleServlet(RestServlet): return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, role_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, role_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - resp = yield self.groups_handler.delete_group_summary_user( + resp = await self.groups_handler.delete_group_summary_user( group_id, requester_user_id, user_id=user_id, role_id=role_id ) @@ -338,12 +321,11 @@ class GroupRoomServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_rooms_in_group( + result = await self.groups_handler.get_rooms_in_group( group_id, requester_user_id ) @@ -362,12 +344,11 @@ class GroupUsersServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_users_in_group( + result = await self.groups_handler.get_users_in_group( group_id, requester_user_id ) @@ -386,12 +367,11 @@ class GroupInvitedUsersServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_invited_users_in_group( + result = await self.groups_handler.get_invited_users_in_group( group_id, requester_user_id ) @@ -409,14 +389,13 @@ class GroupSettingJoinPolicyServlet(RestServlet): self.auth = hs.get_auth() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.set_group_join_policy( + result = await self.groups_handler.set_group_join_policy( group_id, requester_user_id, content ) @@ -436,9 +415,8 @@ class GroupCreateServlet(RestServlet): self.groups_handler = hs.get_groups_local_handler() self.server_name = hs.hostname - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() # TODO: Create group on remote server @@ -446,7 +424,7 @@ class GroupCreateServlet(RestServlet): localpart = content.pop("localpart") group_id = GroupID(localpart, self.server_name).to_string() - result = yield self.groups_handler.create_group( + result = await self.groups_handler.create_group( group_id, requester_user_id, content ) @@ -467,24 +445,22 @@ class GroupAdminRoomsServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.add_room_to_group( + result = await self.groups_handler.add_room_to_group( group_id, requester_user_id, room_id, content ) return 200, result - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.remove_room_from_group( + result = await self.groups_handler.remove_room_from_group( group_id, requester_user_id, room_id ) @@ -506,13 +482,12 @@ class GroupAdminRoomsConfigServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, room_id, config_key): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, room_id, config_key): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.update_room_in_group( + result = await self.groups_handler.update_room_in_group( group_id, requester_user_id, room_id, config_key, content ) @@ -535,14 +510,13 @@ class GroupAdminUsersInviteServlet(RestServlet): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id - @defer.inlineCallbacks - def on_PUT(self, request, group_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) config = content.get("config", {}) - result = yield self.groups_handler.invite( + result = await self.groups_handler.invite( group_id, user_id, requester_user_id, config ) @@ -563,13 +537,12 @@ class GroupAdminUsersKickServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.remove_user_from_group( + result = await self.groups_handler.remove_user_from_group( group_id, user_id, requester_user_id, content ) @@ -588,13 +561,12 @@ class GroupSelfLeaveServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.remove_user_from_group( + result = await self.groups_handler.remove_user_from_group( group_id, requester_user_id, requester_user_id, content ) @@ -613,13 +585,12 @@ class GroupSelfJoinServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.join_group( + result = await self.groups_handler.join_group( group_id, requester_user_id, content ) @@ -638,13 +609,12 @@ class GroupSelfAcceptInviteServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.accept_invite( + result = await self.groups_handler.accept_invite( group_id, requester_user_id, content ) @@ -663,14 +633,13 @@ class GroupSelfUpdatePublicityServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) publicise = content["publicise"] - yield self.store.update_group_publicity(group_id, requester_user_id, publicise) + await self.store.update_group_publicity(group_id, requester_user_id, publicise) return 200, {} @@ -688,11 +657,10 @@ class PublicisedGroupsForUserServlet(RestServlet): self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, user_id): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, user_id): + await self.auth.get_user_by_req(request, allow_guest=True) - result = yield self.groups_handler.get_publicised_groups_for_user(user_id) + result = await self.groups_handler.get_publicised_groups_for_user(user_id) return 200, result @@ -710,14 +678,13 @@ class PublicisedGroupsForUsersServlet(RestServlet): self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) user_ids = content["user_ids"] - result = yield self.groups_handler.bulk_get_publicised_groups(user_ids) + result = await self.groups_handler.bulk_get_publicised_groups(user_ids) return 200, result @@ -734,12 +701,11 @@ class GroupsForUserServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_joined_groups(requester_user_id) + result = await self.groups_handler.get_joined_groups(requester_user_id) return 200, result diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 341567ae21..f7ed4daf90 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import ( RestServlet, @@ -71,9 +69,8 @@ class KeyUploadServlet(RestServlet): self.e2e_keys_handler = hs.get_e2e_keys_handler() @trace(opname="upload_keys") - @defer.inlineCallbacks - def on_POST(self, request, device_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -103,7 +100,7 @@ class KeyUploadServlet(RestServlet): 400, "To upload keys, you must pass device_id when authenticating" ) - result = yield self.e2e_keys_handler.upload_keys_for_user( + result = await self.e2e_keys_handler.upload_keys_for_user( user_id, device_id, body ) return 200, result @@ -154,13 +151,12 @@ class KeyQueryServlet(RestServlet): self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.query_devices(body, timeout, user_id) + result = await self.e2e_keys_handler.query_devices(body, timeout, user_id) return 200, result @@ -185,9 +181,8 @@ class KeyChangesServlet(RestServlet): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) from_token_string = parse_string(request, "from") set_tag("from", from_token_string) @@ -200,7 +195,7 @@ class KeyChangesServlet(RestServlet): user_id = requester.user.to_string() - results = yield self.device_handler.get_user_ids_changed(user_id, from_token) + results = await self.device_handler.get_user_ids_changed(user_id, from_token) return 200, results @@ -231,12 +226,11 @@ class OneTimeKeyServlet(RestServlet): self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - @defer.inlineCallbacks - def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.claim_one_time_keys(body, timeout) + result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout) return 200, result @@ -263,17 +257,16 @@ class SigningKeyUploadServlet(RestServlet): self.auth_handler = hs.get_auth_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) - yield self.auth_handler.validate_user_via_ui_auth( + await self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) ) - result = yield self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) + result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) return 200, result @@ -315,13 +308,12 @@ class SignaturesUploadServlet(RestServlet): self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.upload_signatures_for_device_keys( + result = await self.e2e_keys_handler.upload_signatures_for_device_keys( user_id, body ) return 200, result diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index 10c1ad5b07..aa911d75ee 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.http.servlet import RestServlet, parse_integer, parse_string @@ -35,9 +33,8 @@ class NotificationsServlet(RestServlet): self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() from_token = parse_string(request, "from", required=False) @@ -46,16 +43,16 @@ class NotificationsServlet(RestServlet): limit = min(limit, 500) - push_actions = yield self.store.get_push_actions_for_user( + push_actions = await self.store.get_push_actions_for_user( user_id, from_token, limit, only_highlight=(only == "highlight") ) - receipts_by_room = yield self.store.get_receipts_for_user_with_orderings( + receipts_by_room = await self.store.get_receipts_for_user_with_orderings( user_id, "m.read" ) notif_event_ids = [pa["event_id"] for pa in push_actions] - notif_events = yield self.store.get_events(notif_event_ids) + notif_events = await self.store.get_events(notif_event_ids) returned_push_actions = [] @@ -68,7 +65,7 @@ class NotificationsServlet(RestServlet): "actions": pa["actions"], "ts": pa["received_ts"], "event": ( - yield self._event_serializer.serialize_event( + await self._event_serializer.serialize_event( notif_events[pa["event_id"]], self.clock.time_msec(), event_format=format_event_for_client_v2_without_room_id, diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py index b4925c0f59..6ae9a5a8e9 100644 --- a/synapse/rest/client/v2_alpha/openid.py +++ b/synapse/rest/client/v2_alpha/openid.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.util.stringutils import random_string @@ -68,9 +66,8 @@ class IdTokenServlet(RestServlet): self.clock = hs.get_clock() self.server_name = hs.config.server_name - @defer.inlineCallbacks - def on_POST(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, user_id): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot request tokens for other users.") @@ -81,7 +78,7 @@ class IdTokenServlet(RestServlet): token = random_string(24) ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS - yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) + await self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) return ( 200, diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 91db923814..66de16a1fa 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -20,8 +20,6 @@ from typing import List, Union from six import string_types -from twisted.internet import defer - import synapse import synapse.types from synapse.api.constants import LoginType @@ -102,8 +100,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): template_text=template_text, ) - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.hs.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -129,7 +126,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", body["email"] ) @@ -140,7 +137,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): assert self.hs.config.account_threepid_delegate_email # Have the configured identity server handle the request - ret = yield self.identity_handler.requestEmailToken( + ret = await self.identity_handler.requestEmailToken( self.hs.config.account_threepid_delegate_email, email, client_secret, @@ -149,7 +146,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): ) else: # Send registration emails from Synapse - sid = yield self.identity_handler.send_threepid_validation( + sid = await self.identity_handler.send_threepid_validation( email, client_secret, send_attempt, @@ -175,8 +172,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): self.hs = hs self.identity_handler = hs.get_handlers().identity_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict( @@ -197,7 +193,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "msisdn", msisdn ) @@ -215,7 +211,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): 400, "Registration by phone number is not supported on this homeserver" ) - ret = yield self.identity_handler.requestMsisdnToken( + ret = await self.identity_handler.requestMsisdnToken( self.hs.config.account_threepid_delegate_msisdn, country, phone_number, @@ -258,8 +254,7 @@ class RegistrationSubmitTokenServlet(RestServlet): [self.config.email_registration_template_failure_html], ) - @defer.inlineCallbacks - def on_GET(self, request, medium): + async def on_GET(self, request, medium): if medium != "email": raise SynapseError( 400, "This medium is currently not supported for registration" @@ -280,7 +275,7 @@ class RegistrationSubmitTokenServlet(RestServlet): # Attempt to validate a 3PID session try: # Mark the session as valid - next_link = yield self.store.validate_threepid_session( + next_link = await self.store.validate_threepid_session( sid, client_secret, token, self.clock.time_msec() ) @@ -338,8 +333,7 @@ class UsernameAvailabilityRestServlet(RestServlet): ), ) - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): if not self.hs.config.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN @@ -347,11 +341,11 @@ class UsernameAvailabilityRestServlet(RestServlet): ip = self.hs.get_ip_from_request(request) with self.ratelimiter.ratelimit(ip) as wait_deferred: - yield wait_deferred + await wait_deferred username = parse_string(request, "username", required=True) - yield self.registration_handler.check_username(username) + await self.registration_handler.check_username(username) return 200, {"available": True} @@ -382,8 +376,7 @@ class RegisterRestServlet(RestServlet): ) @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) client_addr = request.getClientIP() @@ -408,7 +401,7 @@ class RegisterRestServlet(RestServlet): kind = request.args[b"kind"][0] if kind == b"guest": - ret = yield self._do_guest_registration(body, address=client_addr) + ret = await self._do_guest_registration(body, address=client_addr) return ret elif kind != b"user": raise UnrecognizedRequestError( @@ -435,7 +428,7 @@ class RegisterRestServlet(RestServlet): appservice = None if self.auth.has_access_token(request): - appservice = yield self.auth.get_appservice_by_req(request) + appservice = await self.auth.get_appservice_by_req(request) # fork off as soon as possible for ASes which have completely # different registration flows to normal users @@ -455,7 +448,7 @@ class RegisterRestServlet(RestServlet): access_token = self.auth.get_access_token_from_request(request) if isinstance(desired_username, string_types): - result = yield self._do_appservice_registration( + result = await self._do_appservice_registration( desired_username, access_token, body ) return 200, result # we throw for non 200 responses @@ -495,13 +488,13 @@ class RegisterRestServlet(RestServlet): ) if desired_username is not None: - yield self.registration_handler.check_username( + await self.registration_handler.check_username( desired_username, guest_access_token=guest_access_token, assigned_user_id=registered_user_id, ) - auth_result, params, session_id = yield self.auth_handler.check_auth( + auth_result, params, session_id = await self.auth_handler.check_auth( self._registration_flows, body, self.hs.get_ip_from_request(request) ) @@ -557,7 +550,7 @@ class RegisterRestServlet(RestServlet): medium = auth_result[login_type]["medium"] address = auth_result[login_type]["address"] - existing_user_id = yield self.store.get_user_id_by_threepid( + existing_user_id = await self.store.get_user_id_by_threepid( medium, address ) @@ -568,7 +561,7 @@ class RegisterRestServlet(RestServlet): Codes.THREEPID_IN_USE, ) - registered_user_id = yield self.registration_handler.register_user( + registered_user_id = await self.registration_handler.register_user( localpart=desired_username, password=new_password, guest_access_token=guest_access_token, @@ -581,7 +574,7 @@ class RegisterRestServlet(RestServlet): if is_threepid_reserved( self.hs.config.mau_limits_reserved_threepids, threepid ): - yield self.store.upsert_monthly_active_user(registered_user_id) + await self.store.upsert_monthly_active_user(registered_user_id) # remember that we've now registered that user account, and with # what user ID (since the user may not have specified) @@ -591,12 +584,12 @@ class RegisterRestServlet(RestServlet): registered = True - return_dict = yield self._create_registration_details( + return_dict = await self._create_registration_details( registered_user_id, params ) if registered: - yield self.registration_handler.post_registration_actions( + await self.registration_handler.post_registration_actions( user_id=registered_user_id, auth_result=auth_result, access_token=return_dict.get("access_token"), @@ -607,15 +600,13 @@ class RegisterRestServlet(RestServlet): def on_OPTIONS(self, _): return 200, {} - @defer.inlineCallbacks - def _do_appservice_registration(self, username, as_token, body): - user_id = yield self.registration_handler.appservice_register( + async def _do_appservice_registration(self, username, as_token, body): + user_id = await self.registration_handler.appservice_register( username, as_token ) - return (yield self._create_registration_details(user_id, body)) + return await self._create_registration_details(user_id, body) - @defer.inlineCallbacks - def _create_registration_details(self, user_id, params): + async def _create_registration_details(self, user_id, params): """Complete registration of newly-registered user Allocates device_id if one was not given; also creates access_token. @@ -631,18 +622,17 @@ class RegisterRestServlet(RestServlet): if not params.get("inhibit_login", False): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=False ) result.update({"access_token": access_token, "device_id": device_id}) return result - @defer.inlineCallbacks - def _do_guest_registration(self, params, address=None): + async def _do_guest_registration(self, params, address=None): if not self.hs.config.allow_guest_access: raise SynapseError(403, "Guest access is disabled") - user_id = yield self.registration_handler.register_user( + user_id = await self.registration_handler.register_user( make_guest=True, address=address ) @@ -650,7 +640,7 @@ class RegisterRestServlet(RestServlet): # we have nowhere to store it. device_id = synapse.api.auth.GUEST_DEVICE_ID initial_display_name = params.get("initial_device_display_name") - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=True ) diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 040b37c504..9be9a34b91 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -21,8 +21,6 @@ any time to reflect changes in the MSC. import logging -from twisted.internet import defer - from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.http.servlet import ( @@ -86,11 +84,10 @@ class RelationSendServlet(RestServlet): request, self.on_PUT_or_POST, request, *args, **kwargs ) - @defer.inlineCallbacks - def on_PUT_or_POST( + async def on_PUT_or_POST( self, request, room_id, parent_id, relation_type, event_type, txn_id=None ): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = await self.auth.get_user_by_req(request, allow_guest=True) if event_type == EventTypes.Member: # Add relations to a membership is meaningless, so we just deny it @@ -114,7 +111,7 @@ class RelationSendServlet(RestServlet): "sender": requester.user.to_string(), } - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict=event_dict, txn_id=txn_id ) @@ -140,17 +137,18 @@ class RelationPaginationServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() - @defer.inlineCallbacks - def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET( + self, request, room_id, parent_id, relation_type=None, event_type=None + ): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - yield self.auth.check_in_room_or_world_readable( + await self.auth.check_in_room_or_world_readable( room_id, requester.user.to_string() ) # This gets the original event and checks that a) the event exists and # b) the user is allowed to view it. - event = yield self.event_handler.get_event(requester.user, room_id, parent_id) + event = await self.event_handler.get_event(requester.user, room_id, parent_id) limit = parse_integer(request, "limit", default=5) from_token = parse_string(request, "from") @@ -167,7 +165,7 @@ class RelationPaginationServlet(RestServlet): if to_token: to_token = RelationPaginationToken.from_string(to_token) - pagination_chunk = yield self.store.get_relations_for_event( + pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, relation_type=relation_type, event_type=event_type, @@ -176,7 +174,7 @@ class RelationPaginationServlet(RestServlet): to_token=to_token, ) - events = yield self.store.get_events_as_list( + events = await self.store.get_events_as_list( [c["event_id"] for c in pagination_chunk.chunk] ) @@ -184,13 +182,13 @@ class RelationPaginationServlet(RestServlet): # We set bundle_aggregations to False when retrieving the original # event because we want the content before relations were applied to # it. - original_event = yield self._event_serializer.serialize_event( + original_event = await self._event_serializer.serialize_event( event, now, bundle_aggregations=False ) # Similarly, we don't allow relations to be applied to relations, so we # return the original relations without any aggregations on top of them # here. - events = yield self._event_serializer.serialize_events( + events = await self._event_serializer.serialize_events( events, now, bundle_aggregations=False ) @@ -232,17 +230,18 @@ class RelationAggregationPaginationServlet(RestServlet): self.store = hs.get_datastore() self.event_handler = hs.get_event_handler() - @defer.inlineCallbacks - def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET( + self, request, room_id, parent_id, relation_type=None, event_type=None + ): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - yield self.auth.check_in_room_or_world_readable( + await self.auth.check_in_room_or_world_readable( room_id, requester.user.to_string() ) # This checks that a) the event exists and b) the user is allowed to # view it. - event = yield self.event_handler.get_event(requester.user, room_id, parent_id) + event = await self.event_handler.get_event(requester.user, room_id, parent_id) if relation_type not in (RelationTypes.ANNOTATION, None): raise SynapseError(400, "Relation type must be 'annotation'") @@ -262,7 +261,7 @@ class RelationAggregationPaginationServlet(RestServlet): if to_token: to_token = AggregationPaginationToken.from_string(to_token) - pagination_chunk = yield self.store.get_aggregation_groups_for_event( + pagination_chunk = await self.store.get_aggregation_groups_for_event( event_id=parent_id, event_type=event_type, limit=limit, @@ -311,17 +310,16 @@ class RelationAggregationGroupPaginationServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() - @defer.inlineCallbacks - def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - yield self.auth.check_in_room_or_world_readable( + await self.auth.check_in_room_or_world_readable( room_id, requester.user.to_string() ) # This checks that a) the event exists and b) the user is allowed to # view it. - yield self.event_handler.get_event(requester.user, room_id, parent_id) + await self.event_handler.get_event(requester.user, room_id, parent_id) if relation_type != RelationTypes.ANNOTATION: raise SynapseError(400, "Relation type must be 'annotation'") @@ -336,7 +334,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): if to_token: to_token = RelationPaginationToken.from_string(to_token) - result = yield self.store.get_relations_for_event( + result = await self.store.get_relations_for_event( event_id=parent_id, relation_type=relation_type, event_type=event_type, @@ -346,12 +344,12 @@ class RelationAggregationGroupPaginationServlet(RestServlet): to_token=to_token, ) - events = yield self.store.get_events_as_list( + events = await self.store.get_events_as_list( [c["event_id"] for c in result.chunk] ) now = self.clock.time_msec() - events = yield self._event_serializer.serialize_events(events, now) + events = await self._event_serializer.serialize_events(events, now) return_value = result.to_dict() return_value["chunk"] = events diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index e7449864cd..f067b5edac 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -18,8 +18,6 @@ import logging from six import string_types from six.moves import http_client -from twisted.internet import defer - from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( RestServlet, @@ -42,9 +40,8 @@ class ReportEventRestServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_POST(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -63,7 +60,7 @@ class ReportEventRestServlet(RestServlet): Codes.BAD_JSON, ) - yield self.store.add_event_report( + await self.store.add_event_report( room_id=room_id, event_id=event_id, user_id=user_id, diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index d83ac8e3c5..38952a1d27 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, @@ -43,8 +41,7 @@ class RoomKeysServlet(RestServlet): self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - @defer.inlineCallbacks - def on_PUT(self, request, room_id, session_id): + async def on_PUT(self, request, room_id, session_id): """ Uploads one or more encrypted E2E room keys for backup purposes. room_id: the ID of the room the keys are for (optional) @@ -123,7 +120,7 @@ class RoomKeysServlet(RestServlet): } } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() body = parse_json_object_from_request(request) version = parse_string(request, "version") @@ -134,11 +131,10 @@ class RoomKeysServlet(RestServlet): if room_id: body = {"rooms": {room_id: body}} - ret = yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) + ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) return 200, ret - @defer.inlineCallbacks - def on_GET(self, request, room_id, session_id): + async def on_GET(self, request, room_id, session_id): """ Retrieves one or more encrypted E2E room keys for backup purposes. Symmetric with the PUT version of the API. @@ -190,11 +186,11 @@ class RoomKeysServlet(RestServlet): } } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() version = parse_string(request, "version") - room_keys = yield self.e2e_room_keys_handler.get_room_keys( + room_keys = await self.e2e_room_keys_handler.get_room_keys( user_id, version, room_id, session_id ) @@ -220,8 +216,7 @@ class RoomKeysServlet(RestServlet): return 200, room_keys - @defer.inlineCallbacks - def on_DELETE(self, request, room_id, session_id): + async def on_DELETE(self, request, room_id, session_id): """ Deletes one or more encrypted E2E room keys for a user for backup purposes. @@ -235,11 +230,11 @@ class RoomKeysServlet(RestServlet): the version must already have been created via the /change_secret API. """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() version = parse_string(request, "version") - ret = yield self.e2e_room_keys_handler.delete_room_keys( + ret = await self.e2e_room_keys_handler.delete_room_keys( user_id, version, room_id, session_id ) return 200, ret @@ -257,8 +252,7 @@ class RoomKeysNewVersionServlet(RestServlet): self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): """ Create a new backup version for this user's room_keys with the given info. The version is allocated by the server and returned to the user @@ -288,11 +282,11 @@ class RoomKeysNewVersionServlet(RestServlet): "version": 12345 } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() info = parse_json_object_from_request(request) - new_version = yield self.e2e_room_keys_handler.create_version(user_id, info) + new_version = await self.e2e_room_keys_handler.create_version(user_id, info) return 200, {"version": new_version} # we deliberately don't have a PUT /version, as these things really should @@ -311,8 +305,7 @@ class RoomKeysVersionServlet(RestServlet): self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - @defer.inlineCallbacks - def on_GET(self, request, version): + async def on_GET(self, request, version): """ Retrieve the version information about a given version of the user's room_keys backup. If the version part is missing, returns info about the @@ -330,18 +323,17 @@ class RoomKeysVersionServlet(RestServlet): "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() try: - info = yield self.e2e_room_keys_handler.get_version_info(user_id, version) + info = await self.e2e_room_keys_handler.get_version_info(user_id, version) except SynapseError as e: if e.code == 404: raise SynapseError(404, "No backup found", Codes.NOT_FOUND) return 200, info - @defer.inlineCallbacks - def on_DELETE(self, request, version): + async def on_DELETE(self, request, version): """ Delete the information about a given version of the user's room_keys backup. If the version part is missing, deletes the most @@ -354,14 +346,13 @@ class RoomKeysVersionServlet(RestServlet): if version is None: raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND) - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - yield self.e2e_room_keys_handler.delete_version(user_id, version) + await self.e2e_room_keys_handler.delete_version(user_id, version) return 200, {} - @defer.inlineCallbacks - def on_PUT(self, request, version): + async def on_PUT(self, request, version): """ Update the information about a given version of the user's room_keys backup. @@ -382,7 +373,7 @@ class RoomKeysVersionServlet(RestServlet): Content-Type: application/json {} """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() info = parse_json_object_from_request(request) @@ -391,7 +382,7 @@ class RoomKeysVersionServlet(RestServlet): 400, "No version specified to update", Codes.MISSING_PARAM ) - yield self.e2e_room_keys_handler.update_version(user_id, version, info) + await self.e2e_room_keys_handler.update_version(user_id, version, info) return 200, {} diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index d2c3316eb7..ca97330797 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import ( @@ -59,9 +57,8 @@ class RoomUpgradeRestServlet(RestServlet): self._room_creation_handler = hs.get_room_creation_handler() self._auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request, room_id): - requester = yield self._auth.get_user_by_req(request) + async def on_POST(self, request, room_id): + requester = await self._auth.get_user_by_req(request) content = parse_json_object_from_request(request) assert_params_in_dict(content, ("new_version",)) @@ -74,7 +71,7 @@ class RoomUpgradeRestServlet(RestServlet): Codes.UNSUPPORTED_ROOM_VERSION, ) - new_room_id = yield self._room_creation_handler.upgrade_room( + new_room_id = await self._room_creation_handler.upgrade_room( requester, room_id, new_version ) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index d90e52ed1a..501b52fb6c 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http import servlet from synapse.http.servlet import parse_json_object_from_request from synapse.logging.opentracing import set_tag, trace @@ -51,15 +49,14 @@ class SendToDeviceRestServlet(servlet.RestServlet): request, self._put, request, message_type, txn_id ) - @defer.inlineCallbacks - def _put(self, request, message_type, txn_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def _put(self, request, message_type, txn_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) sender_user_id = requester.user.to_string() - yield self.device_message_handler.send_device_message( + await self.device_message_handler.send_device_message( sender_user_id, message_type, content["messages"] ) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index ccd8b17b23..d8292ce29f 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -18,8 +18,6 @@ import logging from canonicaljson import json -from twisted.internet import defer - from synapse.api.constants import PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection @@ -87,8 +85,7 @@ class SyncRestServlet(RestServlet): self._server_notices_sender = hs.get_server_notices_sender() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): if b"from" in request.args: # /events used to use 'from', but /sync uses 'since'. # Lets be helpful and whine if we see a 'from'. @@ -96,7 +93,7 @@ class SyncRestServlet(RestServlet): 400, "'from' is not a valid query parameter. Did you mean 'since'?" ) - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = await self.auth.get_user_by_req(request, allow_guest=True) user = requester.user device_id = requester.device_id @@ -138,7 +135,7 @@ class SyncRestServlet(RestServlet): filter_collection = FilterCollection(filter_object) else: try: - filter_collection = yield self.filtering.get_user_filter( + filter_collection = await self.filtering.get_user_filter( user.localpart, filter_id ) except StoreError as err: @@ -161,20 +158,20 @@ class SyncRestServlet(RestServlet): since_token = None # send any outstanding server notices to the user. - yield self._server_notices_sender.on_user_syncing(user.to_string()) + await self._server_notices_sender.on_user_syncing(user.to_string()) affect_presence = set_presence != PresenceState.OFFLINE if affect_presence: - yield self.presence_handler.set_state( + await self.presence_handler.set_state( user, {"presence": set_presence}, True ) - context = yield self.presence_handler.user_syncing( + context = await self.presence_handler.user_syncing( user.to_string(), affect_presence=affect_presence ) with context: - sync_result = yield self.sync_handler.wait_for_sync_for_user( + sync_result = await self.sync_handler.wait_for_sync_for_user( sync_config, since_token=since_token, timeout=timeout, @@ -182,14 +179,13 @@ class SyncRestServlet(RestServlet): ) time_now = self.clock.time_msec() - response_content = yield self.encode_response( + response_content = await self.encode_response( time_now, sync_result, requester.access_token_id, filter_collection ) return 200, response_content - @defer.inlineCallbacks - def encode_response(self, time_now, sync_result, access_token_id, filter): + async def encode_response(self, time_now, sync_result, access_token_id, filter): if filter.event_format == "client": event_formatter = format_event_for_client_v2_without_room_id elif filter.event_format == "federation": @@ -197,7 +193,7 @@ class SyncRestServlet(RestServlet): else: raise Exception("Unknown event format %s" % (filter.event_format,)) - joined = yield self.encode_joined( + joined = await self.encode_joined( sync_result.joined, time_now, access_token_id, @@ -205,11 +201,11 @@ class SyncRestServlet(RestServlet): event_formatter, ) - invited = yield self.encode_invited( + invited = await self.encode_invited( sync_result.invited, time_now, access_token_id, event_formatter ) - archived = yield self.encode_archived( + archived = await self.encode_archived( sync_result.archived, time_now, access_token_id, @@ -250,8 +246,9 @@ class SyncRestServlet(RestServlet): ] } - @defer.inlineCallbacks - def encode_joined(self, rooms, time_now, token_id, event_fields, event_formatter): + async def encode_joined( + self, rooms, time_now, token_id, event_fields, event_formatter + ): """ Encode the joined rooms in a sync result @@ -272,7 +269,7 @@ class SyncRestServlet(RestServlet): """ joined = {} for room in rooms: - joined[room.room_id] = yield self.encode_room( + joined[room.room_id] = await self.encode_room( room, time_now, token_id, @@ -283,8 +280,7 @@ class SyncRestServlet(RestServlet): return joined - @defer.inlineCallbacks - def encode_invited(self, rooms, time_now, token_id, event_formatter): + async def encode_invited(self, rooms, time_now, token_id, event_formatter): """ Encode the invited rooms in a sync result @@ -304,7 +300,7 @@ class SyncRestServlet(RestServlet): """ invited = {} for room in rooms: - invite = yield self._event_serializer.serialize_event( + invite = await self._event_serializer.serialize_event( room.invite, time_now, token_id=token_id, @@ -319,8 +315,9 @@ class SyncRestServlet(RestServlet): return invited - @defer.inlineCallbacks - def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatter): + async def encode_archived( + self, rooms, time_now, token_id, event_fields, event_formatter + ): """ Encode the archived rooms in a sync result @@ -341,7 +338,7 @@ class SyncRestServlet(RestServlet): """ joined = {} for room in rooms: - joined[room.room_id] = yield self.encode_room( + joined[room.room_id] = await self.encode_room( room, time_now, token_id, @@ -352,8 +349,7 @@ class SyncRestServlet(RestServlet): return joined - @defer.inlineCallbacks - def encode_room( + async def encode_room( self, room, time_now, token_id, joined, only_fields, event_formatter ): """ @@ -401,8 +397,8 @@ class SyncRestServlet(RestServlet): event.room_id, ) - serialized_state = yield serialize(state_events) - serialized_timeline = yield serialize(timeline_events) + serialized_state = await serialize(state_events) + serialized_timeline = await serialize(timeline_events) account_data = room.account_data diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index 3b555669a0..a3f12e8a77 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -37,13 +35,12 @@ class TagListServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_GET(self, request, user_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id, room_id): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get tags for other users.") - tags = yield self.store.get_tags_for_room(user_id, room_id) + tags = await self.store.get_tags_for_room(user_id, room_id) return 200, {"tags": tags} @@ -64,27 +61,25 @@ class TagServlet(RestServlet): self.store = hs.get_datastore() self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def on_PUT(self, request, user_id, room_id, tag): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id, room_id, tag): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") body = parse_json_object_from_request(request) - max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) + max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body) self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, user_id, room_id, tag): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, user_id, room_id, tag): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") - max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) + max_id = await self.store.remove_tag_from_room(user_id, room_id, tag) self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 2e8d672471..23709960ad 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import ThirdPartyEntityKind from synapse.http.servlet import RestServlet @@ -35,11 +33,10 @@ class ThirdPartyProtocolsServlet(RestServlet): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) - protocols = yield self.appservice_handler.get_3pe_protocols() + protocols = await self.appservice_handler.get_3pe_protocols() return 200, protocols @@ -52,11 +49,10 @@ class ThirdPartyProtocolServlet(RestServlet): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, protocol): + await self.auth.get_user_by_req(request, allow_guest=True) - protocols = yield self.appservice_handler.get_3pe_protocols( + protocols = await self.appservice_handler.get_3pe_protocols( only_protocol=protocol ) if protocol in protocols: @@ -74,14 +70,13 @@ class ThirdPartyUserServlet(RestServlet): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, protocol): + await self.auth.get_user_by_req(request, allow_guest=True) fields = request.args fields.pop(b"access_token", None) - results = yield self.appservice_handler.query_3pe( + results = await self.appservice_handler.query_3pe( ThirdPartyEntityKind.USER, protocol, fields ) @@ -97,14 +92,13 @@ class ThirdPartyLocationServlet(RestServlet): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, protocol): + await self.auth.get_user_by_req(request, allow_guest=True) fields = request.args fields.pop(b"access_token", None) - results = yield self.appservice_handler.query_3pe( + results = await self.appservice_handler.query_3pe( ThirdPartyEntityKind.LOCATION, protocol, fields ) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 2da0f55811..83f3b6b70a 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet @@ -32,8 +30,7 @@ class TokenRefreshRestServlet(RestServlet): def __init__(self, hs): super(TokenRefreshRestServlet, self).__init__() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): raise AuthError(403, "tokenrefresh is no longer supported.") diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index 2863affbab..bef91a2d3e 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -38,8 +36,7 @@ class UserDirectorySearchRestServlet(RestServlet): self.auth = hs.get_auth() self.user_directory_handler = hs.get_user_directory_handler() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): """Searches for users in directory Returns: @@ -56,7 +53,7 @@ class UserDirectorySearchRestServlet(RestServlet): ] } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() if not self.hs.config.user_directory_search_enabled: @@ -72,7 +69,7 @@ class UserDirectorySearchRestServlet(RestServlet): except Exception: raise SynapseError(400, "`search_term` is required field") - results = yield self.user_directory_handler.search_users( + results = await self.user_directory_handler.search_users( user_id, search_term, limit ) -- cgit 1.4.1