diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index cc0b15ae07..09af033233 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -25,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
+from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -45,8 +46,9 @@ class DeviceKeyLookupResult:
# key) and "signatures" (a signature of the structure by the ed25519 key)
key_json = attr.ib(type=Optional[str])
- # cross-signing sigs
- signatures = attr.ib(type=Optional[Dict], default=None)
+ # cross-signing sigs on this device.
+ # dict from (signing user_id)->(signing device_id)->sig
+ signatures = attr.ib(type=Optional[Dict[str, Dict[str, str]]], factory=dict)
class EndToEndKeyWorkerStore(SQLBaseStore):
@@ -133,7 +135,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
include_all_devices: bool = False,
include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
- """Fetch a list of device keys, together with their cross-signatures.
+ """Fetch a list of device keys
+
+ Any cross-signatures made on the keys by the owner of the device are also
+ included.
Args:
query_list: List of pairs of user_ids and device_ids. Device id can be None
@@ -154,22 +159,51 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result = await self.db_pool.runInteraction(
"get_e2e_device_keys",
- self._get_e2e_device_keys_and_signatures_txn,
+ self._get_e2e_device_keys_txn,
query_list,
include_all_devices,
include_deleted_devices,
)
+ # get the (user_id, device_id) tuples to look up cross-signatures for
+ signature_query = (
+ (user_id, device_id)
+ for user_id, dev in result.items()
+ for device_id, d in dev.items()
+ if d is not None
+ )
+
+ for batch in batch_iter(signature_query, 50):
+ cross_sigs_result = await self.db_pool.runInteraction(
+ "get_e2e_cross_signing_signatures",
+ self._get_e2e_cross_signing_signatures_for_devices_txn,
+ batch,
+ )
+
+ # add each cross-signing signature to the correct device in the result dict.
+ for (user_id, key_id, device_id, signature) in cross_sigs_result:
+ target_device_result = result[user_id][device_id]
+ target_device_signatures = target_device_result.signatures
+
+ signing_user_signatures = target_device_signatures.setdefault(
+ user_id, {}
+ )
+ signing_user_signatures[key_id] = signature
+
log_kv(result)
return result
- def _get_e2e_device_keys_and_signatures_txn(
+ def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+ """Get information on devices from the database
+
+ The results include the device's keys and self-signatures, but *not* any
+ cross-signing signatures which have been added subsequently (for which, see
+ get_e2e_device_keys_and_signatures)
+ """
query_clauses = []
query_params = []
- signature_query_clauses = []
- signature_query_params = []
if include_all_devices is False:
include_deleted_devices = False
@@ -180,20 +214,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for (user_id, device_id) in query_list:
query_clause = "user_id = ?"
query_params.append(user_id)
- signature_query_clause = "target_user_id = ?"
- signature_query_params.append(user_id)
if device_id is not None:
query_clause += " AND device_id = ?"
query_params.append(device_id)
- signature_query_clause += " AND target_device_id = ?"
- signature_query_params.append(device_id)
-
- signature_query_clause += " AND user_id = ?"
- signature_query_params.append(user_id)
query_clauses.append(query_clause)
- signature_query_clauses.append(signature_query_clause)
sql = (
"SELECT user_id, device_id, "
@@ -221,41 +247,36 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for user_id, device_id in deleted_devices:
result.setdefault(user_id, {})[device_id] = None
- # get signatures on the device
- signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % (
- " OR ".join("(" + q + ")" for q in signature_query_clauses)
- )
+ return result
- txn.execute(signature_sql, signature_query_params)
- rows = self.db_pool.cursor_to_dict(txn)
-
- # add each cross-signing signature to the correct device in the result dict.
- for row in rows:
- signing_user_id = row["user_id"]
- signing_key_id = row["key_id"]
- target_user_id = row["target_user_id"]
- target_device_id = row["target_device_id"]
- signature = row["signature"]
-
- target_user_result = result.get(target_user_id)
- if not target_user_result:
- continue
+ def _get_e2e_cross_signing_signatures_for_devices_txn(
+ self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
+ ) -> List[Tuple[str, str, str, str]]:
+ """Get cross-signing signatures for a given list of devices
- target_device_result = target_user_result.get(target_device_id)
- if not target_device_result:
- # note that target_device_result will be None for deleted devices.
- continue
+ Returns signatures made by the owners of the devices.
- target_device_signatures = target_device_result.signatures
- if target_device_signatures is None:
- target_device_signatures = target_device_result.signatures = {}
+ Returns: a list of results; each entry in the list is a tuple of
+ (user_id, key_id, target_device_id, signature).
+ """
+ signature_query_clauses = []
+ signature_query_params = []
- signing_user_signatures = target_device_signatures.setdefault(
- signing_user_id, {}
+ for (user_id, device_id) in device_query:
+ signature_query_clauses.append(
+ "target_user_id = ? AND target_device_id = ? AND user_id = ?"
)
- signing_user_signatures[signing_key_id] = signature
+ signature_query_params.extend([user_id, device_id, user_id])
- return result
+ signature_sql = """
+ SELECT user_id, key_id, target_device_id, signature
+ FROM e2e_cross_signing_signatures WHERE %s
+ """ % (
+ " OR ".join("(" + q + ")" for q in signature_query_clauses)
+ )
+
+ txn.execute(signature_sql, signature_query_params)
+ return txn.fetchall()
async def get_e2e_one_time_keys(
self, user_id: str, device_id: str, key_ids: List[str]
|