summary refs log tree commit diff
path: root/synapse/rest/client/v2_alpha/keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/v2_alpha/keys.py')
-rw-r--r--synapse/rest/client/v2_alpha/keys.py77
1 files changed, 51 insertions, 26 deletions
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index c5ff16adf3..8f05727652 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -19,11 +19,12 @@ import simplejson as json
 from canonicaljson import encode_canonical_json
 from twisted.internet import defer
 
-import synapse.api.errors
-import synapse.server
-import synapse.types
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import UserID
+from synapse.api.errors import SynapseError, CodeMessageException
+from synapse.http.servlet import (
+    RestServlet, parse_json_object_from_request, parse_integer
+)
+from synapse.types import get_domain_from_id
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 from ._base import client_v2_patterns
 
 logger = logging.getLogger(__name__)
@@ -88,7 +89,7 @@ class KeyUploadServlet(RestServlet):
             device_id = requester.device_id
 
         if device_id is None:
-            raise synapse.api.errors.SynapseError(
+            raise SynapseError(
                 400,
                 "To upload keys, you must pass device_id when authenticating"
             )
@@ -195,18 +196,21 @@ class KeyQueryServlet(RestServlet):
     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id):
         yield self.auth.get_user_by_req(request)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
-        result = yield self.e2e_keys_handler.query_devices(body)
+        result = yield self.e2e_keys_handler.query_devices(body, timeout)
         defer.returnValue(result)
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id):
         requester = yield self.auth.get_user_by_req(request)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
         auth_user_id = requester.user.to_string()
         user_id = user_id if user_id else auth_user_id
         device_ids = [device_id] if device_id else []
         result = yield self.e2e_keys_handler.query_devices(
-            {"device_keys": {user_id: device_ids}}
+            {"device_keys": {user_id: device_ids}},
+            timeout,
         )
         defer.returnValue(result)
 
@@ -244,39 +248,43 @@ class OneTimeKeyServlet(RestServlet):
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.federation = hs.get_replication_layer()
-        self.is_mine = hs.is_mine
+        self.is_mine_id = hs.is_mine_id
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id, algorithm):
         yield self.auth.get_user_by_req(request)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
         result = yield self.handle_request(
-            {"one_time_keys": {user_id: {device_id: algorithm}}}
+            {"one_time_keys": {user_id: {device_id: algorithm}}},
+            timeout,
         )
         defer.returnValue(result)
 
     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id, algorithm):
         yield self.auth.get_user_by_req(request)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
-        result = yield self.handle_request(body)
+        result = yield self.handle_request(body, timeout)
         defer.returnValue(result)
 
     @defer.inlineCallbacks
-    def handle_request(self, body):
+    def handle_request(self, body, timeout):
         local_query = []
         remote_queries = {}
+
         for user_id, device_keys in body.get("one_time_keys", {}).items():
-            user = UserID.from_string(user_id)
-            if self.is_mine(user):
+            if self.is_mine_id(user_id):
                 for device_id, algorithm in device_keys.items():
                     local_query.append((user_id, device_id, algorithm))
             else:
-                remote_queries.setdefault(user.domain, {})[user_id] = (
-                    device_keys
-                )
+                domain = get_domain_from_id(user_id)
+                remote_queries.setdefault(domain, {})[user_id] = device_keys
+
         results = yield self.store.claim_e2e_one_time_keys(local_query)
 
         json_result = {}
+        failures = {}
         for user_id, device_keys in results.items():
             for device_id, keys in device_keys.items():
                 for key_id, json_bytes in keys.items():
@@ -284,15 +292,32 @@ class OneTimeKeyServlet(RestServlet):
                         key_id: json.loads(json_bytes)
                     }
 
-        for destination, device_keys in remote_queries.items():
-            remote_result = yield self.federation.claim_client_keys(
-                destination, {"one_time_keys": device_keys}
-            )
-            for user_id, keys in remote_result["one_time_keys"].items():
-                if user_id in device_keys:
-                    json_result[user_id] = keys
-
-        defer.returnValue((200, {"one_time_keys": json_result}))
+        @defer.inlineCallbacks
+        def claim_client_keys(destination):
+            device_keys = remote_queries[destination]
+            try:
+                remote_result = yield self.federation.claim_client_keys(
+                    destination,
+                    {"one_time_keys": device_keys},
+                    timeout=timeout
+                )
+                for user_id, keys in remote_result["one_time_keys"].items():
+                    if user_id in device_keys:
+                        json_result[user_id] = keys
+            except CodeMessageException as e:
+                failures[destination] = {
+                    "status": e.code, "message": e.message
+                }
+
+        yield preserve_context_over_deferred(defer.gatherResults([
+            preserve_fn(claim_client_keys)(destination)
+            for destination in remote_queries
+        ]))
+
+        defer.returnValue((200, {
+            "one_time_keys": json_result,
+            "failures": failures
+        }))
 
 
 def register_servlets(hs, http_server):