summary refs log tree commit diff
path: root/synapse/handlers/e2e_keys.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/handlers/e2e_keys.py247
1 files changed, 143 insertions, 104 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 929752150d..9a946a3cfe 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
 from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.types import (
+    JsonDict,
     UserID,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
@@ -40,11 +41,14 @@ from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class E2eKeysHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.federation = hs.get_federation_client()
         self.device_handler = hs.get_device_handler()
@@ -57,8 +61,8 @@ class E2eKeysHandler:
 
         self._is_master = hs.config.worker_app is None
         if not self._is_master:
-            self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client(
-                hs
+            self._user_device_resync_client = (
+                ReplicationUserDevicesResyncRestServlet.make_client(hs)
             )
         else:
             # Only register this edu handler on master as it requires writing
@@ -78,8 +82,10 @@ class E2eKeysHandler:
         )
 
     @trace
-    async def query_devices(self, query_body, timeout, from_user_id):
-        """ Handle a device key query from a client
+    async def query_devices(
+        self, query_body: JsonDict, timeout: int, from_user_id: str
+    ) -> JsonDict:
+        """Handle a device key query from a client
 
         {
             "device_keys": {
@@ -98,12 +104,14 @@ class E2eKeysHandler:
         }
 
         Args:
-            from_user_id (str): the user making the query.  This is used when
+            from_user_id: the user making the query.  This is used when
                 adding cross-signing signatures to limit what signatures users
                 can see.
         """
 
-        device_keys_query = query_body.get("device_keys", {})
+        device_keys_query = query_body.get(
+            "device_keys", {}
+        )  # type: Dict[str, Iterable[str]]
 
         # separate users by domain.
         # make a map from domain to user_id to device_ids
@@ -121,7 +129,8 @@ class E2eKeysHandler:
         set_tag("remote_key_query", remote_queries)
 
         # First get local devices.
-        failures = {}
+        # A map of destination -> failure response.
+        failures = {}  # type: Dict[str, JsonDict]
         results = {}
         if local_query:
             local_result = await self.query_local_devices(local_query)
@@ -135,9 +144,10 @@ class E2eKeysHandler:
         )
 
         # Now attempt to get any remote devices from our local cache.
-        remote_queries_not_in_cache = {}
+        # A map of destination -> user ID -> device IDs.
+        remote_queries_not_in_cache = {}  # type: Dict[str, Dict[str, Iterable[str]]]
         if remote_queries:
-            query_list = []
+            query_list = []  # type: List[Tuple[str, Optional[str]]]
             for user_id, device_ids in remote_queries.items():
                 if device_ids:
                     query_list.extend((user_id, device_id) for device_id in device_ids)
@@ -284,15 +294,15 @@ class E2eKeysHandler:
         return ret
 
     async def get_cross_signing_keys_from_cache(
-        self, query, from_user_id
+        self, query: Iterable[str], from_user_id: Optional[str]
     ) -> Dict[str, Dict[str, dict]]:
         """Get cross-signing keys for users from the database
 
         Args:
-            query (Iterable[string]) an iterable of user IDs.  A dict whose keys
+            query: an iterable of user IDs.  A dict whose keys
                 are user IDs satisfies this, so the query format used for
                 query_devices can be used here.
-            from_user_id (str): the user making the query.  This is used when
+            from_user_id: the user making the query.  This is used when
                 adding cross-signing signatures to limit what signatures users
                 can see.
 
@@ -315,14 +325,12 @@ class E2eKeysHandler:
             if "self_signing" in user_info:
                 self_signing_keys[user_id] = user_info["self_signing"]
 
-        if (
-            from_user_id in keys
-            and keys[from_user_id] is not None
-            and "user_signing" in keys[from_user_id]
-        ):
-            # users can see other users' master and self-signing keys, but can
-            # only see their own user-signing keys
-            user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
+        # users can see other users' master and self-signing keys, but can
+        # only see their own user-signing keys
+        if from_user_id:
+            from_user_key = keys.get(from_user_id)
+            if from_user_key and "user_signing" in from_user_key:
+                user_signing_keys[from_user_id] = from_user_key["user_signing"]
 
         return {
             "master_keys": master_keys,
@@ -344,9 +352,9 @@ class E2eKeysHandler:
             A map from user_id -> device_id -> device details
         """
         set_tag("local_query", query)
-        local_query = []
+        local_query = []  # type: List[Tuple[str, Optional[str]]]
 
-        result_dict = {}
+        result_dict = {}  # type: Dict[str, Dict[str, dict]]
         for user_id, device_ids in query.items():
             # we use UserID.from_string to catch invalid user ids
             if not self.is_mine(UserID.from_string(user_id)):
@@ -380,10 +388,13 @@ class E2eKeysHandler:
         log_kv(results)
         return result_dict
 
-    async def on_federation_query_client_keys(self, query_body):
-        """ Handle a device key query from a federated server
-        """
-        device_keys_query = query_body.get("device_keys", {})
+    async def on_federation_query_client_keys(
+        self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
+    ) -> JsonDict:
+        """Handle a device key query from a federated server"""
+        device_keys_query = query_body.get(
+            "device_keys", {}
+        )  # type: Dict[str, Optional[List[str]]]
         res = await self.query_local_devices(device_keys_query)
         ret = {"device_keys": res}
 
@@ -397,31 +408,34 @@ class E2eKeysHandler:
         return ret
 
     @trace
-    async def claim_one_time_keys(self, query, timeout):
-        local_query = []
-        remote_queries = {}
+    async def claim_one_time_keys(
+        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
+    ) -> JsonDict:
+        local_query = []  # type: List[Tuple[str, str, str]]
+        remote_queries = {}  # type: Dict[str, Dict[str, Dict[str, str]]]
 
-        for user_id, device_keys in query.get("one_time_keys", {}).items():
+        for user_id, one_time_keys in query.get("one_time_keys", {}).items():
             # we use UserID.from_string to catch invalid user ids
             if self.is_mine(UserID.from_string(user_id)):
-                for device_id, algorithm in device_keys.items():
+                for device_id, algorithm in one_time_keys.items():
                     local_query.append((user_id, device_id, algorithm))
             else:
                 domain = get_domain_from_id(user_id)
-                remote_queries.setdefault(domain, {})[user_id] = device_keys
+                remote_queries.setdefault(domain, {})[user_id] = one_time_keys
 
         set_tag("local_key_query", local_query)
         set_tag("remote_key_query", remote_queries)
 
         results = await self.store.claim_e2e_one_time_keys(local_query)
 
-        json_result = {}
-        failures = {}
+        # A map of user ID -> device ID -> key ID -> key.
+        json_result = {}  # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
+        failures = {}  # type: Dict[str, JsonDict]
         for user_id, device_keys in results.items():
             for device_id, keys in device_keys.items():
-                for key_id, json_bytes in keys.items():
+                for key_id, json_str in keys.items():
                     json_result.setdefault(user_id, {})[device_id] = {
-                        key_id: json_decoder.decode(json_bytes)
+                        key_id: json_decoder.decode(json_str)
                     }
 
         @trace
@@ -468,7 +482,9 @@ class E2eKeysHandler:
         return {"one_time_keys": json_result, "failures": failures}
 
     @tag_args
-    async def upload_keys_for_user(self, user_id, device_id, keys):
+    async def upload_keys_for_user(
+        self, user_id: str, device_id: str, keys: JsonDict
+    ) -> JsonDict:
 
         time_now = self.clock.time_msec()
 
@@ -543,8 +559,8 @@ class E2eKeysHandler:
         return {"one_time_key_counts": result}
 
     async def _upload_one_time_keys_for_user(
-        self, user_id, device_id, time_now, one_time_keys
-    ):
+        self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
+    ) -> None:
         logger.info(
             "Adding one_time_keys %r for device %r for user %r at %d",
             one_time_keys.keys(),
@@ -585,12 +601,14 @@ class E2eKeysHandler:
         log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
         await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
 
-    async def upload_signing_keys_for_user(self, user_id, keys):
+    async def upload_signing_keys_for_user(
+        self, user_id: str, keys: JsonDict
+    ) -> JsonDict:
         """Upload signing keys for cross-signing
 
         Args:
-            user_id (string): the user uploading the keys
-            keys (dict[string, dict]): the signing keys
+            user_id: the user uploading the keys
+            keys: the signing keys
         """
 
         # if a master key is uploaded, then check it.  Otherwise, load the
@@ -667,16 +685,17 @@ class E2eKeysHandler:
 
         return {}
 
-    async def upload_signatures_for_device_keys(self, user_id, signatures):
+    async def upload_signatures_for_device_keys(
+        self, user_id: str, signatures: JsonDict
+    ) -> JsonDict:
         """Upload device signatures for cross-signing
 
         Args:
-            user_id (string): the user uploading the signatures
-            signatures (dict[string, dict[string, dict]]): map of users to
-                devices to signed keys. This is the submission from the user; an
-                exception will be raised if it is malformed.
+            user_id: the user uploading the signatures
+            signatures: map of users to devices to signed keys. This is the submission
+            from the user; an exception will be raised if it is malformed.
         Returns:
-            dict: response to be sent back to the client.  The response will have
+            The response to be sent back to the client.  The response will have
                 a "failures" key, which will be a dict mapping users to devices
                 to errors for the signatures that failed.
         Raises:
@@ -719,7 +738,9 @@ class E2eKeysHandler:
 
         return {"failures": failures}
 
-    async def _process_self_signatures(self, user_id, signatures):
+    async def _process_self_signatures(
+        self, user_id: str, signatures: JsonDict
+    ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
         """Process uploaded signatures of the user's own keys.
 
         Signatures of the user's own keys from this API come in two forms:
@@ -731,15 +752,14 @@ class E2eKeysHandler:
             signatures (dict[string, dict]): map of devices to signed keys
 
         Returns:
-            (list[SignatureListItem], dict[string, dict[string, dict]]):
-            a list of signatures to store, and a map of users to devices to failure
-            reasons
+            A tuple of a list of signatures to store, and a map of users to
+            devices to failure reasons
 
         Raises:
             SynapseError: if the input is malformed
         """
-        signature_list = []
-        failures = {}
+        signature_list = []  # type: List[SignatureListItem]
+        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
         if not signatures:
             return signature_list, failures
 
@@ -834,19 +854,24 @@ class E2eKeysHandler:
         return signature_list, failures
 
     def _check_master_key_signature(
-        self, user_id, master_key_id, signed_master_key, stored_master_key, devices
-    ):
+        self,
+        user_id: str,
+        master_key_id: str,
+        signed_master_key: JsonDict,
+        stored_master_key: JsonDict,
+        devices: Dict[str, Dict[str, JsonDict]],
+    ) -> List["SignatureListItem"]:
         """Check signatures of a user's master key made by their devices.
 
         Args:
-            user_id (string): the user whose master key is being checked
-            master_key_id (string): the ID of the user's master key
-            signed_master_key (dict): the user's signed master key that was uploaded
-            stored_master_key (dict): our previously-stored copy of the user's master key
-            devices (iterable(dict)): the user's devices
+            user_id: the user whose master key is being checked
+            master_key_id: the ID of the user's master key
+            signed_master_key: the user's signed master key that was uploaded
+            stored_master_key: our previously-stored copy of the user's master key
+            devices: the user's devices
 
         Returns:
-            list[SignatureListItem]: a list of signatures to store
+            A list of signatures to store
 
         Raises:
             SynapseError: if a signature is invalid
@@ -877,25 +902,26 @@ class E2eKeysHandler:
 
         return master_key_signature_list
 
-    async def _process_other_signatures(self, user_id, signatures):
+    async def _process_other_signatures(
+        self, user_id: str, signatures: Dict[str, dict]
+    ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
         """Process uploaded signatures of other users' keys.  These will be the
         target user's master keys, signed by the uploading user's user-signing
         key.
 
         Args:
-            user_id (string): the user uploading the keys
-            signatures (dict[string, dict]): map of users to devices to signed keys
+            user_id: the user uploading the keys
+            signatures: map of users to devices to signed keys
 
         Returns:
-            (list[SignatureListItem], dict[string, dict[string, dict]]):
-            a list of signatures to store, and a map of users to devices to failure
+            A list of signatures to store, and a map of users to devices to failure
             reasons
 
         Raises:
             SynapseError: if the input is malformed
         """
-        signature_list = []
-        failures = {}
+        signature_list = []  # type: List[SignatureListItem]
+        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
         if not signatures:
             return signature_list, failures
 
@@ -983,7 +1009,7 @@ class E2eKeysHandler:
 
     async def _get_e2e_cross_signing_verify_key(
         self, user_id: str, key_type: str, from_user_id: str = None
-    ):
+    ) -> Tuple[JsonDict, str, VerifyKey]:
         """Fetch locally or remotely query for a cross-signing public key.
 
         First, attempt to fetch the cross-signing public key from storage.
@@ -997,8 +1023,7 @@ class E2eKeysHandler:
                 This affects what signatures are fetched.
 
         Returns:
-            dict, str, VerifyKey: the raw key data, the key ID, and the
-                signedjson verify key
+            The raw key data, the key ID, and the signedjson verify key
 
         Raises:
             NotFoundError: if the key is not found
@@ -1039,7 +1064,9 @@ class E2eKeysHandler:
         return key, key_id, verify_key
 
     async def _retrieve_cross_signing_keys_for_remote_user(
-        self, user: UserID, desired_key_type: str,
+        self,
+        user: UserID,
+        desired_key_type: str,
     ) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
         """Queries cross-signing keys for a remote user and saves them to the database
 
@@ -1135,16 +1162,18 @@ class E2eKeysHandler:
         return desired_key, desired_key_id, desired_verify_key
 
 
-def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
+def _check_cross_signing_key(
+    key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None
+) -> None:
     """Check a cross-signing key uploaded by a user.  Performs some basic sanity
     checking, and ensures that it is signed, if a signature is required.
 
     Args:
-        key (dict): the key data to verify
-        user_id (str): the user whose key is being checked
-        key_type (str): the type of key that the key should be
-        signing_key (VerifyKey): (optional) the signing key that the key should
-            be signed with.  If omitted, signatures will not be checked.
+        key: the key data to verify
+        user_id: the user whose key is being checked
+        key_type: the type of key that the key should be
+        signing_key: the signing key that the key should be signed with.  If
+            omitted, signatures will not be checked.
     """
     if (
         key.get("user_id") != user_id
@@ -1162,16 +1191,21 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
             )
 
 
-def _check_device_signature(user_id, verify_key, signed_device, stored_device):
+def _check_device_signature(
+    user_id: str,
+    verify_key: VerifyKey,
+    signed_device: JsonDict,
+    stored_device: JsonDict,
+) -> None:
     """Check that a signature on a device or cross-signing key is correct and
     matches the copy of the device/key that we have stored.  Throws an
     exception if an error is detected.
 
     Args:
-        user_id (str): the user ID whose signature is being checked
-        verify_key (VerifyKey): the key to verify the device with
-        signed_device (dict): the uploaded signed device data
-        stored_device (dict): our previously stored copy of the device
+        user_id: the user ID whose signature is being checked
+        verify_key: the key to verify the device with
+        signed_device: the uploaded signed device data
+        stored_device: our previously stored copy of the device
 
     Raises:
         SynapseError: if the signature was invalid or the sent device is not the
@@ -1201,7 +1235,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device):
         raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
 
 
-def _exception_to_failure(e):
+def _exception_to_failure(e: Exception) -> JsonDict:
     if isinstance(e, SynapseError):
         return {"status": e.code, "errcode": e.errcode, "message": str(e)}
 
@@ -1218,7 +1252,7 @@ def _exception_to_failure(e):
     return {"status": 503, "message": str(e)}
 
 
-def _one_time_keys_match(old_key_json, new_key):
+def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
     old_key = json_decoder.decode(old_key_json)
 
     # if either is a string rather than an object, they must match exactly
@@ -1236,19 +1270,18 @@ def _one_time_keys_match(old_key_json, new_key):
 
 @attr.s(slots=True)
 class SignatureListItem:
-    """An item in the signature list as used by upload_signatures_for_device_keys.
-    """
+    """An item in the signature list as used by upload_signatures_for_device_keys."""
 
-    signing_key_id = attr.ib()
-    target_user_id = attr.ib()
-    target_device_id = attr.ib()
-    signature = attr.ib()
+    signing_key_id = attr.ib(type=str)
+    target_user_id = attr.ib(type=str)
+    target_device_id = attr.ib(type=str)
+    signature = attr.ib(type=JsonDict)
 
 
 class SigningKeyEduUpdater:
     """Handles incoming signing key updates from federation and updates the DB"""
 
-    def __init__(self, hs, e2e_keys_handler):
+    def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
         self.store = hs.get_datastore()
         self.federation = hs.get_federation_client()
         self.clock = hs.get_clock()
@@ -1257,7 +1290,7 @@ class SigningKeyEduUpdater:
         self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
 
         # user_id -> list of updates waiting to be handled.
-        self._pending_updates = {}
+        self._pending_updates = {}  # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
 
         # Recently seen stream ids. We don't bother keeping these in the DB,
         # but they're useful to have them about to reduce the number of spurious
@@ -1270,13 +1303,15 @@ class SigningKeyEduUpdater:
             iterable=True,
         )
 
-    async def incoming_signing_key_update(self, origin, edu_content):
+    async def incoming_signing_key_update(
+        self, origin: str, edu_content: JsonDict
+    ) -> None:
         """Called on incoming signing key update from federation. Responsible for
         parsing the EDU and adding to pending updates list.
 
         Args:
-            origin (string): the server that sent the EDU
-            edu_content (dict): the contents of the EDU
+            origin: the server that sent the EDU
+            edu_content: the contents of the EDU
         """
 
         user_id = edu_content.pop("user_id")
@@ -1299,11 +1334,11 @@ class SigningKeyEduUpdater:
 
         await self._handle_signing_key_updates(user_id)
 
-    async def _handle_signing_key_updates(self, user_id):
+    async def _handle_signing_key_updates(self, user_id: str) -> None:
         """Actually handle pending updates.
 
         Args:
-            user_id (string): the user whose updates we are processing
+            user_id: the user whose updates we are processing
         """
 
         device_handler = self.e2e_keys_handler.device_handler
@@ -1315,13 +1350,17 @@ class SigningKeyEduUpdater:
                 # This can happen since we batch updates
                 return
 
-            device_ids = []
+            device_ids = []  # type: List[str]
 
             logger.info("pending updates: %r", pending_updates)
 
             for master_key, self_signing_key in pending_updates:
-                new_device_ids = await device_list_updater.process_cross_signing_key_update(
-                    user_id, master_key, self_signing_key,
+                new_device_ids = (
+                    await device_list_updater.process_cross_signing_key_update(
+                        user_id,
+                        master_key,
+                        self_signing_key,
+                    )
                 )
                 device_ids = device_ids + new_device_ids