summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/e2e_keys.py84
-rw-r--r--synapse/rest/client/keys.py13
2 files changed, 59 insertions, 38 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 7d4feecaf1..668cec513b 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -35,6 +35,7 @@ from synapse.api.errors import CodeMessageException, Codes, NotFoundError, Synap
 from synapse.handlers.device import DeviceHandler
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
+from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
 from synapse.types import (
     JsonDict,
     JsonMapping,
@@ -89,6 +90,12 @@ class E2eKeysHandler:
                 edu_updater.incoming_signing_key_update,
             )
 
+            self.device_key_uploader = self.upload_device_keys_for_user
+        else:
+            self.device_key_uploader = (
+                ReplicationUploadKeysForUserRestServlet.make_client(hs)
+            )
+
         # doesn't really work as part of the generic query API, because the
         # query request requires an object POST, but we abuse the
         # "query handler" interface.
@@ -796,36 +803,17 @@ class E2eKeysHandler:
             "one_time_keys": A mapping from algorithm to number of keys for that
                 algorithm, including those previously persisted.
         """
-        # This can only be called from the main process.
-        assert isinstance(self.device_handler, DeviceHandler)
-
         time_now = self.clock.time_msec()
 
         # TODO: Validate the JSON to make sure it has the right keys.
         device_keys = keys.get("device_keys", None)
         if device_keys:
-            logger.info(
-                "Updating device_keys for device %r for user %s at %d",
-                device_id,
-                user_id,
-                time_now,
+            await self.device_key_uploader(
+                user_id=user_id,
+                device_id=device_id,
+                keys={"device_keys": device_keys},
             )
-            log_kv(
-                {
-                    "message": "Updating device_keys for user.",
-                    "user_id": user_id,
-                    "device_id": device_id,
-                }
-            )
-            # TODO: Sign the JSON with the server key
-            changed = await self.store.set_e2e_device_keys(
-                user_id, device_id, time_now, device_keys
-            )
-            if changed:
-                # Only notify about device updates *if* the keys actually changed
-                await self.device_handler.notify_device_update(user_id, [device_id])
-        else:
-            log_kv({"message": "Not updating device_keys for user", "user_id": user_id})
+
         one_time_keys = keys.get("one_time_keys", None)
         if one_time_keys:
             log_kv(
@@ -861,6 +849,49 @@ class E2eKeysHandler:
                 {"message": "Did not update fallback_keys", "reason": "no keys given"}
             )
 
+        result = await self.store.count_e2e_one_time_keys(user_id, device_id)
+
+        set_tag("one_time_key_counts", str(result))
+        return {"one_time_key_counts": result}
+
+    @tag_args
+    async def upload_device_keys_for_user(
+        self, user_id: str, device_id: str, keys: JsonDict
+    ) -> None:
+        """
+        Args:
+            user_id: user whose keys are being uploaded.
+            device_id: device whose keys are being uploaded.
+            device_keys: the `device_keys` of an /keys/upload request.
+
+        """
+        # This can only be called from the main process.
+        assert isinstance(self.device_handler, DeviceHandler)
+
+        time_now = self.clock.time_msec()
+
+        device_keys = keys["device_keys"]
+        logger.info(
+            "Updating device_keys for device %r for user %s at %d",
+            device_id,
+            user_id,
+            time_now,
+        )
+        log_kv(
+            {
+                "message": "Updating device_keys for user.",
+                "user_id": user_id,
+                "device_id": device_id,
+            }
+        )
+        # TODO: Sign the JSON with the server key
+        changed = await self.store.set_e2e_device_keys(
+            user_id, device_id, time_now, device_keys
+        )
+        if changed:
+            # Only notify about device updates *if* the keys actually changed
+            await self.device_handler.notify_device_update(user_id, [device_id])
+
         # the device should have been registered already, but it may have been
         # deleted due to a race with a DELETE request. Or we may be using an
         # old access_token without an associated device_id. Either way, we
@@ -868,11 +899,6 @@ class E2eKeysHandler:
         # keys without a corresponding device.
         await self.device_handler.check_device_registered(user_id, device_id)
 
-        result = await self.store.count_e2e_one_time_keys(user_id, device_id)
-
-        set_tag("one_time_key_counts", str(result))
-        return {"one_time_key_counts": result}
-
     async def _upload_one_time_keys_for_user(
         self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
     ) -> None:
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index a0017257ce..306db07b86 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -36,7 +36,6 @@ from synapse.http.servlet import (
 )
 from synapse.http.site import SynapseRequest
 from synapse.logging.opentracing import log_kv, set_tag
-from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
 from synapse.rest.client._base import client_patterns, interactive_auth_handler
 from synapse.types import JsonDict, StreamToken
 from synapse.util.cancellation import cancellable
@@ -105,13 +104,8 @@ class KeyUploadServlet(RestServlet):
         self.auth = hs.get_auth()
         self.e2e_keys_handler = hs.get_e2e_keys_handler()
         self.device_handler = hs.get_device_handler()
-
-        if hs.config.worker.worker_app is None:
-            # if main process
-            self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
-        else:
-            # then a worker
-            self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)
+        self._clock = hs.get_clock()
+        self._store = hs.get_datastores().main
 
     async def on_POST(
         self, request: SynapseRequest, device_id: Optional[str]
@@ -151,9 +145,10 @@ class KeyUploadServlet(RestServlet):
                 400, "To upload keys, you must pass device_id when authenticating"
             )
 
-        result = await self.key_uploader(
+        result = await self.e2e_keys_handler.upload_keys_for_user(
             user_id=user_id, device_id=device_id, keys=body
         )
+
         return 200, result