diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index e3f454896a..ee038c7192 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -26,10 +26,11 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
+from synapse.logging.opentracing import log_kv, set_tag
+from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
+from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.types import JsonDict, StreamToken
-
-from ._base import client_patterns, interactive_auth_handler
+from synapse.util.cancellation import cancellable
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -43,24 +44,48 @@ class KeyUploadServlet(RestServlet):
Content-Type: application/json
{
- "device_keys": {
- "user_id": "<user_id>",
- "device_id": "<device_id>",
- "valid_until_ts": <millisecond_timestamp>,
- "algorithms": [
- "m.olm.curve25519-aes-sha2",
- ]
- "keys": {
- "<algorithm>:<device_id>": "<key_base64>",
+ "device_keys": {
+ "user_id": "<user_id>",
+ "device_id": "<device_id>",
+ "valid_until_ts": <millisecond_timestamp>,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ ]
+ "keys": {
+ "<algorithm>:<device_id>": "<key_base64>",
+ },
+ "signatures:" {
+ "<user_id>" {
+ "<algorithm>:<device_id>": "<signature_base64>"
+ }
+ }
},
- "signatures:" {
- "<user_id>" {
- "<algorithm>:<device_id>": "<signature_base64>"
- } } },
- "one_time_keys": {
- "<algorithm>:<key_id>": "<key_base64>"
- },
+ "fallback_keys": {
+ "<algorithm>:<device_id>": "<key_base64>",
+ "signed_<algorithm>:<device_id>": {
+ "fallback": true,
+ "key": "<key_base64>",
+ "signatures": {
+ "<user_id>": {
+ "<algorithm>:<device_id>": "<key_base64>"
+ }
+ }
+ }
+ }
+ "one_time_keys": {
+ "<algorithm>:<key_id>": "<key_base64>"
+ },
+ }
+
+ response, e.g.:
+
+ {
+ "one_time_key_counts": {
+ "curve25519": 10,
+ "signed_curve25519": 20
+ }
}
+
"""
PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
@@ -71,7 +96,13 @@ class KeyUploadServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler()
- @trace_with_opname("upload_keys")
+ if hs.config.worker.worker_app is None:
+ # if main process
+ self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
+ else:
+ # then a worker
+ self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)
+
async def on_POST(
self, request: SynapseRequest, device_id: Optional[str]
) -> Tuple[int, JsonDict]:
@@ -110,8 +141,8 @@ class KeyUploadServlet(RestServlet):
400, "To upload keys, you must pass device_id when authenticating"
)
- result = await self.e2e_keys_handler.upload_keys_for_user(
- user_id, device_id, body
+ result = await self.key_uploader(
+ user_id=user_id, device_id=device_id, keys=body
)
return 200, result
@@ -157,6 +188,7 @@ class KeyQueryServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
+ @cancellable
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
@@ -200,6 +232,7 @@ class KeyChangesServlet(RestServlet):
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastores().main
+ @cancellable
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|