summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/device.py2
-rw-r--r--synapse/storage/databases/main/deviceinbox.py15
2 files changed, 10 insertions, 7 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 544bc7c13d..3ce96ef3cb 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -592,6 +592,8 @@ class DeviceHandler(DeviceWorkerHandler):
                 )
 
             # Delete device messages asynchronously and in batches using the task scheduler
+            # We specify an upper stream id to avoid deleting non delivered messages
+            # if an user re-uses a device ID.
             await self._task_scheduler.schedule_task(
                 DELETE_DEVICE_MSGS_TASK_NAME,
                 resource_id=device_id,
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 1faa6f04b2..3e7425d4a6 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -478,18 +478,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 log_kv({"message": "No changes in cache since last check"})
                 return 0
 
-        ROW_ID_NAME = self.database_engine.row_id_name
-
         def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
             limit_statement = "" if limit is None else f"LIMIT {limit}"
             sql = f"""
-                DELETE FROM device_inbox WHERE {ROW_ID_NAME} IN (
-                  SELECT {ROW_ID_NAME} FROM device_inbox
-                  WHERE user_id = ? AND device_id = ? AND stream_id <= ?
-                  {limit_statement}
+                DELETE FROM device_inbox WHERE user_id = ? AND device_id = ? AND stream_id <= (
+                  SELECT MAX(stream_id) FROM (
+                    SELECT stream_id FROM device_inbox
+                    WHERE user_id = ? AND device_id = ? AND stream_id <= ?
+                    ORDER BY stream_id
+                    {limit_statement}
+                  ) AS q1
                 )
                 """
-            txn.execute(sql, (user_id, device_id, up_to_stream_id))
+            txn.execute(sql, (user_id, device_id, user_id, device_id, up_to_stream_id))
             return txn.rowcount
 
         count = await self.db_pool.runInteraction(