summary refs log tree commit diff
path: root/synapse/rest/client/v2_alpha/keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/v2_alpha/keys.py')
-rw-r--r--synapse/rest/client/v2_alpha/keys.py47
1 files changed, 35 insertions, 12 deletions
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 56364af337..0bf32a089b 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -19,6 +19,9 @@ import simplejson as json
 from canonicaljson import encode_canonical_json
 from twisted.internet import defer
 
+import synapse.api.errors
+import synapse.server
+import synapse.types
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.types import UserID
 from ._base import client_v2_patterns
@@ -28,7 +31,7 @@ logger = logging.getLogger(__name__)
 
 class KeyUploadServlet(RestServlet):
     """
-    POST /keys/upload/<device_id> HTTP/1.1
+    POST /keys/upload HTTP/1.1
     Content-Type: application/json
 
     {
@@ -51,23 +54,51 @@ class KeyUploadServlet(RestServlet):
       },
     }
     """
-    PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)", releases=())
+    PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
+                                  releases=(), v2_alpha=False)
 
     def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
         super(KeyUploadServlet, self).__init__()
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
         self.auth = hs.get_auth()
+        self.device_handler = hs.get_device_handler()
 
     @defer.inlineCallbacks
     def on_POST(self, request, device_id):
         requester = yield self.auth.get_user_by_req(request)
+
         user_id = requester.user.to_string()
-        # TODO: Check that the device_id matches that in the authentication
-        # or derive the device_id from the authentication instead.
 
         body = parse_json_object_from_request(request)
 
+        if device_id is not None:
+            # passing the device_id here is deprecated; however, we allow it
+            # for now for compatibility with older clients. But if a device_id
+            # was given here and in the auth, they must match.
+
+            if (requester.device_id is not None and
+                    device_id != requester.device_id):
+                raise synapse.api.errors.SynapseError(
+                    400, "Can only upload keys for current device"
+                )
+
+            self.device_handler.check_device_registered(
+                user_id, device_id, "unknown device"
+            )
+        else:
+            device_id = requester.device_id
+
+        if device_id is None:
+            raise synapse.api.errors.SynapseError(
+                400,
+                "To upload keys, you must pass device_id when authenticating"
+            )
+
         time_now = self.clock.time_msec()
 
         # TODO: Validate the JSON to make sure it has the right keys.
@@ -103,14 +134,6 @@ class KeyUploadServlet(RestServlet):
         result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
         defer.returnValue((200, {"one_time_key_counts": result}))
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, device_id):
-        requester = yield self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
-
-        result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
-        defer.returnValue((200, {"one_time_key_counts": result}))
-
 
 class KeyQueryServlet(RestServlet):
     """