summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorMark Haines <mark.haines@matrix.org>2016-09-12 18:17:09 +0100
committerMark Haines <mark.haines@matrix.org>2016-09-12 18:17:09 +0100
commit949c2c54352f5a1fe2d8de39c4ddebc1f1e13aac (patch)
tree268c6c1da8d430c55569e7fb3ad617620126a92e /synapse
parentMerge pull request #1104 from matrix-org/markjh/direct_to_device_federation_sync (diff)
downloadsynapse-949c2c54352f5a1fe2d8de39c4ddebc1f1e13aac.tar.xz
Add a timeout parameter for end2end key queries.
Add a timeout parameter for controlling how long synapse will wait
for responses from remote servers. For servers that fail include how
they failed to make it easier to debug.

Fetch keys from different servers in parallel rather than in series.

Set the default timeout to 10s.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/federation/federation_client.py12
-rw-r--r--synapse/federation/transport/client.py6
-rw-r--r--synapse/handlers/e2e_keys.py64
-rw-r--r--synapse/http/matrixfederationclient.py11
-rw-r--r--synapse/rest/client/v2_alpha/keys.py77
5 files changed, 115 insertions, 55 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 3395c9e41e..cf8a52510d 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -176,7 +176,7 @@ class FederationClient(FederationBase):
         )
 
     @log_function
-    def query_client_keys(self, destination, content):
+    def query_client_keys(self, destination, content, timeout):
         """Query device keys for a device hosted on a remote server.
 
         Args:
@@ -188,10 +188,12 @@ class FederationClient(FederationBase):
             response
         """
         sent_queries_counter.inc("client_device_keys")
-        return self.transport_layer.query_client_keys(destination, content)
+        return self.transport_layer.query_client_keys(
+            destination, content, timeout
+        )
 
     @log_function
-    def claim_client_keys(self, destination, content):
+    def claim_client_keys(self, destination, content, timeout):
         """Claims one-time keys for a device hosted on a remote server.
 
         Args:
@@ -203,7 +205,9 @@ class FederationClient(FederationBase):
             response
         """
         sent_queries_counter.inc("client_one_time_keys")
-        return self.transport_layer.claim_client_keys(destination, content)
+        return self.transport_layer.claim_client_keys(
+            destination, content, timeout
+        )
 
     @defer.inlineCallbacks
     @log_function
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 3d088e43cb..2b138526ba 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -298,7 +298,7 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
-    def query_client_keys(self, destination, query_content):
+    def query_client_keys(self, destination, query_content, timeout):
         """Query the device keys for a list of user ids hosted on a remote
         server.
 
@@ -327,12 +327,13 @@ class TransportLayerClient(object):
             destination=destination,
             path=path,
             data=query_content,
+            timeout=timeout,
         )
         defer.returnValue(content)
 
     @defer.inlineCallbacks
     @log_function
-    def claim_client_keys(self, destination, query_content):
+    def claim_client_keys(self, destination, query_content, timeout):
         """Claim one-time keys for a list of devices hosted on a remote server.
 
         Request:
@@ -363,6 +364,7 @@ class TransportLayerClient(object):
             destination=destination,
             path=path,
             data=query_content,
+            timeout=timeout,
         )
         defer.returnValue(content)
 
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 2c7bfd91ed..5bfd700931 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -13,14 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import collections
 import json
 import logging
 
 from twisted.internet import defer
 
-from synapse.api import errors
-import synapse.types
+from synapse.api.errors import SynapseError, CodeMessageException
+from synapse.types import get_domain_from_id
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 
 logger = logging.getLogger(__name__)
 
@@ -30,7 +30,6 @@ class E2eKeysHandler(object):
         self.store = hs.get_datastore()
         self.federation = hs.get_replication_layer()
         self.is_mine_id = hs.is_mine_id
-        self.server_name = hs.hostname
 
         # doesn't really work as part of the generic query API, because the
         # query request requires an object POST, but we abuse the
@@ -40,7 +39,7 @@ class E2eKeysHandler(object):
         )
 
     @defer.inlineCallbacks
-    def query_devices(self, query_body):
+    def query_devices(self, query_body, timeout):
         """ Handle a device key query from a client
 
         {
@@ -63,27 +62,50 @@ class E2eKeysHandler(object):
 
         # separate users by domain.
         # make a map from domain to user_id to device_ids
-        queries_by_domain = collections.defaultdict(dict)
+        local_query = {}
+        remote_queries = {}
+
         for user_id, device_ids in device_keys_query.items():
-            user = synapse.types.UserID.from_string(user_id)
-            queries_by_domain[user.domain][user_id] = device_ids
+            if self.is_mine_id(user_id):
+                local_query[user_id] = device_ids
+            else:
+                domain = get_domain_from_id(user_id)
+                remote_queries.setdefault(domain, {})[user_id] = device_ids
 
         # do the queries
-        # TODO: do these in parallel
+        failures = {}
         results = {}
-        for destination, destination_query in queries_by_domain.items():
-            if destination == self.server_name:
-                res = yield self.query_local_devices(destination_query)
-            else:
-                res = yield self.federation.query_client_keys(
-                    destination, {"device_keys": destination_query}
-                )
-                res = res["device_keys"]
-            for user_id, keys in res.items():
-                if user_id in destination_query:
+        if local_query:
+            local_result = yield self.query_local_devices(local_query)
+            for user_id, keys in local_result.items():
+                if user_id in local_query:
                     results[user_id] = keys
 
-        defer.returnValue((200, {"device_keys": results}))
+        @defer.inlineCallbacks
+        def do_remote_query(destination):
+            destination_query = remote_queries[destination]
+            try:
+                remote_result = yield self.federation.query_client_keys(
+                    destination,
+                    {"device_keys": destination_query},
+                    timeout=timeout
+                )
+                for user_id, keys in remote_result["device_keys"].items():
+                    if user_id in destination_query:
+                        results[user_id] = keys
+            except CodeMessageException as e:
+                failures[destination] = {
+                    "status": e.code, "message": e.message
+                }
+
+        yield preserve_context_over_deferred(defer.gatherResults([
+            preserve_fn(do_remote_query)(destination)
+            for destination in remote_queries
+        ]))
+
+        defer.returnValue((200, {
+            "device_keys": results, "failures": failures,
+        }))
 
     @defer.inlineCallbacks
     def query_local_devices(self, query):
@@ -104,7 +126,7 @@ class E2eKeysHandler(object):
             if not self.is_mine_id(user_id):
                 logger.warning("Request for keys for non-local user %s",
                                user_id)
-                raise errors.SynapseError(400, "Not a user here")
+                raise SynapseError(400, "Not a user here")
 
             if not device_ids:
                 local_query.append((user_id, None))
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index f93093dd85..d0556ae347 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -246,7 +246,7 @@ class MatrixFederationHttpClient(object):
 
     @defer.inlineCallbacks
     def put_json(self, destination, path, data={}, json_data_callback=None,
-                 long_retries=False):
+                 long_retries=False, timeout=None):
         """ Sends the specifed json data using PUT
 
         Args:
@@ -259,6 +259,8 @@ class MatrixFederationHttpClient(object):
                 use as the request body.
             long_retries (bool): A boolean that indicates whether we should
                 retry for a short or long time.
+            timeout(int): How long to try (in ms) the destination for before
+                giving up. None indicates no timeout.
 
         Returns:
             Deferred: Succeeds when we get a 2xx HTTP response. The result
@@ -285,6 +287,7 @@ class MatrixFederationHttpClient(object):
             body_callback=body_callback,
             headers_dict={"Content-Type": ["application/json"]},
             long_retries=long_retries,
+            timeout=timeout,
         )
 
         if 200 <= response.code < 300:
@@ -300,7 +303,8 @@ class MatrixFederationHttpClient(object):
         defer.returnValue(json.loads(body))
 
     @defer.inlineCallbacks
-    def post_json(self, destination, path, data={}, long_retries=True):
+    def post_json(self, destination, path, data={}, long_retries=True,
+                  timeout=None):
         """ Sends the specifed json data using POST
 
         Args:
@@ -311,6 +315,8 @@ class MatrixFederationHttpClient(object):
                 the request body. This will be encoded as JSON.
             long_retries (bool): A boolean that indicates whether we should
                 retry for a short or long time.
+            timeout(int): How long to try (in ms) the destination for before
+                giving up. None indicates no timeout.
 
         Returns:
             Deferred: Succeeds when we get a 2xx HTTP response. The result
@@ -331,6 +337,7 @@ class MatrixFederationHttpClient(object):
             body_callback=body_callback,
             headers_dict={"Content-Type": ["application/json"]},
             long_retries=True,
+            timeout=timeout,
         )
 
         if 200 <= response.code < 300:
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index c5ff16adf3..8f05727652 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -19,11 +19,12 @@ import simplejson as json
 from canonicaljson import encode_canonical_json
 from twisted.internet import defer
 
-import synapse.api.errors
-import synapse.server
-import synapse.types
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import UserID
+from synapse.api.errors import SynapseError, CodeMessageException
+from synapse.http.servlet import (
+    RestServlet, parse_json_object_from_request, parse_integer
+)
+from synapse.types import get_domain_from_id
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 from ._base import client_v2_patterns
 
 logger = logging.getLogger(__name__)
@@ -88,7 +89,7 @@ class KeyUploadServlet(RestServlet):
             device_id = requester.device_id
 
         if device_id is None:
-            raise synapse.api.errors.SynapseError(
+            raise SynapseError(
                 400,
                 "To upload keys, you must pass device_id when authenticating"
             )
@@ -195,18 +196,21 @@ class KeyQueryServlet(RestServlet):
     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id):
         yield self.auth.get_user_by_req(request)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
-        result = yield self.e2e_keys_handler.query_devices(body)
+        result = yield self.e2e_keys_handler.query_devices(body, timeout)
         defer.returnValue(result)
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id):
         requester = yield self.auth.get_user_by_req(request)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
         auth_user_id = requester.user.to_string()
         user_id = user_id if user_id else auth_user_id
         device_ids = [device_id] if device_id else []
         result = yield self.e2e_keys_handler.query_devices(
-            {"device_keys": {user_id: device_ids}}
+            {"device_keys": {user_id: device_ids}},
+            timeout,
         )
         defer.returnValue(result)
 
@@ -244,39 +248,43 @@ class OneTimeKeyServlet(RestServlet):
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.federation = hs.get_replication_layer()
-        self.is_mine = hs.is_mine
+        self.is_mine_id = hs.is_mine_id
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id, algorithm):
         yield self.auth.get_user_by_req(request)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
         result = yield self.handle_request(
-            {"one_time_keys": {user_id: {device_id: algorithm}}}
+            {"one_time_keys": {user_id: {device_id: algorithm}}},
+            timeout,
         )
         defer.returnValue(result)
 
     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id, algorithm):
         yield self.auth.get_user_by_req(request)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
-        result = yield self.handle_request(body)
+        result = yield self.handle_request(body, timeout)
         defer.returnValue(result)
 
     @defer.inlineCallbacks
-    def handle_request(self, body):
+    def handle_request(self, body, timeout):
         local_query = []
         remote_queries = {}
+
         for user_id, device_keys in body.get("one_time_keys", {}).items():
-            user = UserID.from_string(user_id)
-            if self.is_mine(user):
+            if self.is_mine_id(user_id):
                 for device_id, algorithm in device_keys.items():
                     local_query.append((user_id, device_id, algorithm))
             else:
-                remote_queries.setdefault(user.domain, {})[user_id] = (
-                    device_keys
-                )
+                domain = get_domain_from_id(user_id)
+                remote_queries.setdefault(domain, {})[user_id] = device_keys
+
         results = yield self.store.claim_e2e_one_time_keys(local_query)
 
         json_result = {}
+        failures = {}
         for user_id, device_keys in results.items():
             for device_id, keys in device_keys.items():
                 for key_id, json_bytes in keys.items():
@@ -284,15 +292,32 @@ class OneTimeKeyServlet(RestServlet):
                         key_id: json.loads(json_bytes)
                     }
 
-        for destination, device_keys in remote_queries.items():
-            remote_result = yield self.federation.claim_client_keys(
-                destination, {"one_time_keys": device_keys}
-            )
-            for user_id, keys in remote_result["one_time_keys"].items():
-                if user_id in device_keys:
-                    json_result[user_id] = keys
-
-        defer.returnValue((200, {"one_time_keys": json_result}))
+        @defer.inlineCallbacks
+        def claim_client_keys(destination):
+            device_keys = remote_queries[destination]
+            try:
+                remote_result = yield self.federation.claim_client_keys(
+                    destination,
+                    {"one_time_keys": device_keys},
+                    timeout=timeout
+                )
+                for user_id, keys in remote_result["one_time_keys"].items():
+                    if user_id in device_keys:
+                        json_result[user_id] = keys
+            except CodeMessageException as e:
+                failures[destination] = {
+                    "status": e.code, "message": e.message
+                }
+
+        yield preserve_context_over_deferred(defer.gatherResults([
+            preserve_fn(claim_client_keys)(destination)
+            for destination in remote_queries
+        ]))
+
+        defer.returnValue((200, {
+            "one_time_keys": json_result,
+            "failures": failures
+        }))
 
 
 def register_servlets(hs, http_server):