summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/devices.py8
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py52
2 files changed, 40 insertions, 20 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 11956cc48e..8bedcdbdff 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -293,17 +293,17 @@ class DeviceWorkerStore(SQLBaseStore):
                 prev_id = stream_id
 
                 if device is not None:
-                    key_json = device.get("key_json", None)
+                    key_json = device.key_json
                     if key_json:
                         result["keys"] = db_to_json(key_json)
 
-                        if "signatures" in device:
-                            for sig_user_id, sigs in device["signatures"].items():
+                        if device.signatures:
+                            for sig_user_id, sigs in device.signatures.items():
                                 result["keys"].setdefault("signatures", {}).setdefault(
                                     sig_user_id, {}
                                 ).update(sigs)
 
-                    device_display_name = device.get("device_display_name", None)
+                    device_display_name = device.display_name
                     if device_display_name:
                         result["device_display_name"] = device_display_name
                 else:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 5a7de44b33..449d95f31e 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -17,6 +17,7 @@
 import abc
 from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
 
+import attr
 from canonicaljson import encode_canonical_json
 
 from twisted.enterprise.adbapi import Connection
@@ -33,6 +34,21 @@ if TYPE_CHECKING:
     from synapse.handlers.e2e_keys import SignatureListItem
 
 
+@attr.s
+class DeviceKeyLookupResult:
+    """The type returned by _get_e2e_device_keys_and_signatures_txn"""
+
+    display_name = attr.ib(type=Optional[str])
+
+    # the key data from e2e_device_keys_json. Typically includes fields like
+    # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
+    # key) and "signatures" (a signature of the structure by the ed25519 key)
+    key_json = attr.ib(type=Optional[str])
+
+    # cross-signing sigs
+    signatures = attr.ib(type=Optional[Dict], default=None)
+
+
 class EndToEndKeyWorkerStore(SQLBaseStore):
     async def get_e2e_device_keys_for_federation_query(
         self, user_id: str
@@ -61,17 +77,17 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             for device_id, device in user_devices.items():
                 result = {"device_id": device_id}
 
-                key_json = device.get("key_json", None)
+                key_json = device.key_json
                 if key_json:
                     result["keys"] = db_to_json(key_json)
 
-                    if "signatures" in device:
-                        for sig_user_id, sigs in device["signatures"].items():
+                    if device.signatures:
+                        for sig_user_id, sigs in device.signatures.items():
                             result["keys"].setdefault("signatures", {}).setdefault(
                                 sig_user_id, {}
                             ).update(sigs)
 
-                device_display_name = device.get("device_display_name", None)
+                device_display_name = device.display_name
                 if device_display_name:
                     result["device_display_name"] = device_display_name
 
@@ -109,13 +125,13 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         for user_id, device_keys in results.items():
             rv[user_id] = {}
             for device_id, device_info in device_keys.items():
-                r = db_to_json(device_info.pop("key_json"))
+                r = db_to_json(device_info.key_json)
                 r["unsigned"] = {}
-                display_name = device_info["device_display_name"]
+                display_name = device_info.display_name
                 if display_name is not None:
                     r["unsigned"]["device_display_name"] = display_name
-                if "signatures" in device_info:
-                    for sig_user_id, sigs in device_info["signatures"].items():
+                if device_info.signatures:
+                    for sig_user_id, sigs in device_info.signatures.items():
                         r.setdefault("signatures", {}).setdefault(
                             sig_user_id, {}
                         ).update(sigs)
@@ -126,7 +142,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
     @trace
     def _get_e2e_device_keys_and_signatures_txn(
         self, txn, query_list, include_all_devices=False, include_deleted_devices=False
-    ) -> Dict[str, Dict[str, Optional[Dict]]]:
+    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
         set_tag("include_all_devices", include_all_devices)
         set_tag("include_deleted_devices", include_deleted_devices)
 
@@ -161,7 +177,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
 
         sql = (
             "SELECT user_id, device_id, "
-            "    d.display_name AS device_display_name, "
+            "    d.display_name, "
             "    k.key_json"
             " FROM devices d"
             "    %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
@@ -172,13 +188,14 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         )
 
         txn.execute(sql, query_params)
-        rows = self.db_pool.cursor_to_dict(txn)
 
-        result = {}
-        for row in rows:
+        result = {}  # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
+        for (user_id, device_id, display_name, key_json) in txn:
             if include_deleted_devices:
-                deleted_devices.remove((row["user_id"], row["device_id"]))
-            result.setdefault(row["user_id"], {})[row["device_id"]] = row
+                deleted_devices.remove((user_id, device_id))
+            result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
+                display_name, key_json
+            )
 
         if include_deleted_devices:
             for user_id, device_id in deleted_devices:
@@ -209,7 +226,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                 # note that target_device_result will be None for deleted devices.
                 continue
 
-            target_device_signatures = target_device_result.setdefault("signatures", {})
+            target_device_signatures = target_device_result.signatures
+            if target_device_signatures is None:
+                target_device_signatures = target_device_result.signatures = {}
+
             signing_user_signatures = target_device_signatures.setdefault(
                 signing_user_id, {}
             )