summary refs log tree commit diff
path: root/synapse/handlers/e2e_keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/e2e_keys.py')
-rw-r--r--synapse/handlers/e2e_keys.py32
1 files changed, 17 insertions, 15 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 52bb5c9c55..c938339ddd 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -92,7 +92,11 @@ class E2eKeysHandler:
 
     @trace
     async def query_devices(
-        self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
+        self,
+        query_body: JsonDict,
+        timeout: int,
+        from_user_id: str,
+        from_device_id: Optional[str],
     ) -> JsonDict:
         """Handle a device key query from a client
 
@@ -120,9 +124,7 @@ class E2eKeysHandler:
                 the number of in-flight queries at a time.
         """
         async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
-            device_keys_query: Dict[str, Iterable[str]] = query_body.get(
-                "device_keys", {}
-            )
+            device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {})
 
             # separate users by domain.
             # make a map from domain to user_id to device_ids
@@ -136,8 +138,8 @@ class E2eKeysHandler:
                 else:
                     remote_queries[user_id] = device_ids
 
-            set_tag("local_key_query", local_query)
-            set_tag("remote_key_query", remote_queries)
+            set_tag("local_key_query", str(local_query))
+            set_tag("remote_key_query", str(remote_queries))
 
             # First get local devices.
             # A map of destination -> failure response.
@@ -341,7 +343,7 @@ class E2eKeysHandler:
             failure = _exception_to_failure(e)
             failures[destination] = failure
             set_tag("error", True)
-            set_tag("reason", failure)
+            set_tag("reason", str(failure))
 
         return
 
@@ -392,7 +394,7 @@ class E2eKeysHandler:
 
     @trace
     async def query_local_devices(
-        self, query: Dict[str, Optional[List[str]]]
+        self, query: Mapping[str, Optional[List[str]]]
     ) -> Dict[str, Dict[str, dict]]:
         """Get E2E device keys for local users
 
@@ -403,7 +405,7 @@ class E2eKeysHandler:
         Returns:
             A map from user_id -> device_id -> device details
         """
-        set_tag("local_query", query)
+        set_tag("local_query", str(query))
         local_query: List[Tuple[str, Optional[str]]] = []
 
         result_dict: Dict[str, Dict[str, dict]] = {}
@@ -461,7 +463,7 @@ class E2eKeysHandler:
 
     @trace
     async def claim_one_time_keys(
-        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
+        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
     ) -> JsonDict:
         local_query: List[Tuple[str, str, str]] = []
         remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
@@ -475,8 +477,8 @@ class E2eKeysHandler:
                 domain = get_domain_from_id(user_id)
                 remote_queries.setdefault(domain, {})[user_id] = one_time_keys
 
-        set_tag("local_key_query", local_query)
-        set_tag("remote_key_query", remote_queries)
+        set_tag("local_key_query", str(local_query))
+        set_tag("remote_key_query", str(remote_queries))
 
         results = await self.store.claim_e2e_one_time_keys(local_query)
 
@@ -506,7 +508,7 @@ class E2eKeysHandler:
                 failure = _exception_to_failure(e)
                 failures[destination] = failure
                 set_tag("error", True)
-                set_tag("reason", failure)
+                set_tag("reason", str(failure))
 
         await make_deferred_yieldable(
             defer.gatherResults(
@@ -609,7 +611,7 @@ class E2eKeysHandler:
 
         result = await self.store.count_e2e_one_time_keys(user_id, device_id)
 
-        set_tag("one_time_key_counts", result)
+        set_tag("one_time_key_counts", str(result))
         return {"one_time_key_counts": result}
 
     async def _upload_one_time_keys_for_user(