summary refs log tree commit diff
path: root/synapse/handlers/e2e_keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/e2e_keys.py')
-rw-r--r--synapse/handlers/e2e_keys.py64
1 files changed, 43 insertions, 21 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 2c7bfd91ed..5bfd700931 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -13,14 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import collections
 import json
 import logging
 
 from twisted.internet import defer
 
-from synapse.api import errors
-import synapse.types
+from synapse.api.errors import SynapseError, CodeMessageException
+from synapse.types import get_domain_from_id
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 
 logger = logging.getLogger(__name__)
 
@@ -30,7 +30,6 @@ class E2eKeysHandler(object):
         self.store = hs.get_datastore()
         self.federation = hs.get_replication_layer()
         self.is_mine_id = hs.is_mine_id
-        self.server_name = hs.hostname
 
         # doesn't really work as part of the generic query API, because the
         # query request requires an object POST, but we abuse the
@@ -40,7 +39,7 @@ class E2eKeysHandler(object):
         )
 
     @defer.inlineCallbacks
-    def query_devices(self, query_body):
+    def query_devices(self, query_body, timeout):
         """ Handle a device key query from a client
 
         {
@@ -63,27 +62,50 @@ class E2eKeysHandler(object):
 
         # separate users by domain.
         # make a map from domain to user_id to device_ids
-        queries_by_domain = collections.defaultdict(dict)
+        local_query = {}
+        remote_queries = {}
+
         for user_id, device_ids in device_keys_query.items():
-            user = synapse.types.UserID.from_string(user_id)
-            queries_by_domain[user.domain][user_id] = device_ids
+            if self.is_mine_id(user_id):
+                local_query[user_id] = device_ids
+            else:
+                domain = get_domain_from_id(user_id)
+                remote_queries.setdefault(domain, {})[user_id] = device_ids
 
         # do the queries
-        # TODO: do these in parallel
+        failures = {}
         results = {}
-        for destination, destination_query in queries_by_domain.items():
-            if destination == self.server_name:
-                res = yield self.query_local_devices(destination_query)
-            else:
-                res = yield self.federation.query_client_keys(
-                    destination, {"device_keys": destination_query}
-                )
-                res = res["device_keys"]
-            for user_id, keys in res.items():
-                if user_id in destination_query:
+        if local_query:
+            local_result = yield self.query_local_devices(local_query)
+            for user_id, keys in local_result.items():
+                if user_id in local_query:
                     results[user_id] = keys
 
-        defer.returnValue((200, {"device_keys": results}))
+        @defer.inlineCallbacks
+        def do_remote_query(destination):
+            destination_query = remote_queries[destination]
+            try:
+                remote_result = yield self.federation.query_client_keys(
+                    destination,
+                    {"device_keys": destination_query},
+                    timeout=timeout
+                )
+                for user_id, keys in remote_result["device_keys"].items():
+                    if user_id in destination_query:
+                        results[user_id] = keys
+            except CodeMessageException as e:
+                failures[destination] = {
+                    "status": e.code, "message": e.message
+                }
+
+        yield preserve_context_over_deferred(defer.gatherResults([
+            preserve_fn(do_remote_query)(destination)
+            for destination in remote_queries
+        ]))
+
+        defer.returnValue((200, {
+            "device_keys": results, "failures": failures,
+        }))
 
     @defer.inlineCallbacks
     def query_local_devices(self, query):
@@ -104,7 +126,7 @@ class E2eKeysHandler(object):
             if not self.is_mine_id(user_id):
                 logger.warning("Request for keys for non-local user %s",
                                user_id)
-                raise errors.SynapseError(400, "Not a user here")
+                raise SynapseError(400, "Not a user here")
 
             if not device_ids:
                 local_query.append((user_id, None))