diff --git a/changelog.d/8205.misc b/changelog.d/8205.misc
new file mode 100644
index 0000000000..fb8fd83278
--- /dev/null
+++ b/changelog.d/8205.misc
@@ -0,0 +1 @@
+ Refactor queries for device keys and cross-signatures.
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d8def45e38..dfd1c78549 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -353,7 +353,7 @@ class E2eKeysHandler(object):
# make sure that each queried user appears in the result dict
result_dict[user_id] = {}
- results = await self.store.get_e2e_device_keys(local_query)
+ results = await self.store.get_e2e_device_keys_for_cs_api(local_query)
# Build the result structure
for user_id, device_keys in results.items():
@@ -734,7 +734,7 @@ class E2eKeysHandler(object):
# fetch our stored devices. This is used to 1. verify
# signatures on the master key, and 2. to compare with what
# was sent if the device was signed
- devices = await self.store.get_e2e_device_keys([(user_id, None)])
+ devices = await self.store.get_e2e_device_keys_for_cs_api([(user_id, None)])
if user_id not in devices:
raise NotFoundError("No device keys found")
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index af0b85e2c9..50ecddf7fa 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -23,6 +23,7 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@@ -33,17 +34,12 @@ if TYPE_CHECKING:
class EndToEndKeyWorkerStore(SQLBaseStore):
@trace
- async def get_e2e_device_keys(
- self, query_list, include_all_devices=False, include_deleted_devices=False
- ):
- """Fetch a list of device keys.
+ async def get_e2e_device_keys_for_cs_api(
+ self, query_list: List[Tuple[str, Optional[str]]]
+ ) -> Dict[str, Dict[str, JsonDict]]:
+ """Fetch a list of device keys, formatted suitably for the C/S API.
Args:
query_list(list): List of pairs of user_ids and device_ids.
- include_all_devices (bool): whether to include entries for devices
- that don't have device keys
- include_deleted_devices (bool): whether to include null entries for
- devices which no longer exist (but were in the query_list).
- This option only takes effect if include_all_devices is true.
Returns:
Dict mapping from user-id to dict mapping from device_id to
key data. The key data will be a dict in the same format as the
@@ -54,11 +50,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return {}
results = await self.db_pool.runInteraction(
- "get_e2e_device_keys",
- self._get_e2e_device_keys_txn,
- query_list,
- include_all_devices,
- include_deleted_devices,
+ "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list,
)
# Build the result structure, un-jsonify the results, and add the
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 261bf5b08b..3fc4bb13b6 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -37,7 +37,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
)
res = yield defer.ensureDeferred(
- self.store.get_e2e_device_keys((("user", "device"),))
+ self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
self.assertIn("device", res["user"])
@@ -76,7 +76,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
)
res = yield defer.ensureDeferred(
- self.store.get_e2e_device_keys((("user", "device"),))
+ self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
self.assertIn("device", res["user"])
@@ -108,7 +108,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
)
res = yield defer.ensureDeferred(
- self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2")))
+ self.store.get_e2e_device_keys_for_cs_api(
+ (("user1", "device1"), ("user2", "device2"))
+ )
)
self.assertIn("user1", res)
self.assertIn("device1", res["user1"])
|