diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 3395c9e41e..cf8a52510d 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -176,7 +176,7 @@ class FederationClient(FederationBase):
)
@log_function
- def query_client_keys(self, destination, content):
+ def query_client_keys(self, destination, content, timeout):
"""Query device keys for a device hosted on a remote server.
Args:
@@ -188,10 +188,12 @@ class FederationClient(FederationBase):
response
"""
sent_queries_counter.inc("client_device_keys")
- return self.transport_layer.query_client_keys(destination, content)
+ return self.transport_layer.query_client_keys(
+ destination, content, timeout
+ )
@log_function
- def claim_client_keys(self, destination, content):
+ def claim_client_keys(self, destination, content, timeout):
"""Claims one-time keys for a device hosted on a remote server.
Args:
@@ -203,7 +205,9 @@ class FederationClient(FederationBase):
response
"""
sent_queries_counter.inc("client_one_time_keys")
- return self.transport_layer.claim_client_keys(destination, content)
+ return self.transport_layer.claim_client_keys(
+ destination, content, timeout
+ )
@defer.inlineCallbacks
@log_function
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 3d088e43cb..2b138526ba 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -298,7 +298,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
- def query_client_keys(self, destination, query_content):
+ def query_client_keys(self, destination, query_content, timeout):
"""Query the device keys for a list of user ids hosted on a remote
server.
@@ -327,12 +327,13 @@ class TransportLayerClient(object):
destination=destination,
path=path,
data=query_content,
+ timeout=timeout,
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
- def claim_client_keys(self, destination, query_content):
+ def claim_client_keys(self, destination, query_content, timeout):
"""Claim one-time keys for a list of devices hosted on a remote server.
Request:
@@ -363,6 +364,7 @@ class TransportLayerClient(object):
destination=destination,
path=path,
data=query_content,
+ timeout=timeout,
)
defer.returnValue(content)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 2c7bfd91ed..5bfd700931 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import collections
import json
import logging
from twisted.internet import defer
-from synapse.api import errors
-import synapse.types
+from synapse.api.errors import SynapseError, CodeMessageException
+from synapse.types import get_domain_from_id
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
logger = logging.getLogger(__name__)
@@ -30,7 +30,6 @@ class E2eKeysHandler(object):
self.store = hs.get_datastore()
self.federation = hs.get_replication_layer()
self.is_mine_id = hs.is_mine_id
- self.server_name = hs.hostname
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
@@ -40,7 +39,7 @@ class E2eKeysHandler(object):
)
@defer.inlineCallbacks
- def query_devices(self, query_body):
+ def query_devices(self, query_body, timeout):
""" Handle a device key query from a client
{
@@ -63,27 +62,50 @@ class E2eKeysHandler(object):
# separate users by domain.
# make a map from domain to user_id to device_ids
- queries_by_domain = collections.defaultdict(dict)
+ local_query = {}
+ remote_queries = {}
+
for user_id, device_ids in device_keys_query.items():
- user = synapse.types.UserID.from_string(user_id)
- queries_by_domain[user.domain][user_id] = device_ids
+ if self.is_mine_id(user_id):
+ local_query[user_id] = device_ids
+ else:
+ domain = get_domain_from_id(user_id)
+ remote_queries.setdefault(domain, {})[user_id] = device_ids
# do the queries
- # TODO: do these in parallel
+ failures = {}
results = {}
- for destination, destination_query in queries_by_domain.items():
- if destination == self.server_name:
- res = yield self.query_local_devices(destination_query)
- else:
- res = yield self.federation.query_client_keys(
- destination, {"device_keys": destination_query}
- )
- res = res["device_keys"]
- for user_id, keys in res.items():
- if user_id in destination_query:
+ if local_query:
+ local_result = yield self.query_local_devices(local_query)
+ for user_id, keys in local_result.items():
+ if user_id in local_query:
results[user_id] = keys
- defer.returnValue((200, {"device_keys": results}))
+ @defer.inlineCallbacks
+ def do_remote_query(destination):
+ destination_query = remote_queries[destination]
+ try:
+ remote_result = yield self.federation.query_client_keys(
+ destination,
+ {"device_keys": destination_query},
+ timeout=timeout
+ )
+ for user_id, keys in remote_result["device_keys"].items():
+ if user_id in destination_query:
+ results[user_id] = keys
+ except CodeMessageException as e:
+ failures[destination] = {
+ "status": e.code, "message": e.message
+ }
+
+ yield preserve_context_over_deferred(defer.gatherResults([
+ preserve_fn(do_remote_query)(destination)
+ for destination in remote_queries
+ ]))
+
+ defer.returnValue((200, {
+ "device_keys": results, "failures": failures,
+ }))
@defer.inlineCallbacks
def query_local_devices(self, query):
@@ -104,7 +126,7 @@ class E2eKeysHandler(object):
if not self.is_mine_id(user_id):
logger.warning("Request for keys for non-local user %s",
user_id)
- raise errors.SynapseError(400, "Not a user here")
+ raise SynapseError(400, "Not a user here")
if not device_ids:
local_query.append((user_id, None))
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index f93093dd85..d0556ae347 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -246,7 +246,7 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks
def put_json(self, destination, path, data={}, json_data_callback=None,
- long_retries=False):
+ long_retries=False, timeout=None):
""" Sends the specifed json data using PUT
Args:
@@ -259,6 +259,8 @@ class MatrixFederationHttpClient(object):
use as the request body.
long_retries (bool): A boolean that indicates whether we should
retry for a short or long time.
+ timeout(int): How long to try (in ms) the destination for before
+ giving up. None indicates no timeout.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
@@ -285,6 +287,7 @@ class MatrixFederationHttpClient(object):
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries,
+ timeout=timeout,
)
if 200 <= response.code < 300:
@@ -300,7 +303,8 @@ class MatrixFederationHttpClient(object):
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
- def post_json(self, destination, path, data={}, long_retries=True):
+ def post_json(self, destination, path, data={}, long_retries=True,
+ timeout=None):
""" Sends the specifed json data using POST
Args:
@@ -311,6 +315,8 @@ class MatrixFederationHttpClient(object):
the request body. This will be encoded as JSON.
long_retries (bool): A boolean that indicates whether we should
retry for a short or long time.
+ timeout(int): How long to try (in ms) the destination for before
+ giving up. None indicates no timeout.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
@@ -331,6 +337,7 @@ class MatrixFederationHttpClient(object):
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
long_retries=True,
+ timeout=timeout,
)
if 200 <= response.code < 300:
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index c5ff16adf3..8f05727652 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -19,11 +19,12 @@ import simplejson as json
from canonicaljson import encode_canonical_json
from twisted.internet import defer
-import synapse.api.errors
-import synapse.server
-import synapse.types
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import UserID
+from synapse.api.errors import SynapseError, CodeMessageException
+from synapse.http.servlet import (
+ RestServlet, parse_json_object_from_request, parse_integer
+)
+from synapse.types import get_domain_from_id
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
@@ -88,7 +89,7 @@ class KeyUploadServlet(RestServlet):
device_id = requester.device_id
if device_id is None:
- raise synapse.api.errors.SynapseError(
+ raise SynapseError(
400,
"To upload keys, you must pass device_id when authenticating"
)
@@ -195,18 +196,21 @@ class KeyQueryServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id):
yield self.auth.get_user_by_req(request)
+ timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
- result = yield self.e2e_keys_handler.query_devices(body)
+ result = yield self.e2e_keys_handler.query_devices(body, timeout)
defer.returnValue(result)
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id):
requester = yield self.auth.get_user_by_req(request)
+ timeout = parse_integer(request, "timeout", 10 * 1000)
auth_user_id = requester.user.to_string()
user_id = user_id if user_id else auth_user_id
device_ids = [device_id] if device_id else []
result = yield self.e2e_keys_handler.query_devices(
- {"device_keys": {user_id: device_ids}}
+ {"device_keys": {user_id: device_ids}},
+ timeout,
)
defer.returnValue(result)
@@ -244,39 +248,43 @@ class OneTimeKeyServlet(RestServlet):
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.federation = hs.get_replication_layer()
- self.is_mine = hs.is_mine
+ self.is_mine_id = hs.is_mine_id
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request)
+ timeout = parse_integer(request, "timeout", 10 * 1000)
result = yield self.handle_request(
- {"one_time_keys": {user_id: {device_id: algorithm}}}
+ {"one_time_keys": {user_id: {device_id: algorithm}}},
+ timeout,
)
defer.returnValue(result)
@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request)
+ timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
- result = yield self.handle_request(body)
+ result = yield self.handle_request(body, timeout)
defer.returnValue(result)
@defer.inlineCallbacks
- def handle_request(self, body):
+ def handle_request(self, body, timeout):
local_query = []
remote_queries = {}
+
for user_id, device_keys in body.get("one_time_keys", {}).items():
- user = UserID.from_string(user_id)
- if self.is_mine(user):
+ if self.is_mine_id(user_id):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
- remote_queries.setdefault(user.domain, {})[user_id] = (
- device_keys
- )
+ domain = get_domain_from_id(user_id)
+ remote_queries.setdefault(domain, {})[user_id] = device_keys
+
results = yield self.store.claim_e2e_one_time_keys(local_query)
json_result = {}
+ failures = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
@@ -284,15 +292,32 @@ class OneTimeKeyServlet(RestServlet):
key_id: json.loads(json_bytes)
}
- for destination, device_keys in remote_queries.items():
- remote_result = yield self.federation.claim_client_keys(
- destination, {"one_time_keys": device_keys}
- )
- for user_id, keys in remote_result["one_time_keys"].items():
- if user_id in device_keys:
- json_result[user_id] = keys
-
- defer.returnValue((200, {"one_time_keys": json_result}))
+ @defer.inlineCallbacks
+ def claim_client_keys(destination):
+ device_keys = remote_queries[destination]
+ try:
+ remote_result = yield self.federation.claim_client_keys(
+ destination,
+ {"one_time_keys": device_keys},
+ timeout=timeout
+ )
+ for user_id, keys in remote_result["one_time_keys"].items():
+ if user_id in device_keys:
+ json_result[user_id] = keys
+ except CodeMessageException as e:
+ failures[destination] = {
+ "status": e.code, "message": e.message
+ }
+
+ yield preserve_context_over_deferred(defer.gatherResults([
+ preserve_fn(claim_client_keys)(destination)
+ for destination in remote_queries
+ ]))
+
+ defer.returnValue((200, {
+ "one_time_keys": json_result,
+ "failures": failures
+ }))
def register_servlets(hs, http_server):
|