diff --git a/changelog.d/17271.misc b/changelog.d/17271.misc
new file mode 100644
index 0000000000..915d717ad7
--- /dev/null
+++ b/changelog.d/17271.misc
@@ -0,0 +1 @@
+Handle OTK uploads off master.
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 7d4feecaf1..668cec513b 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -35,6 +35,7 @@ from synapse.api.errors import CodeMessageException, Codes, NotFoundError, Synap
from synapse.handlers.device import DeviceHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
+from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.types import (
JsonDict,
JsonMapping,
@@ -89,6 +90,12 @@ class E2eKeysHandler:
edu_updater.incoming_signing_key_update,
)
+ self.device_key_uploader = self.upload_device_keys_for_user
+ else:
+ self.device_key_uploader = (
+ ReplicationUploadKeysForUserRestServlet.make_client(hs)
+ )
+
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
@@ -796,36 +803,17 @@ class E2eKeysHandler:
"one_time_keys": A mapping from algorithm to number of keys for that
algorithm, including those previously persisted.
"""
- # This can only be called from the main process.
- assert isinstance(self.device_handler, DeviceHandler)
-
time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys.
device_keys = keys.get("device_keys", None)
if device_keys:
- logger.info(
- "Updating device_keys for device %r for user %s at %d",
- device_id,
- user_id,
- time_now,
+ await self.device_key_uploader(
+ user_id=user_id,
+ device_id=device_id,
+ keys={"device_keys": device_keys},
)
- log_kv(
- {
- "message": "Updating device_keys for user.",
- "user_id": user_id,
- "device_id": device_id,
- }
- )
- # TODO: Sign the JSON with the server key
- changed = await self.store.set_e2e_device_keys(
- user_id, device_id, time_now, device_keys
- )
- if changed:
- # Only notify about device updates *if* the keys actually changed
- await self.device_handler.notify_device_update(user_id, [device_id])
- else:
- log_kv({"message": "Not updating device_keys for user", "user_id": user_id})
+
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
log_kv(
@@ -861,6 +849,49 @@ class E2eKeysHandler:
{"message": "Did not update fallback_keys", "reason": "no keys given"}
)
+ result = await self.store.count_e2e_one_time_keys(user_id, device_id)
+
+ set_tag("one_time_key_counts", str(result))
+ return {"one_time_key_counts": result}
+
+ @tag_args
+ async def upload_device_keys_for_user(
+ self, user_id: str, device_id: str, keys: JsonDict
+ ) -> None:
+ """
+ Args:
+ user_id: user whose keys are being uploaded.
+ device_id: device whose keys are being uploaded.
+ device_keys: the `device_keys` of an /keys/upload request.
+
+ """
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
+
+ time_now = self.clock.time_msec()
+
+ device_keys = keys["device_keys"]
+ logger.info(
+ "Updating device_keys for device %r for user %s at %d",
+ device_id,
+ user_id,
+ time_now,
+ )
+ log_kv(
+ {
+ "message": "Updating device_keys for user.",
+ "user_id": user_id,
+ "device_id": device_id,
+ }
+ )
+ # TODO: Sign the JSON with the server key
+ changed = await self.store.set_e2e_device_keys(
+ user_id, device_id, time_now, device_keys
+ )
+ if changed:
+ # Only notify about device updates *if* the keys actually changed
+ await self.device_handler.notify_device_update(user_id, [device_id])
+
# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
# old access_token without an associated device_id. Either way, we
@@ -868,11 +899,6 @@ class E2eKeysHandler:
# keys without a corresponding device.
await self.device_handler.check_device_registered(user_id, device_id)
- result = await self.store.count_e2e_one_time_keys(user_id, device_id)
-
- set_tag("one_time_key_counts", str(result))
- return {"one_time_key_counts": result}
-
async def _upload_one_time_keys_for_user(
self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
) -> None:
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index a0017257ce..306db07b86 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -36,7 +36,6 @@ from synapse.http.servlet import (
)
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag
-from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.types import JsonDict, StreamToken
from synapse.util.cancellation import cancellable
@@ -105,13 +104,8 @@ class KeyUploadServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler()
-
- if hs.config.worker.worker_app is None:
- # if main process
- self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
- else:
- # then a worker
- self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)
+ self._clock = hs.get_clock()
+ self._store = hs.get_datastores().main
async def on_POST(
self, request: SynapseRequest, device_id: Optional[str]
@@ -151,9 +145,10 @@ class KeyUploadServlet(RestServlet):
400, "To upload keys, you must pass device_id when authenticating"
)
- result = await self.key_uploader(
+ result = await self.e2e_keys_handler.upload_keys_for_user(
user_id=user_id, device_id=device_id, keys=body
)
+
return 200, result
|