diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d0fb2fc7dc..98b8ac9390 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -29,6 +29,7 @@ from synapse.api.errors import CodeMessageException, Codes, NotFoundError, Synap
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
+from synapse.rest.client.keys import KeyUploadBody
from synapse.types import (
JsonDict,
UserID,
@@ -491,14 +492,12 @@ class E2eKeysHandler:
@tag_args
async def upload_keys_for_user(
- self, user_id: str, device_id: str, keys: JsonDict
+ self, user_id: str, device_id: str, keys: KeyUploadBody
) -> JsonDict:
time_now = self.clock.time_msec()
- # TODO: Validate the JSON to make sure it has the right keys.
- device_keys = keys.get("device_keys", None)
- if device_keys:
+ if keys.device_keys is not None:
logger.info(
"Updating device_keys for device %r for user %s at %d",
device_id,
@@ -514,15 +513,14 @@ class E2eKeysHandler:
)
# TODO: Sign the JSON with the server key
changed = await self.store.set_e2e_device_keys(
- user_id, device_id, time_now, device_keys
+ user_id, device_id, time_now, keys.device_keys
)
if changed:
# Only notify about device updates *if* the keys actually changed
await self.device_handler.notify_device_update(user_id, [device_id])
else:
log_kv({"message": "Not updating device_keys for user", "user_id": user_id})
- one_time_keys = keys.get("one_time_keys", None)
- if one_time_keys:
+ if keys.one_time_keys is not None:
log_kv(
{
"message": "Updating one_time_keys for device.",
@@ -531,13 +529,13 @@ class E2eKeysHandler:
}
)
await self._upload_one_time_keys_for_user(
- user_id, device_id, time_now, one_time_keys
+ user_id, device_id, time_now, keys.one_time_keys
)
else:
log_kv(
{"message": "Did not update one_time_keys", "reason": "no keys given"}
)
- fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None)
+ fallback_keys = keys.org_matrix_msc2732_fallback_keys
if fallback_keys and isinstance(fallback_keys, dict):
log_kv(
{
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 7281b2ee29..5ba4fabbbc 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -15,7 +15,10 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, Dict, Sequence, TypeVar, Type
+
+import attr
+from attr.validators import instance_of, deep_iterable, deep_mapping, optional
from synapse.api.errors import InvalidAPICallError, SynapseError
from synapse.http.server import HttpServer
@@ -37,6 +40,53 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+Signatures = Dict[str, Dict[str, str]]
+
+
+C = TypeVar("C")
+
+
+def check_str_or_jsondict(instance: Any, attribute: str, input: Any) -> None:
+ if isinstance(input, str):
+ pass
+ elif isinstance(input, dict):
+ for key, val in input.items():
+ if not isinstance(key, str):
+ raise TypeError(f"{attribute} dictionary has non-string key {key}")
+ else:
+ raise TypeError(f"{attribute} should be a string, or a map from strings to any type.")
+
+
+@attr.s(frozen=True, slots=True)
+class DeviceKeys:
+ user_id: str = attr.ib(validator=instance_of(str))
+ device_id: str = attr.ib(validator=instance_of(str))
+ algorithms: Sequence[str] = attr.ib(validator=deep_iterable(instance_of(str)))
+ keys: Dict[str, str] = attr.ib(validator=deep_mapping(instance_of(str), instance_of(str)))
+ signatures: Signatures = attr.ib(validator=deep_mapping(
+ instance_of(str), deep_mapping(instance_of(str), instance_of(str))
+ ))
+
+
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class KeyUploadBody:
+ device_keys: Optional[DeviceKeys] = attr.ib(converter=attr.converters.optional(DeviceKeys))
+ one_time_keys: Optional[Dict[str, Union[str, JsonDict]]] = attr.ib(
+ validator=optional(deep_mapping(
+ instance_of(str), check_str_or_jsondict
+ ))
+ )
+ org_matrix_msc2732_fallback_keys: Optional[Dict[str, Union[str, JsonDict]]]
+
+ @classmethod
+ def load(cls: Type[C], src: JsonDict) -> C:
+ return cls(
+ device_keys=src.get("device_keys"),
+ one_time_keys=src.get("one_time_keys"),
+ org_matrix_msc2732_fallback_keys=src.get("org.matrix.msc2732.fallback.keys"),
+ )
+
+
class KeyUploadServlet(RestServlet):
"""
POST /keys/upload HTTP/1.1
@@ -77,7 +127,7 @@ class KeyUploadServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
- body = parse_json_object_from_request(request)
+ body = KeyUploadBody.load(parse_json_object_from_request(request))
if device_id is not None:
# Providing the device_id should only be done for setting keys
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index a95ac34f09..20e019e923 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -23,6 +23,7 @@ from twisted.enterprise.adbapi import Connection
from synapse.api.constants import DeviceKeyAlgorithms
from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.rest.client.keys import DeviceKeys
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.storage.engines import PostgresEngine
@@ -902,7 +903,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
async def set_e2e_device_keys(
- self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
+ self, user_id: str, device_id: str, time_now: int, device_keys: DeviceKeys
) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
@@ -924,7 +925,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
# In py3 we need old_key_json to match new_key_json type. The DB
# returns unicode while encode_canonical_json returns bytes.
- new_key_json = encode_canonical_json(device_keys).decode("utf-8")
+ new_key_json = encode_canonical_json(attr.asdict(device_keys)).decode("utf-8")
if old_key_json == new_key_json:
log_kv({"Message": "Device key already stored."})
|