summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/devices.py40
-rw-r--r--synapse/storage/end_to_end_keys.py16
2 files changed, 43 insertions, 13 deletions
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index ec68e39f1e..0c797f9f3e 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -239,6 +239,7 @@ class DeviceStore(SQLBaseStore):
     def update_remote_device_list_cache_entry(self, user_id, device_id, content,
                                               stream_id):
         """Updates a single user's device in the cache.
+           If the content is null, delete the device from the cache.
         """
         return self.runInteraction(
             "update_remote_device_list_cache_entry",
@@ -248,17 +249,32 @@ class DeviceStore(SQLBaseStore):
 
     def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
                                                    content, stream_id):
-        self._simple_upsert_txn(
-            txn,
-            table="device_lists_remote_cache",
-            keyvalues={
-                "user_id": user_id,
-                "device_id": device_id,
-            },
-            values={
-                "content": json.dumps(content),
-            }
-        )
+        if content is None:
+            self._simple_delete_txn(
+                txn,
+                table="device_lists_remote_cache",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+            )
+
+            # Do we need this?
+            txn.call_after(
+                self.device_id_exists_cache.invalidate, (user_id, device_id,)
+            )
+        else:
+            self._simple_upsert_txn(
+                txn,
+                table="device_lists_remote_cache",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+                values={
+                    "content": json.dumps(content),
+                }
+            )
 
         txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
         txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
@@ -366,7 +382,7 @@ class DeviceStore(SQLBaseStore):
             now_stream_id = max(stream_id for stream_id in itervalues(query_map))
 
         devices = self._get_e2e_device_keys_txn(
-            txn, query_map.keys(), include_all_devices=True
+            txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
         )
 
         prev_sent_id_sql = """
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 7ae5c65482..f61553cec8 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -64,12 +64,17 @@ class EndToEndKeyStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def get_e2e_device_keys(self, query_list, include_all_devices=False):
+    def get_e2e_device_keys(
+        self, query_list, include_all_devices=False,
+        include_deleted_devices=False
+    ):
         """Fetch a list of device keys.
         Args:
             query_list(list): List of pairs of user_ids and device_ids.
             include_all_devices (bool): whether to include entries for devices
                 that don't have device keys
+            include_deleted_devices (bool): whether to include null entries for
+                devices which no longer exist (but where in the query_list)
         Returns:
             Dict mapping from user-id to dict mapping from device_id to
             dict containing "key_json", "device_display_name".
@@ -82,10 +87,19 @@ class EndToEndKeyStore(SQLBaseStore):
             query_list, include_all_devices,
         )
 
+        if include_deleted_devices:
+            deleted_devices = set(query_list)
+
         for user_id, device_keys in iteritems(results):
             for device_id, device_info in iteritems(device_keys):
+                if include_deleted_devices:
+                    deleted_devices -= (user_id, device_id)
                 device_info["keys"] = json.loads(device_info.pop("key_json"))
 
+        if include_deleted_devices:
+            for user_id, device_id in deleted_devices:
+                results.setdefault(user_id, {})[device_id] = None
+
         defer.returnValue(results)
 
     def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):