diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index ec68e39f1e..c0943ecf91 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -21,6 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import StoreError
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from ._base import Cache, SQLBaseStore
@@ -248,17 +249,31 @@ class DeviceStore(SQLBaseStore):
def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
content, stream_id):
- self._simple_upsert_txn(
- txn,
- table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- },
- values={
- "content": json.dumps(content),
- }
- )
+ if content.get("deleted"):
+ self._simple_delete_txn(
+ txn,
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ )
+
+ txn.call_after(
+ self.device_id_exists_cache.invalidate, (user_id, device_id,)
+ )
+ else:
+ self._simple_upsert_txn(
+ txn,
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ values={
+ "content": json.dumps(content),
+ }
+ )
txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
@@ -366,7 +381,7 @@ class DeviceStore(SQLBaseStore):
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
devices = self._get_e2e_device_keys_txn(
- txn, query_map.keys(), include_all_devices=True
+ txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
)
prev_sent_id_sql = """
@@ -393,12 +408,15 @@ class DeviceStore(SQLBaseStore):
prev_id = stream_id
- key_json = device.get("key_json", None)
- if key_json:
- result["keys"] = json.loads(key_json)
- device_display_name = device.get("device_display_name", None)
- if device_display_name:
- result["device_display_name"] = device_display_name
+ if device is not None:
+ key_json = device.get("key_json", None)
+ if key_json:
+ result["keys"] = json.loads(key_json)
+ device_display_name = device.get("device_display_name", None)
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+ else:
+ result["deleted"] = True
results.append(result)
@@ -694,6 +712,9 @@ class DeviceStore(SQLBaseStore):
logger.info("Pruned %d device list outbound pokes", txn.rowcount)
- return self.runInteraction(
- "_prune_old_outbound_device_pokes", _prune_txn
+ return run_as_background_process(
+ "prune_old_outbound_device_pokes",
+ self.runInteraction,
+ "_prune_old_outbound_device_pokes",
+ _prune_txn,
)
|