summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/federation/federation_server.py20
-rw-r--r--synapse/federation/transport/server.py4
-rw-r--r--synapse/handlers/e2e_keys.py115
3 files changed, 92 insertions, 47 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 85f5e752fe..e637f2a8bd 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -348,27 +348,9 @@ class FederationServer(FederationBase):
             (200, send_content)
         )
 
-    @defer.inlineCallbacks
     @log_function
     def on_query_client_keys(self, origin, content):
-        query = []
-        for user_id, device_ids in content.get("device_keys", {}).items():
-            if not device_ids:
-                query.append((user_id, None))
-            else:
-                for device_id in device_ids:
-                    query.append((user_id, device_id))
-
-        results = yield self.store.get_e2e_device_keys(query)
-
-        json_result = {}
-        for user_id, device_keys in results.items():
-            for device_id, json_bytes in device_keys.items():
-                json_result.setdefault(user_id, {})[device_id] = json.loads(
-                    json_bytes
-                )
-
-        defer.returnValue({"device_keys": json_result})
+        return self.on_query_request("client_keys", content)
 
     @defer.inlineCallbacks
     @log_function
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 26fa88ae84..1a88413d18 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -367,10 +367,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
 class FederationClientKeysQueryServlet(BaseFederationServlet):
     PATH = "/user/keys/query"
 
-    @defer.inlineCallbacks
     def on_POST(self, origin, content, query):
-        response = yield self.handler.on_query_client_keys(origin, content)
-        defer.returnValue((200, response))
+        return self.handler.on_query_client_keys(origin, content)
 
 
 class FederationClientKeysClaimServlet(BaseFederationServlet):
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 73a14cf952..9c7e9494d6 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -13,12 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import collections
 import json
 import logging
 
 from twisted.internet import defer
 
+from synapse.api import errors
 import synapse.types
+
 from ._base import BaseHandler
 
 logger = logging.getLogger(__name__)
@@ -29,39 +32,101 @@ class E2eKeysHandler(BaseHandler):
         super(E2eKeysHandler, self).__init__(hs)
         self.store = hs.get_datastore()
         self.federation = hs.get_replication_layer()
-        self.is_mine = hs.is_mine
+        self.is_mine_id = hs.is_mine_id
+
+        # doesn't really work as part of the generic query API, because the
+        # query request requires an object POST, but we abuse the
+        # "query handler" interface.
+        self.federation.register_query_handler(
+            "client_keys", self.on_federation_query_client_keys
+        )
 
     @defer.inlineCallbacks
     def query_devices(self, query_body):
-        local_query = []
-        remote_queries = {}
-        for user_id, device_ids in query_body.get("device_keys", {}).items():
+        """ Handle a device key query from a client
+
+        {
+            "device_keys": {
+                "<user_id>": ["<device_id>"]
+            }
+        }
+        ->
+        {
+            "device_keys": {
+                "<user_id>": {
+                    "<device_id>": {
+                        ...
+                    }
+                }
+            }
+        }
+        """
+        device_keys_query = query_body.get("device_keys", {})
+
+        # separate users by domain.
+        # make a map from domain to user_id to device_ids
+        queries_by_domain = collections.defaultdict(dict)
+        for user_id, device_ids in device_keys_query.items():
             user = synapse.types.UserID.from_string(user_id)
-            if self.is_mine(user):
-                if not device_ids:
-                    local_query.append((user_id, None))
-                else:
-                    for device_id in device_ids:
-                        local_query.append((user_id, device_id))
+            queries_by_domain[user.domain][user_id] = device_ids
+
+        # do the queries
+        # TODO: do these in parallel
+        results = {}
+        for destination, destination_query in queries_by_domain.items():
+            if destination == self.hs.hostname:
+                res = yield self.query_local_devices(destination_query)
             else:
-                remote_queries.setdefault(user.domain, {})[user_id] = list(
-                    device_ids
+                res = yield self.federation.query_client_keys(
+                    destination, {"device_keys": destination_query}
                 )
+                res = res["device_keys"]
+            for user_id, keys in res.items():
+                if user_id in destination_query:
+                    results[user_id] = keys
+
+        defer.returnValue((200, {"device_keys": results}))
+
+    @defer.inlineCallbacks
+    def query_local_devices(self, query):
+        """Get E2E device keys for local users
+
+        Args:
+            query (dict[string, list[string]|None): map from user_id to a list
+                 of devices to query (None for all devices)
+
+        Returns:
+            defer.Deferred: (resolves to dict[string, dict[string, dict]]):
+                 map from user_id -> device_id -> device details
+        """
+        local_query = []
+
+        for user_id, device_ids in query.items():
+            if not self.is_mine_id(user_id):
+                logger.warning("Request for keys for non-local user %s",
+                               user_id)
+                raise errors.SynapseError(400, "Not a user here")
+
+            if not device_ids:
+                local_query.append((user_id, None))
+            else:
+                for device_id in device_ids:
+                    local_query.append((user_id, device_id))
+
         results = yield self.store.get_e2e_device_keys(local_query)
 
-        json_result = {}
+        # un-jsonify the results
+        json_result = collections.defaultdict(dict)
         for user_id, device_keys in results.items():
             for device_id, json_bytes in device_keys.items():
-                json_result.setdefault(user_id, {})[
-                    device_id] = json.loads(
-                    json_bytes
-                )
+                json_result[user_id][device_id] = json.loads(json_bytes)
 
-        for destination, device_keys in remote_queries.items():
-            remote_result = yield self.federation.query_client_keys(
-                destination, {"device_keys": device_keys}
-            )
-            for user_id, keys in remote_result["device_keys"].items():
-                if user_id in device_keys:
-                    json_result[user_id] = keys
-        defer.returnValue((200, {"device_keys": json_result}))
+        defer.returnValue(json_result)
+
+    @defer.inlineCallbacks
+    def on_federation_query_client_keys(self, query_body):
+        """ Handle a device key query from a federated server
+        """
+        device_keys_query = query_body.get("device_keys", {})
+        res = yield self.query_local_devices(device_keys_query)
+        defer.returnValue({"device_keys": res})