diff options
Diffstat (limited to 'synapse/handlers')
-rw-r--r-- | synapse/handlers/e2e_keys.py | 64 |
1 files changed, 43 insertions, 21 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 2c7bfd91ed..5bfd700931 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections import json import logging from twisted.internet import defer -from synapse.api import errors -import synapse.types +from synapse.api.errors import SynapseError, CodeMessageException +from synapse.types import get_domain_from_id +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred logger = logging.getLogger(__name__) @@ -30,7 +30,6 @@ class E2eKeysHandler(object): self.store = hs.get_datastore() self.federation = hs.get_replication_layer() self.is_mine_id = hs.is_mine_id - self.server_name = hs.hostname # doesn't really work as part of the generic query API, because the # query request requires an object POST, but we abuse the @@ -40,7 +39,7 @@ class E2eKeysHandler(object): ) @defer.inlineCallbacks - def query_devices(self, query_body): + def query_devices(self, query_body, timeout): """ Handle a device key query from a client { @@ -63,27 +62,50 @@ class E2eKeysHandler(object): # separate users by domain. # make a map from domain to user_id to device_ids - queries_by_domain = collections.defaultdict(dict) + local_query = {} + remote_queries = {} + for user_id, device_ids in device_keys_query.items(): - user = synapse.types.UserID.from_string(user_id) - queries_by_domain[user.domain][user_id] = device_ids + if self.is_mine_id(user_id): + local_query[user_id] = device_ids + else: + domain = get_domain_from_id(user_id) + remote_queries.setdefault(domain, {})[user_id] = device_ids # do the queries - # TODO: do these in parallel + failures = {} results = {} - for destination, destination_query in queries_by_domain.items(): - if destination == self.server_name: - res = yield self.query_local_devices(destination_query) - else: - res = yield self.federation.query_client_keys( - destination, {"device_keys": destination_query} - ) - res = res["device_keys"] - for user_id, keys in res.items(): - if user_id in destination_query: + if local_query: + local_result = yield self.query_local_devices(local_query) + for user_id, keys in local_result.items(): + if user_id in local_query: results[user_id] = keys - defer.returnValue((200, {"device_keys": results})) + @defer.inlineCallbacks + def do_remote_query(destination): + destination_query = remote_queries[destination] + try: + remote_result = yield self.federation.query_client_keys( + destination, + {"device_keys": destination_query}, + timeout=timeout + ) + for user_id, keys in remote_result["device_keys"].items(): + if user_id in destination_query: + results[user_id] = keys + except CodeMessageException as e: + failures[destination] = { + "status": e.code, "message": e.message + } + + yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(do_remote_query)(destination) + for destination in remote_queries + ])) + + defer.returnValue((200, { + "device_keys": results, "failures": failures, + })) @defer.inlineCallbacks def query_local_devices(self, query): @@ -104,7 +126,7 @@ class E2eKeysHandler(object): if not self.is_mine_id(user_id): logger.warning("Request for keys for non-local user %s", user_id) - raise errors.SynapseError(400, "Not a user here") + raise SynapseError(400, "Not a user here") if not device_ids: local_query.append((user_id, None)) |