summary refs log tree commit diff
path: root/synapse/storage/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/devices.py')
-rw-r--r--synapse/storage/devices.py77
1 files changed, 50 insertions, 27 deletions
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index d149d8392e..d10ff9e4b9 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -13,15 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import simplejson as json
+
+from six import iteritems, itervalues
+
+from canonicaljson import json
 
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
-from ._base import SQLBaseStore, Cache
-from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 
-from six import itervalues, iteritems
+from ._base import Cache, SQLBaseStore, db_to_json
 
 logger = logging.getLogger(__name__)
 
@@ -246,17 +249,31 @@ class DeviceStore(SQLBaseStore):
 
     def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
                                                    content, stream_id):
-        self._simple_upsert_txn(
-            txn,
-            table="device_lists_remote_cache",
-            keyvalues={
-                "user_id": user_id,
-                "device_id": device_id,
-            },
-            values={
-                "content": json.dumps(content),
-            }
-        )
+        if content.get("deleted"):
+            self._simple_delete_txn(
+                txn,
+                table="device_lists_remote_cache",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+            )
+
+            txn.call_after(
+                self.device_id_exists_cache.invalidate, (user_id, device_id,)
+            )
+        else:
+            self._simple_upsert_txn(
+                txn,
+                table="device_lists_remote_cache",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+                values={
+                    "content": json.dumps(content),
+                }
+            )
 
         txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
         txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
@@ -364,7 +381,7 @@ class DeviceStore(SQLBaseStore):
             now_stream_id = max(stream_id for stream_id in itervalues(query_map))
 
         devices = self._get_e2e_device_keys_txn(
-            txn, query_map.keys(), include_all_devices=True
+            txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
         )
 
         prev_sent_id_sql = """
@@ -391,12 +408,15 @@ class DeviceStore(SQLBaseStore):
 
                 prev_id = stream_id
 
-                key_json = device.get("key_json", None)
-                if key_json:
-                    result["keys"] = json.loads(key_json)
-                device_display_name = device.get("device_display_name", None)
-                if device_display_name:
-                    result["device_display_name"] = device_display_name
+                if device is not None:
+                    key_json = device.get("key_json", None)
+                    if key_json:
+                        result["keys"] = db_to_json(key_json)
+                    device_display_name = device.get("device_display_name", None)
+                    if device_display_name:
+                        result["device_display_name"] = device_display_name
+                else:
+                    result["deleted"] = True
 
                 results.append(result)
 
@@ -446,7 +466,7 @@ class DeviceStore(SQLBaseStore):
             retcol="content",
             desc="_get_cached_user_device",
         )
-        defer.returnValue(json.loads(content))
+        defer.returnValue(db_to_json(content))
 
     @cachedInlineCallbacks()
     def _get_cached_devices_for_user(self, user_id):
@@ -459,7 +479,7 @@ class DeviceStore(SQLBaseStore):
             desc="_get_cached_devices_for_user",
         )
         defer.returnValue({
-            device["device_id"]: json.loads(device["content"])
+            device["device_id"]: db_to_json(device["content"])
             for device in devices
         })
 
@@ -491,7 +511,7 @@ class DeviceStore(SQLBaseStore):
 
                 key_json = device.get("key_json", None)
                 if key_json:
-                    result["keys"] = json.loads(key_json)
+                    result["keys"] = db_to_json(key_json)
                 device_display_name = device.get("device_display_name", None)
                 if device_display_name:
                     result["device_display_name"] = device_display_name
@@ -692,6 +712,9 @@ class DeviceStore(SQLBaseStore):
 
             logger.info("Pruned %d device list outbound pokes", txn.rowcount)
 
-        return self.runInteraction(
-            "_prune_old_outbound_device_pokes", _prune_txn
+        return run_as_background_process(
+            "prune_old_outbound_device_pokes",
+            self.runInteraction,
+            "_prune_old_outbound_device_pokes",
+            _prune_txn,
         )