diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 2c7bfd91ed..5bfd700931 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import collections
import json
import logging
from twisted.internet import defer
-from synapse.api import errors
-import synapse.types
+from synapse.api.errors import SynapseError, CodeMessageException
+from synapse.types import get_domain_from_id
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
logger = logging.getLogger(__name__)
@@ -30,7 +30,6 @@ class E2eKeysHandler(object):
self.store = hs.get_datastore()
self.federation = hs.get_replication_layer()
self.is_mine_id = hs.is_mine_id
- self.server_name = hs.hostname
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
@@ -40,7 +39,7 @@ class E2eKeysHandler(object):
)
@defer.inlineCallbacks
- def query_devices(self, query_body):
+ def query_devices(self, query_body, timeout):
""" Handle a device key query from a client
{
@@ -63,27 +62,50 @@ class E2eKeysHandler(object):
# separate users by domain.
# make a map from domain to user_id to device_ids
- queries_by_domain = collections.defaultdict(dict)
+ local_query = {}
+ remote_queries = {}
+
for user_id, device_ids in device_keys_query.items():
- user = synapse.types.UserID.from_string(user_id)
- queries_by_domain[user.domain][user_id] = device_ids
+ if self.is_mine_id(user_id):
+ local_query[user_id] = device_ids
+ else:
+ domain = get_domain_from_id(user_id)
+ remote_queries.setdefault(domain, {})[user_id] = device_ids
# do the queries
- # TODO: do these in parallel
+ failures = {}
results = {}
- for destination, destination_query in queries_by_domain.items():
- if destination == self.server_name:
- res = yield self.query_local_devices(destination_query)
- else:
- res = yield self.federation.query_client_keys(
- destination, {"device_keys": destination_query}
- )
- res = res["device_keys"]
- for user_id, keys in res.items():
- if user_id in destination_query:
+ if local_query:
+ local_result = yield self.query_local_devices(local_query)
+ for user_id, keys in local_result.items():
+ if user_id in local_query:
results[user_id] = keys
- defer.returnValue((200, {"device_keys": results}))
+ @defer.inlineCallbacks
+ def do_remote_query(destination):
+ destination_query = remote_queries[destination]
+ try:
+ remote_result = yield self.federation.query_client_keys(
+ destination,
+ {"device_keys": destination_query},
+ timeout=timeout
+ )
+ for user_id, keys in remote_result["device_keys"].items():
+ if user_id in destination_query:
+ results[user_id] = keys
+ except CodeMessageException as e:
+ failures[destination] = {
+ "status": e.code, "message": e.message
+ }
+
+ yield preserve_context_over_deferred(defer.gatherResults([
+ preserve_fn(do_remote_query)(destination)
+ for destination in remote_queries
+ ]))
+
+ defer.returnValue((200, {
+ "device_keys": results, "failures": failures,
+ }))
@defer.inlineCallbacks
def query_local_devices(self, query):
@@ -104,7 +126,7 @@ class E2eKeysHandler(object):
if not self.is_mine_id(user_id):
logger.warning("Request for keys for non-local user %s",
user_id)
- raise errors.SynapseError(400, "Not a user here")
+ raise SynapseError(400, "Not a user here")
if not device_ids:
local_query.append((user_id, None))
|