diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/handlers/e2e_keys.py | 21 | ||||
-rw-r--r-- | synapse/storage/end_to_end_keys.py | 60 |
2 files changed, 58 insertions, 23 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 1312cdf5ab..2c7bfd91ed 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -99,6 +99,7 @@ class E2eKeysHandler(object): """ local_query = [] + result_dict = {} for user_id, device_ids in query.items(): if not self.is_mine_id(user_id): logger.warning("Request for keys for non-local user %s", @@ -111,15 +112,23 @@ class E2eKeysHandler(object): for device_id in device_ids: local_query.append((user_id, device_id)) + # make sure that each queried user appears in the result dict + result_dict[user_id] = {} + results = yield self.store.get_e2e_device_keys(local_query) - # un-jsonify the results - json_result = collections.defaultdict(dict) + # Build the result structure, un-jsonify the results, and add the + # "unsigned" section for user_id, device_keys in results.items(): - for device_id, json_bytes in device_keys.items(): - json_result[user_id][device_id] = json.loads(json_bytes) - - defer.returnValue(json_result) + for device_id, device_info in device_keys.items(): + r = json.loads(device_info["key_json"]) + r["unsigned"] = {} + display_name = device_info["device_display_name"] + if display_name is not None: + r["unsigned"]["device_display_name"] = display_name + result_dict[user_id][device_id] = r + + defer.returnValue(result_dict) @defer.inlineCallbacks def on_federation_query_client_keys(self, query_body): diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 62b7790e91..385d607056 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections import twisted.internet.defer @@ -38,24 +39,49 @@ class EndToEndKeyStore(SQLBaseStore): query_list(list): List of pairs of user_ids and device_ids. Returns: Dict mapping from user-id to dict mapping from device_id to - key json byte strings. + dict containing "key_json", "device_display_name". """ - def _get_e2e_device_keys(txn): - result = {} - for user_id, device_id in query_list: - user_result = result.setdefault(user_id, {}) - keyvalues = {"user_id": user_id} - if device_id: - keyvalues["device_id"] = device_id - rows = self._simple_select_list_txn( - txn, table="e2e_device_keys_json", - keyvalues=keyvalues, - retcols=["device_id", "key_json"] - ) - for row in rows: - user_result[row["device_id"]] = row["key_json"] - return result - return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys) + if not query_list: + return {} + + return self.runInteraction( + "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list + ) + + def _get_e2e_device_keys_txn(self, txn, query_list): + query_clauses = [] + query_params = [] + + for (user_id, device_id) in query_list: + query_clause = "k.user_id = ?" + query_params.append(user_id) + + if device_id: + query_clause += " AND k.device_id = ?" + query_params.append(device_id) + + query_clauses.append(query_clause) + + sql = ( + "SELECT k.user_id, k.device_id, " + " d.display_name AS device_display_name, " + " k.key_json" + " FROM e2e_device_keys_json k" + " LEFT JOIN devices d ON d.user_id = k.user_id" + " AND d.device_id = k.device_id" + " WHERE %s" + ) % ( + " OR ".join("(" + q + ")" for q in query_clauses) + ) + + txn.execute(sql, query_params) + rows = self.cursor_to_dict(txn) + + result = collections.defaultdict(dict) + for row in rows: + result[row["user_id"]][row["device_id"]] = row + + return result def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): def _add_e2e_one_time_keys(txn): |