diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/storage/databases/main/end_to_end_keys.py | 109 |
1 files changed, 65 insertions, 44 deletions
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index cc0b15ae07..09af033233 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -25,6 +25,7 @@ from twisted.enterprise.adbapi import Connection from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import make_in_list_sql_clause +from synapse.storage.types import Cursor from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -45,8 +46,9 @@ class DeviceKeyLookupResult: # key) and "signatures" (a signature of the structure by the ed25519 key) key_json = attr.ib(type=Optional[str]) - # cross-signing sigs - signatures = attr.ib(type=Optional[Dict], default=None) + # cross-signing sigs on this device. + # dict from (signing user_id)->(signing device_id)->sig + signatures = attr.ib(type=Optional[Dict[str, Dict[str, str]]], factory=dict) class EndToEndKeyWorkerStore(SQLBaseStore): @@ -133,7 +135,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore): include_all_devices: bool = False, include_deleted_devices: bool = False, ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: - """Fetch a list of device keys, together with their cross-signatures. + """Fetch a list of device keys + + Any cross-signatures made on the keys by the owner of the device are also + included. Args: query_list: List of pairs of user_ids and device_ids. Device id can be None @@ -154,22 +159,51 @@ class EndToEndKeyWorkerStore(SQLBaseStore): result = await self.db_pool.runInteraction( "get_e2e_device_keys", - self._get_e2e_device_keys_and_signatures_txn, + self._get_e2e_device_keys_txn, query_list, include_all_devices, include_deleted_devices, ) + # get the (user_id, device_id) tuples to look up cross-signatures for + signature_query = ( + (user_id, device_id) + for user_id, dev in result.items() + for device_id, d in dev.items() + if d is not None + ) + + for batch in batch_iter(signature_query, 50): + cross_sigs_result = await self.db_pool.runInteraction( + "get_e2e_cross_signing_signatures", + self._get_e2e_cross_signing_signatures_for_devices_txn, + batch, + ) + + # add each cross-signing signature to the correct device in the result dict. + for (user_id, key_id, device_id, signature) in cross_sigs_result: + target_device_result = result[user_id][device_id] + target_device_signatures = target_device_result.signatures + + signing_user_signatures = target_device_signatures.setdefault( + user_id, {} + ) + signing_user_signatures[key_id] = signature + log_kv(result) return result - def _get_e2e_device_keys_and_signatures_txn( + def _get_e2e_device_keys_txn( self, txn, query_list, include_all_devices=False, include_deleted_devices=False ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: + """Get information on devices from the database + + The results include the device's keys and self-signatures, but *not* any + cross-signing signatures which have been added subsequently (for which, see + get_e2e_device_keys_and_signatures) + """ query_clauses = [] query_params = [] - signature_query_clauses = [] - signature_query_params = [] if include_all_devices is False: include_deleted_devices = False @@ -180,20 +214,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore): for (user_id, device_id) in query_list: query_clause = "user_id = ?" query_params.append(user_id) - signature_query_clause = "target_user_id = ?" - signature_query_params.append(user_id) if device_id is not None: query_clause += " AND device_id = ?" query_params.append(device_id) - signature_query_clause += " AND target_device_id = ?" - signature_query_params.append(device_id) - - signature_query_clause += " AND user_id = ?" - signature_query_params.append(user_id) query_clauses.append(query_clause) - signature_query_clauses.append(signature_query_clause) sql = ( "SELECT user_id, device_id, " @@ -221,41 +247,36 @@ class EndToEndKeyWorkerStore(SQLBaseStore): for user_id, device_id in deleted_devices: result.setdefault(user_id, {})[device_id] = None - # get signatures on the device - signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % ( - " OR ".join("(" + q + ")" for q in signature_query_clauses) - ) + return result - txn.execute(signature_sql, signature_query_params) - rows = self.db_pool.cursor_to_dict(txn) - - # add each cross-signing signature to the correct device in the result dict. - for row in rows: - signing_user_id = row["user_id"] - signing_key_id = row["key_id"] - target_user_id = row["target_user_id"] - target_device_id = row["target_device_id"] - signature = row["signature"] - - target_user_result = result.get(target_user_id) - if not target_user_result: - continue + def _get_e2e_cross_signing_signatures_for_devices_txn( + self, txn: Cursor, device_query: Iterable[Tuple[str, str]] + ) -> List[Tuple[str, str, str, str]]: + """Get cross-signing signatures for a given list of devices - target_device_result = target_user_result.get(target_device_id) - if not target_device_result: - # note that target_device_result will be None for deleted devices. - continue + Returns signatures made by the owners of the devices. - target_device_signatures = target_device_result.signatures - if target_device_signatures is None: - target_device_signatures = target_device_result.signatures = {} + Returns: a list of results; each entry in the list is a tuple of + (user_id, key_id, target_device_id, signature). + """ + signature_query_clauses = [] + signature_query_params = [] - signing_user_signatures = target_device_signatures.setdefault( - signing_user_id, {} + for (user_id, device_id) in device_query: + signature_query_clauses.append( + "target_user_id = ? AND target_device_id = ? AND user_id = ?" ) - signing_user_signatures[signing_key_id] = signature + signature_query_params.extend([user_id, device_id, user_id]) - return result + signature_sql = """ + SELECT user_id, key_id, target_device_id, signature + FROM e2e_cross_signing_signatures WHERE %s + """ % ( + " OR ".join("(" + q + ")" for q in signature_query_clauses) + ) + + txn.execute(signature_sql, signature_query_params) + return txn.fetchall() async def get_e2e_one_time_keys( self, user_id: str, device_id: str, key_ids: List[str] |