summary refs log tree commit diff
diff options
context:
space:
mode:
authorHubert Chathi <hubert@uhoreg.ca>2018-11-19 11:15:44 -0500
committerHubert Chathi <hubert@uhoreg.ca>2018-11-19 11:15:44 -0500
commit0d0ee82a6b575ad0c61f64c9d3d30146b11c4c7f (patch)
tree5c9a6410140474e0e4a6256057eed366dcbde8e0
parentwork in Python 3 (diff)
downloadsynapse-0d0ee82a6b575ad0c61f64c9d3d30146b11c4c7f.tar.xz
put attestations in the right place
-rw-r--r--synapse/handlers/e2e_keys.py54
-rw-r--r--synapse/storage/end_to_end_keys.py67
2 files changed, 52 insertions, 69 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d81887f5b7..7ac93aef33 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -62,18 +62,7 @@ class E2eKeysHandler(object):
                         ...
                     }
                 }
-            },
-            "attestations": [
-               "user_id": "<user_id>",
-               "device_id": "<device_id>",
-               "keys": {
-                  "ed25519": "<key_base64>"
-               },
-               "state": "<verified or revoked>",
-               "signatures": {
-                  "<algorithm>:<device_id>": "<signature_base64>"
-               }
-            ]
+            }
         }
         """
         device_keys_query = query_body.get("device_keys", {})
@@ -93,13 +82,11 @@ class E2eKeysHandler(object):
         # First get local devices.
         failures = {}
         results = {}
-        attestations = []
         if local_query:
-            local_result = yield self.query_local_devices(local_query)
+            local_result = yield self.query_local_devices(local_query, req_user_id)
             for user_id, keys in local_result.items():
                 if user_id in local_query:
                     results[user_id] = keys
-            attestations = yield self.query_local_attestations(req_user_id, local_query)
 
         # Now attempt to get any remote devices from our local cache.
         remote_queries_not_in_cache = {}
@@ -121,11 +108,14 @@ class E2eKeysHandler(object):
                 for device_id, device in iteritems(devices):
                     keys = device.get("keys", None)
                     device_display_name = device.get("device_display_name", None)
+                    attestations = device.get("attestations", None)
                     if keys:
                         result = dict(keys)
                         unsigned = result.setdefault("unsigned", {})
                         if device_display_name:
                             unsigned["device_display_name"] = device_display_name
+                        if attestations:
+                            unsigned["attestations"] = attestations
                         user_devices[device_id] = result
 
             for user_id in user_ids_not_in_cache:
@@ -158,16 +148,16 @@ class E2eKeysHandler(object):
 
         defer.returnValue({
             "device_keys": results, "failures": failures,
-            "attestations": attestations
         })
 
     @defer.inlineCallbacks
-    def query_local_devices(self, query):
+    def query_local_devices(self, query, req_user_id=None):
         """Get E2E device keys for local users
 
         Args:
             query (dict[string, list[string]|None): map from user_id to a list
                  of devices to query (None for all devices)
+            req_user_id: the user requesting the devices
 
         Returns:
             defer.Deferred: (resolves to dict[string, dict[string, dict]]):
@@ -192,7 +182,9 @@ class E2eKeysHandler(object):
             # make sure that each queried user appears in the result dict
             result_dict[user_id] = {}
 
-        results = yield self.store.get_e2e_device_keys(local_query)
+        results = yield self.store.get_e2e_device_keys(
+            local_query, req_user_id=req_user_id,
+        )
 
         # Build the result structure, un-jsonify the results, and add the
         # "unsigned" section
@@ -201,36 +193,16 @@ class E2eKeysHandler(object):
                 r = dict(device_info["keys"])
                 r["unsigned"] = {}
                 display_name = device_info["device_display_name"]
+                attestations = device_info.get("attestations", None)
                 if display_name is not None:
                     r["unsigned"]["device_display_name"] = display_name
+                if attestations is not None:
+                    r["unsigned"]["attestations"] = attestations
                 result_dict[user_id][device_id] = r
 
         defer.returnValue(result_dict)
 
     @defer.inlineCallbacks
-    def query_local_attestations(self, req_user_id, query):
-        local_query = []
-
-        for user_id, device_ids in query.items():
-            # we use UserID.from_string to catch invalid user ids
-            if not self.is_mine(UserID.from_string(user_id)):
-                logger.warning("Request for keys for non-local user %s",
-                               user_id)
-                raise SynapseError(400, "Not a user here")
-
-            if not device_ids:
-                local_query.append((user_id, None))
-            else:
-                for device_id in device_ids:
-                    local_query.append((user_id, device_id))
-
-        results = yield self.store.get_e2e_attestations(req_user_id, local_query)
-
-        # FIXME: combine signatures of the same payload
-
-        defer.returnValue(results)
-
-    @defer.inlineCallbacks
     def on_federation_query_client_keys(self, query_body):
         """ Handle a device key query from a federated server
         """
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index ff3ce84eb1..f08a4efbfc 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -66,7 +66,7 @@ class EndToEndKeyStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_e2e_device_keys(
         self, query_list, include_all_devices=False,
-        include_deleted_devices=False,
+        include_deleted_devices=False, req_user_id=None,
     ):
         """Fetch a list of device keys.
         Args:
@@ -76,6 +76,7 @@ class EndToEndKeyStore(SQLBaseStore):
             include_deleted_devices (bool): whether to include null entries for
                 devices which no longer exist (but were in the query_list).
                 This option only takes effect if include_all_devices is true.
+            req_user_id: The user requesting the device list
         Returns:
             Dict mapping from user-id to dict mapping from device_id to
             dict containing "key_json", "device_display_name".
@@ -85,7 +86,7 @@ class EndToEndKeyStore(SQLBaseStore):
 
         results = yield self.runInteraction(
             "get_e2e_device_keys", self._get_e2e_device_keys_txn,
-            query_list, include_all_devices, include_deleted_devices,
+            query_list, include_all_devices, include_deleted_devices, req_user_id,
         )
 
         for user_id, device_keys in iteritems(results):
@@ -96,7 +97,7 @@ class EndToEndKeyStore(SQLBaseStore):
 
     def _get_e2e_device_keys_txn(
         self, txn, query_list, include_all_devices=False,
-        include_deleted_devices=False,
+        include_deleted_devices=False, req_user_id=None,
     ):
         query_clauses = []
         query_params = []
@@ -142,6 +143,16 @@ class EndToEndKeyStore(SQLBaseStore):
             for user_id, device_id in deleted_devices:
                 result.setdefault(user_id, {})[device_id] = None
 
+        attestations = self._get_e2e_attestations_txn(txn, req_user_id, query_list)
+
+        for attestation in attestations:
+            user_id = attestation["user_id"]
+            device_id = attestation["device_id"]
+            # FIXME: combine signatures of the same payload?
+            if user_id in result and device_id in result[user_id]:
+                result[user_id][device_id].setdefault("attestations", []) \
+                    .append(attestation)
+
         return result
 
     @defer.inlineCallbacks
@@ -305,35 +316,35 @@ class EndToEndKeyStore(SQLBaseStore):
             "add_e2e_attestations", add_e2e_attestations_txn
         )
 
-    @defer.inlineCallbacks
-    def get_e2e_attestations(self, from_user_id, query_list):
-        def get_e2e_attestations_txn(txn):
-            query_clauses = []
-            query_params = []
+    def _get_e2e_attestations_txn(self, txn, from_user_id, query_list):
+
+        query_clauses = []
+        query_params = []
 
-            for (user_id, device_id) in query_list:
-                query_clause = "(from_user_id = ? OR from_user_id = ?) AND user_id = ?"
+        for (user_id, device_id) in query_list:
+            if from_user_id:
+                query_clause = "(from_user_id = ? OR from_user_id = ?)"
                 query_params.append(from_user_id)
-                query_params.append(user_id)
-                query_params.append(user_id)
+            else:
+                query_clause = "(from_user_id = ?)"
+            query_params.append(user_id)
 
-                if device_id is not None:
-                    query_clause += " AND device_id = ?"
-                    query_params.append(device_id)
+            query_clause += " AND user_id = ?"
+            query_params.append(user_id)
 
-                query_clauses.append(query_clause)
+            if device_id is not None:
+                query_clause += " AND device_id = ?"
+                query_params.append(device_id)
 
-            sql = (
-                "SELECT attestation "
-                "  FROM e2e_attestations "
-                " WHERE %s"
-            ) % (
-                " OR ".join("(" + q + ")" for q in query_clauses)
-            )
+            query_clauses.append(query_clause)
 
-            txn.execute(sql, query_params)
-            return [json.loads(row[0]) for row in txn]
-        results = yield self.runInteraction(
-            "get_e2e_attestations", get_e2e_attestations_txn
+        sql = (
+            "SELECT attestation "
+            "  FROM e2e_attestations "
+            " WHERE %s"
+        ) % (
+            " OR ".join("(" + q + ")" for q in query_clauses)
         )
-        defer.returnValue(results)
+
+        txn.execute(sql, query_params)
+        return [json.loads(row[0]) for row in txn]