summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/federation/federation_client.py12
-rw-r--r--synapse/federation/transport/client.py6
-rw-r--r--synapse/handlers/e2e_keys.py64
-rw-r--r--synapse/http/matrixfederationclient.py11
-rw-r--r--synapse/rest/client/v2_alpha/keys.py77
5 files changed, 115 insertions, 55 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 3395c9e41e..cf8a52510d 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -176,7 +176,7 @@ class FederationClient(FederationBase):
         )
 
     @log_function
-    def query_client_keys(self, destination, content):
+    def query_client_keys(self, destination, content, timeout):
         """Query device keys for a device hosted on a remote server.
 
         Args:
@@ -188,10 +188,12 @@ class FederationClient(FederationBase):
             response
         """
         sent_queries_counter.inc("client_device_keys")
-        return self.transport_layer.query_client_keys(destination, content)
+        return self.transport_layer.query_client_keys(
+            destination, content, timeout
+        )
 
     @log_function
-    def claim_client_keys(self, destination, content):
+    def claim_client_keys(self, destination, content, timeout):
         """Claims one-time keys for a device hosted on a remote server.
 
         Args:
@@ -203,7 +205,9 @@ class FederationClient(FederationBase):
             response
         """
         sent_queries_counter.inc("client_one_time_keys")
-        return self.transport_layer.claim_client_keys(destination, content)
+        return self.transport_layer.claim_client_keys(
+            destination, content, timeout
+        )
 
     @defer.inlineCallbacks
     @log_function
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 3d088e43cb..2b138526ba 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -298,7 +298,7 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
-    def query_client_keys(self, destination, query_content):
+    def query_client_keys(self, destination, query_content, timeout):
         """Query the device keys for a list of user ids hosted on a remote
         server.
 
@@ -327,12 +327,13 @@ class TransportLayerClient(object):
             destination=destination,
             path=path,
             data=query_content,
+            timeout=timeout,
         )
         defer.returnValue(content)
 
     @defer.inlineCallbacks
     @log_function
-    def claim_client_keys(self, destination, query_content):
+    def claim_client_keys(self, destination, query_content, timeout):
         """Claim one-time keys for a list of devices hosted on a remote server.
 
         Request:
@@ -363,6 +364,7 @@ class TransportLayerClient(object):
             destination=destination,
             path=path,
             data=query_content,
+            timeout=timeout,
         )
         defer.returnValue(content)
 
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 2c7bfd91ed..5bfd700931 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -13,14 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import collections
 import json
 import logging
 
 from twisted.internet import defer
 
-from synapse.api import errors
-import synapse.types
+from synapse.api.errors import SynapseError, CodeMessageException
+from synapse.types import get_domain_from_id
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 
 logger = logging.getLogger(__name__)
 
@@ -30,7 +30,6 @@ class E2eKeysHandler(object):
         self.store = hs.get_datastore()
         self.federation = hs.get_replication_layer()
         self.is_mine_id = hs.is_mine_id
-        self.server_name = hs.hostname
 
         # doesn't really work as part of the generic query API, because the
         # query request requires an object POST, but we abuse the
@@ -40,7 +39,7 @@ class E2eKeysHandler(object):
         )
 
     @defer.inlineCallbacks
-    def query_devices(self, query_body):
+    def query_devices(self, query_body, timeout):
         """ Handle a device key query from a client
 
         {
@@ -63,27 +62,50 @@ class E2eKeysHandler(object):
 
         # separate users by domain.
         # make a map from domain to user_id to device_ids
-        queries_by_domain = collections.defaultdict(dict)
+        local_query = {}
+        remote_queries = {}
+
         for user_id, device_ids in device_keys_query.items():
-            user = synapse.types.UserID.from_string(user_id)
-            queries_by_domain[user.domain][user_id] = device_ids
+            if self.is_mine_id(user_id):
+                local_query[user_id] = device_ids
+            else:
+                domain = get_domain_from_id(user_id)
+                remote_queries.setdefault(domain, {})[user_id] = device_ids
 
         # do the queries
-        # TODO: do these in parallel
+        failures = {}
         results = {}
-        for destination, destination_query in queries_by_domain.items():
-            if destination == self.server_name:
-                res = yield self.query_local_devices(destination_query)
-            else:
-                res = yield self.federation.query_client_keys(
-                    destination, {"device_keys": destination_query}
-                )
-                res = res["device_keys"]
-            for user_id, keys in res.items():
-                if user_id in destination_query:
+        if local_query:
+            local_result = yield self.query_local_devices(local_query)
+            for user_id, keys in local_result.items():
+                if user_id in local_query:
                     results[user_id] = keys
 
-        defer.returnValue((200, {"device_keys": results}))
+        @defer.inlineCallbacks
+        def do_remote_query(destination):
+            destination_query = remote_queries[destination]
+            try:
+                remote_result = yield self.federation.query_client_keys(
+                    destination,
+                    {"device_keys": destination_query},
+                    timeout=timeout
+                )
+                for user_id, keys in remote_result["device_keys"].items():
+                    if user_id in destination_query:
+                        results[user_id] = keys
+            except CodeMessageException as e:
+                failures[destination] = {
+                    "status": e.code, "message": e.message
+                }
+
+        yield preserve_context_over_deferred(defer.gatherResults([
+            preserve_fn(do_remote_query)(destination)
+            for destination in remote_queries
+        ]))
+
+        defer.returnValue((200, {
+            "device_keys": results, "failures": failures,
+        }))
 
     @defer.inlineCallbacks
     def query_local_devices(self, query):
@@ -104,7 +126,7 @@ class E2eKeysHandler(object):
             if not self.is_mine_id(user_id):
                 logger.warning("Request for keys for non-local user %s",
                                user_id)
-                raise errors.SynapseError(400, "Not a user here")
+                raise SynapseError(400, "Not a user here")
 
             if not device_ids:
                 local_query.append((user_id, None))
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index f93093dd85..d0556ae347 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -246,7 +246,7 @@ class MatrixFederationHttpClient(object):
 
     @defer.inlineCallbacks
     def put_json(self, destination, path, data={}, json_data_callback=None,
-                 long_retries=False):
+                 long_retries=False, timeout=None):
         """ Sends the specifed json data using PUT
 
         Args:
@@ -259,6 +259,8 @@ class MatrixFederationHttpClient(object):
                 use as the request body.
             long_retries (bool): A boolean that indicates whether we should
                 retry for a short or long time.
+            timeout(int): How long to try (in ms) the destination for before
+                giving up. None indicates no timeout.
 
         Returns:
             Deferred: Succeeds when we get a 2xx HTTP response. The result
@@ -285,6 +287,7 @@ class MatrixFederationHttpClient(object):
             body_callback=body_callback,
             headers_dict={"Content-Type": ["application/json"]},
             long_retries=long_retries,
+            timeout=timeout,
         )
 
         if 200 <= response.code < 300:
@@ -300,7 +303,8 @@ class MatrixFederationHttpClient(object):
         defer.returnValue(json.loads(body))
 
     @defer.inlineCallbacks
-    def post_json(self, destination, path, data={}, long_retries=True):
+    def post_json(self, destination, path, data={}, long_retries=True,
+                  timeout=None):
         """ Sends the specifed json data using POST
 
         Args:
@@ -311,6 +315,8 @@ class MatrixFederationHttpClient(object):
                 the request body. This will be encoded as JSON.
             long_retries (bool): A boolean that indicates whether we should
                 retry for a short or long time.
+            timeout(int): How long to try (in ms) the destination for before
+                giving up. None indicates no timeout.
 
         Returns:
             Deferred: Succeeds when we get a 2xx HTTP response. The result
@@ -331,6 +337,7 @@ class MatrixFederationHttpClient(object):
             body_callback=body_callback,
             headers_dict={"Content-Type": ["application/json"]},
             long_retries=True,
+            timeout=timeout,
         )
 
         if 200 <= response.code < 300:
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index c5ff16adf3..8f05727652 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -19,11 +19,12 @@ import simplejson as json
 from canonicaljson import encode_canonical_json
 from twisted.internet import defer
 
-import synapse.api.errors
-import synapse.server
-import synapse.types
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import UserID
+from synapse.api.errors import SynapseError, CodeMessageException
+from synapse.http.servlet import (
+    RestServlet, parse_json_object_from_request, parse_integer
+)
+from synapse.types import get_domain_from_id
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 from ._base import client_v2_patterns
 
 logger = logging.getLogger(__name__)
@@ -88,7 +89,7 @@ class KeyUploadServlet(RestServlet):
             device_id = requester.device_id
 
         if device_id is None:
-            raise synapse.api.errors.SynapseError(
+            raise SynapseError(
                 400,
                 "To upload keys, you must pass device_id when authenticating"
             )
@@ -195,18 +196,21 @@ class KeyQueryServlet(RestServlet):
     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id):
         yield self.auth.get_user_by_req(request)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
-        result = yield self.e2e_keys_handler.query_devices(body)
+        result = yield self.e2e_keys_handler.query_devices(body, timeout)
         defer.returnValue(result)
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id):
         requester = yield self.auth.get_user_by_req(request)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
         auth_user_id = requester.user.to_string()
         user_id = user_id if user_id else auth_user_id
         device_ids = [device_id] if device_id else []
         result = yield self.e2e_keys_handler.query_devices(
-            {"device_keys": {user_id: device_ids}}
+            {"device_keys": {user_id: device_ids}},
+            timeout,
         )
         defer.returnValue(result)
 
@@ -244,39 +248,43 @@ class OneTimeKeyServlet(RestServlet):
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.federation = hs.get_replication_layer()
-        self.is_mine = hs.is_mine
+        self.is_mine_id = hs.is_mine_id
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id, algorithm):
         yield self.auth.get_user_by_req(request)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
         result = yield self.handle_request(
-            {"one_time_keys": {user_id: {device_id: algorithm}}}
+            {"one_time_keys": {user_id: {device_id: algorithm}}},
+            timeout,
         )
         defer.returnValue(result)
 
     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id, algorithm):
         yield self.auth.get_user_by_req(request)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
-        result = yield self.handle_request(body)
+        result = yield self.handle_request(body, timeout)
         defer.returnValue(result)
 
     @defer.inlineCallbacks
-    def handle_request(self, body):
+    def handle_request(self, body, timeout):
         local_query = []
         remote_queries = {}
+
         for user_id, device_keys in body.get("one_time_keys", {}).items():
-            user = UserID.from_string(user_id)
-            if self.is_mine(user):
+            if self.is_mine_id(user_id):
                 for device_id, algorithm in device_keys.items():
                     local_query.append((user_id, device_id, algorithm))
             else:
-                remote_queries.setdefault(user.domain, {})[user_id] = (
-                    device_keys
-                )
+                domain = get_domain_from_id(user_id)
+                remote_queries.setdefault(domain, {})[user_id] = device_keys
+
         results = yield self.store.claim_e2e_one_time_keys(local_query)
 
         json_result = {}
+        failures = {}
         for user_id, device_keys in results.items():
             for device_id, keys in device_keys.items():
                 for key_id, json_bytes in keys.items():
@@ -284,15 +292,32 @@ class OneTimeKeyServlet(RestServlet):
                         key_id: json.loads(json_bytes)
                     }
 
-        for destination, device_keys in remote_queries.items():
-            remote_result = yield self.federation.claim_client_keys(
-                destination, {"one_time_keys": device_keys}
-            )
-            for user_id, keys in remote_result["one_time_keys"].items():
-                if user_id in device_keys:
-                    json_result[user_id] = keys
-
-        defer.returnValue((200, {"one_time_keys": json_result}))
+        @defer.inlineCallbacks
+        def claim_client_keys(destination):
+            device_keys = remote_queries[destination]
+            try:
+                remote_result = yield self.federation.claim_client_keys(
+                    destination,
+                    {"one_time_keys": device_keys},
+                    timeout=timeout
+                )
+                for user_id, keys in remote_result["one_time_keys"].items():
+                    if user_id in device_keys:
+                        json_result[user_id] = keys
+            except CodeMessageException as e:
+                failures[destination] = {
+                    "status": e.code, "message": e.message
+                }
+
+        yield preserve_context_over_deferred(defer.gatherResults([
+            preserve_fn(claim_client_keys)(destination)
+            for destination in remote_queries
+        ]))
+
+        defer.returnValue((200, {
+            "one_time_keys": json_result,
+            "failures": failures
+        }))
 
 
 def register_servlets(hs, http_server):