summary refs log tree commit diff
path: root/synapse/handlers/e2e_keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/e2e_keys.py')
-rw-r--r--synapse/handlers/e2e_keys.py40
1 files changed, 19 insertions, 21 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 3972849d4d..d92370859f 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -115,9 +115,9 @@ class E2eKeysHandler:
                 the number of in-flight queries at a time.
         """
         with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
-            device_keys_query = query_body.get(
+            device_keys_query: Dict[str, Iterable[str]] = query_body.get(
                 "device_keys", {}
-            )  # type: Dict[str, Iterable[str]]
+            )
 
             # separate users by domain.
             # make a map from domain to user_id to device_ids
@@ -136,7 +136,7 @@ class E2eKeysHandler:
 
             # First get local devices.
             # A map of destination -> failure response.
-            failures = {}  # type: Dict[str, JsonDict]
+            failures: Dict[str, JsonDict] = {}
             results = {}
             if local_query:
                 local_result = await self.query_local_devices(local_query)
@@ -151,11 +151,9 @@ class E2eKeysHandler:
 
             # Now attempt to get any remote devices from our local cache.
             # A map of destination -> user ID -> device IDs.
-            remote_queries_not_in_cache = (
-                {}
-            )  # type: Dict[str, Dict[str, Iterable[str]]]
+            remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
             if remote_queries:
-                query_list = []  # type: List[Tuple[str, Optional[str]]]
+                query_list: List[Tuple[str, Optional[str]]] = []
                 for user_id, device_ids in remote_queries.items():
                     if device_ids:
                         query_list.extend(
@@ -362,9 +360,9 @@ class E2eKeysHandler:
             A map from user_id -> device_id -> device details
         """
         set_tag("local_query", query)
-        local_query = []  # type: List[Tuple[str, Optional[str]]]
+        local_query: List[Tuple[str, Optional[str]]] = []
 
-        result_dict = {}  # type: Dict[str, Dict[str, dict]]
+        result_dict: Dict[str, Dict[str, dict]] = {}
         for user_id, device_ids in query.items():
             # we use UserID.from_string to catch invalid user ids
             if not self.is_mine(UserID.from_string(user_id)):
@@ -402,9 +400,9 @@ class E2eKeysHandler:
         self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
     ) -> JsonDict:
         """Handle a device key query from a federated server"""
-        device_keys_query = query_body.get(
+        device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
             "device_keys", {}
-        )  # type: Dict[str, Optional[List[str]]]
+        )
         res = await self.query_local_devices(device_keys_query)
         ret = {"device_keys": res}
 
@@ -421,8 +419,8 @@ class E2eKeysHandler:
     async def claim_one_time_keys(
         self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
     ) -> JsonDict:
-        local_query = []  # type: List[Tuple[str, str, str]]
-        remote_queries = {}  # type: Dict[str, Dict[str, Dict[str, str]]]
+        local_query: List[Tuple[str, str, str]] = []
+        remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
 
         for user_id, one_time_keys in query.get("one_time_keys", {}).items():
             # we use UserID.from_string to catch invalid user ids
@@ -439,8 +437,8 @@ class E2eKeysHandler:
         results = await self.store.claim_e2e_one_time_keys(local_query)
 
         # A map of user ID -> device ID -> key ID -> key.
-        json_result = {}  # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
-        failures = {}  # type: Dict[str, JsonDict]
+        json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+        failures: Dict[str, JsonDict] = {}
         for user_id, device_keys in results.items():
             for device_id, keys in device_keys.items():
                 for key_id, json_str in keys.items():
@@ -768,8 +766,8 @@ class E2eKeysHandler:
         Raises:
             SynapseError: if the input is malformed
         """
-        signature_list = []  # type: List[SignatureListItem]
-        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
+        signature_list: List["SignatureListItem"] = []
+        failures: Dict[str, Dict[str, JsonDict]] = {}
         if not signatures:
             return signature_list, failures
 
@@ -930,8 +928,8 @@ class E2eKeysHandler:
         Raises:
             SynapseError: if the input is malformed
         """
-        signature_list = []  # type: List[SignatureListItem]
-        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
+        signature_list: List["SignatureListItem"] = []
+        failures: Dict[str, Dict[str, JsonDict]] = {}
         if not signatures:
             return signature_list, failures
 
@@ -1300,7 +1298,7 @@ class SigningKeyEduUpdater:
         self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
 
         # user_id -> list of updates waiting to be handled.
-        self._pending_updates = {}  # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
+        self._pending_updates: Dict[str, List[Tuple[JsonDict, JsonDict]]] = {}
 
     async def incoming_signing_key_update(
         self, origin: str, edu_content: JsonDict
@@ -1349,7 +1347,7 @@ class SigningKeyEduUpdater:
                 # This can happen since we batch updates
                 return
 
-            device_ids = []  # type: List[str]
+            device_ids: List[str] = []
 
             logger.info("pending updates: %r", pending_updates)