diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 52bb5c9c55..84c28c480e 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
@@ -92,7 +92,11 @@ class E2eKeysHandler:
@trace
async def query_devices(
- self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
+ self,
+ query_body: JsonDict,
+ timeout: int,
+ from_user_id: str,
+ from_device_id: Optional[str],
) -> JsonDict:
"""Handle a device key query from a client
@@ -120,9 +124,7 @@ class E2eKeysHandler:
the number of in-flight queries at a time.
"""
async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
- device_keys_query: Dict[str, Iterable[str]] = query_body.get(
- "device_keys", {}
- )
+ device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {})
# separate users by domain.
# make a map from domain to user_id to device_ids
@@ -392,7 +394,7 @@ class E2eKeysHandler:
@trace
async def query_local_devices(
- self, query: Dict[str, Optional[List[str]]]
+ self, query: Mapping[str, Optional[List[str]]]
) -> Dict[str, Dict[str, dict]]:
"""Get E2E device keys for local users
@@ -461,7 +463,7 @@ class E2eKeysHandler:
@trace
async def claim_one_time_keys(
- self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
+ self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
) -> JsonDict:
local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
|