diff options
author | Richard van der Hoff <1389908+richvdh@users.noreply.github.com> | 2020-09-02 11:47:26 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-09-02 11:47:26 +0100 |
commit | abeab964d5a0cc41fa421cf9d89dc12b7a796391 (patch) | |
tree | 02433188ace90fdb6fab8ac9552c3d9eb5f16064 /synapse | |
parent | Fix errors when updating the user directory with invalid data (#8223) (diff) | |
download | synapse-abeab964d5a0cc41fa421cf9d89dc12b7a796391.tar.xz |
Make _get_e2e_device_keys_and_signatures_txn return an attrs (#8224)
this makes it a bit clearer what's going on.
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/storage/databases/main/devices.py | 8 | ||||
-rw-r--r-- | synapse/storage/databases/main/end_to_end_keys.py | 52 |
2 files changed, 40 insertions, 20 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 11956cc48e..8bedcdbdff 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -293,17 +293,17 @@ class DeviceWorkerStore(SQLBaseStore): prev_id = stream_id if device is not None: - key_json = device.get("key_json", None) + key_json = device.key_json if key_json: result["keys"] = db_to_json(key_json) - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): + if device.signatures: + for sig_user_id, sigs in device.signatures.items(): result["keys"].setdefault("signatures", {}).setdefault( sig_user_id, {} ).update(sigs) - device_display_name = device.get("device_display_name", None) + device_display_name = device.display_name if device_display_name: result["device_display_name"] = device_display_name else: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 5a7de44b33..449d95f31e 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -17,6 +17,7 @@ import abc from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple +import attr from canonicaljson import encode_canonical_json from twisted.enterprise.adbapi import Connection @@ -33,6 +34,21 @@ if TYPE_CHECKING: from synapse.handlers.e2e_keys import SignatureListItem +@attr.s +class DeviceKeyLookupResult: + """The type returned by _get_e2e_device_keys_and_signatures_txn""" + + display_name = attr.ib(type=Optional[str]) + + # the key data from e2e_device_keys_json. Typically includes fields like + # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing + # key) and "signatures" (a signature of the structure by the ed25519 key) + key_json = attr.ib(type=Optional[str]) + + # cross-signing sigs + signatures = attr.ib(type=Optional[Dict], default=None) + + class EndToEndKeyWorkerStore(SQLBaseStore): async def get_e2e_device_keys_for_federation_query( self, user_id: str @@ -61,17 +77,17 @@ class EndToEndKeyWorkerStore(SQLBaseStore): for device_id, device in user_devices.items(): result = {"device_id": device_id} - key_json = device.get("key_json", None) + key_json = device.key_json if key_json: result["keys"] = db_to_json(key_json) - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): + if device.signatures: + for sig_user_id, sigs in device.signatures.items(): result["keys"].setdefault("signatures", {}).setdefault( sig_user_id, {} ).update(sigs) - device_display_name = device.get("device_display_name", None) + device_display_name = device.display_name if device_display_name: result["device_display_name"] = device_display_name @@ -109,13 +125,13 @@ class EndToEndKeyWorkerStore(SQLBaseStore): for user_id, device_keys in results.items(): rv[user_id] = {} for device_id, device_info in device_keys.items(): - r = db_to_json(device_info.pop("key_json")) + r = db_to_json(device_info.key_json) r["unsigned"] = {} - display_name = device_info["device_display_name"] + display_name = device_info.display_name if display_name is not None: r["unsigned"]["device_display_name"] = display_name - if "signatures" in device_info: - for sig_user_id, sigs in device_info["signatures"].items(): + if device_info.signatures: + for sig_user_id, sigs in device_info.signatures.items(): r.setdefault("signatures", {}).setdefault( sig_user_id, {} ).update(sigs) @@ -126,7 +142,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): @trace def _get_e2e_device_keys_and_signatures_txn( self, txn, query_list, include_all_devices=False, include_deleted_devices=False - ) -> Dict[str, Dict[str, Optional[Dict]]]: + ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: set_tag("include_all_devices", include_all_devices) set_tag("include_deleted_devices", include_deleted_devices) @@ -161,7 +177,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): sql = ( "SELECT user_id, device_id, " - " d.display_name AS device_display_name, " + " d.display_name, " " k.key_json" " FROM devices d" " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" @@ -172,13 +188,14 @@ class EndToEndKeyWorkerStore(SQLBaseStore): ) txn.execute(sql, query_params) - rows = self.db_pool.cursor_to_dict(txn) - result = {} - for row in rows: + result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] + for (user_id, device_id, display_name, key_json) in txn: if include_deleted_devices: - deleted_devices.remove((row["user_id"], row["device_id"])) - result.setdefault(row["user_id"], {})[row["device_id"]] = row + deleted_devices.remove((user_id, device_id)) + result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult( + display_name, key_json + ) if include_deleted_devices: for user_id, device_id in deleted_devices: @@ -209,7 +226,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore): # note that target_device_result will be None for deleted devices. continue - target_device_signatures = target_device_result.setdefault("signatures", {}) + target_device_signatures = target_device_result.signatures + if target_device_signatures is None: + target_device_signatures = target_device_result.signatures = {} + signing_user_signatures = target_device_signatures.setdefault( signing_user_id, {} ) |