diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 950fc927b1..bb69089b91 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -117,10 +117,15 @@ class E2eKeysHandler(object):
results = yield self.store.get_e2e_device_keys(local_query)
- # un-jsonify the results
+ # Build the result structure, un-jsonify the results, and add the
+ # "unsigned" section
for user_id, device_keys in results.items():
- for device_id, json_bytes in device_keys.items():
- result_dict[user_id][device_id] = json.loads(json_bytes)
+ for device_id, device_info in device_keys.items():
+ r = json.loads(device_info["key_json"])
+ r["unsigned"] = {
+ "device_display_name": device_info["device_display_name"],
+ }
+ result_dict[user_id][device_id] = r
defer.returnValue(result_dict)
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 62b7790e91..5c8ed3e492 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import collections
import twisted.internet.defer
@@ -38,24 +39,49 @@ class EndToEndKeyStore(SQLBaseStore):
query_list(list): List of pairs of user_ids and device_ids.
Returns:
Dict mapping from user-id to dict mapping from device_id to
- key json byte strings.
+ dict containing "key_json", "device_display_name".
"""
- def _get_e2e_device_keys(txn):
- result = {}
- for user_id, device_id in query_list:
- user_result = result.setdefault(user_id, {})
- keyvalues = {"user_id": user_id}
- if device_id:
- keyvalues["device_id"] = device_id
- rows = self._simple_select_list_txn(
- txn, table="e2e_device_keys_json",
- keyvalues=keyvalues,
- retcols=["device_id", "key_json"]
- )
- for row in rows:
- user_result[row["device_id"]] = row["key_json"]
- return result
- return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys)
+ if not query_list:
+ return {}
+
+ return self.runInteraction(
+ "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list
+ )
+
+ def _get_e2e_device_keys_txn(self, txn, query_list):
+ query_clauses = []
+ query_params = []
+
+ for (user_id, device_id) in query_list:
+ query_clause = "k.user_id = ?"
+ query_params.append(user_id)
+
+ if device_id:
+ query_clause += " AND k.device_id = ?"
+ query_params.append(device_id)
+
+ query_clauses.append(query_clause)
+
+ sql = (
+ "SELECT k.user_id, k.device_id, "
+ " d.display_name AS device_display_name, "
+ " k.key_json"
+ " FROM e2e_device_keys_json k"
+ " LEFT JOIN devices d ON d.user_id = k.user_id"
+ " AND d.device_id = k.device_id"
+ " WHERE %s"
+ ) % (
+ " OR ".join("("+q+")" for q in query_clauses)
+ )
+
+ txn.execute(sql, query_params)
+ rows = self.cursor_to_dict(txn)
+
+ result = collections.defaultdict(dict)
+ for row in rows:
+ result[row["user_id"]][row["device_id"]] = row
+
+ return result
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
def _add_e2e_one_time_keys(txn):
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
new file mode 100644
index 0000000000..0ebc6dafe8
--- /dev/null
+++ b/tests/storage/test_end_to_end_keys.py
@@ -0,0 +1,92 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import synapse.api.errors
+import tests.unittest
+import tests.utils
+
+
+class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
+ def __init__(self, *args, **kwargs):
+ super(EndToEndKeyStoreTestCase, self).__init__(*args, **kwargs)
+ self.store = None # type: synapse.storage.DataStore
+
+ @defer.inlineCallbacks
+ def setUp(self):
+ hs = yield tests.utils.setup_test_homeserver()
+
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def test_key_without_device_name(self):
+ now = 1470174257070
+ json = '{ "key": "value" }'
+
+ yield self.store.set_e2e_device_keys(
+ "user", "device", now, json)
+
+ res = yield self.store.get_e2e_device_keys((("user", "device"),))
+ self.assertIn("user", res)
+ self.assertIn("device", res["user"])
+ dev = res["user"]["device"]
+ self.assertDictContainsSubset({
+ "key_json": json,
+ "device_display_name": None,
+ }, dev)
+
+ @defer.inlineCallbacks
+ def test_get_key_with_device_name(self):
+ now = 1470174257070
+ json = '{ "key": "value" }'
+
+ yield self.store.set_e2e_device_keys(
+ "user", "device", now, json)
+ yield self.store.store_device(
+ "user", "device", "display_name"
+ )
+
+ res = yield self.store.get_e2e_device_keys((("user", "device"),))
+ self.assertIn("user", res)
+ self.assertIn("device", res["user"])
+ dev = res["user"]["device"]
+ self.assertDictContainsSubset({
+ "key_json": json,
+ "device_display_name": "display_name",
+ }, dev)
+
+
+ @defer.inlineCallbacks
+ def test_multiple_devices(self):
+ now = 1470174257070
+
+ yield self.store.set_e2e_device_keys(
+ "user1", "device1", now, 'json11')
+ yield self.store.set_e2e_device_keys(
+ "user1", "device2", now, 'json12')
+ yield self.store.set_e2e_device_keys(
+ "user2", "device1", now, 'json21')
+ yield self.store.set_e2e_device_keys(
+ "user2", "device2", now, 'json22')
+
+ res = yield self.store.get_e2e_device_keys((("user1", "device1"),
+ ("user2", "device2")))
+ self.assertIn("user1", res)
+ self.assertIn("device1", res["user1"])
+ self.assertNotIn("device2", res["user1"])
+ self.assertIn("user2", res)
+ self.assertNotIn("device1", res["user2"])
+ self.assertIn("device2", res["user2"])
|