summary refs log tree commit diff
path: root/synapse/storage/end_to_end_keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/end_to_end_keys.py')
-rw-r--r--synapse/storage/end_to_end_keys.py16
1 files changed, 15 insertions, 1 deletions
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 7ae5c65482..f61553cec8 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -64,12 +64,17 @@ class EndToEndKeyStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def get_e2e_device_keys(self, query_list, include_all_devices=False):
+    def get_e2e_device_keys(
+        self, query_list, include_all_devices=False,
+        include_deleted_devices=False
+    ):
         """Fetch a list of device keys.
         Args:
             query_list(list): List of pairs of user_ids and device_ids.
             include_all_devices (bool): whether to include entries for devices
                 that don't have device keys
+            include_deleted_devices (bool): whether to include null entries for
+                devices which no longer exist (but where in the query_list)
         Returns:
             Dict mapping from user-id to dict mapping from device_id to
             dict containing "key_json", "device_display_name".
@@ -82,10 +87,19 @@ class EndToEndKeyStore(SQLBaseStore):
             query_list, include_all_devices,
         )
 
+        if include_deleted_devices:
+            deleted_devices = set(query_list)
+
         for user_id, device_keys in iteritems(results):
             for device_id, device_info in iteritems(device_keys):
+                if include_deleted_devices:
+                    deleted_devices -= (user_id, device_id)
                 device_info["keys"] = json.loads(device_info.pop("key_json"))
 
+        if include_deleted_devices:
+            for user_id, device_id in deleted_devices:
+                results.setdefault(user_id, {})[device_id] = None
+
         defer.returnValue(results)
 
     def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):