summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/e2e_keys.py21
-rw-r--r--synapse/storage/end_to_end_keys.py60
2 files changed, 58 insertions, 23 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 1312cdf5ab..2c7bfd91ed 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -99,6 +99,7 @@ class E2eKeysHandler(object):
         """
         local_query = []
 
+        result_dict = {}
         for user_id, device_ids in query.items():
             if not self.is_mine_id(user_id):
                 logger.warning("Request for keys for non-local user %s",
@@ -111,15 +112,23 @@ class E2eKeysHandler(object):
                 for device_id in device_ids:
                     local_query.append((user_id, device_id))
 
+            # make sure that each queried user appears in the result dict
+            result_dict[user_id] = {}
+
         results = yield self.store.get_e2e_device_keys(local_query)
 
-        # un-jsonify the results
-        json_result = collections.defaultdict(dict)
+        # Build the result structure, un-jsonify the results, and add the
+        # "unsigned" section
         for user_id, device_keys in results.items():
-            for device_id, json_bytes in device_keys.items():
-                json_result[user_id][device_id] = json.loads(json_bytes)
-
-        defer.returnValue(json_result)
+            for device_id, device_info in device_keys.items():
+                r = json.loads(device_info["key_json"])
+                r["unsigned"] = {}
+                display_name = device_info["device_display_name"]
+                if display_name is not None:
+                    r["unsigned"]["device_display_name"] = display_name
+                result_dict[user_id][device_id] = r
+
+        defer.returnValue(result_dict)
 
     @defer.inlineCallbacks
     def on_federation_query_client_keys(self, query_body):
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 62b7790e91..385d607056 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import collections
 
 import twisted.internet.defer
 
@@ -38,24 +39,49 @@ class EndToEndKeyStore(SQLBaseStore):
             query_list(list): List of pairs of user_ids and device_ids.
         Returns:
             Dict mapping from user-id to dict mapping from device_id to
-            key json byte strings.
+            dict containing "key_json", "device_display_name".
         """
-        def _get_e2e_device_keys(txn):
-            result = {}
-            for user_id, device_id in query_list:
-                user_result = result.setdefault(user_id, {})
-                keyvalues = {"user_id": user_id}
-                if device_id:
-                    keyvalues["device_id"] = device_id
-                rows = self._simple_select_list_txn(
-                    txn, table="e2e_device_keys_json",
-                    keyvalues=keyvalues,
-                    retcols=["device_id", "key_json"]
-                )
-                for row in rows:
-                    user_result[row["device_id"]] = row["key_json"]
-            return result
-        return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys)
+        if not query_list:
+            return {}
+
+        return self.runInteraction(
+            "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list
+        )
+
+    def _get_e2e_device_keys_txn(self, txn, query_list):
+        query_clauses = []
+        query_params = []
+
+        for (user_id, device_id) in query_list:
+            query_clause = "k.user_id = ?"
+            query_params.append(user_id)
+
+            if device_id:
+                query_clause += " AND k.device_id = ?"
+                query_params.append(device_id)
+
+            query_clauses.append(query_clause)
+
+        sql = (
+            "SELECT k.user_id, k.device_id, "
+            "    d.display_name AS device_display_name, "
+            "    k.key_json"
+            " FROM e2e_device_keys_json k"
+            "    LEFT JOIN devices d ON d.user_id = k.user_id"
+            "      AND d.device_id = k.device_id"
+            " WHERE %s"
+        ) % (
+            " OR ".join("(" + q + ")" for q in query_clauses)
+        )
+
+        txn.execute(sql, query_params)
+        rows = self.cursor_to_dict(txn)
+
+        result = collections.defaultdict(dict)
+        for row in rows:
+            result[row["user_id"]][row["device_id"]] = row
+
+        return result
 
     def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
         def _add_e2e_one_time_keys(txn):