summary refs log tree commit diff
path: root/synapse/storage/end_to_end_keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/end_to_end_keys.py')
-rw-r--r--synapse/storage/end_to_end_keys.py27
1 files changed, 15 insertions, 12 deletions
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index f61553cec8..6c28719420 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -74,7 +74,7 @@ class EndToEndKeyStore(SQLBaseStore):
             include_all_devices (bool): whether to include entries for devices
                 that don't have device keys
             include_deleted_devices (bool): whether to include null entries for
-                devices which no longer exist (but where in the query_list)
+                devices which no longer exist (but were in the query_list)
         Returns:
             Dict mapping from user-id to dict mapping from device_id to
             dict containing "key_json", "device_display_name".
@@ -84,28 +84,25 @@ class EndToEndKeyStore(SQLBaseStore):
 
         results = yield self.runInteraction(
             "get_e2e_device_keys", self._get_e2e_device_keys_txn,
-            query_list, include_all_devices,
+            query_list, include_all_devices, include_deleted_devices,
         )
 
-        if include_deleted_devices:
-            deleted_devices = set(query_list)
-
         for user_id, device_keys in iteritems(results):
             for device_id, device_info in iteritems(device_keys):
-                if include_deleted_devices:
-                    deleted_devices -= (user_id, device_id)
                 device_info["keys"] = json.loads(device_info.pop("key_json"))
 
-        if include_deleted_devices:
-            for user_id, device_id in deleted_devices:
-                results.setdefault(user_id, {})[device_id] = None
-
         defer.returnValue(results)
 
-    def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
+    def _get_e2e_device_keys_txn(
+        self, txn, query_list, include_all_devices=False,
+        include_deleted_devices=False,
+    ):
         query_clauses = []
         query_params = []
 
+        if include_deleted_devices:
+            deleted_devices = set(query_list)
+
         for (user_id, device_id) in query_list:
             query_clause = "user_id = ?"
             query_params.append(user_id)
@@ -133,8 +130,14 @@ class EndToEndKeyStore(SQLBaseStore):
 
         result = {}
         for row in rows:
+            if include_deleted_devices:
+                deleted_devices.remove((row["user_id"], row["device_id"]))
             result.setdefault(row["user_id"], {})[row["device_id"]] = row
 
+        if include_deleted_devices:
+            for user_id, device_id in deleted_devices:
+                result.setdefault(user_id, {})[device_id] = None
+
         return result
 
     @defer.inlineCallbacks