summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/devices.py18
-rw-r--r--synapse/storage/end_to_end_keys.py27
2 files changed, 25 insertions, 20 deletions
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 0c797f9f3e..203f50f07d 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -239,7 +239,6 @@ class DeviceStore(SQLBaseStore):
     def update_remote_device_list_cache_entry(self, user_id, device_id, content,
                                               stream_id):
         """Updates a single user's device in the cache.
-           If the content is null, delete the device from the cache.
         """
         return self.runInteraction(
             "update_remote_device_list_cache_entry",
@@ -249,7 +248,7 @@ class DeviceStore(SQLBaseStore):
 
     def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
                                                    content, stream_id):
-        if content is None:
+        if content.get("deleted"):
             self._simple_delete_txn(
                 txn,
                 table="device_lists_remote_cache",
@@ -409,12 +408,15 @@ class DeviceStore(SQLBaseStore):
 
                 prev_id = stream_id
 
-                key_json = device.get("key_json", None)
-                if key_json:
-                    result["keys"] = json.loads(key_json)
-                device_display_name = device.get("device_display_name", None)
-                if device_display_name:
-                    result["device_display_name"] = device_display_name
+                if device is not None:
+                    key_json = device.get("key_json", None)
+                    if key_json:
+                        result["keys"] = json.loads(key_json)
+                    device_display_name = device.get("device_display_name", None)
+                    if device_display_name:
+                        result["device_display_name"] = device_display_name
+                else:
+                    result["deleted"] = True
 
                 results.append(result)
 
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index f61553cec8..6c28719420 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -74,7 +74,7 @@ class EndToEndKeyStore(SQLBaseStore):
             include_all_devices (bool): whether to include entries for devices
                 that don't have device keys
             include_deleted_devices (bool): whether to include null entries for
-                devices which no longer exist (but where in the query_list)
+                devices which no longer exist (but were in the query_list)
         Returns:
             Dict mapping from user-id to dict mapping from device_id to
             dict containing "key_json", "device_display_name".
@@ -84,28 +84,25 @@ class EndToEndKeyStore(SQLBaseStore):
 
         results = yield self.runInteraction(
             "get_e2e_device_keys", self._get_e2e_device_keys_txn,
-            query_list, include_all_devices,
+            query_list, include_all_devices, include_deleted_devices,
         )
 
-        if include_deleted_devices:
-            deleted_devices = set(query_list)
-
         for user_id, device_keys in iteritems(results):
             for device_id, device_info in iteritems(device_keys):
-                if include_deleted_devices:
-                    deleted_devices -= (user_id, device_id)
                 device_info["keys"] = json.loads(device_info.pop("key_json"))
 
-        if include_deleted_devices:
-            for user_id, device_id in deleted_devices:
-                results.setdefault(user_id, {})[device_id] = None
-
         defer.returnValue(results)
 
-    def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
+    def _get_e2e_device_keys_txn(
+        self, txn, query_list, include_all_devices=False,
+        include_deleted_devices=False,
+    ):
         query_clauses = []
         query_params = []
 
+        if include_deleted_devices:
+            deleted_devices = set(query_list)
+
         for (user_id, device_id) in query_list:
             query_clause = "user_id = ?"
             query_params.append(user_id)
@@ -133,8 +130,14 @@ class EndToEndKeyStore(SQLBaseStore):
 
         result = {}
         for row in rows:
+            if include_deleted_devices:
+                deleted_devices.remove((row["user_id"], row["device_id"]))
             result.setdefault(row["user_id"], {})[row["device_id"]] = row
 
+        if include_deleted_devices:
+            for user_id, device_id in deleted_devices:
+                result.setdefault(user_id, {})[device_id] = None
+
         return result
 
     @defer.inlineCallbacks