From 7f837959ea25ef50b3675c9c2596ef42592dc127 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 7 Aug 2020 13:36:29 -0400 Subject: Convert directory, e2e_room_keys, end_to_end_keys, monthly_active_users database to async (#8042) --- tests/storage/test_end_to_end_keys.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'tests/storage/test_end_to_end_keys.py') diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 398d546280..9f8d30373b 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -34,7 +34,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield self.store.set_e2e_device_keys("user", "device", now, json) - res = yield self.store.get_e2e_device_keys((("user", "device"),)) + res = yield defer.ensureDeferred( + self.store.get_e2e_device_keys((("user", "device"),)) + ) self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] @@ -63,7 +65,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield self.store.set_e2e_device_keys("user", "device", now, json) yield self.store.store_device("user", "device", "display_name") - res = yield self.store.get_e2e_device_keys((("user", "device"),)) + res = yield defer.ensureDeferred( + self.store.get_e2e_device_keys((("user", "device"),)) + ) self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] @@ -85,8 +89,8 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"}) - res = yield self.store.get_e2e_device_keys( - (("user1", "device1"), ("user2", "device2")) + res = yield defer.ensureDeferred( + self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2"))) ) self.assertIn("user1", res) self.assertIn("device1", res["user1"]) -- cgit 1.5.1 From 5ecc8b58255d7e33ad63a6c931efa6ed5e41ad01 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Aug 2020 10:51:42 -0400 Subject: Convert devices database to async/await. (#8069) --- changelog.d/8069.misc | 1 + synapse/storage/databases/main/devices.py | 333 ++++++++++++++++-------------- tests/handlers/test_typing.py | 2 +- tests/storage/test_devices.py | 44 ++-- tests/storage/test_end_to_end_keys.py | 16 +- 5 files changed, 220 insertions(+), 176 deletions(-) create mode 100644 changelog.d/8069.misc (limited to 'tests/storage/test_end_to_end_keys.py') diff --git a/changelog.d/8069.misc b/changelog.d/8069.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8069.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 7a5f0bab05..2b33060480 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -15,9 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List, Optional, Set, Tuple - -from twisted.internet import defer +from typing import Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( @@ -33,14 +31,9 @@ from synapse.storage.database import ( LoggingTransaction, make_tuple_comparison_clause, ) -from synapse.types import Collection, get_verify_key_from_cross_signing_key +from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_encoder -from synapse.util.caches.descriptors import ( - Cache, - cached, - cachedInlineCallbacks, - cachedList, -) +from synapse.util.caches.descriptors import Cache, cached, cachedList from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -54,13 +47,13 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" class DeviceWorkerStore(SQLBaseStore): - def get_device(self, user_id, device_id): + def get_device(self, user_id: str, device_id: str): """Retrieve a device. Only returns devices that are not marked as hidden. Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to retrieve + user_id: The ID of the user which owns the device + device_id: The ID of the device to retrieve Returns: defer.Deferred for a dict containing the device information Raises: @@ -73,19 +66,17 @@ class DeviceWorkerStore(SQLBaseStore): desc="get_device", ) - @defer.inlineCallbacks - def get_devices_by_user(self, user_id): + async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]: """Retrieve all of a user's registered devices. Only returns devices that are not marked as hidden. Args: - user_id (str): + user_id: Returns: - defer.Deferred: resolves to a dict from device_id to a dict - containing "device_id", "user_id" and "display_name" for each - device. + A mapping from device_id to a dict containing "device_id", "user_id" + and "display_name" for each device. """ - devices = yield self.db_pool.simple_select_list( + devices = await self.db_pool.simple_select_list( table="devices", keyvalues={"user_id": user_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -95,19 +86,20 @@ class DeviceWorkerStore(SQLBaseStore): return {d["device_id"]: d for d in devices} @trace - @defer.inlineCallbacks - def get_device_updates_by_remote(self, destination, from_stream_id, limit): + async def get_device_updates_by_remote( + self, destination: str, from_stream_id: int, limit: int + ) -> Tuple[int, List[Tuple[str, dict]]]: """Get a stream of device updates to send to the given remote server. Args: - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive - limit (int): Maximum number of device updates to return + destination: The host the device updates are intended for + from_stream_id: The minimum stream_id to filter updates by, exclusive + limit: Maximum number of device updates to return + Returns: - Deferred[tuple[int, list[tuple[string,dict]]]]: - current stream id (ie, the stream id of the last update included in the - response), and the list of updates, where each update is a pair of EDU - type and EDU contents + A mapping from the current stream id (ie, the stream id of the last + update included in the response), and the list of updates, where + each update is a pair of EDU type and EDU contents. """ now_stream_id = self._device_list_id_gen.get_current_token() @@ -117,7 +109,7 @@ class DeviceWorkerStore(SQLBaseStore): if not has_changed: return now_stream_id, [] - updates = yield self.db_pool.runInteraction( + updates = await self.db_pool.runInteraction( "get_device_updates_by_remote", self._get_device_updates_by_remote_txn, destination, @@ -136,9 +128,7 @@ class DeviceWorkerStore(SQLBaseStore): master_key_by_user = {} self_signing_key_by_user = {} for user in users: - cross_signing_key = yield defer.ensureDeferred( - self.get_e2e_cross_signing_key(user, "master") - ) + cross_signing_key = await self.get_e2e_cross_signing_key(user, "master") if cross_signing_key: key_id, verify_key = get_verify_key_from_cross_signing_key( cross_signing_key @@ -151,8 +141,8 @@ class DeviceWorkerStore(SQLBaseStore): "device_id": verify_key.version, } - cross_signing_key = yield defer.ensureDeferred( - self.get_e2e_cross_signing_key(user, "self_signing") + cross_signing_key = await self.get_e2e_cross_signing_key( + user, "self_signing" ) if cross_signing_key: key_id, verify_key = get_verify_key_from_cross_signing_key( @@ -202,7 +192,7 @@ class DeviceWorkerStore(SQLBaseStore): if update_stream_id > previous_update_stream_id: query_map[key] = (update_stream_id, update_context) - results = yield self._get_device_update_edus_by_remote( + results = await self._get_device_update_edus_by_remote( destination, from_stream_id, query_map ) @@ -215,16 +205,21 @@ class DeviceWorkerStore(SQLBaseStore): return now_stream_id, results def _get_device_updates_by_remote_txn( - self, txn, destination, from_stream_id, now_stream_id, limit + self, + txn: LoggingTransaction, + destination: str, + from_stream_id: int, + now_stream_id: int, + limit: int, ): """Return device update information for a given remote destination Args: - txn (LoggingTransaction): The transaction to execute - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive - now_stream_id (int): The maximum stream_id to filter updates by, inclusive - limit (int): Maximum number of device updates to return + txn: The transaction to execute + destination: The host the device updates are intended for + from_stream_id: The minimum stream_id to filter updates by, exclusive + now_stream_id: The maximum stream_id to filter updates by, inclusive + limit: Maximum number of device updates to return Returns: List: List of device updates @@ -240,23 +235,26 @@ class DeviceWorkerStore(SQLBaseStore): return list(txn) - @defer.inlineCallbacks - def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map): + async def _get_device_update_edus_by_remote( + self, + destination: str, + from_stream_id: int, + query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]], + ) -> List[Tuple[str, dict]]: """Returns a list of device update EDUs as well as E2EE keys Args: - destination (str): The host the device updates are intended for - from_stream_id (int): The minimum stream_id to filter updates by, exclusive + destination: The host the device updates are intended for + from_stream_id: The minimum stream_id to filter updates by, exclusive query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping user_id/device_id to update stream_id and the relevant json-encoded opentracing context Returns: - List[Dict]: List of objects representing an device update EDU - + List of objects representing an device update EDU """ devices = ( - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_get_e2e_device_keys_txn", self._get_e2e_device_keys_txn, query_map.keys(), @@ -271,7 +269,7 @@ class DeviceWorkerStore(SQLBaseStore): for user_id, user_devices in devices.items(): # The prev_id for the first row is always the last row before # `from_stream_id` - prev_id = yield self._get_last_device_update_for_remote_user( + prev_id = await self._get_last_device_update_for_remote_user( destination, user_id, from_stream_id ) @@ -315,7 +313,7 @@ class DeviceWorkerStore(SQLBaseStore): return results def _get_last_device_update_for_remote_user( - self, destination, user_id, from_stream_id + self, destination: str, user_id: str, from_stream_id: int ): def f(txn): prev_sent_id_sql = """ @@ -329,7 +327,7 @@ class DeviceWorkerStore(SQLBaseStore): return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f) - def mark_as_sent_devices_by_remote(self, destination, stream_id): + def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int): """Mark that updates have successfully been sent to the destination. """ return self.db_pool.runInteraction( @@ -339,7 +337,9 @@ class DeviceWorkerStore(SQLBaseStore): stream_id, ) - def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): + def _mark_as_sent_devices_by_remote_txn( + self, txn: LoggingTransaction, destination: str, stream_id: int + ) -> None: # We update the device_lists_outbound_last_success with the successfully # poked users. sql = """ @@ -367,17 +367,21 @@ class DeviceWorkerStore(SQLBaseStore): """ txn.execute(sql, (destination, stream_id)) - @defer.inlineCallbacks - def add_user_signature_change_to_streams(self, from_user_id, user_ids): + async def add_user_signature_change_to_streams( + self, from_user_id: str, user_ids: List[str] + ) -> int: """Persist that a user has made new signatures Args: - from_user_id (str): the user who made the signatures - user_ids (list[str]): the users who were signed + from_user_id: the user who made the signatures + user_ids: the users who were signed + + Returns: + THe new stream ID. """ with self._device_list_id_gen.get_next() as stream_id: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, from_user_id, @@ -386,7 +390,13 @@ class DeviceWorkerStore(SQLBaseStore): ) return stream_id - def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id): + def _add_user_signature_change_txn( + self, + txn: LoggingTransaction, + from_user_id: str, + user_ids: List[str], + stream_id: int, + ) -> None: txn.call_after( self._user_signature_stream_cache.entity_has_changed, from_user_id, @@ -402,29 +412,30 @@ class DeviceWorkerStore(SQLBaseStore): }, ) - def get_device_stream_token(self): + def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() @trace - @defer.inlineCallbacks - def get_user_devices_from_cache(self, query_list): + async def get_user_devices_from_cache( + self, query_list: List[Tuple[str, str]] + ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: """Get the devices (and keys if any) for remote users from the cache. Args: - query_list(list): List of (user_id, device_ids), if device_ids is + query_list: List of (user_id, device_ids), if device_ids is falsey then return all device ids for that user. Returns: - (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is - a set of user_ids and results_map is a mapping of - user_id -> device_id -> device_info + A tuple of (user_ids_not_in_cache, results_map), where + user_ids_not_in_cache is a set of user_ids and results_map is a + mapping of user_id -> device_id -> device_info. """ user_ids = {user_id for user_id, _ in query_list} - user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) + user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids)) # We go and check if any of the users need to have their device lists # resynced. If they do then we remove them from the cached list. - users_needing_resync = yield self.get_user_ids_requiring_device_list_resync( + users_needing_resync = await self.get_user_ids_requiring_device_list_resync( user_ids ) user_ids_in_cache = { @@ -438,19 +449,19 @@ class DeviceWorkerStore(SQLBaseStore): continue if device_id: - device = yield self._get_cached_user_device(user_id, device_id) + device = await self._get_cached_user_device(user_id, device_id) results.setdefault(user_id, {})[device_id] = device else: - results[user_id] = yield self.get_cached_devices_for_user(user_id) + results[user_id] = await self.get_cached_devices_for_user(user_id) set_tag("in_cache", results) set_tag("not_in_cache", user_ids_not_in_cache) return user_ids_not_in_cache, results - @cachedInlineCallbacks(num_args=2, tree=True) - def _get_cached_user_device(self, user_id, device_id): - content = yield self.db_pool.simple_select_one_onecol( + @cached(num_args=2, tree=True) + async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict: + content = await self.db_pool.simple_select_one_onecol( table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="content", @@ -458,9 +469,9 @@ class DeviceWorkerStore(SQLBaseStore): ) return db_to_json(content) - @cachedInlineCallbacks() - def get_cached_devices_for_user(self, user_id): - devices = yield self.db_pool.simple_select_list( + @cached() + async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]: + devices = await self.db_pool.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, retcols=("device_id", "content"), @@ -470,11 +481,11 @@ class DeviceWorkerStore(SQLBaseStore): device["device_id"]: db_to_json(device["content"]) for device in devices } - def get_devices_with_keys_by_user(self, user_id): + def get_devices_with_keys_by_user(self, user_id: str): """Get all devices (with any device keys) for a user Returns: - (stream_id, devices) + Deferred which resolves to (stream_id, devices) """ return self.db_pool.runInteraction( "get_devices_with_keys_by_user", @@ -482,7 +493,9 @@ class DeviceWorkerStore(SQLBaseStore): user_id, ) - def _get_devices_with_keys_by_user_txn(self, txn, user_id): + def _get_devices_with_keys_by_user_txn( + self, txn: LoggingTransaction, user_id: str + ) -> Tuple[int, List[JsonDict]]: now_stream_id = self._device_list_id_gen.get_current_token() devices = self._get_e2e_device_keys_txn( @@ -515,17 +528,18 @@ class DeviceWorkerStore(SQLBaseStore): return now_stream_id, [] - def get_users_whose_devices_changed(self, from_key, user_ids): + async def get_users_whose_devices_changed( + self, from_key: str, user_ids: Iterable[str] + ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that are in the given list of user_ids. Args: - from_key (str): The device lists stream token - user_ids (Iterable[str]) + from_key: The device lists stream token + user_ids: The user IDs to query for devices. Returns: - Deferred[set[str]]: The set of user_ids whose devices have changed - since `from_key` + The set of user_ids whose devices have changed since `from_key` """ from_key = int(from_key) @@ -536,7 +550,7 @@ class DeviceWorkerStore(SQLBaseStore): ) if not to_check: - return defer.succeed(set()) + return set() def _get_users_whose_devices_changed_txn(txn): changes = set() @@ -556,18 +570,22 @@ class DeviceWorkerStore(SQLBaseStore): return changes - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn ) - @defer.inlineCallbacks - def get_users_whose_signatures_changed(self, user_id, from_key): + async def get_users_whose_signatures_changed( + self, user_id: str, from_key: str + ) -> Set[str]: """Get the users who have new cross-signing signatures made by `user_id` since `from_key`. Args: - user_id (str): the user who made the signatures - from_key (str): The device lists stream token + user_id: the user who made the signatures + from_key: The device lists stream token + + Returns: + A set of user IDs with updated signatures. """ from_key = int(from_key) if self._user_signature_stream_cache.has_entity_changed(user_id, from_key): @@ -575,7 +593,7 @@ class DeviceWorkerStore(SQLBaseStore): SELECT DISTINCT user_ids FROM user_signature_stream WHERE from_user_id = ? AND stream_id > ? """ - rows = yield self.db_pool.execute( + rows = await self.db_pool.execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) return {user for row in rows for user in db_to_json(row[0])} @@ -638,7 +656,7 @@ class DeviceWorkerStore(SQLBaseStore): ) @cached(max_entries=10000) - def get_device_list_last_stream_id_for_remote(self, user_id): + def get_device_list_last_stream_id_for_remote(self, user_id: str): """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ @@ -655,7 +673,7 @@ class DeviceWorkerStore(SQLBaseStore): list_name="user_ids", inlineCallbacks=True, ) - def get_device_list_last_stream_id_for_remotes(self, user_ids): + def get_device_list_last_stream_id_for_remotes(self, user_ids: str): rows = yield self.db_pool.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", @@ -669,8 +687,7 @@ class DeviceWorkerStore(SQLBaseStore): return results - @defer.inlineCallbacks - def get_user_ids_requiring_device_list_resync( + async def get_user_ids_requiring_device_list_resync( self, user_ids: Optional[Collection[str]] = None, ) -> Set[str]: """Given a list of remote users return the list of users that we @@ -681,7 +698,7 @@ class DeviceWorkerStore(SQLBaseStore): The IDs of users whose device lists need resync. """ if user_ids: - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="device_lists_remote_resync", column="user_id", iterable=user_ids, @@ -689,7 +706,7 @@ class DeviceWorkerStore(SQLBaseStore): desc="get_user_ids_requiring_device_list_resync_with_iterable", ) else: - rows = yield self.db_pool.simple_select_list( + rows = await self.db_pool.simple_select_list( table="device_lists_remote_resync", keyvalues=None, retcols=("user_id",), @@ -710,7 +727,7 @@ class DeviceWorkerStore(SQLBaseStore): desc="make_remote_user_device_cache_as_stale", ) - def mark_remote_user_device_list_as_unsubscribed(self, user_id): + def mark_remote_user_device_list_as_unsubscribed(self, user_id: str): """Mark that we no longer track device lists for remote user. """ @@ -779,16 +796,15 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): "drop_device_lists_outbound_last_success_non_unique_idx", ) - @defer.inlineCallbacks - def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): + async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): def f(conn): txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") txn.close() - yield self.db_pool.runWithConnection(f) - yield self.db_pool.updates._end_background_update( + await self.db_pool.runWithConnection(f) + await self.db_pool.updates._end_background_update( DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES ) return 1 @@ -868,18 +884,20 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) - @defer.inlineCallbacks - def store_device(self, user_id, device_id, initial_device_display_name): + async def store_device( + self, user_id: str, device_id: str, initial_device_display_name: str + ) -> bool: """Ensure the given device is known; add it to the store if not Args: - user_id (str): id of user associated with the device - device_id (str): id of device - initial_device_display_name (str): initial displayname of the - device. Ignored if device exists. + user_id: id of user associated with the device + device_id: id of device + initial_device_display_name: initial displayname of the device. + Ignored if device exists. + Returns: - defer.Deferred: boolean whether the device was inserted or an - existing device existed with that ID. + Whether the device was inserted or an existing device existed with that ID. + Raises: StoreError: if the device is already in use """ @@ -888,7 +906,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return False try: - inserted = yield self.db_pool.simple_insert( + inserted = await self.db_pool.simple_insert( "devices", values={ "user_id": user_id, @@ -902,7 +920,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not inserted: # if the device already exists, check if it's a real device, or # if the device ID is reserved by something else - hidden = yield self.db_pool.simple_select_one_onecol( + hidden = await self.db_pool.simple_select_one_onecol( "devices", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="hidden", @@ -927,17 +945,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) raise StoreError(500, "Problem storing device.") - @defer.inlineCallbacks - def delete_device(self, user_id, device_id): + async def delete_device(self, user_id: str, device_id: str) -> None: """Delete a device. Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to delete - Returns: - defer.Deferred + user_id: The ID of the user which owns the device + device_id: The ID of the device to delete """ - yield self.db_pool.simple_delete_one( + await self.db_pool.simple_delete_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, desc="delete_device", @@ -945,17 +960,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.device_id_exists_cache.invalidate((user_id, device_id)) - @defer.inlineCallbacks - def delete_devices(self, user_id, device_ids): + async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: """Deletes several devices. Args: - user_id (str): The ID of the user which owns the devices - device_ids (list): The IDs of the devices to delete - Returns: - defer.Deferred + user_id: The ID of the user which owns the devices + device_ids: The IDs of the devices to delete """ - yield self.db_pool.simple_delete_many( + await self.db_pool.simple_delete_many( table="devices", column="device_id", iterable=device_ids, @@ -965,26 +977,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) - def update_device(self, user_id, device_id, new_display_name=None): + async def update_device( + self, user_id: str, device_id: str, new_display_name: Optional[str] = None + ) -> None: """Update a device. Only updates the device if it is not marked as hidden. Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to update - new_display_name (str|None): new displayname for device; None - to leave unchanged + user_id: The ID of the user which owns the device + device_id: The ID of the device to update + new_display_name: new displayname for device; None to leave unchanged Raises: StoreError: if the device is not found - Returns: - defer.Deferred """ updates = {} if new_display_name is not None: updates["display_name"] = new_display_name if not updates: - return defer.succeed(None) - return self.db_pool.simple_update_one( + return None + await self.db_pool.simple_update_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, updatevalues=updates, @@ -992,7 +1003,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) def update_remote_device_list_cache_entry( - self, user_id, device_id, content, stream_id + self, user_id: str, device_id: str, content: JsonDict, stream_id: int ): """Updates a single device in the cache of a remote user's devicelist. @@ -1000,10 +1011,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device list. Args: - user_id (str): User to update device list for - device_id (str): ID of decivice being updated - content (dict): new data on this device - stream_id (int): the version of the device list + user_id: User to update device list for + device_id: ID of decivice being updated + content: new data on this device + stream_id: the version of the device list Returns: Deferred[None] @@ -1018,8 +1029,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) def _update_remote_device_list_cache_entry_txn( - self, txn, user_id, device_id, content, stream_id - ): + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + content: JsonDict, + stream_id: int, + ) -> None: if content.get("deleted"): self.db_pool.simple_delete_txn( txn, @@ -1055,16 +1071,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): lock=False, ) - def update_remote_device_list_cache(self, user_id, devices, stream_id): + def update_remote_device_list_cache( + self, user_id: str, devices: List[dict], stream_id: int + ): """Replace the entire cache of the remote user's devices. Note: assumes that we are the only thread that can be updating this user's device list. Args: - user_id (str): User to update device list for - devices (list[dict]): list of device objects supplied over federation - stream_id (int): the version of the device list + user_id: User to update device list for + devices: list of device objects supplied over federation + stream_id: the version of the device list Returns: Deferred[None] @@ -1077,7 +1095,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): stream_id, ) - def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): + def _update_remote_device_list_cache_txn( + self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int + ): self.db_pool.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) @@ -1118,8 +1138,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id}, ) - @defer.inlineCallbacks - def add_device_change_to_streams(self, user_id, device_ids, hosts): + async def add_device_change_to_streams( + self, user_id: str, device_ids: Collection[str], hosts: List[str] + ): """Persist that a user's devices have been updated, and which hosts (if any) should be poked. """ @@ -1127,7 +1148,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_device_change_to_stream", self._add_device_change_to_stream_txn, user_id, @@ -1142,7 +1163,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): with self._device_list_id_gen.get_next_mult( len(hosts) * len(device_ids) ) as stream_ids: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_device_outbound_poke_to_stream", self._add_device_outbound_poke_to_stream_txn, user_id, @@ -1187,7 +1208,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) def _add_device_outbound_poke_to_stream_txn( - self, txn, user_id, device_ids, hosts, stream_ids, context, + self, + txn: LoggingTransaction, + user_id: str, + device_ids: Collection[str], + hosts: List[str], + stream_ids: List[str], + context: Dict[str, str], ): for host in hosts: txn.call_after( @@ -1219,7 +1246,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ], ) - def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000): + def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000): """Delete old entries out of the device_lists_outbound_pokes to ensure that we don't fill up due to dead servers. diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 64ddd8243d..64afd581bc 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -116,7 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): retry_timings_res ) - self.datastore.get_device_updates_by_remote.return_value = defer.succeed( + self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable( (0, []) ) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index c2539b353a..87ed8f8cd1 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -34,7 +34,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_store_new_device(self): - yield self.store.store_device("user_id", "device_id", "display_name") + yield defer.ensureDeferred( + self.store.store_device("user_id", "device_id", "display_name") + ) res = yield self.store.get_device("user_id", "device_id") self.assertDictContainsSubset( @@ -48,11 +50,17 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_get_devices_by_user(self): - yield self.store.store_device("user_id", "device1", "display_name 1") - yield self.store.store_device("user_id", "device2", "display_name 2") - yield self.store.store_device("user_id2", "device3", "display_name 3") + yield defer.ensureDeferred( + self.store.store_device("user_id", "device1", "display_name 1") + ) + yield defer.ensureDeferred( + self.store.store_device("user_id", "device2", "display_name 2") + ) + yield defer.ensureDeferred( + self.store.store_device("user_id2", "device3", "display_name 3") + ) - res = yield self.store.get_devices_by_user("user_id") + res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id")) self.assertEqual(2, len(res.keys())) self.assertDictContainsSubset( { @@ -76,13 +84,13 @@ class DeviceStoreTestCase(tests.unittest.TestCase): device_ids = ["device_id1", "device_id2"] # Add two device updates with a single stream_id - yield self.store.add_device_change_to_streams( - "user_id", device_ids, ["somehost"] + yield defer.ensureDeferred( + self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) ) # Get all device updates ever meant for this remote - now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( - "somehost", -1, limit=100 + now_stream_id, device_updates = yield defer.ensureDeferred( + self.store.get_device_updates_by_remote("somehost", -1, limit=100) ) # Check original device_ids are contained within these updates @@ -99,19 +107,23 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_update_device(self): - yield self.store.store_device("user_id", "device_id", "display_name 1") + yield defer.ensureDeferred( + self.store.store_device("user_id", "device_id", "display_name 1") + ) res = yield self.store.get_device("user_id", "device_id") self.assertEqual("display_name 1", res["display_name"]) # do a no-op first - yield self.store.update_device("user_id", "device_id") + yield defer.ensureDeferred(self.store.update_device("user_id", "device_id")) res = yield self.store.get_device("user_id", "device_id") self.assertEqual("display_name 1", res["display_name"]) # do the update - yield self.store.update_device( - "user_id", "device_id", new_display_name="display_name 2" + yield defer.ensureDeferred( + self.store.update_device( + "user_id", "device_id", new_display_name="display_name 2" + ) ) # check it worked @@ -121,7 +133,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_update_unknown_device(self): with self.assertRaises(synapse.api.errors.StoreError) as cm: - yield self.store.update_device( - "user_id", "unknown_device_id", new_display_name="display_name 2" + yield defer.ensureDeferred( + self.store.update_device( + "user_id", "unknown_device_id", new_display_name="display_name 2" + ) ) self.assertEqual(404, cm.exception.code) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 9f8d30373b..d57cdffd8b 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -30,7 +30,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): now = 1470174257070 json = {"key": "value"} - yield self.store.store_device("user", "device", None) + yield defer.ensureDeferred(self.store.store_device("user", "device", None)) yield self.store.set_e2e_device_keys("user", "device", now, json) @@ -47,7 +47,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): now = 1470174257070 json = {"key": "value"} - yield self.store.store_device("user", "device", None) + yield defer.ensureDeferred(self.store.store_device("user", "device", None)) changed = yield self.store.set_e2e_device_keys("user", "device", now, json) self.assertTrue(changed) @@ -63,7 +63,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): json = {"key": "value"} yield self.store.set_e2e_device_keys("user", "device", now, json) - yield self.store.store_device("user", "device", "display_name") + yield defer.ensureDeferred( + self.store.store_device("user", "device", "display_name") + ) res = yield defer.ensureDeferred( self.store.get_e2e_device_keys((("user", "device"),)) @@ -79,10 +81,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): def test_multiple_devices(self): now = 1470174257070 - yield self.store.store_device("user1", "device1", None) - yield self.store.store_device("user1", "device2", None) - yield self.store.store_device("user2", "device1", None) - yield self.store.store_device("user2", "device2", None) + yield defer.ensureDeferred(self.store.store_device("user1", "device1", None)) + yield defer.ensureDeferred(self.store.store_device("user1", "device2", None)) + yield defer.ensureDeferred(self.store.store_device("user2", "device1", None)) + yield defer.ensureDeferred(self.store.store_device("user2", "device2", None)) yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"}) -- cgit 1.5.1 From e00816ad98a1165b67238f9711cb1b0e7135f25f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Aug 2020 17:24:46 -0400 Subject: Do not yield on awaitables in tests. (#8193) --- changelog.d/8193.misc | 1 + tests/api/test_filtering.py | 36 ++++++++----- tests/crypto/test_keyring.py | 4 +- tests/federation/test_complexity.py | 6 +-- tests/handlers/test_typing.py | 4 +- tests/rest/client/v2_alpha/test_filter.py | 8 ++- tests/storage/test_appservice.py | 24 ++++++--- tests/storage/test_background_update.py | 9 ++-- tests/storage/test_end_to_end_keys.py | 32 +++++++++--- tests/storage/test_event_push_actions.py | 78 +++++++++++++++++++--------- tests/storage/test_main.py | 8 +-- tests/storage/test_registration.py | 44 ++++++++++------ tests/storage/test_user_directory.py | 16 ++++-- tests/test_state.py | 86 +++++++++++++++++-------------- tests/test_visibility.py | 5 +- 15 files changed, 230 insertions(+), 131 deletions(-) create mode 100644 changelog.d/8193.misc (limited to 'tests/storage/test_end_to_end_keys.py') diff --git a/changelog.d/8193.misc b/changelog.d/8193.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8193.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 1fab1d6b69..d2d535d23c 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -369,8 +369,10 @@ class FilteringTestCase(unittest.TestCase): @defer.inlineCallbacks def test_filter_presence_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} - filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + filter_id = yield defer.ensureDeferred( + self.datastore.add_user_filter( + user_localpart=user_localpart, user_filter=user_filter_json + ) ) event = MockEvent(sender="@foo:bar", type="m.profile") events = [event] @@ -388,8 +390,10 @@ class FilteringTestCase(unittest.TestCase): def test_filter_presence_no_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} - filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart + "2", user_filter=user_filter_json + filter_id = yield defer.ensureDeferred( + self.datastore.add_user_filter( + user_localpart=user_localpart + "2", user_filter=user_filter_json + ) ) event = MockEvent( event_id="$asdasd:localhost", @@ -410,8 +414,10 @@ class FilteringTestCase(unittest.TestCase): @defer.inlineCallbacks def test_filter_room_state_match(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + filter_id = yield defer.ensureDeferred( + self.datastore.add_user_filter( + user_localpart=user_localpart, user_filter=user_filter_json + ) ) event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") events = [event] @@ -428,8 +434,10 @@ class FilteringTestCase(unittest.TestCase): @defer.inlineCallbacks def test_filter_room_state_no_match(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + filter_id = yield defer.ensureDeferred( + self.datastore.add_user_filter( + user_localpart=user_localpart, user_filter=user_filter_json + ) ) event = MockEvent( sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar" @@ -465,8 +473,10 @@ class FilteringTestCase(unittest.TestCase): def test_add_filter(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield self.filtering.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + filter_id = yield defer.ensureDeferred( + self.filtering.add_user_filter( + user_localpart=user_localpart, user_filter=user_filter_json + ) ) self.assertEquals(filter_id, 0) @@ -485,8 +495,10 @@ class FilteringTestCase(unittest.TestCase): def test_get_filter(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart, user_filter=user_filter_json + filter_id = yield defer.ensureDeferred( + self.datastore.add_user_filter( + user_localpart=user_localpart, user_filter=user_filter_json + ) ) filter = yield defer.ensureDeferred( diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 0d4b05304b..d264653e74 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -190,7 +190,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): # should fail immediately on an unsigned object d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") - self.failureResultOf(d, SynapseError) + self.get_failure(d, SynapseError) # should succeed on a signed object d = _verify_json_for_server(kr, "server9", json1, 500, "test signed") @@ -221,7 +221,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): # should fail immediately on an unsigned object d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") - self.failureResultOf(d, SynapseError) + self.get_failure(d, SynapseError) # should fail on a signed object with a non-zero minimum_valid_until_ms, # as it tries to refetch the keys and fails. diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 9bd515080c..3d880c499d 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -15,8 +15,6 @@ from mock import Mock -from twisted.internet import defer - from synapse.api.errors import Codes, SynapseError from synapse.rest import admin from synapse.rest.client.v1 import login, room @@ -60,7 +58,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): # Artificially raise the complexity store = self.hs.get_datastore() - store.get_current_state_event_counts = lambda x: defer.succeed(500 * 1.23) + store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23) # Get the room complexity again -- make sure it's our artificial value request, channel = self.make_request( @@ -160,7 +158,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): ) # Artificially raise the complexity - self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed( + self.hs.get_datastore().get_current_state_event_counts = lambda x: make_awaitable( 600 ) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 81c1839637..7bf15c4ba9 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -155,7 +155,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable( ([], 0) ) - self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None + self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: make_awaitable( + None + ) self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable( None ) diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index e0e9e94fbf..de00350580 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + from synapse.api.errors import Codes from synapse.rest.client.v2_alpha import filter @@ -73,8 +75,10 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) def test_get_filter(self): - filter_id = self.filtering.add_user_filter( - user_localpart="apple", user_filter=self.EXAMPLE_FILTER + filter_id = defer.ensureDeferred( + self.filtering.add_user_filter( + user_localpart="apple", user_filter=self.EXAMPLE_FILTER + ) ) self.reactor.advance(1) filter_id = filter_id.result diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 17fbde284a..cb808d4de4 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -243,7 +243,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): def test_create_appservice_txn_first(self): service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] - txn = yield self.store.create_appservice_txn(service, events) + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events) + ) self.assertEquals(txn.id, 1) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) @@ -255,7 +257,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._set_last_txn(service.id, 9643) # AS is falling behind yield self._insert_txn(service.id, 9644, events) yield self._insert_txn(service.id, 9645, events) - txn = yield self.store.create_appservice_txn(service, events) + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events) + ) self.assertEquals(txn.id, 9646) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) @@ -265,7 +269,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] yield self._set_last_txn(service.id, 9643) - txn = yield self.store.create_appservice_txn(service, events) + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events) + ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) @@ -286,7 +292,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._insert_txn(self.as_list[2]["id"], 10, events) yield self._insert_txn(self.as_list[3]["id"], 9643, events) - txn = yield self.store.create_appservice_txn(service, events) + txn = yield defer.ensureDeferred( + self.store.create_appservice_txn(service, events) + ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) @@ -298,7 +306,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): txn_id = 1 yield self._insert_txn(service.id, txn_id, events) - yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) + yield defer.ensureDeferred( + self.store.complete_appservice_txn(txn_id=txn_id, service=service) + ) res = yield self.db_pool.runQuery( self.engine.convert_param_style( @@ -324,7 +334,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): txn_id = 5 yield self._set_last_txn(service.id, 4) yield self._insert_txn(service.id, txn_id, events) - yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) + yield defer.ensureDeferred( + self.store.complete_appservice_txn(txn_id=txn_id, service=service) + ) res = yield self.db_pool.runQuery( self.engine.convert_param_style( diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 1a1c59256c..02aae1c13d 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -1,7 +1,5 @@ from mock import Mock -from twisted.internet import defer - from synapse.storage.background_updates import BackgroundUpdater from tests import unittest @@ -38,11 +36,10 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): ) # first step: make a bit of progress - @defer.inlineCallbacks - def update(progress, count): - yield self.clock.sleep((count * duration_ms) / 1000) + async def update(progress, count): + await self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} - yield store.db_pool.runInteraction( + await store.db_pool.runInteraction( "update_progress", self.updates._background_update_progress_txn, "test_update", diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index d57cdffd8b..261bf5b08b 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -32,7 +32,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield defer.ensureDeferred(self.store.store_device("user", "device", None)) - yield self.store.set_e2e_device_keys("user", "device", now, json) + yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user", "device", now, json) + ) res = yield defer.ensureDeferred( self.store.get_e2e_device_keys((("user", "device"),)) @@ -49,12 +51,16 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield defer.ensureDeferred(self.store.store_device("user", "device", None)) - changed = yield self.store.set_e2e_device_keys("user", "device", now, json) + changed = yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user", "device", now, json) + ) self.assertTrue(changed) # If we try to upload the same key then we should be told nothing # changed - changed = yield self.store.set_e2e_device_keys("user", "device", now, json) + changed = yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user", "device", now, json) + ) self.assertFalse(changed) @defer.inlineCallbacks @@ -62,7 +68,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): now = 1470174257070 json = {"key": "value"} - yield self.store.set_e2e_device_keys("user", "device", now, json) + yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user", "device", now, json) + ) yield defer.ensureDeferred( self.store.store_device("user", "device", "display_name") ) @@ -86,10 +94,18 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield defer.ensureDeferred(self.store.store_device("user2", "device1", None)) yield defer.ensureDeferred(self.store.store_device("user2", "device2", None)) - yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) - yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"}) - yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) - yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"}) + yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) + ) + yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"}) + ) + yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) + ) + yield defer.ensureDeferred( + self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"}) + ) res = yield defer.ensureDeferred( self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2"))) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 0e7427e57a..cdfd2634aa 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -60,8 +60,10 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def _assert_counts(noitf_count, highlight_count): - counts = yield self.store.db_pool.runInteraction( - "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 + counts = yield defer.ensureDeferred( + self.store.db_pool.runInteraction( + "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 + ) ) self.assertEquals( counts, @@ -81,25 +83,31 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): event.event_id, {user_id: action} ) ) - yield self.store.db_pool.runInteraction( - "", - self.persist_events_store._set_push_actions_for_event_and_users_txn, - [(event, None)], - [(event, None)], + yield defer.ensureDeferred( + self.store.db_pool.runInteraction( + "", + self.persist_events_store._set_push_actions_for_event_and_users_txn, + [(event, None)], + [(event, None)], + ) ) def _rotate(stream): - return self.store.db_pool.runInteraction( - "", self.store._rotate_notifs_before_txn, stream + return defer.ensureDeferred( + self.store.db_pool.runInteraction( + "", self.store._rotate_notifs_before_txn, stream + ) ) def _mark_read(stream, depth): - return self.store.db_pool.runInteraction( - "", - self.store._remove_old_push_actions_before_txn, - room_id, - user_id, - stream, + return defer.ensureDeferred( + self.store.db_pool.runInteraction( + "", + self.store._remove_old_push_actions_before_txn, + room_id, + user_id, + stream, + ) ) yield _assert_counts(0, 0) @@ -163,16 +171,24 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): ) # start with the base case where there are no events in the table - r = yield self.store.find_first_stream_ordering_after_ts(11) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(11) + ) self.assertEqual(r, 0) # now with one event yield add_event(2, 10) - r = yield self.store.find_first_stream_ordering_after_ts(9) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(9) + ) self.assertEqual(r, 2) - r = yield self.store.find_first_stream_ordering_after_ts(10) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(10) + ) self.assertEqual(r, 2) - r = yield self.store.find_first_stream_ordering_after_ts(11) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(11) + ) self.assertEqual(r, 3) # add a bunch of dummy events to the events table @@ -185,25 +201,37 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): ): yield add_event(stream_ordering, ts) - r = yield self.store.find_first_stream_ordering_after_ts(110) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(110) + ) self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r) # 4 and 5 are both after 120: we want 4 rather than 5 - r = yield self.store.find_first_stream_ordering_after_ts(120) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(120) + ) self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r) - r = yield self.store.find_first_stream_ordering_after_ts(129) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(129) + ) self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r) # check we can get the last event - r = yield self.store.find_first_stream_ordering_after_ts(140) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(140) + ) self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r) # off the end - r = yield self.store.find_first_stream_ordering_after_ts(160) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(160) + ) self.assertEqual(r, 21) # check we can find an event at ordering zero yield add_event(0, 5) - r = yield self.store.find_first_stream_ordering_after_ts(1) + r = yield defer.ensureDeferred( + self.store.find_first_stream_ordering_after_ts(1) + ) self.assertEqual(r, 0) diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index 954338a592..7e7f1286d9 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -34,14 +34,16 @@ class DataStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_users_paginate(self): - yield self.store.register_user(self.user.to_string(), "pass") + yield defer.ensureDeferred( + self.store.register_user(self.user.to_string(), "pass") + ) yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) yield defer.ensureDeferred( self.store.set_profile_displayname(self.user.localpart, self.displayname) ) - users, total = yield self.store.get_users_paginate( - 0, 10, name="bc", guests=False + users, total = yield defer.ensureDeferred( + self.store.get_users_paginate(0, 10, name="bc", guests=False) ) self.assertEquals(1, total) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 70c55cd650..6b582771fe 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -37,7 +37,7 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_register(self): - yield self.store.register_user(self.user_id, self.pwhash) + yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) self.assertEquals( { @@ -58,14 +58,16 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_add_tokens(self): - yield self.store.register_user(self.user_id, self.pwhash) + yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) yield defer.ensureDeferred( self.store.add_access_token_to_user( self.user_id, self.tokens[1], self.device_id, valid_until_ms=None ) ) - result = yield self.store.get_user_by_access_token(self.tokens[1]) + result = yield defer.ensureDeferred( + self.store.get_user_by_access_token(self.tokens[1]) + ) self.assertDictContainsSubset( {"name": self.user_id, "device_id": self.device_id}, result @@ -76,7 +78,7 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_user_delete_access_tokens(self): # add some tokens - yield self.store.register_user(self.user_id, self.pwhash) + yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) yield defer.ensureDeferred( self.store.add_access_token_to_user( self.user_id, self.tokens[0], device_id=None, valid_until_ms=None @@ -89,22 +91,28 @@ class RegistrationStoreTestCase(unittest.TestCase): ) # now delete some - yield self.store.user_delete_access_tokens( - self.user_id, device_id=self.device_id + yield defer.ensureDeferred( + self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id) ) # check they were deleted - user = yield self.store.get_user_by_access_token(self.tokens[1]) + user = yield defer.ensureDeferred( + self.store.get_user_by_access_token(self.tokens[1]) + ) self.assertIsNone(user, "access token was not deleted by device_id") # check the one not associated with the device was not deleted - user = yield self.store.get_user_by_access_token(self.tokens[0]) + user = yield defer.ensureDeferred( + self.store.get_user_by_access_token(self.tokens[0]) + ) self.assertEqual(self.user_id, user["name"]) # now delete the rest - yield self.store.user_delete_access_tokens(self.user_id) + yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id)) - user = yield self.store.get_user_by_access_token(self.tokens[0]) + user = yield defer.ensureDeferred( + self.store.get_user_by_access_token(self.tokens[0]) + ) self.assertIsNone(user, "access token was not deleted without device_id") @defer.inlineCallbacks @@ -112,16 +120,20 @@ class RegistrationStoreTestCase(unittest.TestCase): TEST_USER = "@test:test" SUPPORT_USER = "@support:test" - res = yield self.store.is_support_user(None) + res = yield defer.ensureDeferred(self.store.is_support_user(None)) self.assertFalse(res) - yield self.store.register_user(user_id=TEST_USER, password_hash=None) - res = yield self.store.is_support_user(TEST_USER) + yield defer.ensureDeferred( + self.store.register_user(user_id=TEST_USER, password_hash=None) + ) + res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER)) self.assertFalse(res) - yield self.store.register_user( - user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT + yield defer.ensureDeferred( + self.store.register_user( + user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT + ) ) - res = yield self.store.is_support_user(SUPPORT_USER) + res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER)) self.assertTrue(res) @defer.inlineCallbacks diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index ecfafe68a9..738e912468 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -31,10 +31,18 @@ class UserDirectoryStoreTestCase(unittest.TestCase): # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. - yield self.store.update_profile_in_user_dir(ALICE, "alice", None) - yield self.store.update_profile_in_user_dir(BOB, "bob", None) - yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None) - yield self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)) + yield defer.ensureDeferred( + self.store.update_profile_in_user_dir(ALICE, "alice", None) + ) + yield defer.ensureDeferred( + self.store.update_profile_in_user_dir(BOB, "bob", None) + ) + yield defer.ensureDeferred( + self.store.update_profile_in_user_dir(BOBBY, "bobby", None) + ) + yield defer.ensureDeferred( + self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)) + ) @defer.inlineCallbacks def test_search_user_dir(self): diff --git a/tests/test_state.py b/tests/test_state.py index b5c3667d2a..56ba0fecf5 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -80,16 +80,16 @@ class StateGroupStore(object): self._next_group = 1 - def get_state_groups_ids(self, room_id, event_ids): + async def get_state_groups_ids(self, room_id, event_ids): groups = {} for event_id in event_ids: group = self._event_to_state_group.get(event_id) if group: groups[group] = self._group_to_state[group] - return defer.succeed(groups) + return groups - def store_state_group( + async def store_state_group( self, event_id, room_id, prev_group, delta_ids, current_state_ids ): state_group = self._next_group @@ -97,19 +97,17 @@ class StateGroupStore(object): self._group_to_state[state_group] = dict(current_state_ids) - return defer.succeed(state_group) + return state_group - def get_events(self, event_ids, **kwargs): - return defer.succeed( - { - e_id: self._event_id_to_event[e_id] - for e_id in event_ids - if e_id in self._event_id_to_event - } - ) + async def get_events(self, event_ids, **kwargs): + return { + e_id: self._event_id_to_event[e_id] + for e_id in event_ids + if e_id in self._event_id_to_event + } - def get_state_group_delta(self, name): - return defer.succeed((None, None)) + async def get_state_group_delta(self, name): + return (None, None) def register_events(self, events): for e in events: @@ -121,8 +119,8 @@ class StateGroupStore(object): def register_event_id_state_group(self, event_id, state_group): self._event_to_state_group[event_id] = state_group - def get_room_version_id(self, room_id): - return defer.succeed(RoomVersions.V1.identifier) + async def get_room_version_id(self, room_id): + return RoomVersions.V1.identifier class DictObj(dict): @@ -476,12 +474,14 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - group_name = yield self.store.store_state_group( - prev_event_id, - event.room_id, - None, - None, - {(e.type, e.state_key): e.event_id for e in old_state}, + group_name = yield defer.ensureDeferred( + self.store.store_state_group( + prev_event_id, + event.room_id, + None, + None, + {(e.type, e.state_key): e.event_id for e in old_state}, + ) ) self.store.register_event_id_state_group(prev_event_id, group_name) @@ -508,12 +508,14 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - group_name = yield self.store.store_state_group( - prev_event_id, - event.room_id, - None, - None, - {(e.type, e.state_key): e.event_id for e in old_state}, + group_name = yield defer.ensureDeferred( + self.store.store_state_group( + prev_event_id, + event.room_id, + None, + None, + {(e.type, e.state_key): e.event_id for e in old_state}, + ) ) self.store.register_event_id_state_group(prev_event_id, group_name) @@ -691,21 +693,25 @@ class StateTestCase(unittest.TestCase): def _get_context( self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2 ): - sg1 = yield self.store.store_state_group( - prev_event_id_1, - event.room_id, - None, - None, - {(e.type, e.state_key): e.event_id for e in old_state_1}, + sg1 = yield defer.ensureDeferred( + self.store.store_state_group( + prev_event_id_1, + event.room_id, + None, + None, + {(e.type, e.state_key): e.event_id for e in old_state_1}, + ) ) self.store.register_event_id_state_group(prev_event_id_1, sg1) - sg2 = yield self.store.store_state_group( - prev_event_id_2, - event.room_id, - None, - None, - {(e.type, e.state_key): e.event_id for e in old_state_2}, + sg2 = yield defer.ensureDeferred( + self.store.store_state_group( + prev_event_id_2, + event.room_id, + None, + None, + {(e.type, e.state_key): e.event_id for e in old_state_2}, + ) ) self.store.register_event_id_state_group(prev_event_id_2, sg2) diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 531a9b9118..4a4483ba12 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -37,7 +37,6 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.hs = yield setup_test_homeserver(self.addCleanup) self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() - self.store = self.hs.get_datastore() self.storage = self.hs.get_storage() yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) @@ -99,7 +98,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): events_to_filter.append(evt) # the erasey user gets erased - yield self.hs.get_datastore().mark_user_erased("@erased:local_hs") + yield defer.ensureDeferred( + self.hs.get_datastore().mark_user_erased("@erased:local_hs") + ) # ... and the filtering happens. filtered = yield defer.ensureDeferred( -- cgit 1.5.1 From 45e8f7726f24d98a9d3fa06ea52ae960cc1d8689 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Sat, 29 Aug 2020 00:14:17 +0100 Subject: Rename `get_e2e_device_keys` to better reflect its purpose (#8205) ... and to show that it does something slightly different to `_get_e2e_device_keys_txn`. `include_all_devices` and `include_deleted_devices` were never used (and `include_deleted_devices` was broken, since that would cause `None`s in the result which were not handled in the loop below. Add some typing too. --- changelog.d/8205.misc | 1 + synapse/handlers/e2e_keys.py | 4 ++-- synapse/storage/databases/main/end_to_end_keys.py | 20 ++++++-------------- tests/storage/test_end_to_end_keys.py | 8 +++++--- 4 files changed, 14 insertions(+), 19 deletions(-) create mode 100644 changelog.d/8205.misc (limited to 'tests/storage/test_end_to_end_keys.py') diff --git a/changelog.d/8205.misc b/changelog.d/8205.misc new file mode 100644 index 0000000000..fb8fd83278 --- /dev/null +++ b/changelog.d/8205.misc @@ -0,0 +1 @@ + Refactor queries for device keys and cross-signatures. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d8def45e38..dfd1c78549 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -353,7 +353,7 @@ class E2eKeysHandler(object): # make sure that each queried user appears in the result dict result_dict[user_id] = {} - results = await self.store.get_e2e_device_keys(local_query) + results = await self.store.get_e2e_device_keys_for_cs_api(local_query) # Build the result structure for user_id, device_keys in results.items(): @@ -734,7 +734,7 @@ class E2eKeysHandler(object): # fetch our stored devices. This is used to 1. verify # signatures on the master key, and 2. to compare with what # was sent if the device was signed - devices = await self.store.get_e2e_device_keys([(user_id, None)]) + devices = await self.store.get_e2e_device_keys_for_cs_api([(user_id, None)]) if user_id not in devices: raise NotFoundError("No device keys found") diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index af0b85e2c9..50ecddf7fa 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -23,6 +23,7 @@ from twisted.enterprise.adbapi import Connection from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import make_in_list_sql_clause +from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter @@ -33,17 +34,12 @@ if TYPE_CHECKING: class EndToEndKeyWorkerStore(SQLBaseStore): @trace - async def get_e2e_device_keys( - self, query_list, include_all_devices=False, include_deleted_devices=False - ): - """Fetch a list of device keys. + async def get_e2e_device_keys_for_cs_api( + self, query_list: List[Tuple[str, Optional[str]]] + ) -> Dict[str, Dict[str, JsonDict]]: + """Fetch a list of device keys, formatted suitably for the C/S API. Args: query_list(list): List of pairs of user_ids and device_ids. - include_all_devices (bool): whether to include entries for devices - that don't have device keys - include_deleted_devices (bool): whether to include null entries for - devices which no longer exist (but were in the query_list). - This option only takes effect if include_all_devices is true. Returns: Dict mapping from user-id to dict mapping from device_id to key data. The key data will be a dict in the same format as the @@ -54,11 +50,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): return {} results = await self.db_pool.runInteraction( - "get_e2e_device_keys", - self._get_e2e_device_keys_txn, - query_list, - include_all_devices, - include_deleted_devices, + "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list, ) # Build the result structure, un-jsonify the results, and add the diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 261bf5b08b..3fc4bb13b6 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -37,7 +37,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): ) res = yield defer.ensureDeferred( - self.store.get_e2e_device_keys((("user", "device"),)) + self.store.get_e2e_device_keys_for_cs_api((("user", "device"),)) ) self.assertIn("user", res) self.assertIn("device", res["user"]) @@ -76,7 +76,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): ) res = yield defer.ensureDeferred( - self.store.get_e2e_device_keys((("user", "device"),)) + self.store.get_e2e_device_keys_for_cs_api((("user", "device"),)) ) self.assertIn("user", res) self.assertIn("device", res["user"]) @@ -108,7 +108,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): ) res = yield defer.ensureDeferred( - self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2"))) + self.store.get_e2e_device_keys_for_cs_api( + (("user1", "device1"), ("user2", "device2")) + ) ) self.assertIn("user1", res) self.assertIn("device1", res["user1"]) -- cgit 1.5.1