diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 929752150d..8f3a6b35a4 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -16,7 +16,7 @@
# limitations under the License.
import logging
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
@@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import (
+ JsonDict,
UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
@@ -40,11 +41,14 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class E2eKeysHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler()
@@ -78,7 +82,9 @@ class E2eKeysHandler:
)
@trace
- async def query_devices(self, query_body, timeout, from_user_id):
+ async def query_devices(
+ self, query_body: JsonDict, timeout: int, from_user_id: str
+ ) -> JsonDict:
""" Handle a device key query from a client
{
@@ -98,12 +104,14 @@ class E2eKeysHandler:
}
Args:
- from_user_id (str): the user making the query. This is used when
+ from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users
can see.
"""
- device_keys_query = query_body.get("device_keys", {})
+ device_keys_query = query_body.get(
+ "device_keys", {}
+ ) # type: Dict[str, Iterable[str]]
# separate users by domain.
# make a map from domain to user_id to device_ids
@@ -121,7 +129,8 @@ class E2eKeysHandler:
set_tag("remote_key_query", remote_queries)
# First get local devices.
- failures = {}
+ # A map of destination -> failure response.
+ failures = {} # type: Dict[str, JsonDict]
results = {}
if local_query:
local_result = await self.query_local_devices(local_query)
@@ -135,9 +144,10 @@ class E2eKeysHandler:
)
# Now attempt to get any remote devices from our local cache.
- remote_queries_not_in_cache = {}
+ # A map of destination -> user ID -> device IDs.
+ remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]]
if remote_queries:
- query_list = []
+ query_list = [] # type: List[Tuple[str, Optional[str]]]
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
@@ -284,15 +294,15 @@ class E2eKeysHandler:
return ret
async def get_cross_signing_keys_from_cache(
- self, query, from_user_id
+ self, query: Iterable[str], from_user_id: Optional[str]
) -> Dict[str, Dict[str, dict]]:
"""Get cross-signing keys for users from the database
Args:
- query (Iterable[string]) an iterable of user IDs. A dict whose keys
+ query: an iterable of user IDs. A dict whose keys
are user IDs satisfies this, so the query format used for
query_devices can be used here.
- from_user_id (str): the user making the query. This is used when
+ from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users
can see.
@@ -315,14 +325,12 @@ class E2eKeysHandler:
if "self_signing" in user_info:
self_signing_keys[user_id] = user_info["self_signing"]
- if (
- from_user_id in keys
- and keys[from_user_id] is not None
- and "user_signing" in keys[from_user_id]
- ):
- # users can see other users' master and self-signing keys, but can
- # only see their own user-signing keys
- user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
+ # users can see other users' master and self-signing keys, but can
+ # only see their own user-signing keys
+ if from_user_id:
+ from_user_key = keys.get(from_user_id)
+ if from_user_key and "user_signing" in from_user_key:
+ user_signing_keys[from_user_id] = from_user_key["user_signing"]
return {
"master_keys": master_keys,
@@ -344,9 +352,9 @@ class E2eKeysHandler:
A map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
- local_query = []
+ local_query = [] # type: List[Tuple[str, Optional[str]]]
- result_dict = {}
+ result_dict = {} # type: Dict[str, Dict[str, dict]]
for user_id, device_ids in query.items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
@@ -380,10 +388,14 @@ class E2eKeysHandler:
log_kv(results)
return result_dict
- async def on_federation_query_client_keys(self, query_body):
+ async def on_federation_query_client_keys(
+ self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
+ ) -> JsonDict:
""" Handle a device key query from a federated server
"""
- device_keys_query = query_body.get("device_keys", {})
+ device_keys_query = query_body.get(
+ "device_keys", {}
+ ) # type: Dict[str, Optional[List[str]]]
res = await self.query_local_devices(device_keys_query)
ret = {"device_keys": res}
@@ -397,31 +409,34 @@ class E2eKeysHandler:
return ret
@trace
- async def claim_one_time_keys(self, query, timeout):
- local_query = []
- remote_queries = {}
+ async def claim_one_time_keys(
+ self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
+ ) -> JsonDict:
+ local_query = [] # type: List[Tuple[str, str, str]]
+ remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]]
- for user_id, device_keys in query.get("one_time_keys", {}).items():
+ for user_id, one_time_keys in query.get("one_time_keys", {}).items():
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
- for device_id, algorithm in device_keys.items():
+ for device_id, algorithm in one_time_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
domain = get_domain_from_id(user_id)
- remote_queries.setdefault(domain, {})[user_id] = device_keys
+ remote_queries.setdefault(domain, {})[user_id] = one_time_keys
set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
results = await self.store.claim_e2e_one_time_keys(local_query)
- json_result = {}
- failures = {}
+ # A map of user ID -> device ID -> key ID -> key.
+ json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
+ failures = {} # type: Dict[str, JsonDict]
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
- for key_id, json_bytes in keys.items():
+ for key_id, json_str in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
- key_id: json_decoder.decode(json_bytes)
+ key_id: json_decoder.decode(json_str)
}
@trace
@@ -468,7 +483,9 @@ class E2eKeysHandler:
return {"one_time_keys": json_result, "failures": failures}
@tag_args
- async def upload_keys_for_user(self, user_id, device_id, keys):
+ async def upload_keys_for_user(
+ self, user_id: str, device_id: str, keys: JsonDict
+ ) -> JsonDict:
time_now = self.clock.time_msec()
@@ -543,8 +560,8 @@ class E2eKeysHandler:
return {"one_time_key_counts": result}
async def _upload_one_time_keys_for_user(
- self, user_id, device_id, time_now, one_time_keys
- ):
+ self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
+ ) -> None:
logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
one_time_keys.keys(),
@@ -585,12 +602,14 @@ class E2eKeysHandler:
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
- async def upload_signing_keys_for_user(self, user_id, keys):
+ async def upload_signing_keys_for_user(
+ self, user_id: str, keys: JsonDict
+ ) -> JsonDict:
"""Upload signing keys for cross-signing
Args:
- user_id (string): the user uploading the keys
- keys (dict[string, dict]): the signing keys
+ user_id: the user uploading the keys
+ keys: the signing keys
"""
# if a master key is uploaded, then check it. Otherwise, load the
@@ -667,16 +686,17 @@ class E2eKeysHandler:
return {}
- async def upload_signatures_for_device_keys(self, user_id, signatures):
+ async def upload_signatures_for_device_keys(
+ self, user_id: str, signatures: JsonDict
+ ) -> JsonDict:
"""Upload device signatures for cross-signing
Args:
- user_id (string): the user uploading the signatures
- signatures (dict[string, dict[string, dict]]): map of users to
- devices to signed keys. This is the submission from the user; an
- exception will be raised if it is malformed.
+ user_id: the user uploading the signatures
+ signatures: map of users to devices to signed keys. This is the submission
+ from the user; an exception will be raised if it is malformed.
Returns:
- dict: response to be sent back to the client. The response will have
+ The response to be sent back to the client. The response will have
a "failures" key, which will be a dict mapping users to devices
to errors for the signatures that failed.
Raises:
@@ -719,7 +739,9 @@ class E2eKeysHandler:
return {"failures": failures}
- async def _process_self_signatures(self, user_id, signatures):
+ async def _process_self_signatures(
+ self, user_id: str, signatures: JsonDict
+ ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
"""Process uploaded signatures of the user's own keys.
Signatures of the user's own keys from this API come in two forms:
@@ -731,15 +753,14 @@ class E2eKeysHandler:
signatures (dict[string, dict]): map of devices to signed keys
Returns:
- (list[SignatureListItem], dict[string, dict[string, dict]]):
- a list of signatures to store, and a map of users to devices to failure
- reasons
+ A tuple of a list of signatures to store, and a map of users to
+ devices to failure reasons
Raises:
SynapseError: if the input is malformed
"""
- signature_list = []
- failures = {}
+ signature_list = [] # type: List[SignatureListItem]
+ failures = {} # type: Dict[str, Dict[str, JsonDict]]
if not signatures:
return signature_list, failures
@@ -834,19 +855,24 @@ class E2eKeysHandler:
return signature_list, failures
def _check_master_key_signature(
- self, user_id, master_key_id, signed_master_key, stored_master_key, devices
- ):
+ self,
+ user_id: str,
+ master_key_id: str,
+ signed_master_key: JsonDict,
+ stored_master_key: JsonDict,
+ devices: Dict[str, Dict[str, JsonDict]],
+ ) -> List["SignatureListItem"]:
"""Check signatures of a user's master key made by their devices.
Args:
- user_id (string): the user whose master key is being checked
- master_key_id (string): the ID of the user's master key
- signed_master_key (dict): the user's signed master key that was uploaded
- stored_master_key (dict): our previously-stored copy of the user's master key
- devices (iterable(dict)): the user's devices
+ user_id: the user whose master key is being checked
+ master_key_id: the ID of the user's master key
+ signed_master_key: the user's signed master key that was uploaded
+ stored_master_key: our previously-stored copy of the user's master key
+ devices: the user's devices
Returns:
- list[SignatureListItem]: a list of signatures to store
+ A list of signatures to store
Raises:
SynapseError: if a signature is invalid
@@ -877,25 +903,26 @@ class E2eKeysHandler:
return master_key_signature_list
- async def _process_other_signatures(self, user_id, signatures):
+ async def _process_other_signatures(
+ self, user_id: str, signatures: Dict[str, dict]
+ ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
"""Process uploaded signatures of other users' keys. These will be the
target user's master keys, signed by the uploading user's user-signing
key.
Args:
- user_id (string): the user uploading the keys
- signatures (dict[string, dict]): map of users to devices to signed keys
+ user_id: the user uploading the keys
+ signatures: map of users to devices to signed keys
Returns:
- (list[SignatureListItem], dict[string, dict[string, dict]]):
- a list of signatures to store, and a map of users to devices to failure
+ A list of signatures to store, and a map of users to devices to failure
reasons
Raises:
SynapseError: if the input is malformed
"""
- signature_list = []
- failures = {}
+ signature_list = [] # type: List[SignatureListItem]
+ failures = {} # type: Dict[str, Dict[str, JsonDict]]
if not signatures:
return signature_list, failures
@@ -983,7 +1010,7 @@ class E2eKeysHandler:
async def _get_e2e_cross_signing_verify_key(
self, user_id: str, key_type: str, from_user_id: str = None
- ):
+ ) -> Tuple[JsonDict, str, VerifyKey]:
"""Fetch locally or remotely query for a cross-signing public key.
First, attempt to fetch the cross-signing public key from storage.
@@ -997,8 +1024,7 @@ class E2eKeysHandler:
This affects what signatures are fetched.
Returns:
- dict, str, VerifyKey: the raw key data, the key ID, and the
- signedjson verify key
+ The raw key data, the key ID, and the signedjson verify key
Raises:
NotFoundError: if the key is not found
@@ -1135,16 +1161,18 @@ class E2eKeysHandler:
return desired_key, desired_key_id, desired_verify_key
-def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
+def _check_cross_signing_key(
+ key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None
+) -> None:
"""Check a cross-signing key uploaded by a user. Performs some basic sanity
checking, and ensures that it is signed, if a signature is required.
Args:
- key (dict): the key data to verify
- user_id (str): the user whose key is being checked
- key_type (str): the type of key that the key should be
- signing_key (VerifyKey): (optional) the signing key that the key should
- be signed with. If omitted, signatures will not be checked.
+ key: the key data to verify
+ user_id: the user whose key is being checked
+ key_type: the type of key that the key should be
+ signing_key: the signing key that the key should be signed with. If
+ omitted, signatures will not be checked.
"""
if (
key.get("user_id") != user_id
@@ -1162,16 +1190,21 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
)
-def _check_device_signature(user_id, verify_key, signed_device, stored_device):
+def _check_device_signature(
+ user_id: str,
+ verify_key: VerifyKey,
+ signed_device: JsonDict,
+ stored_device: JsonDict,
+) -> None:
"""Check that a signature on a device or cross-signing key is correct and
matches the copy of the device/key that we have stored. Throws an
exception if an error is detected.
Args:
- user_id (str): the user ID whose signature is being checked
- verify_key (VerifyKey): the key to verify the device with
- signed_device (dict): the uploaded signed device data
- stored_device (dict): our previously stored copy of the device
+ user_id: the user ID whose signature is being checked
+ verify_key: the key to verify the device with
+ signed_device: the uploaded signed device data
+ stored_device: our previously stored copy of the device
Raises:
SynapseError: if the signature was invalid or the sent device is not the
@@ -1201,7 +1234,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device):
raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
-def _exception_to_failure(e):
+def _exception_to_failure(e: Exception) -> JsonDict:
if isinstance(e, SynapseError):
return {"status": e.code, "errcode": e.errcode, "message": str(e)}
@@ -1218,7 +1251,7 @@ def _exception_to_failure(e):
return {"status": 503, "message": str(e)}
-def _one_time_keys_match(old_key_json, new_key):
+def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
old_key = json_decoder.decode(old_key_json)
# if either is a string rather than an object, they must match exactly
@@ -1239,16 +1272,16 @@ class SignatureListItem:
"""An item in the signature list as used by upload_signatures_for_device_keys.
"""
- signing_key_id = attr.ib()
- target_user_id = attr.ib()
- target_device_id = attr.ib()
- signature = attr.ib()
+ signing_key_id = attr.ib(type=str)
+ target_user_id = attr.ib(type=str)
+ target_device_id = attr.ib(type=str)
+ signature = attr.ib(type=JsonDict)
class SigningKeyEduUpdater:
"""Handles incoming signing key updates from federation and updates the DB"""
- def __init__(self, hs, e2e_keys_handler):
+ def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
@@ -1257,7 +1290,7 @@ class SigningKeyEduUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
# user_id -> list of updates waiting to be handled.
- self._pending_updates = {}
+ self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
@@ -1270,13 +1303,15 @@ class SigningKeyEduUpdater:
iterable=True,
)
- async def incoming_signing_key_update(self, origin, edu_content):
+ async def incoming_signing_key_update(
+ self, origin: str, edu_content: JsonDict
+ ) -> None:
"""Called on incoming signing key update from federation. Responsible for
parsing the EDU and adding to pending updates list.
Args:
- origin (string): the server that sent the EDU
- edu_content (dict): the contents of the EDU
+ origin: the server that sent the EDU
+ edu_content: the contents of the EDU
"""
user_id = edu_content.pop("user_id")
@@ -1299,11 +1334,11 @@ class SigningKeyEduUpdater:
await self._handle_signing_key_updates(user_id)
- async def _handle_signing_key_updates(self, user_id):
+ async def _handle_signing_key_updates(self, user_id: str) -> None:
"""Actually handle pending updates.
Args:
- user_id (string): the user whose updates we are processing
+ user_id: the user whose updates we are processing
"""
device_handler = self.e2e_keys_handler.device_handler
@@ -1315,7 +1350,7 @@ class SigningKeyEduUpdater:
# This can happen since we batch updates
return
- device_ids = []
+ device_ids = [] # type: List[str]
logger.info("pending updates: %r", pending_updates)
|