diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 11956cc48e..8bedcdbdff 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -293,17 +293,17 @@ class DeviceWorkerStore(SQLBaseStore):
prev_id = stream_id
if device is not None:
- key_json = device.get("key_json", None)
+ key_json = device.key_json
if key_json:
result["keys"] = db_to_json(key_json)
- if "signatures" in device:
- for sig_user_id, sigs in device["signatures"].items():
+ if device.signatures:
+ for sig_user_id, sigs in device.signatures.items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
- device_display_name = device.get("device_display_name", None)
+ device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
else:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 5a7de44b33..449d95f31e 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -17,6 +17,7 @@
import abc
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+import attr
from canonicaljson import encode_canonical_json
from twisted.enterprise.adbapi import Connection
@@ -33,6 +34,21 @@ if TYPE_CHECKING:
from synapse.handlers.e2e_keys import SignatureListItem
+@attr.s
+class DeviceKeyLookupResult:
+ """The type returned by _get_e2e_device_keys_and_signatures_txn"""
+
+ display_name = attr.ib(type=Optional[str])
+
+ # the key data from e2e_device_keys_json. Typically includes fields like
+ # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
+ # key) and "signatures" (a signature of the structure by the ed25519 key)
+ key_json = attr.ib(type=Optional[str])
+
+ # cross-signing sigs
+ signatures = attr.ib(type=Optional[Dict], default=None)
+
+
class EndToEndKeyWorkerStore(SQLBaseStore):
async def get_e2e_device_keys_for_federation_query(
self, user_id: str
@@ -61,17 +77,17 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for device_id, device in user_devices.items():
result = {"device_id": device_id}
- key_json = device.get("key_json", None)
+ key_json = device.key_json
if key_json:
result["keys"] = db_to_json(key_json)
- if "signatures" in device:
- for sig_user_id, sigs in device["signatures"].items():
+ if device.signatures:
+ for sig_user_id, sigs in device.signatures.items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
- device_display_name = device.get("device_display_name", None)
+ device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
@@ -109,13 +125,13 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for user_id, device_keys in results.items():
rv[user_id] = {}
for device_id, device_info in device_keys.items():
- r = db_to_json(device_info.pop("key_json"))
+ r = db_to_json(device_info.key_json)
r["unsigned"] = {}
- display_name = device_info["device_display_name"]
+ display_name = device_info.display_name
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
- if "signatures" in device_info:
- for sig_user_id, sigs in device_info["signatures"].items():
+ if device_info.signatures:
+ for sig_user_id, sigs in device_info.signatures.items():
r.setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
@@ -126,7 +142,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
@trace
def _get_e2e_device_keys_and_signatures_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
- ) -> Dict[str, Dict[str, Optional[Dict]]]:
+ ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
set_tag("include_all_devices", include_all_devices)
set_tag("include_deleted_devices", include_deleted_devices)
@@ -161,7 +177,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
sql = (
"SELECT user_id, device_id, "
- " d.display_name AS device_display_name, "
+ " d.display_name, "
" k.key_json"
" FROM devices d"
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
@@ -172,13 +188,14 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
txn.execute(sql, query_params)
- rows = self.db_pool.cursor_to_dict(txn)
- result = {}
- for row in rows:
+ result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
+ for (user_id, device_id, display_name, key_json) in txn:
if include_deleted_devices:
- deleted_devices.remove((row["user_id"], row["device_id"]))
- result.setdefault(row["user_id"], {})[row["device_id"]] = row
+ deleted_devices.remove((user_id, device_id))
+ result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
+ display_name, key_json
+ )
if include_deleted_devices:
for user_id, device_id in deleted_devices:
@@ -209,7 +226,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
# note that target_device_result will be None for deleted devices.
continue
- target_device_signatures = target_device_result.setdefault("signatures", {})
+ target_device_signatures = target_device_result.signatures
+ if target_device_signatures is None:
+ target_device_signatures = target_device_result.signatures = {}
+
signing_user_signatures = target_device_signatures.setdefault(
signing_user_id, {}
)
|