summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8205.misc1
-rw-r--r--synapse/handlers/e2e_keys.py4
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py20
-rw-r--r--tests/storage/test_end_to_end_keys.py8
4 files changed, 14 insertions, 19 deletions
diff --git a/changelog.d/8205.misc b/changelog.d/8205.misc
new file mode 100644
index 0000000000..fb8fd83278
--- /dev/null
+++ b/changelog.d/8205.misc
@@ -0,0 +1 @@
+ Refactor queries for device keys and cross-signatures.
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d8def45e38..dfd1c78549 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -353,7 +353,7 @@ class E2eKeysHandler(object):
             # make sure that each queried user appears in the result dict
             result_dict[user_id] = {}
 
-        results = await self.store.get_e2e_device_keys(local_query)
+        results = await self.store.get_e2e_device_keys_for_cs_api(local_query)
 
         # Build the result structure
         for user_id, device_keys in results.items():
@@ -734,7 +734,7 @@ class E2eKeysHandler(object):
             # fetch our stored devices.  This is used to 1. verify
             # signatures on the master key, and 2. to compare with what
             # was sent if the device was signed
-            devices = await self.store.get_e2e_device_keys([(user_id, None)])
+            devices = await self.store.get_e2e_device_keys_for_cs_api([(user_id, None)])
 
             if user_id not in devices:
                 raise NotFoundError("No device keys found")
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index af0b85e2c9..50ecddf7fa 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -23,6 +23,7 @@ from twisted.enterprise.adbapi import Connection
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import make_in_list_sql_clause
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.iterutils import batch_iter
@@ -33,17 +34,12 @@ if TYPE_CHECKING:
 
 class EndToEndKeyWorkerStore(SQLBaseStore):
     @trace
-    async def get_e2e_device_keys(
-        self, query_list, include_all_devices=False, include_deleted_devices=False
-    ):
-        """Fetch a list of device keys.
+    async def get_e2e_device_keys_for_cs_api(
+        self, query_list: List[Tuple[str, Optional[str]]]
+    ) -> Dict[str, Dict[str, JsonDict]]:
+        """Fetch a list of device keys, formatted suitably for the C/S API.
         Args:
             query_list(list): List of pairs of user_ids and device_ids.
-            include_all_devices (bool): whether to include entries for devices
-                that don't have device keys
-            include_deleted_devices (bool): whether to include null entries for
-                devices which no longer exist (but were in the query_list).
-                This option only takes effect if include_all_devices is true.
         Returns:
             Dict mapping from user-id to dict mapping from device_id to
             key data.  The key data will be a dict in the same format as the
@@ -54,11 +50,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             return {}
 
         results = await self.db_pool.runInteraction(
-            "get_e2e_device_keys",
-            self._get_e2e_device_keys_txn,
-            query_list,
-            include_all_devices,
-            include_deleted_devices,
+            "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list,
         )
 
         # Build the result structure, un-jsonify the results, and add the
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 261bf5b08b..3fc4bb13b6 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -37,7 +37,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         )
 
         res = yield defer.ensureDeferred(
-            self.store.get_e2e_device_keys((("user", "device"),))
+            self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
         )
         self.assertIn("user", res)
         self.assertIn("device", res["user"])
@@ -76,7 +76,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         )
 
         res = yield defer.ensureDeferred(
-            self.store.get_e2e_device_keys((("user", "device"),))
+            self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
         )
         self.assertIn("user", res)
         self.assertIn("device", res["user"])
@@ -108,7 +108,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         )
 
         res = yield defer.ensureDeferred(
-            self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2")))
+            self.store.get_e2e_device_keys_for_cs_api(
+                (("user1", "device1"), ("user2", "device2"))
+            )
         )
         self.assertIn("user1", res)
         self.assertIn("device1", res["user1"])