From 62b1ce85398f52e7d6137e77083294d0c90af459 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Sun, 5 Jul 2020 16:32:02 +0100 Subject: isort 5 compatibility (#7786) The CI appears to use the latest version of isort, which is a problem when isort gets a major version bump. Rather than try to pin the version, I've done the necessary to make isort5 happy with synapse. --- tests/handlers/test_e2e_keys.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'tests/handlers/test_e2e_keys.py') diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 6c1dc72bd1..1acf287ca4 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -14,11 +14,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import mock -import signedjson.key as key -import signedjson.sign as sign +from signedjson import key as key, sign as sign from twisted.internet import defer -- cgit 1.5.1 From b11450dedc59b117ad23426b47f2465c459ea62a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 15 Jul 2020 08:48:58 -0400 Subject: Convert E2E key and room key handlers to async/await. (#7851) --- changelog.d/7851.misc | 1 + synapse/handlers/e2e_keys.py | 147 ++++++-------- synapse/handlers/e2e_room_keys.py | 75 ++++--- tests/handlers/test_e2e_keys.py | 286 ++++++++++++++++----------- tests/handlers/test_e2e_room_keys.py | 373 +++++++++++++++++++++++------------ 5 files changed, 521 insertions(+), 361 deletions(-) create mode 100644 changelog.d/7851.misc (limited to 'tests/handlers/test_e2e_keys.py') diff --git a/changelog.d/7851.misc b/changelog.d/7851.misc new file mode 100644 index 0000000000..e5cf540edf --- /dev/null +++ b/changelog.d/7851.misc @@ -0,0 +1 @@ +Convert E2E keys and room keys handlers to async/await. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index a7e60cbc26..361dd64cd2 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -77,8 +77,7 @@ class E2eKeysHandler(object): ) @trace - @defer.inlineCallbacks - def query_devices(self, query_body, timeout, from_user_id): + async def query_devices(self, query_body, timeout, from_user_id): """ Handle a device key query from a client { @@ -124,7 +123,7 @@ class E2eKeysHandler(object): failures = {} results = {} if local_query: - local_result = yield self.query_local_devices(local_query) + local_result = await self.query_local_devices(local_query) for user_id, keys in local_result.items(): if user_id in local_query: results[user_id] = keys @@ -142,7 +141,7 @@ class E2eKeysHandler(object): ( user_ids_not_in_cache, remote_results, - ) = yield self.store.get_user_devices_from_cache(query_list) + ) = await self.store.get_user_devices_from_cache(query_list) for user_id, devices in remote_results.items(): user_devices = results.setdefault(user_id, {}) for device_id, device in devices.items(): @@ -161,14 +160,13 @@ class E2eKeysHandler(object): r[user_id] = remote_queries[user_id] # Get cached cross-signing keys - cross_signing_keys = yield self.get_cross_signing_keys_from_cache( + cross_signing_keys = await self.get_cross_signing_keys_from_cache( device_keys_query, from_user_id ) # Now fetch any devices that we don't have in our cache @trace - @defer.inlineCallbacks - def do_remote_query(destination): + async def do_remote_query(destination): """This is called when we are querying the device list of a user on a remote homeserver and their device list is not in the device list cache. If we share a room with this user and we're not querying for @@ -192,7 +190,7 @@ class E2eKeysHandler(object): if device_list: continue - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: continue @@ -201,11 +199,11 @@ class E2eKeysHandler(object): # done an initial sync on the device list so we do it now. try: if self._is_master: - user_devices = yield self.device_handler.device_list_updater.user_device_resync( + user_devices = await self.device_handler.device_list_updater.user_device_resync( user_id ) else: - user_devices = yield self._user_device_resync_client( + user_devices = await self._user_device_resync_client( user_id=user_id ) @@ -227,7 +225,7 @@ class E2eKeysHandler(object): destination_query.pop(user_id) try: - remote_result = yield self.federation.query_client_keys( + remote_result = await self.federation.query_client_keys( destination, {"device_keys": destination_query}, timeout=timeout ) @@ -251,7 +249,7 @@ class E2eKeysHandler(object): set_tag("error", True) set_tag("reason", failure) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background(do_remote_query, destination) @@ -267,8 +265,7 @@ class E2eKeysHandler(object): return ret - @defer.inlineCallbacks - def get_cross_signing_keys_from_cache(self, query, from_user_id): + async def get_cross_signing_keys_from_cache(self, query, from_user_id): """Get cross-signing keys for users from the database Args: @@ -289,7 +286,7 @@ class E2eKeysHandler(object): user_ids = list(query) - keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) + keys = await self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) for user_id, user_info in keys.items(): if user_info is None: @@ -315,8 +312,7 @@ class E2eKeysHandler(object): } @trace - @defer.inlineCallbacks - def query_local_devices(self, query): + async def query_local_devices(self, query): """Get E2E device keys for local users Args: @@ -354,7 +350,7 @@ class E2eKeysHandler(object): # make sure that each queried user appears in the result dict result_dict[user_id] = {} - results = yield self.store.get_e2e_device_keys(local_query) + results = await self.store.get_e2e_device_keys(local_query) # Build the result structure for user_id, device_keys in results.items(): @@ -364,16 +360,15 @@ class E2eKeysHandler(object): log_kv(results) return result_dict - @defer.inlineCallbacks - def on_federation_query_client_keys(self, query_body): + async def on_federation_query_client_keys(self, query_body): """ Handle a device key query from a federated server """ device_keys_query = query_body.get("device_keys", {}) - res = yield self.query_local_devices(device_keys_query) + res = await self.query_local_devices(device_keys_query) ret = {"device_keys": res} # add in the cross-signing keys - cross_signing_keys = yield self.get_cross_signing_keys_from_cache( + cross_signing_keys = await self.get_cross_signing_keys_from_cache( device_keys_query, None ) @@ -382,8 +377,7 @@ class E2eKeysHandler(object): return ret @trace - @defer.inlineCallbacks - def claim_one_time_keys(self, query, timeout): + async def claim_one_time_keys(self, query, timeout): local_query = [] remote_queries = {} @@ -399,7 +393,7 @@ class E2eKeysHandler(object): set_tag("local_key_query", local_query) set_tag("remote_key_query", remote_queries) - results = yield self.store.claim_e2e_one_time_keys(local_query) + results = await self.store.claim_e2e_one_time_keys(local_query) json_result = {} failures = {} @@ -411,12 +405,11 @@ class E2eKeysHandler(object): } @trace - @defer.inlineCallbacks - def claim_client_keys(destination): + async def claim_client_keys(destination): set_tag("destination", destination) device_keys = remote_queries[destination] try: - remote_result = yield self.federation.claim_client_keys( + remote_result = await self.federation.claim_client_keys( destination, {"one_time_keys": device_keys}, timeout=timeout ) for user_id, keys in remote_result["one_time_keys"].items(): @@ -429,7 +422,7 @@ class E2eKeysHandler(object): set_tag("error", True) set_tag("reason", failure) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background(claim_client_keys, destination) @@ -454,9 +447,8 @@ class E2eKeysHandler(object): log_kv({"one_time_keys": json_result, "failures": failures}) return {"one_time_keys": json_result, "failures": failures} - @defer.inlineCallbacks @tag_args - def upload_keys_for_user(self, user_id, device_id, keys): + async def upload_keys_for_user(self, user_id, device_id, keys): time_now = self.clock.time_msec() @@ -477,12 +469,12 @@ class E2eKeysHandler(object): } ) # TODO: Sign the JSON with the server key - changed = yield self.store.set_e2e_device_keys( + changed = await self.store.set_e2e_device_keys( user_id, device_id, time_now, device_keys ) if changed: # Only notify about device updates *if* the keys actually changed - yield self.device_handler.notify_device_update(user_id, [device_id]) + await self.device_handler.notify_device_update(user_id, [device_id]) else: log_kv({"message": "Not updating device_keys for user", "user_id": user_id}) one_time_keys = keys.get("one_time_keys", None) @@ -494,7 +486,7 @@ class E2eKeysHandler(object): "device_id": device_id, } ) - yield self._upload_one_time_keys_for_user( + await self._upload_one_time_keys_for_user( user_id, device_id, time_now, one_time_keys ) else: @@ -507,15 +499,14 @@ class E2eKeysHandler(object): # old access_token without an associated device_id. Either way, we # need to double-check the device is registered to avoid ending up with # keys without a corresponding device. - yield self.device_handler.check_device_registered(user_id, device_id) + await self.device_handler.check_device_registered(user_id, device_id) - result = yield self.store.count_e2e_one_time_keys(user_id, device_id) + result = await self.store.count_e2e_one_time_keys(user_id, device_id) set_tag("one_time_key_counts", result) return {"one_time_key_counts": result} - @defer.inlineCallbacks - def _upload_one_time_keys_for_user( + async def _upload_one_time_keys_for_user( self, user_id, device_id, time_now, one_time_keys ): logger.info( @@ -533,7 +524,7 @@ class E2eKeysHandler(object): key_list.append((algorithm, key_id, key_obj)) # First we check if we have already persisted any of the keys. - existing_key_map = yield self.store.get_e2e_one_time_keys( + existing_key_map = await self.store.get_e2e_one_time_keys( user_id, device_id, [k_id for _, k_id, _ in key_list] ) @@ -556,10 +547,9 @@ class E2eKeysHandler(object): ) log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys}) - yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) + await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) - @defer.inlineCallbacks - def upload_signing_keys_for_user(self, user_id, keys): + async def upload_signing_keys_for_user(self, user_id, keys): """Upload signing keys for cross-signing Args: @@ -574,7 +564,7 @@ class E2eKeysHandler(object): _check_cross_signing_key(master_key, user_id, "master") else: - master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") + master_key = await self.store.get_e2e_cross_signing_key(user_id, "master") # if there is no master key, then we can't do anything, because all the # other cross-signing keys need to be signed by the master key @@ -613,10 +603,10 @@ class E2eKeysHandler(object): # if everything checks out, then store the keys and send notifications deviceids = [] if "master_key" in keys: - yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + await self.store.set_e2e_cross_signing_key(user_id, "master", master_key) deviceids.append(master_verify_key.version) if "self_signing_key" in keys: - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user_id, "self_signing", self_signing_key ) try: @@ -626,23 +616,22 @@ class E2eKeysHandler(object): except ValueError: raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM) if "user_signing_key" in keys: - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user_id, "user_signing", user_signing_key ) # the signature stream matches the semantics that we want for # user-signing key updates: only the user themselves is notified of # their own user-signing key updates - yield self.device_handler.notify_user_signature_update(user_id, [user_id]) + await self.device_handler.notify_user_signature_update(user_id, [user_id]) # master key and self-signing key updates match the semantics of device # list updates: all users who share an encrypted room are notified if len(deviceids): - yield self.device_handler.notify_device_update(user_id, deviceids) + await self.device_handler.notify_device_update(user_id, deviceids) return {} - @defer.inlineCallbacks - def upload_signatures_for_device_keys(self, user_id, signatures): + async def upload_signatures_for_device_keys(self, user_id, signatures): """Upload device signatures for cross-signing Args: @@ -667,13 +656,13 @@ class E2eKeysHandler(object): self_signatures = signatures.get(user_id, {}) other_signatures = {k: v for k, v in signatures.items() if k != user_id} - self_signature_list, self_failures = yield self._process_self_signatures( + self_signature_list, self_failures = await self._process_self_signatures( user_id, self_signatures ) signature_list.extend(self_signature_list) failures.update(self_failures) - other_signature_list, other_failures = yield self._process_other_signatures( + other_signature_list, other_failures = await self._process_other_signatures( user_id, other_signatures ) signature_list.extend(other_signature_list) @@ -681,21 +670,20 @@ class E2eKeysHandler(object): # store the signature, and send the appropriate notifications for sync logger.debug("upload signature failures: %r", failures) - yield self.store.store_e2e_cross_signing_signatures(user_id, signature_list) + await self.store.store_e2e_cross_signing_signatures(user_id, signature_list) self_device_ids = [item.target_device_id for item in self_signature_list] if self_device_ids: - yield self.device_handler.notify_device_update(user_id, self_device_ids) + await self.device_handler.notify_device_update(user_id, self_device_ids) signed_users = [item.target_user_id for item in other_signature_list] if signed_users: - yield self.device_handler.notify_user_signature_update( + await self.device_handler.notify_user_signature_update( user_id, signed_users ) return {"failures": failures} - @defer.inlineCallbacks - def _process_self_signatures(self, user_id, signatures): + async def _process_self_signatures(self, user_id, signatures): """Process uploaded signatures of the user's own keys. Signatures of the user's own keys from this API come in two forms: @@ -728,7 +716,7 @@ class E2eKeysHandler(object): _, self_signing_key_id, self_signing_verify_key, - ) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing") + ) = await self._get_e2e_cross_signing_verify_key(user_id, "self_signing") # get our master key, since we may have received a signature of it. # We need to fetch it here so that we know what its key ID is, so @@ -738,12 +726,12 @@ class E2eKeysHandler(object): master_key, _, master_verify_key, - ) = yield self._get_e2e_cross_signing_verify_key(user_id, "master") + ) = await self._get_e2e_cross_signing_verify_key(user_id, "master") # fetch our stored devices. This is used to 1. verify # signatures on the master key, and 2. to compare with what # was sent if the device was signed - devices = yield self.store.get_e2e_device_keys([(user_id, None)]) + devices = await self.store.get_e2e_device_keys([(user_id, None)]) if user_id not in devices: raise NotFoundError("No device keys found") @@ -853,8 +841,7 @@ class E2eKeysHandler(object): return master_key_signature_list - @defer.inlineCallbacks - def _process_other_signatures(self, user_id, signatures): + async def _process_other_signatures(self, user_id, signatures): """Process uploaded signatures of other users' keys. These will be the target user's master keys, signed by the uploading user's user-signing key. @@ -882,7 +869,7 @@ class E2eKeysHandler(object): user_signing_key, user_signing_key_id, user_signing_verify_key, - ) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing") + ) = await self._get_e2e_cross_signing_verify_key(user_id, "user_signing") except SynapseError as e: failure = _exception_to_failure(e) for user, devicemap in signatures.items(): @@ -905,7 +892,7 @@ class E2eKeysHandler(object): master_key, master_key_id, _, - ) = yield self._get_e2e_cross_signing_verify_key( + ) = await self._get_e2e_cross_signing_verify_key( target_user, "master", user_id ) @@ -958,8 +945,7 @@ class E2eKeysHandler(object): return signature_list, failures - @defer.inlineCallbacks - def _get_e2e_cross_signing_verify_key( + async def _get_e2e_cross_signing_verify_key( self, user_id: str, key_type: str, from_user_id: str = None ): """Fetch locally or remotely query for a cross-signing public key. @@ -983,7 +969,7 @@ class E2eKeysHandler(object): SynapseError: if `user_id` is invalid """ user = UserID.from_string(user_id) - key = yield self.store.get_e2e_cross_signing_key( + key = await self.store.get_e2e_cross_signing_key( user_id, key_type, from_user_id ) @@ -1009,15 +995,14 @@ class E2eKeysHandler(object): key, key_id, verify_key, - ) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type) + ) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type) if key is None: raise NotFoundError("No %s key found for %s" % (key_type, user_id)) return key, key_id, verify_key - @defer.inlineCallbacks - def _retrieve_cross_signing_keys_for_remote_user( + async def _retrieve_cross_signing_keys_for_remote_user( self, user: UserID, desired_key_type: str, ): """Queries cross-signing keys for a remote user and saves them to the database @@ -1035,7 +1020,7 @@ class E2eKeysHandler(object): If the key cannot be retrieved, all values in the tuple will instead be None. """ try: - remote_result = yield self.federation.query_user_devices( + remote_result = await self.federation.query_user_devices( user.domain, user.to_string() ) except Exception as e: @@ -1101,14 +1086,14 @@ class E2eKeysHandler(object): desired_key_id = key_id # At the same time, store this key in the db for subsequent queries - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user.to_string(), key_type, key_content ) # Notify clients that new devices for this user have been discovered if retrieved_device_ids: # XXX is this necessary? - yield self.device_handler.notify_device_update( + await self.device_handler.notify_device_update( user.to_string(), retrieved_device_ids ) @@ -1250,8 +1235,7 @@ class SigningKeyEduUpdater(object): iterable=True, ) - @defer.inlineCallbacks - def incoming_signing_key_update(self, origin, edu_content): + async def incoming_signing_key_update(self, origin, edu_content): """Called on incoming signing key update from federation. Responsible for parsing the EDU and adding to pending updates list. @@ -1268,7 +1252,7 @@ class SigningKeyEduUpdater(object): logger.warning("Got signing key update edu for %r from %r", user_id, origin) return - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: # We don't share any rooms with this user. Ignore update, as we # probably won't get any further updates. @@ -1278,10 +1262,9 @@ class SigningKeyEduUpdater(object): (master_key, self_signing_key) ) - yield self._handle_signing_key_updates(user_id) + await self._handle_signing_key_updates(user_id) - @defer.inlineCallbacks - def _handle_signing_key_updates(self, user_id): + async def _handle_signing_key_updates(self, user_id): """Actually handle pending updates. Args: @@ -1291,7 +1274,7 @@ class SigningKeyEduUpdater(object): device_handler = self.e2e_keys_handler.device_handler device_list_updater = device_handler.device_list_updater - with (yield self._remote_edu_linearizer.queue(user_id)): + with (await self._remote_edu_linearizer.queue(user_id)): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: # This can happen since we batch updates @@ -1302,9 +1285,9 @@ class SigningKeyEduUpdater(object): logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - new_device_ids = yield device_list_updater.process_cross_signing_key_update( + new_device_ids = await device_list_updater.process_cross_signing_key_update( user_id, master_key, self_signing_key, ) device_ids = device_ids + new_device_ids - yield device_handler.notify_device_update(user_id, device_ids) + await device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index f55470a707..0bb983dc28 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import ( Codes, NotFoundError, @@ -50,8 +48,7 @@ class E2eRoomKeysHandler(object): self._upload_linearizer = Linearizer("upload_room_keys_lock") @trace - @defer.inlineCallbacks - def get_room_keys(self, user_id, version, room_id=None, session_id=None): + async def get_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. See EndToEndRoomKeyStore.get_e2e_room_keys for full details. @@ -71,17 +68,17 @@ class E2eRoomKeysHandler(object): # we deliberately take the lock to get keys so that changing the version # works atomically - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): # make sure the backup version exists try: - yield self.store.get_e2e_room_keys_version_info(user_id, version) + await self.store.get_e2e_room_keys_version_info(user_id, version) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown backup version") else: raise - results = yield self.store.get_e2e_room_keys( + results = await self.store.get_e2e_room_keys( user_id, version, room_id, session_id ) @@ -89,8 +86,7 @@ class E2eRoomKeysHandler(object): return results @trace - @defer.inlineCallbacks - def delete_room_keys(self, user_id, version, room_id=None, session_id=None): + async def delete_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. See EndToEndRoomKeyStore.delete_e2e_room_keys for full details. @@ -109,10 +105,10 @@ class E2eRoomKeysHandler(object): """ # lock for consistency with uploading - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): # make sure the backup version exists try: - version_info = yield self.store.get_e2e_room_keys_version_info( + version_info = await self.store.get_e2e_room_keys_version_info( user_id, version ) except StoreError as e: @@ -121,19 +117,18 @@ class E2eRoomKeysHandler(object): else: raise - yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) + await self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) version_etag = version_info["etag"] + 1 - yield self.store.update_e2e_room_keys_version( + await self.store.update_e2e_room_keys_version( user_id, version, None, version_etag ) - count = yield self.store.count_e2e_room_keys(user_id, version) + count = await self.store.count_e2e_room_keys(user_id, version) return {"etag": str(version_etag), "count": count} @trace - @defer.inlineCallbacks - def upload_room_keys(self, user_id, version, room_keys): + async def upload_room_keys(self, user_id, version, room_keys): """Bulk upload a list of room keys into a given backup version, asserting that the given version is the current backup version. room_keys are merged into the current backup as described in RoomKeysServlet.on_PUT(). @@ -169,11 +164,11 @@ class E2eRoomKeysHandler(object): # TODO: Validate the JSON to make sure it has the right keys. # XXX: perhaps we should use a finer grained lock here? - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): # Check that the version we're trying to upload is the current version try: - version_info = yield self.store.get_e2e_room_keys_version_info(user_id) + version_info = await self.store.get_e2e_room_keys_version_info(user_id) except StoreError as e: if e.code == 404: raise NotFoundError("Version '%s' not found" % (version,)) @@ -183,7 +178,7 @@ class E2eRoomKeysHandler(object): if version_info["version"] != version: # Check that the version we're trying to upload actually exists try: - version_info = yield self.store.get_e2e_room_keys_version_info( + version_info = await self.store.get_e2e_room_keys_version_info( user_id, version ) # if we get this far, the version must exist @@ -198,7 +193,7 @@ class E2eRoomKeysHandler(object): # submitted. Then compare them with the submitted keys. If the # key is new, insert it; if the key should be updated, then update # it; otherwise, drop it. - existing_keys = yield self.store.get_e2e_room_keys_multi( + existing_keys = await self.store.get_e2e_room_keys_multi( user_id, version, room_keys["rooms"] ) to_insert = [] # batch the inserts together @@ -227,7 +222,7 @@ class E2eRoomKeysHandler(object): # updates are done one at a time in the DB, so send # updates right away rather than batching them up, # like we do with the inserts - yield self.store.update_e2e_room_key( + await self.store.update_e2e_room_key( user_id, version, room_id, session_id, room_key ) changed = True @@ -246,16 +241,16 @@ class E2eRoomKeysHandler(object): changed = True if len(to_insert): - yield self.store.add_e2e_room_keys(user_id, version, to_insert) + await self.store.add_e2e_room_keys(user_id, version, to_insert) version_etag = version_info["etag"] if changed: version_etag = version_etag + 1 - yield self.store.update_e2e_room_keys_version( + await self.store.update_e2e_room_keys_version( user_id, version, None, version_etag ) - count = yield self.store.count_e2e_room_keys(user_id, version) + count = await self.store.count_e2e_room_keys(user_id, version) return {"etag": str(version_etag), "count": count} @staticmethod @@ -291,8 +286,7 @@ class E2eRoomKeysHandler(object): return True @trace - @defer.inlineCallbacks - def create_version(self, user_id, version_info): + async def create_version(self, user_id, version_info): """Create a new backup version. This automatically becomes the new backup version for the user's keys; previous backups will no longer be writeable to. @@ -313,14 +307,13 @@ class E2eRoomKeysHandler(object): # TODO: Validate the JSON to make sure it has the right keys. # lock everyone out until we've switched version - with (yield self._upload_linearizer.queue(user_id)): - new_version = yield self.store.create_e2e_room_keys_version( + with (await self._upload_linearizer.queue(user_id)): + new_version = await self.store.create_e2e_room_keys_version( user_id, version_info ) return new_version - @defer.inlineCallbacks - def get_version_info(self, user_id, version=None): + async def get_version_info(self, user_id, version=None): """Get the info about a given version of the user's backup Args: @@ -339,22 +332,21 @@ class E2eRoomKeysHandler(object): } """ - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): try: - res = yield self.store.get_e2e_room_keys_version_info(user_id, version) + res = await self.store.get_e2e_room_keys_version_info(user_id, version) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown backup version") else: raise - res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"]) + res["count"] = await self.store.count_e2e_room_keys(user_id, res["version"]) res["etag"] = str(res["etag"]) return res @trace - @defer.inlineCallbacks - def delete_version(self, user_id, version=None): + async def delete_version(self, user_id, version=None): """Deletes a given version of the user's e2e_room_keys backup Args: @@ -364,9 +356,9 @@ class E2eRoomKeysHandler(object): NotFoundError: if this backup version doesn't exist """ - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): try: - yield self.store.delete_e2e_room_keys_version(user_id, version) + await self.store.delete_e2e_room_keys_version(user_id, version) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown backup version") @@ -374,8 +366,7 @@ class E2eRoomKeysHandler(object): raise @trace - @defer.inlineCallbacks - def update_version(self, user_id, version, version_info): + async def update_version(self, user_id, version, version_info): """Update the info about a given version of the user's backup Args: @@ -393,9 +384,9 @@ class E2eRoomKeysHandler(object): raise SynapseError( 400, "Version in body does not match", Codes.INVALID_PARAM ) - with (yield self._upload_linearizer.queue(user_id)): + with (await self._upload_linearizer.queue(user_id)): try: - old_info = yield self.store.get_e2e_room_keys_version_info( + old_info = await self.store.get_e2e_room_keys_version_info( user_id, version ) except StoreError as e: @@ -406,7 +397,7 @@ class E2eRoomKeysHandler(object): if old_info["algorithm"] != version_info["algorithm"]: raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM) - yield self.store.update_e2e_room_keys_version( + await self.store.update_e2e_room_keys_version( user_id, version, version_info ) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 1acf287ca4..cdd093ffa8 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -46,7 +46,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): """If the user has no devices, we expect an empty list. """ local_user = "@boris:" + self.hs.hostname - res = yield self.handler.query_local_devices({local_user: None}) + res = yield defer.ensureDeferred( + self.handler.query_local_devices({local_user: None}) + ) self.assertDictEqual(res, {local_user: {}}) @defer.inlineCallbacks @@ -60,15 +62,19 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "alg2:k3": {"key": "key3"}, } - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) # we should be able to change the signature without a problem keys["alg2:k2"]["signatures"]["k1"] = "sig2" - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) @@ -84,44 +90,56 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "alg2:k3": {"key": "key3"}, } - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) try: - yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}} + ) ) self.fail("No error when changing string key") except errors.SynapseError: pass try: - yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} + ) ) self.fail("No error when replacing dict key with string") except errors.SynapseError: pass try: - yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, + device_id, + {"one_time_keys": {"alg1:k1": {"key": "key"}}}, + ) ) self.fail("No error when replacing string key with dict") except errors.SynapseError: pass try: - yield self.handler.upload_keys_for_user( - local_user, - device_id, - { - "one_time_keys": { - "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}} - } - }, + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, + device_id, + { + "one_time_keys": { + "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}} + } + }, + ) ) self.fail("No error when replacing dict key") except errors.SynapseError: @@ -133,13 +151,17 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_id = "xyz" keys = {"alg1:k1": "key1"} - res = yield self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + res = yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": keys} + ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}}) - res2 = yield self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + res2 = yield defer.ensureDeferred( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) ) self.assertEqual( res2, @@ -163,7 +185,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield self.handler.upload_signing_keys_for_user(local_user, keys1) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys1) + ) keys2 = { "master_key": { @@ -175,10 +199,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield self.handler.upload_signing_keys_for_user(local_user, keys2) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys2) + ) - devices = yield self.handler.query_devices( - {"device_keys": {local_user: []}}, 0, local_user + devices = yield defer.ensureDeferred( + self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) @@ -215,7 +241,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0", ) - yield self.handler.upload_signing_keys_for_user(local_user, keys1) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys1) + ) # upload two device keys, which will be signed later by the self-signing key device_key_1 = { @@ -245,18 +273,24 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "signatures": {local_user: {"ed25519:def": "base64+signature"}}, } - yield self.handler.upload_keys_for_user( - local_user, "abc", {"device_keys": device_key_1} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, "abc", {"device_keys": device_key_1} + ) ) - yield self.handler.upload_keys_for_user( - local_user, "def", {"device_keys": device_key_2} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, "def", {"device_keys": device_key_2} + ) ) # sign the first device key and upload it del device_key_1["signatures"] sign.sign_json(device_key_1, local_user, signing_key) - yield self.handler.upload_signatures_for_device_keys( - local_user, {local_user: {"abc": device_key_1}} + yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1}} + ) ) # sign the second device key and upload both device keys. The server @@ -264,14 +298,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase): # signature for it del device_key_2["signatures"] sign.sign_json(device_key_2, local_user, signing_key) - yield self.handler.upload_signatures_for_device_keys( - local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} + yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} + ) ) device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" - devices = yield self.handler.query_devices( - {"device_keys": {local_user: []}}, 0, local_user + devices = yield defer.ensureDeferred( + self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) ) del devices["device_keys"][local_user]["abc"]["unsigned"] del devices["device_keys"][local_user]["def"]["unsigned"] @@ -292,7 +328,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield self.handler.upload_signing_keys_for_user(local_user, keys1) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, keys1) + ) res = None try: @@ -305,7 +343,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): res = e.code self.assertEqual(res, 400) - res = yield self.handler.query_local_devices({local_user: None}) + res = yield defer.ensureDeferred( + self.handler.query_local_devices({local_user: None}) + ) self.assertDictEqual(res, {local_user: {}}) @defer.inlineCallbacks @@ -331,8 +371,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA" ) - yield self.handler.upload_keys_for_user( - local_user, device_id, {"device_keys": device_key} + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"device_keys": device_key} + ) ) # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 @@ -372,7 +414,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "user_signing_key": usersigning_key, "self_signing_key": selfsigning_key, } - yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) + ) # set up another user with a master key. This user will be signed by # the first user @@ -384,76 +428,90 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "usage": ["master"], "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, } - yield self.handler.upload_signing_keys_for_user( - other_user, {"master_key": other_master_key} + yield defer.ensureDeferred( + self.handler.upload_signing_keys_for_user( + other_user, {"master_key": other_master_key} + ) ) # test various signature failures (see below) - ret = yield self.handler.upload_signatures_for_device_keys( - local_user, - { - local_user: { - # fails because the signature is invalid - # should fail with INVALID_SIGNATURE - device_id: { - "user_id": local_user, - "device_id": device_id, - "algorithms": [ - "m.olm.curve25519-aes-sha2", - RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2, - ], - "keys": { - "curve25519:xyz": "curve25519+key", - # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA - "ed25519:xyz": device_pubkey, - }, - "signatures": { - local_user: {"ed25519:" + selfsigning_pubkey: "something"} + ret = yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: { + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + device_id: { + "user_id": local_user, + "device_id": device_id, + "algorithms": [ + "m.olm.curve25519-aes-sha2", + RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2, + ], + "keys": { + "curve25519:xyz": "curve25519+key", + # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA + "ed25519:xyz": device_pubkey, + }, + "signatures": { + local_user: { + "ed25519:" + selfsigning_pubkey: "something" + } + }, }, - }, - # fails because device is unknown - # should fail with NOT_FOUND - "unknown": { - "user_id": local_user, - "device_id": "unknown", - "signatures": { - local_user: {"ed25519:" + selfsigning_pubkey: "something"} + # fails because device is unknown + # should fail with NOT_FOUND + "unknown": { + "user_id": local_user, + "device_id": "unknown", + "signatures": { + local_user: { + "ed25519:" + selfsigning_pubkey: "something" + } + }, }, - }, - # fails because the signature is invalid - # should fail with INVALID_SIGNATURE - master_pubkey: { - "user_id": local_user, - "usage": ["master"], - "keys": {"ed25519:" + master_pubkey: master_pubkey}, - "signatures": { - local_user: {"ed25519:" + device_pubkey: "something"} + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + master_pubkey: { + "user_id": local_user, + "usage": ["master"], + "keys": {"ed25519:" + master_pubkey: master_pubkey}, + "signatures": { + local_user: {"ed25519:" + device_pubkey: "something"} + }, }, }, - }, - other_user: { - # fails because the device is not the user's master-signing key - # should fail with NOT_FOUND - "unknown": { - "user_id": other_user, - "device_id": "unknown", - "signatures": { - local_user: {"ed25519:" + usersigning_pubkey: "something"} + other_user: { + # fails because the device is not the user's master-signing key + # should fail with NOT_FOUND + "unknown": { + "user_id": other_user, + "device_id": "unknown", + "signatures": { + local_user: { + "ed25519:" + usersigning_pubkey: "something" + } + }, }, - }, - other_master_pubkey: { - # fails because the key doesn't match what the server has - # should fail with UNKNOWN - "user_id": other_user, - "usage": ["master"], - "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, - "something": "random", - "signatures": { - local_user: {"ed25519:" + usersigning_pubkey: "something"} + other_master_pubkey: { + # fails because the key doesn't match what the server has + # should fail with UNKNOWN + "user_id": other_user, + "usage": ["master"], + "keys": { + "ed25519:" + other_master_pubkey: other_master_pubkey + }, + "something": "random", + "signatures": { + local_user: { + "ed25519:" + usersigning_pubkey: "something" + } + }, }, }, }, - }, + ) ) user_failures = ret["failures"][local_user] @@ -478,19 +536,23 @@ class E2eKeysHandlerTestCase(unittest.TestCase): sign.sign_json(device_key, local_user, selfsigning_signing_key) sign.sign_json(master_key, local_user, device_signing_key) sign.sign_json(other_master_key, local_user, usersigning_signing_key) - ret = yield self.handler.upload_signatures_for_device_keys( - local_user, - { - local_user: {device_id: device_key, master_pubkey: master_key}, - other_user: {other_master_pubkey: other_master_key}, - }, + ret = yield defer.ensureDeferred( + self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: {device_id: device_key, master_pubkey: master_key}, + other_user: {other_master_pubkey: other_master_key}, + }, + ) ) self.assertEqual(ret["failures"], {}) # fetch the signed keys/devices and make sure that the signatures are there - ret = yield self.handler.query_devices( - {"device_keys": {local_user: [], other_user: []}}, 0, local_user + ret = yield defer.ensureDeferred( + self.handler.query_devices( + {"device_keys": {local_user: [], other_user: []}}, 0, local_user + ) ) self.assertEqual( diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 822ea42dde..3362050ce0 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -66,7 +66,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.get_version_info(self.local_user) + yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -78,7 +78,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.get_version_info(self.local_user, "bogus_version") + yield defer.ensureDeferred( + self.handler.get_version_info(self.local_user, "bogus_version") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -87,14 +89,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_create_version(self): """Check that we can create and then retrieve versions. """ - res = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + res = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(res, "1") # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) version_etag = res["etag"] self.assertIsInstance(version_etag, str) del res["etag"] @@ -109,7 +116,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) # check we can retrieve it as a specific version - res = yield self.handler.get_version_info(self.local_user, "1") + res = yield defer.ensureDeferred( + self.handler.get_version_info(self.local_user, "1") + ) self.assertEqual(res["etag"], version_etag) del res["etag"] self.assertDictEqual( @@ -123,17 +132,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) # upload a new one... - res = yield self.handler.create_version( - self.local_user, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }, + res = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) ) self.assertEqual(res, "2") # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) del res["etag"] self.assertDictEqual( res, @@ -149,25 +160,32 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_update_version(self): """Check that we can update versions. """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - res = yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": version, - }, + res = yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": version, + }, + ) ) self.assertDictEqual(res, {}) # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) del res["etag"] self.assertDictEqual( res, @@ -185,14 +203,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.update_version( - self.local_user, - "1", - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "1", - }, + yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + "1", + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "1", + }, + ) ) except errors.SynapseError as e: res = e.code @@ -202,23 +222,30 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_update_omitted_version(self): """Check that the update succeeds if the version is missing from the body """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - }, + yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + }, + ) ) # check we can retrieve it as the current version - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) del res["etag"] # etag is opaque, so don't test its contents self.assertDictEqual( res, @@ -234,22 +261,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_update_bad_version(self): """Check that we get a 400 if the version in the body doesn't match """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") res = None try: - yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "incorrect", - }, + yield defer.ensureDeferred( + self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "incorrect", + }, + ) ) except errors.SynapseError as e: res = e.code @@ -261,7 +295,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.delete_version(self.local_user, "1") + yield defer.ensureDeferred( + self.handler.delete_version(self.local_user, "1") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -272,7 +308,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.delete_version(self.local_user) + yield defer.ensureDeferred(self.handler.delete_version(self.local_user)) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -281,19 +317,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_delete_version(self): """Check that we can create and then delete versions. """ - res = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + res = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(res, "1") # check we can delete it - yield self.handler.delete_version(self.local_user, "1") + yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1")) # check that it's gone res = None try: - yield self.handler.get_version_info(self.local_user, "1") + yield defer.ensureDeferred( + self.handler.get_version_info(self.local_user, "1") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -304,7 +347,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.get_room_keys(self.local_user, "bogus_version") + yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, "bogus_version") + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -313,13 +358,20 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_get_missing_room_keys(self): """Check we get an empty response from an empty backup """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertDictEqual(res, {"rooms": {}}) # TODO: test the locking semantics when uploading room_keys, @@ -331,8 +383,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.upload_room_keys( - self.local_user, "no_version", room_keys + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, "no_version", room_keys) ) except errors.SynapseError as e: res = e.code @@ -343,16 +395,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """Check that we get a 404 on uploading keys when an nonexistent version is specified """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") res = None try: - yield self.handler.upload_room_keys( - self.local_user, "bogus_version", room_keys + yield defer.ensureDeferred( + self.handler.upload_room_keys( + self.local_user, "bogus_version", room_keys + ) ) except errors.SynapseError as e: res = e.code @@ -362,24 +421,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_wrong_version(self): """Check that we get a 403 on uploading keys for an old version """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - version = yield self.handler.create_version( - self.local_user, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) ) self.assertEqual(version, "2") res = None try: - yield self.handler.upload_room_keys(self.local_user, "1", room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, "1", room_keys) + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 403) @@ -388,26 +456,39 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_insert(self): """Check that we can insert and retrieve keys for a session """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertDictEqual(res, room_keys) # check getting room_keys for a given room - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org" + ) ) self.assertDictEqual(res, room_keys) # check getting room_keys for a given session_id - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, room_keys) @@ -415,16 +496,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_merge(self): """Check that we can upload a new room_key for an existing session and have it correctly merged""" - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") - yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) # get the etag to compare to future versions - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) backup_etag = res["etag"] self.assertEqual(res["count"], 1) @@ -434,29 +522,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # test that increasing the message_index doesn't replace the existing session new_room_key["first_message_index"] = 2 new_room_key["session_data"] = "new" - yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, new_room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "SSBBTSBBIEZJU0gK", ) # the etag should be the same since the session did not change - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) self.assertEqual(res["etag"], backup_etag) # test that marking the session as verified however /does/ replace it new_room_key["is_verified"] = True - yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, new_room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # the etag should NOT be equal now, since the key changed - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) self.assertNotEqual(res["etag"], backup_etag) backup_etag = res["etag"] @@ -464,15 +560,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # with a lower forwarding count new_room_key["forwarded_count"] = 2 new_room_key["session_data"] = "other" - yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, new_room_keys) + ) - res = yield self.handler.get_room_keys(self.local_user, version) + res = yield defer.ensureDeferred( + self.handler.get_room_keys(self.local_user, version) + ) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # the etag should be the same since the session did not change - res = yield self.handler.get_version_info(self.local_user) + res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) self.assertEqual(res["etag"], backup_etag) # TODO: check edge cases as well as the common variations here @@ -481,36 +581,59 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_delete_room_keys(self): """Check that we can insert and delete keys for a session """ - version = yield self.handler.create_version( - self.local_user, - {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + version = yield defer.ensureDeferred( + self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) ) self.assertEqual(version, "1") # check for bulk-delete - yield self.handler.upload_room_keys(self.local_user, version, room_keys) - yield self.handler.delete_room_keys(self.local_user, version) - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) + yield defer.ensureDeferred( + self.handler.delete_room_keys(self.local_user, version) + ) + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per room - yield self.handler.upload_room_keys(self.local_user, version, room_keys) - yield self.handler.delete_room_keys( - self.local_user, version, room_id="!abc:matrix.org" + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) + yield defer.ensureDeferred( + self.handler.delete_room_keys( + self.local_user, version, room_id="!abc:matrix.org" + ) ) - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per session - yield self.handler.upload_room_keys(self.local_user, version, room_keys) - yield self.handler.delete_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + yield defer.ensureDeferred( + self.handler.upload_room_keys(self.local_user, version, room_keys) + ) + yield defer.ensureDeferred( + self.handler.delete_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) - res = yield self.handler.get_room_keys( - self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + res = yield defer.ensureDeferred( + self.handler.get_room_keys( + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" + ) ) self.assertDictEqual(res, {"rooms": {}}) -- cgit 1.5.1 From 6b3ac3b8cddda9911f42a08a0dcefc4a3386ff51 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 17 Jul 2020 07:09:25 -0400 Subject: Convert device handler to async/await (#7871) --- changelog.d/7871.misc | 1 + synapse/handlers/device.py | 241 +++++++++++++++++----------------------- synapse/util/distributor.py | 28 ++++- tests/handlers/test_device.py | 13 +-- tests/handlers/test_e2e_keys.py | 10 +- tests/test_federation.py | 35 +++--- 6 files changed, 162 insertions(+), 166 deletions(-) create mode 100644 changelog.d/7871.misc (limited to 'tests/handlers/test_e2e_keys.py') diff --git a/changelog.d/7871.misc b/changelog.d/7871.misc new file mode 100644 index 0000000000..4d398a9f3a --- /dev/null +++ b/changelog.d/7871.misc @@ -0,0 +1 @@ +Convert device handler to async/await. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 31346b56c3..f947aa1627 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -15,9 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Dict, Optional - -from twisted.internet import defer +from typing import Any, Dict, List, Optional from synapse.api import errors from synapse.api.constants import EventTypes @@ -57,21 +55,20 @@ class DeviceWorkerHandler(BaseHandler): self._auth_handler = hs.get_auth_handler() @trace - @defer.inlineCallbacks - def get_devices_by_user(self, user_id): + async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]: """ Retrieve the given user's devices Args: - user_id (str): + user_id: The user ID to query for devices. Returns: - defer.Deferred: list[dict[str, X]]: info on each device + info on each device """ set_tag("user_id", user_id) - device_map = yield self.store.get_devices_by_user(user_id) + device_map = await self.store.get_devices_by_user(user_id) - ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None) + ips = await self.store.get_last_client_ip_by_device(user_id, device_id=None) devices = list(device_map.values()) for device in devices: @@ -81,24 +78,23 @@ class DeviceWorkerHandler(BaseHandler): return devices @trace - @defer.inlineCallbacks - def get_device(self, user_id, device_id): + async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: """ Retrieve the given device Args: - user_id (str): - device_id (str): + user_id: The user to get the device from + device_id: The device to fetch. Returns: - defer.Deferred: dict[str, X]: info on the device + info on the device Raises: errors.NotFoundError: if the device was not found """ try: - device = yield self.store.get_device(user_id, device_id) + device = await self.store.get_device(user_id, device_id) except errors.StoreError: raise errors.NotFoundError - ips = yield self.store.get_last_client_ip_by_device(user_id, device_id) + ips = await self.store.get_last_client_ip_by_device(user_id, device_id) _update_device_from_client_ips(device, ips) set_tag("device", device) @@ -106,10 +102,9 @@ class DeviceWorkerHandler(BaseHandler): return device - @measure_func("device.get_user_ids_changed") @trace - @defer.inlineCallbacks - def get_user_ids_changed(self, user_id, from_token): + @measure_func("device.get_user_ids_changed") + async def get_user_ids_changed(self, user_id, from_token): """Get list of users that have had the devices updated, or have newly joined a room, that `user_id` may be interested in. @@ -120,13 +115,13 @@ class DeviceWorkerHandler(BaseHandler): set_tag("user_id", user_id) set_tag("from_token", from_token) - now_room_key = yield self.store.get_room_events_max_id() + now_room_key = await self.store.get_room_events_max_id() - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) # First we check if any devices have changed for users that we share # rooms with. - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id ) @@ -135,14 +130,14 @@ class DeviceWorkerHandler(BaseHandler): # Always tell the user about their own devices tracked_users.add(user_id) - changed = yield self.store.get_users_whose_devices_changed( + changed = await self.store.get_users_whose_devices_changed( from_token.device_list_key, tracked_users ) # Then work out if any users have since joined rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) - member_events = yield self.store.get_membership_changes_for_user( + member_events = await self.store.get_membership_changes_for_user( user_id, from_token.room_key, now_room_key ) rooms_changed.update(event.room_id for event in member_events) @@ -152,7 +147,7 @@ class DeviceWorkerHandler(BaseHandler): possibly_changed = set(changed) possibly_left = set() for room_id in rooms_changed: - current_state_ids = yield self.store.get_current_state_ids(room_id) + current_state_ids = await self.store.get_current_state_ids(room_id) # The user may have left the room # TODO: Check if they actually did or if we were just invited. @@ -166,7 +161,7 @@ class DeviceWorkerHandler(BaseHandler): # Fetch the current state at the time. try: - event_ids = yield self.store.get_forward_extremeties_for_room( + event_ids = await self.store.get_forward_extremeties_for_room( room_id, stream_ordering=stream_ordering ) except errors.StoreError: @@ -192,7 +187,7 @@ class DeviceWorkerHandler(BaseHandler): continue # mapping from event_id -> state_dict - prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids) + prev_state_ids = await self.state_store.get_state_ids_for_events(event_ids) # Check if we've joined the room? If so we just blindly add all the users to # the "possibly changed" users. @@ -238,11 +233,10 @@ class DeviceWorkerHandler(BaseHandler): return result - @defer.inlineCallbacks - def on_federation_query_user_devices(self, user_id): - stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) - master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") - self_signing_key = yield self.store.get_e2e_cross_signing_key( + async def on_federation_query_user_devices(self, user_id): + stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id) + master_key = await self.store.get_e2e_cross_signing_key(user_id, "master") + self_signing_key = await self.store.get_e2e_cross_signing_key( user_id, "self_signing" ) @@ -271,8 +265,7 @@ class DeviceHandler(DeviceWorkerHandler): hs.get_distributor().observe("user_left_room", self.user_left_room) - @defer.inlineCallbacks - def check_device_registered( + async def check_device_registered( self, user_id, device_id, initial_device_display_name=None ): """ @@ -290,13 +283,13 @@ class DeviceHandler(DeviceWorkerHandler): str: device id (generated if none was supplied) """ if device_id is not None: - new_device = yield self.store.store_device( + new_device = await self.store.store_device( user_id=user_id, device_id=device_id, initial_device_display_name=initial_device_display_name, ) if new_device: - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) return device_id # if the device id is not specified, we'll autogen one, but loop a few @@ -304,33 +297,29 @@ class DeviceHandler(DeviceWorkerHandler): attempts = 0 while attempts < 5: device_id = stringutils.random_string(10).upper() - new_device = yield self.store.store_device( + new_device = await self.store.store_device( user_id=user_id, device_id=device_id, initial_device_display_name=initial_device_display_name, ) if new_device: - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) return device_id attempts += 1 raise errors.StoreError(500, "Couldn't generate a device ID.") @trace - @defer.inlineCallbacks - def delete_device(self, user_id, device_id): + async def delete_device(self, user_id: str, device_id: str) -> None: """ Delete the given device Args: - user_id (str): - device_id (str): - - Returns: - defer.Deferred: + user_id: The user to delete the device from. + device_id: The device to delete. """ try: - yield self.store.delete_device(user_id, device_id) + await self.store.delete_device(user_id, device_id) except errors.StoreError as e: if e.code == 404: # no match @@ -342,49 +331,40 @@ class DeviceHandler(DeviceWorkerHandler): else: raise - yield defer.ensureDeferred( - self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id - ) + await self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id ) - yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) + await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) @trace - @defer.inlineCallbacks - def delete_all_devices_for_user(self, user_id, except_device_id=None): + async def delete_all_devices_for_user( + self, user_id: str, except_device_id: Optional[str] = None + ) -> None: """Delete all of the user's devices Args: - user_id (str): - except_device_id (str|None): optional device id which should not - be deleted - - Returns: - defer.Deferred: + user_id: The user to remove all devices from + except_device_id: optional device id which should not be deleted """ - device_map = yield self.store.get_devices_by_user(user_id) + device_map = await self.store.get_devices_by_user(user_id) device_ids = list(device_map) if except_device_id is not None: device_ids = [d for d in device_ids if d != except_device_id] - yield self.delete_devices(user_id, device_ids) + await self.delete_devices(user_id, device_ids) - @defer.inlineCallbacks - def delete_devices(self, user_id, device_ids): + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """ Delete several devices Args: - user_id (str): - device_ids (List[str]): The list of device IDs to delete - - Returns: - defer.Deferred: + user_id: The user to delete devices from. + device_ids: The list of device IDs to delete """ try: - yield self.store.delete_devices(user_id, device_ids) + await self.store.delete_devices(user_id, device_ids) except errors.StoreError as e: if e.code == 404: # no match @@ -397,28 +377,22 @@ class DeviceHandler(DeviceWorkerHandler): # Delete access tokens and e2e keys for each device. Not optimised as it is not # considered as part of a critical path. for device_id in device_ids: - yield defer.ensureDeferred( - self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id - ) + await self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id ) - yield self.store.delete_e2e_keys_by_device( + await self.store.delete_e2e_keys_by_device( user_id=user_id, device_id=device_id ) - yield self.notify_device_update(user_id, device_ids) + await self.notify_device_update(user_id, device_ids) - @defer.inlineCallbacks - def update_device(self, user_id, device_id, content): + async def update_device(self, user_id: str, device_id: str, content: dict) -> None: """ Update the given device Args: - user_id (str): - device_id (str): - content (dict): body of update request - - Returns: - defer.Deferred: + user_id: The user to update devices of. + device_id: The device to update. + content: body of update request """ # Reject a new displayname which is too long. @@ -431,10 +405,10 @@ class DeviceHandler(DeviceWorkerHandler): ) try: - yield self.store.update_device( + await self.store.update_device( user_id, device_id, new_display_name=new_display_name ) - yield self.notify_device_update(user_id, [device_id]) + await self.notify_device_update(user_id, [device_id]) except errors.StoreError as e: if e.code == 404: raise errors.NotFoundError() @@ -443,12 +417,11 @@ class DeviceHandler(DeviceWorkerHandler): @trace @measure_func("notify_device_update") - @defer.inlineCallbacks - def notify_device_update(self, user_id, device_ids): + async def notify_device_update(self, user_id, device_ids): """Notify that a user's device(s) has changed. Pokes the notifier, and remote servers if the user is local. """ - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id ) @@ -459,7 +432,7 @@ class DeviceHandler(DeviceWorkerHandler): set_tag("target_hosts", hosts) - position = yield self.store.add_device_change_to_streams( + position = await self.store.add_device_change_to_streams( user_id, device_ids, list(hosts) ) @@ -468,11 +441,11 @@ class DeviceHandler(DeviceWorkerHandler): "Notifying about update %r/%r, ID: %r", user_id, device_id, position ) - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) # specify the user ID too since the user should always get their own device list # updates, even if they aren't in any rooms. - yield self.notifier.on_new_event( + self.notifier.on_new_event( "device_list_key", position, users=[user_id], rooms=room_ids ) @@ -484,29 +457,29 @@ class DeviceHandler(DeviceWorkerHandler): self.federation_sender.send_device_messages(host) log_kv({"message": "sent device update to host", "host": host}) - @defer.inlineCallbacks - def notify_user_signature_update(self, from_user_id, user_ids): + async def notify_user_signature_update( + self, from_user_id: str, user_ids: List[str] + ) -> None: """Notify a user that they have made new signatures of other users. Args: - from_user_id (str): the user who made the signature - user_ids (list[str]): the users IDs that have new signatures + from_user_id: the user who made the signature + user_ids: the users IDs that have new signatures """ - position = yield self.store.add_user_signature_change_to_streams( + position = await self.store.add_user_signature_change_to_streams( from_user_id, user_ids ) self.notifier.on_new_event("device_list_key", position, users=[from_user_id]) - @defer.inlineCallbacks - def user_left_room(self, user, room_id): + async def user_left_room(self, user, room_id): user_id = user.to_string() - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: # We no longer share rooms with this user, so we'll no longer # receive device updates. Mark this in DB. - yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id) + await self.store.mark_remote_user_device_list_as_unsubscribed(user_id) def _update_device_from_client_ips(device, client_ips): @@ -549,8 +522,7 @@ class DeviceListUpdater(object): ) @trace - @defer.inlineCallbacks - def incoming_device_list_update(self, origin, edu_content): + async def incoming_device_list_update(self, origin, edu_content): """Called on incoming device list update from federation. Responsible for parsing the EDU and adding to pending updates list. """ @@ -583,7 +555,7 @@ class DeviceListUpdater(object): ) return - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: # We don't share any rooms with this user. Ignore update, as we # probably won't get any further updates. @@ -608,14 +580,13 @@ class DeviceListUpdater(object): (device_id, stream_id, prev_ids, edu_content) ) - yield self._handle_device_updates(user_id) + await self._handle_device_updates(user_id) @measure_func("_incoming_device_list_update") - @defer.inlineCallbacks - def _handle_device_updates(self, user_id): + async def _handle_device_updates(self, user_id): "Actually handle pending updates." - with (yield self._remote_edu_linearizer.queue(user_id)): + with (await self._remote_edu_linearizer.queue(user_id)): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: # This can happen since we batch updates @@ -632,7 +603,7 @@ class DeviceListUpdater(object): # Given a list of updates we check if we need to resync. This # happens if we've missed updates. - resync = yield self._need_to_do_resync(user_id, pending_updates) + resync = await self._need_to_do_resync(user_id, pending_updates) if logger.isEnabledFor(logging.INFO): logger.info( @@ -643,16 +614,16 @@ class DeviceListUpdater(object): ) if resync: - yield self.user_device_resync(user_id) + await self.user_device_resync(user_id) else: # Simply update the single device, since we know that is the only # change (because of the single prev_id matching the current cache) for device_id, stream_id, prev_ids, content in pending_updates: - yield self.store.update_remote_device_list_cache_entry( + await self.store.update_remote_device_list_cache_entry( user_id, device_id, content, stream_id ) - yield self.device_handler.notify_device_update( + await self.device_handler.notify_device_update( user_id, [device_id for device_id, _, _, _ in pending_updates] ) @@ -660,14 +631,13 @@ class DeviceListUpdater(object): stream_id for _, stream_id, _, _ in pending_updates ) - @defer.inlineCallbacks - def _need_to_do_resync(self, user_id, updates): + async def _need_to_do_resync(self, user_id, updates): """Given a list of updates for a user figure out if we need to do a full resync, or whether we have enough data that we can just apply the delta. """ seen_updates = self._seen_updates.get(user_id, set()) - extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id) + extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id) logger.debug("Current extremity for %r: %r", user_id, extremity) @@ -692,8 +662,7 @@ class DeviceListUpdater(object): return False @trace - @defer.inlineCallbacks - def _maybe_retry_device_resync(self): + async def _maybe_retry_device_resync(self): """Retry to resync device lists that are out of sync, except if another retry is in progress. """ @@ -705,12 +674,12 @@ class DeviceListUpdater(object): # we don't send too many requests. self._resync_retry_in_progress = True # Get all of the users that need resyncing. - need_resync = yield self.store.get_user_ids_requiring_device_list_resync() + need_resync = await self.store.get_user_ids_requiring_device_list_resync() # Iterate over the set of user IDs. for user_id in need_resync: try: # Try to resync the current user's devices list. - result = yield self.user_device_resync( + result = await self.user_device_resync( user_id=user_id, mark_failed_as_stale=False, ) @@ -734,16 +703,17 @@ class DeviceListUpdater(object): # Allow future calls to retry resyncinc out of sync device lists. self._resync_retry_in_progress = False - @defer.inlineCallbacks - def user_device_resync(self, user_id, mark_failed_as_stale=True): + async def user_device_resync( + self, user_id: str, mark_failed_as_stale: bool = True + ) -> Optional[dict]: """Fetches all devices for a user and updates the device cache with them. Args: - user_id (str): The user's id whose device_list will be updated. - mark_failed_as_stale (bool): Whether to mark the user's device list as stale + user_id: The user's id whose device_list will be updated. + mark_failed_as_stale: Whether to mark the user's device list as stale if the attempt to resync failed. Returns: - Deferred[dict]: a dict with device info as under the "devices" in the result of this + A dict with device info as under the "devices" in the result of this request: https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid """ @@ -752,12 +722,12 @@ class DeviceListUpdater(object): # Fetch all devices for the user. origin = get_domain_from_id(user_id) try: - result = yield self.federation.query_user_devices(origin, user_id) + result = await self.federation.query_user_devices(origin, user_id) except NotRetryingDestination: if mark_failed_as_stale: # Mark the remote user's device list as stale so we know we need to retry # it later. - yield self.store.mark_remote_user_device_cache_as_stale(user_id) + await self.store.mark_remote_user_device_cache_as_stale(user_id) return except (RequestSendFailed, HttpResponseException) as e: @@ -768,7 +738,7 @@ class DeviceListUpdater(object): if mark_failed_as_stale: # Mark the remote user's device list as stale so we know we need to retry # it later. - yield self.store.mark_remote_user_device_cache_as_stale(user_id) + await self.store.mark_remote_user_device_cache_as_stale(user_id) # We abort on exceptions rather than accepting the update # as otherwise synapse will 'forget' that its device list @@ -792,7 +762,7 @@ class DeviceListUpdater(object): if mark_failed_as_stale: # Mark the remote user's device list as stale so we know we need to retry # it later. - yield self.store.mark_remote_user_device_cache_as_stale(user_id) + await self.store.mark_remote_user_device_cache_as_stale(user_id) return log_kv({"result": result}) @@ -833,25 +803,24 @@ class DeviceListUpdater(object): stream_id, ) - yield self.store.update_remote_device_list_cache(user_id, devices, stream_id) + await self.store.update_remote_device_list_cache(user_id, devices, stream_id) device_ids = [device["device_id"] for device in devices] # Handle cross-signing keys. - cross_signing_device_ids = yield self.process_cross_signing_key_update( + cross_signing_device_ids = await self.process_cross_signing_key_update( user_id, master_key, self_signing_key, ) device_ids = device_ids + cross_signing_device_ids - yield self.device_handler.notify_device_update(user_id, device_ids) + await self.device_handler.notify_device_update(user_id, device_ids) # We clobber the seen updates since we've re-synced from a given # point. self._seen_updates[user_id] = {stream_id} - defer.returnValue(result) + return result - @defer.inlineCallbacks - def process_cross_signing_key_update( + async def process_cross_signing_key_update( self, user_id: str, master_key: Optional[Dict[str, Any]], @@ -872,14 +841,14 @@ class DeviceListUpdater(object): device_ids = [] if master_key: - yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + await self.store.set_e2e_cross_signing_key(user_id, "master", master_key) _, verify_key = get_verify_key_from_cross_signing_key(master_key) # verify_key is a VerifyKey from signedjson, which uses # .version to denote the portion of the key ID after the # algorithm and colon, which is the device ID device_ids.append(verify_key.version) if self_signing_key: - yield self.store.set_e2e_cross_signing_key( + await self.store.set_e2e_cross_signing_key( user_id, "self_signing", self_signing_key ) _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key) diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index da20523b70..22a857a306 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -12,10 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import inspect import logging from twisted.internet import defer +from twisted.internet.defer import Deferred, fail, succeed +from twisted.python import failure from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -79,6 +81,28 @@ class Distributor(object): run_as_background_process(name, self.signals[name].fire, *args, **kwargs) +def maybeAwaitableDeferred(f, *args, **kw): + """ + Invoke a function that may or may not return a Deferred or an Awaitable. + + This is a modified version of twisted.internet.defer.maybeDeferred. + """ + try: + result = f(*args, **kw) + except Exception: + return fail(failure.Failure(captureVars=Deferred.debug)) + + if isinstance(result, Deferred): + return result + # Handle the additional case of an awaitable being returned. + elif inspect.isawaitable(result): + return defer.ensureDeferred(result) + elif isinstance(result, failure.Failure): + return fail(result) + else: + return succeed(result) + + class Signal(object): """A Signal is a dispatch point that stores a list of callables as observers of it. @@ -122,7 +146,7 @@ class Signal(object): ), ) - return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) + return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb) deferreds = [run_in_background(do, o) for o in self.observers] diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 62b47f6574..6aa322bf3a 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -142,10 +142,8 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.get_success(self.handler.delete_device(user1, "abc")) # check the device was deleted - res = self.handler.get_device(user1, "abc") - self.pump() - self.assertIsInstance( - self.failureResultOf(res).value, synapse.api.errors.NotFoundError + self.get_failure( + self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError ) # we'd like to check the access token was invalidated, but that's a @@ -180,10 +178,9 @@ class DeviceTestCase(unittest.HomeserverTestCase): def test_update_unknown_device(self): update = {"display_name": "new_display"} - res = self.handler.update_device("user_id", "unknown_device_id", update) - self.pump() - self.assertIsInstance( - self.failureResultOf(res).value, synapse.api.errors.NotFoundError + self.get_failure( + self.handler.update_device("user_id", "unknown_device_id", update), + synapse.api.errors.NotFoundError, ) def _record_users(self): diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index cdd093ffa8..210ddcbb88 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -334,10 +334,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase): res = None try: - yield self.hs.get_device_handler().check_device_registered( - user_id=local_user, - device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", - initial_device_display_name="new display name", + yield defer.ensureDeferred( + self.hs.get_device_handler().check_device_registered( + user_id=local_user, + device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", + initial_device_display_name="new display name", + ) ) except errors.SynapseError as e: res = e.code diff --git a/tests/test_federation.py b/tests/test_federation.py index 89dcc58b99..87a16d7d7a 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -173,7 +173,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. store = self.homeserver.get_datastore() - store.get_rooms_for_user = Mock(return_value=["!someroom:test"]) + store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"])) # Manually inject a fake device list update. We need this update to include at # least one prev_id so that the user's device list will need to be retried. @@ -218,23 +218,26 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register mock device list retrieval on the federation client. federation_client = self.homeserver.get_federation_client() federation_client.query_user_devices = Mock( - return_value={ - "user_id": remote_user_id, - "stream_id": 1, - "devices": [], - "master_key": { + return_value=succeed( + { "user_id": remote_user_id, - "usage": ["master"], - "keys": {"ed25519:" + remote_master_key: remote_master_key}, - }, - "self_signing_key": { - "user_id": remote_user_id, - "usage": ["self_signing"], - "keys": { - "ed25519:" + remote_self_signing_key: remote_self_signing_key + "stream_id": 1, + "devices": [], + "master_key": { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, }, - }, - } + "self_signing_key": { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + + remote_self_signing_key: remote_self_signing_key + }, + }, + } + ) ) # Resync the device list. -- cgit 1.5.1