diff options
Diffstat (limited to 'synapse/storage/end_to_end_keys.py')
-rw-r--r-- | synapse/storage/end_to_end_keys.py | 67 |
1 files changed, 39 insertions, 28 deletions
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index ff3ce84eb1..f08a4efbfc 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -66,7 +66,7 @@ class EndToEndKeyStore(SQLBaseStore): @defer.inlineCallbacks def get_e2e_device_keys( self, query_list, include_all_devices=False, - include_deleted_devices=False, + include_deleted_devices=False, req_user_id=None, ): """Fetch a list of device keys. Args: @@ -76,6 +76,7 @@ class EndToEndKeyStore(SQLBaseStore): include_deleted_devices (bool): whether to include null entries for devices which no longer exist (but were in the query_list). This option only takes effect if include_all_devices is true. + req_user_id: The user requesting the device list Returns: Dict mapping from user-id to dict mapping from device_id to dict containing "key_json", "device_display_name". @@ -85,7 +86,7 @@ class EndToEndKeyStore(SQLBaseStore): results = yield self.runInteraction( "get_e2e_device_keys", self._get_e2e_device_keys_txn, - query_list, include_all_devices, include_deleted_devices, + query_list, include_all_devices, include_deleted_devices, req_user_id, ) for user_id, device_keys in iteritems(results): @@ -96,7 +97,7 @@ class EndToEndKeyStore(SQLBaseStore): def _get_e2e_device_keys_txn( self, txn, query_list, include_all_devices=False, - include_deleted_devices=False, + include_deleted_devices=False, req_user_id=None, ): query_clauses = [] query_params = [] @@ -142,6 +143,16 @@ class EndToEndKeyStore(SQLBaseStore): for user_id, device_id in deleted_devices: result.setdefault(user_id, {})[device_id] = None + attestations = self._get_e2e_attestations_txn(txn, req_user_id, query_list) + + for attestation in attestations: + user_id = attestation["user_id"] + device_id = attestation["device_id"] + # FIXME: combine signatures of the same payload? + if user_id in result and device_id in result[user_id]: + result[user_id][device_id].setdefault("attestations", []) \ + .append(attestation) + return result @defer.inlineCallbacks @@ -305,35 +316,35 @@ class EndToEndKeyStore(SQLBaseStore): "add_e2e_attestations", add_e2e_attestations_txn ) - @defer.inlineCallbacks - def get_e2e_attestations(self, from_user_id, query_list): - def get_e2e_attestations_txn(txn): - query_clauses = [] - query_params = [] + def _get_e2e_attestations_txn(self, txn, from_user_id, query_list): + + query_clauses = [] + query_params = [] - for (user_id, device_id) in query_list: - query_clause = "(from_user_id = ? OR from_user_id = ?) AND user_id = ?" + for (user_id, device_id) in query_list: + if from_user_id: + query_clause = "(from_user_id = ? OR from_user_id = ?)" query_params.append(from_user_id) - query_params.append(user_id) - query_params.append(user_id) + else: + query_clause = "(from_user_id = ?)" + query_params.append(user_id) - if device_id is not None: - query_clause += " AND device_id = ?" - query_params.append(device_id) + query_clause += " AND user_id = ?" + query_params.append(user_id) - query_clauses.append(query_clause) + if device_id is not None: + query_clause += " AND device_id = ?" + query_params.append(device_id) - sql = ( - "SELECT attestation " - " FROM e2e_attestations " - " WHERE %s" - ) % ( - " OR ".join("(" + q + ")" for q in query_clauses) - ) + query_clauses.append(query_clause) - txn.execute(sql, query_params) - return [json.loads(row[0]) for row in txn] - results = yield self.runInteraction( - "get_e2e_attestations", get_e2e_attestations_txn + sql = ( + "SELECT attestation " + " FROM e2e_attestations " + " WHERE %s" + ) % ( + " OR ".join("(" + q + ")" for q in query_clauses) ) - defer.returnValue(results) + + txn.execute(sql, query_params) + return [json.loads(row[0]) for row in txn] |