summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/devices.py4
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py52
2 files changed, 39 insertions, 17 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 8bedcdbdff..f8fe948122 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -255,9 +255,7 @@ class DeviceWorkerStore(SQLBaseStore):
             List of objects representing an device update EDU
         """
         devices = (
-            await self.db_pool.runInteraction(
-                "get_e2e_device_keys_and_signatures_txn",
-                self._get_e2e_device_keys_and_signatures_txn,
+            await self.get_e2e_device_keys_and_signatures(
                 query_map.keys(),
                 include_all_devices=True,
                 include_deleted_devices=True,
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 4059701cfd..cc0b15ae07 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
 
 @attr.s
 class DeviceKeyLookupResult:
-    """The type returned by _get_e2e_device_keys_and_signatures_txn"""
+    """The type returned by get_e2e_device_keys_and_signatures"""
 
     display_name = attr.ib(type=Optional[str])
 
@@ -60,11 +60,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         """
         now_stream_id = self.get_device_stream_token()
 
-        devices = await self.db_pool.runInteraction(
-            "get_e2e_device_keys_and_signatures_txn",
-            self._get_e2e_device_keys_and_signatures_txn,
-            [(user_id, None)],
-        )
+        devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
 
         if devices:
             user_devices = devices[user_id]
@@ -108,11 +104,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         if not query_list:
             return {}
 
-        results = await self.db_pool.runInteraction(
-            "get_e2e_device_keys_and_signatures_txn",
-            self._get_e2e_device_keys_and_signatures_txn,
-            query_list,
-        )
+        results = await self.get_e2e_device_keys_and_signatures(query_list)
 
         # Build the result structure, un-jsonify the results, and add the
         # "unsigned" section
@@ -135,12 +127,45 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         return rv
 
     @trace
-    def _get_e2e_device_keys_and_signatures_txn(
-        self, txn, query_list, include_all_devices=False, include_deleted_devices=False
+    async def get_e2e_device_keys_and_signatures(
+        self,
+        query_list: List[Tuple[str, Optional[str]]],
+        include_all_devices: bool = False,
+        include_deleted_devices: bool = False,
     ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+        """Fetch a list of device keys, together with their cross-signatures.
+
+        Args:
+            query_list: List of pairs of user_ids and device_ids. Device id can be None
+                to indicate "all devices for this user"
+
+            include_all_devices: whether to return devices without device keys
+
+            include_deleted_devices: whether to include null entries for
+                devices which no longer exist (but were in the query_list).
+                This option only takes effect if include_all_devices is true.
+
+        Returns:
+            Dict mapping from user-id to dict mapping from device_id to
+            key data.
+        """
         set_tag("include_all_devices", include_all_devices)
         set_tag("include_deleted_devices", include_deleted_devices)
 
+        result = await self.db_pool.runInteraction(
+            "get_e2e_device_keys",
+            self._get_e2e_device_keys_and_signatures_txn,
+            query_list,
+            include_all_devices,
+            include_deleted_devices,
+        )
+
+        log_kv(result)
+        return result
+
+    def _get_e2e_device_keys_and_signatures_txn(
+        self, txn, query_list, include_all_devices=False, include_deleted_devices=False
+    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
         query_clauses = []
         query_params = []
         signature_query_clauses = []
@@ -230,7 +255,6 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             )
             signing_user_signatures[signing_key_id] = signature
 
-        log_kv(result)
         return result
 
     async def get_e2e_one_time_keys(