summary refs log tree commit diff
path: root/synapse/storage/end_to_end_keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/end_to_end_keys.py')
-rw-r--r--synapse/storage/end_to_end_keys.py67
1 files changed, 39 insertions, 28 deletions
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index ff3ce84eb1..f08a4efbfc 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -66,7 +66,7 @@ class EndToEndKeyStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_e2e_device_keys(
         self, query_list, include_all_devices=False,
-        include_deleted_devices=False,
+        include_deleted_devices=False, req_user_id=None,
     ):
         """Fetch a list of device keys.
         Args:
@@ -76,6 +76,7 @@ class EndToEndKeyStore(SQLBaseStore):
             include_deleted_devices (bool): whether to include null entries for
                 devices which no longer exist (but were in the query_list).
                 This option only takes effect if include_all_devices is true.
+            req_user_id: The user requesting the device list
         Returns:
             Dict mapping from user-id to dict mapping from device_id to
             dict containing "key_json", "device_display_name".
@@ -85,7 +86,7 @@ class EndToEndKeyStore(SQLBaseStore):
 
         results = yield self.runInteraction(
             "get_e2e_device_keys", self._get_e2e_device_keys_txn,
-            query_list, include_all_devices, include_deleted_devices,
+            query_list, include_all_devices, include_deleted_devices, req_user_id,
         )
 
         for user_id, device_keys in iteritems(results):
@@ -96,7 +97,7 @@ class EndToEndKeyStore(SQLBaseStore):
 
     def _get_e2e_device_keys_txn(
         self, txn, query_list, include_all_devices=False,
-        include_deleted_devices=False,
+        include_deleted_devices=False, req_user_id=None,
     ):
         query_clauses = []
         query_params = []
@@ -142,6 +143,16 @@ class EndToEndKeyStore(SQLBaseStore):
             for user_id, device_id in deleted_devices:
                 result.setdefault(user_id, {})[device_id] = None
 
+        attestations = self._get_e2e_attestations_txn(txn, req_user_id, query_list)
+
+        for attestation in attestations:
+            user_id = attestation["user_id"]
+            device_id = attestation["device_id"]
+            # FIXME: combine signatures of the same payload?
+            if user_id in result and device_id in result[user_id]:
+                result[user_id][device_id].setdefault("attestations", []) \
+                    .append(attestation)
+
         return result
 
     @defer.inlineCallbacks
@@ -305,35 +316,35 @@ class EndToEndKeyStore(SQLBaseStore):
             "add_e2e_attestations", add_e2e_attestations_txn
         )
 
-    @defer.inlineCallbacks
-    def get_e2e_attestations(self, from_user_id, query_list):
-        def get_e2e_attestations_txn(txn):
-            query_clauses = []
-            query_params = []
+    def _get_e2e_attestations_txn(self, txn, from_user_id, query_list):
+
+        query_clauses = []
+        query_params = []
 
-            for (user_id, device_id) in query_list:
-                query_clause = "(from_user_id = ? OR from_user_id = ?) AND user_id = ?"
+        for (user_id, device_id) in query_list:
+            if from_user_id:
+                query_clause = "(from_user_id = ? OR from_user_id = ?)"
                 query_params.append(from_user_id)
-                query_params.append(user_id)
-                query_params.append(user_id)
+            else:
+                query_clause = "(from_user_id = ?)"
+            query_params.append(user_id)
 
-                if device_id is not None:
-                    query_clause += " AND device_id = ?"
-                    query_params.append(device_id)
+            query_clause += " AND user_id = ?"
+            query_params.append(user_id)
 
-                query_clauses.append(query_clause)
+            if device_id is not None:
+                query_clause += " AND device_id = ?"
+                query_params.append(device_id)
 
-            sql = (
-                "SELECT attestation "
-                "  FROM e2e_attestations "
-                " WHERE %s"
-            ) % (
-                " OR ".join("(" + q + ")" for q in query_clauses)
-            )
+            query_clauses.append(query_clause)
 
-            txn.execute(sql, query_params)
-            return [json.loads(row[0]) for row in txn]
-        results = yield self.runInteraction(
-            "get_e2e_attestations", get_e2e_attestations_txn
+        sql = (
+            "SELECT attestation "
+            "  FROM e2e_attestations "
+            " WHERE %s"
+        ) % (
+            " OR ".join("(" + q + ")" for q in query_clauses)
         )
-        defer.returnValue(results)
+
+        txn.execute(sql, query_params)
+        return [json.loads(row[0]) for row in txn]