summary refs log tree commit diff
path: root/synapse/storage/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/devices.py')
-rw-r--r--synapse/storage/devices.py74
1 files changed, 47 insertions, 27 deletions
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index bb27fd1f70..cc3cdf2ebc 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -13,21 +13,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import ujson as json
+
+from six import iteritems, itervalues
+
+from canonicaljson import json
 
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
-from ._base import SQLBaseStore, Cache
-from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 
+from ._base import Cache, SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
 
 class DeviceStore(SQLBaseStore):
-    def __init__(self, hs):
-        super(DeviceStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(DeviceStore, self).__init__(db_conn, hs)
 
         # Map of (user_id, device_id) -> bool. If there is an entry that implies
         # the device exists.
@@ -245,17 +248,31 @@ class DeviceStore(SQLBaseStore):
 
     def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
                                                    content, stream_id):
-        self._simple_upsert_txn(
-            txn,
-            table="device_lists_remote_cache",
-            keyvalues={
-                "user_id": user_id,
-                "device_id": device_id,
-            },
-            values={
-                "content": json.dumps(content),
-            }
-        )
+        if content.get("deleted"):
+            self._simple_delete_txn(
+                txn,
+                table="device_lists_remote_cache",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+            )
+
+            txn.call_after(
+                self.device_id_exists_cache.invalidate, (user_id, device_id,)
+            )
+        else:
+            self._simple_upsert_txn(
+                txn,
+                table="device_lists_remote_cache",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+                values={
+                    "content": json.dumps(content),
+                }
+            )
 
         txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
         txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
@@ -360,10 +377,10 @@ class DeviceStore(SQLBaseStore):
             return (now_stream_id, [])
 
         if len(query_map) >= 20:
-            now_stream_id = max(stream_id for stream_id in query_map.itervalues())
+            now_stream_id = max(stream_id for stream_id in itervalues(query_map))
 
         devices = self._get_e2e_device_keys_txn(
-            txn, query_map.keys(), include_all_devices=True
+            txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
         )
 
         prev_sent_id_sql = """
@@ -373,13 +390,13 @@ class DeviceStore(SQLBaseStore):
         """
 
         results = []
-        for user_id, user_devices in devices.iteritems():
+        for user_id, user_devices in iteritems(devices):
             # The prev_id for the first row is always the last row before
             # `from_stream_id`
             txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
             rows = txn.fetchall()
             prev_id = rows[0][0]
-            for device_id, device in user_devices.iteritems():
+            for device_id, device in iteritems(user_devices):
                 stream_id = query_map[(user_id, device_id)]
                 result = {
                     "user_id": user_id,
@@ -390,12 +407,15 @@ class DeviceStore(SQLBaseStore):
 
                 prev_id = stream_id
 
-                key_json = device.get("key_json", None)
-                if key_json:
-                    result["keys"] = json.loads(key_json)
-                device_display_name = device.get("device_display_name", None)
-                if device_display_name:
-                    result["device_display_name"] = device_display_name
+                if device is not None:
+                    key_json = device.get("key_json", None)
+                    if key_json:
+                        result["keys"] = json.loads(key_json)
+                    device_display_name = device.get("device_display_name", None)
+                    if device_display_name:
+                        result["device_display_name"] = device_display_name
+                else:
+                    result["deleted"] = True
 
                 results.append(result)
 
@@ -483,7 +503,7 @@ class DeviceStore(SQLBaseStore):
         if devices:
             user_devices = devices[user_id]
             results = []
-            for device_id, device in user_devices.iteritems():
+            for device_id, device in iteritems(user_devices):
                 result = {
                     "device_id": device_id,
                 }