summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2016-08-02 18:06:31 +0100
committerRichard van der Hoff <richard@matrix.org>2016-08-02 18:12:00 +0100
commit1efee2f52b931ddcd90e87d06c7ea614da2c9cd0 (patch)
treeafd7d79065972cc5bacd47f872cab3ec79e94def /synapse
parentMove e2e query logic into a handler (diff)
downloadsynapse-1efee2f52b931ddcd90e87d06c7ea614da2c9cd0.tar.xz
E2E keys: Make federation query share code with client query
Refactor the e2e query handler to separate out the local query, and then make
the federation handler use it.
Diffstat (limited to '')
-rw-r--r--synapse/federation/federation_server.py20
-rw-r--r--synapse/federation/transport/server.py4
-rw-r--r--synapse/handlers/e2e_keys.py115
3 files changed, 92 insertions, 47 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 85f5e752fe..e637f2a8bd 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -348,27 +348,9 @@ class FederationServer(FederationBase):
             (200, send_content)
         )
 
-    @defer.inlineCallbacks
     @log_function
     def on_query_client_keys(self, origin, content):
-        query = []
-        for user_id, device_ids in content.get("device_keys", {}).items():
-            if not device_ids:
-                query.append((user_id, None))
-            else:
-                for device_id in device_ids:
-                    query.append((user_id, device_id))
-
-        results = yield self.store.get_e2e_device_keys(query)
-
-        json_result = {}
-        for user_id, device_keys in results.items():
-            for device_id, json_bytes in device_keys.items():
-                json_result.setdefault(user_id, {})[device_id] = json.loads(
-                    json_bytes
-                )
-
-        defer.returnValue({"device_keys": json_result})
+        return self.on_query_request("client_keys", content)
 
     @defer.inlineCallbacks
     @log_function
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 26fa88ae84..1a88413d18 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -367,10 +367,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
 class FederationClientKeysQueryServlet(BaseFederationServlet):
     PATH = "/user/keys/query"
 
-    @defer.inlineCallbacks
     def on_POST(self, origin, content, query):
-        response = yield self.handler.on_query_client_keys(origin, content)
-        defer.returnValue((200, response))
+        return self.handler.on_query_client_keys(origin, content)
 
 
 class FederationClientKeysClaimServlet(BaseFederationServlet):
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 73a14cf952..9c7e9494d6 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -13,12 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import collections
 import json
 import logging
 
 from twisted.internet import defer
 
+from synapse.api import errors
 import synapse.types
+
 from ._base import BaseHandler
 
 logger = logging.getLogger(__name__)
@@ -29,39 +32,101 @@ class E2eKeysHandler(BaseHandler):
         super(E2eKeysHandler, self).__init__(hs)
         self.store = hs.get_datastore()
         self.federation = hs.get_replication_layer()
-        self.is_mine = hs.is_mine
+        self.is_mine_id = hs.is_mine_id
+
+        # doesn't really work as part of the generic query API, because the
+        # query request requires an object POST, but we abuse the
+        # "query handler" interface.
+        self.federation.register_query_handler(
+            "client_keys", self.on_federation_query_client_keys
+        )
 
     @defer.inlineCallbacks
     def query_devices(self, query_body):
-        local_query = []
-        remote_queries = {}
-        for user_id, device_ids in query_body.get("device_keys", {}).items():
+        """ Handle a device key query from a client
+
+        {
+            "device_keys": {
+                "<user_id>": ["<device_id>"]
+            }
+        }
+        ->
+        {
+            "device_keys": {
+                "<user_id>": {
+                    "<device_id>": {
+                        ...
+                    }
+                }
+            }
+        }
+        """
+        device_keys_query = query_body.get("device_keys", {})
+
+        # separate users by domain.
+        # make a map from domain to user_id to device_ids
+        queries_by_domain = collections.defaultdict(dict)
+        for user_id, device_ids in device_keys_query.items():
             user = synapse.types.UserID.from_string(user_id)
-            if self.is_mine(user):
-                if not device_ids:
-                    local_query.append((user_id, None))
-                else:
-                    for device_id in device_ids:
-                        local_query.append((user_id, device_id))
+            queries_by_domain[user.domain][user_id] = device_ids
+
+        # do the queries
+        # TODO: do these in parallel
+        results = {}
+        for destination, destination_query in queries_by_domain.items():
+            if destination == self.hs.hostname:
+                res = yield self.query_local_devices(destination_query)
             else:
-                remote_queries.setdefault(user.domain, {})[user_id] = list(
-                    device_ids
+                res = yield self.federation.query_client_keys(
+                    destination, {"device_keys": destination_query}
                 )
+                res = res["device_keys"]
+            for user_id, keys in res.items():
+                if user_id in destination_query:
+                    results[user_id] = keys
+
+        defer.returnValue((200, {"device_keys": results}))
+
+    @defer.inlineCallbacks
+    def query_local_devices(self, query):
+        """Get E2E device keys for local users
+
+        Args:
+            query (dict[string, list[string]|None): map from user_id to a list
+                 of devices to query (None for all devices)
+
+        Returns:
+            defer.Deferred: (resolves to dict[string, dict[string, dict]]):
+                 map from user_id -> device_id -> device details
+        """
+        local_query = []
+
+        for user_id, device_ids in query.items():
+            if not self.is_mine_id(user_id):
+                logger.warning("Request for keys for non-local user %s",
+                               user_id)
+                raise errors.SynapseError(400, "Not a user here")
+
+            if not device_ids:
+                local_query.append((user_id, None))
+            else:
+                for device_id in device_ids:
+                    local_query.append((user_id, device_id))
+
         results = yield self.store.get_e2e_device_keys(local_query)
 
-        json_result = {}
+        # un-jsonify the results
+        json_result = collections.defaultdict(dict)
         for user_id, device_keys in results.items():
             for device_id, json_bytes in device_keys.items():
-                json_result.setdefault(user_id, {})[
-                    device_id] = json.loads(
-                    json_bytes
-                )
+                json_result[user_id][device_id] = json.loads(json_bytes)
 
-        for destination, device_keys in remote_queries.items():
-            remote_result = yield self.federation.query_client_keys(
-                destination, {"device_keys": device_keys}
-            )
-            for user_id, keys in remote_result["device_keys"].items():
-                if user_id in device_keys:
-                    json_result[user_id] = keys
-        defer.returnValue((200, {"device_keys": json_result}))
+        defer.returnValue(json_result)
+
+    @defer.inlineCallbacks
+    def on_federation_query_client_keys(self, query_body):
+        """ Handle a device key query from a federated server
+        """
+        device_keys_query = query_body.get("device_keys", {})
+        res = yield self.query_local_devices(device_keys_query)
+        defer.returnValue({"device_keys": res})