From a6895dd576f96d7fd086fb4128d48ac8a3f098c5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 19 Jul 2022 14:14:30 -0400 Subject: Add type annotations to `trace` decorator. (#13328) Functions that are decorated with `trace` are now properly typed and the type hints for them are fixed. --- synapse/storage/databases/main/devices.py | 2 +- synapse/storage/databases/main/end_to_end_keys.py | 47 ++++++++++++++++++++--- 2 files changed, 42 insertions(+), 7 deletions(-) (limited to 'synapse/storage/databases/main') diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index adde5d0978..7a6ed332aa 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -669,7 +669,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore): @trace async def get_user_devices_from_cache( - self, query_list: List[Tuple[str, str]] + self, query_list: List[Tuple[str, Optional[str]]] ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: """Get the devices (and keys if any) for remote users from the cache. diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 9b293475c8..60f622ad71 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -22,11 +22,14 @@ from typing import ( List, Optional, Tuple, + Union, cast, + overload, ) import attr from canonicaljson import encode_canonical_json +from typing_extensions import Literal from synapse.api.constants import DeviceKeyAlgorithms from synapse.appservice import ( @@ -113,7 +116,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker user_devices = devices[user_id] results = [] for device_id, device in user_devices.items(): - result = {"device_id": device_id} + result: JsonDict = {"device_id": device_id} keys = device.keys if keys: @@ -156,6 +159,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker rv[user_id] = {} for device_id, device_info in device_keys.items(): r = device_info.keys + if r is None: + continue + r["unsigned"] = {} display_name = device_info.display_name if display_name is not None: @@ -164,13 +170,42 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker return rv + @overload + async def get_e2e_device_keys_and_signatures( + self, + query_list: Collection[Tuple[str, Optional[str]]], + include_all_devices: Literal[False] = False, + ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]: + ... + + @overload + async def get_e2e_device_keys_and_signatures( + self, + query_list: Collection[Tuple[str, Optional[str]]], + include_all_devices: bool = False, + include_deleted_devices: Literal[False] = False, + ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]: + ... + + @overload + async def get_e2e_device_keys_and_signatures( + self, + query_list: Collection[Tuple[str, Optional[str]]], + include_all_devices: Literal[True], + include_deleted_devices: Literal[True], + ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: + ... + @trace async def get_e2e_device_keys_and_signatures( self, - query_list: List[Tuple[str, Optional[str]]], + query_list: Collection[Tuple[str, Optional[str]]], include_all_devices: bool = False, include_deleted_devices: bool = False, - ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: + ) -> Union[ + Dict[str, Dict[str, DeviceKeyLookupResult]], + Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]], + ]: """Fetch a list of device keys Any cross-signatures made on the keys by the owner of the device are also @@ -1044,7 +1079,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple db_autocommit = False - row = await self.db_pool.runInteraction( + claim_row = await self.db_pool.runInteraction( "claim_e2e_one_time_keys", _claim_e2e_one_time_key, user_id, @@ -1052,11 +1087,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker algorithm, db_autocommit=db_autocommit, ) - if row: + if claim_row: device_results = results.setdefault(user_id, {}).setdefault( device_id, {} ) - device_results[row[0]] = row[1] + device_results[claim_row[0]] = claim_row[1] continue # No one-time key available, so see if there's a fallback -- cgit 1.4.1