summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/16643.misc1
-rw-r--r--synapse/handlers/device.py8
-rw-r--r--synapse/storage/databases/main/deviceinbox.py106
-rw-r--r--synapse/util/task_scheduler.py2
4 files changed, 88 insertions, 29 deletions
diff --git a/changelog.d/16643.misc b/changelog.d/16643.misc
new file mode 100644
index 0000000000..cc0cf0901f
--- /dev/null
+++ b/changelog.d/16643.misc
@@ -0,0 +1 @@
+Speed up deleting of device messages when deleting a device.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 93472d0117..1af6d77545 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -396,15 +396,17 @@ class DeviceWorkerHandler:
         up_to_stream_id = task.params["up_to_stream_id"]
 
         # Delete the messages in batches to avoid too much DB load.
+        from_stream_id = None
         while True:
-            res = await self.store.delete_messages_for_device(
+            from_stream_id, _ = await self.store.delete_messages_for_device_between(
                 user_id=user_id,
                 device_id=device_id,
-                up_to_stream_id=up_to_stream_id,
+                from_stream_id=from_stream_id,
+                to_stream_id=up_to_stream_id,
                 limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT,
             )
 
-            if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT:
+            if from_stream_id is None:
                 return TaskStatus.COMPLETE, None, None
 
             await self.clock.sleep(DeviceHandler.DEVICE_MSGS_DELETE_SLEEP_MS / 1000.0)
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 3e7425d4a6..02dddd1da4 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -450,14 +450,12 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         user_id: str,
         device_id: Optional[str],
         up_to_stream_id: int,
-        limit: Optional[int] = None,
     ) -> int:
         """
         Args:
             user_id: The recipient user_id.
             device_id: The recipient device_id.
             up_to_stream_id: Where to delete messages up to.
-            limit: maximum number of messages to delete
 
         Returns:
             The number of messages deleted.
@@ -478,32 +476,22 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 log_kv({"message": "No changes in cache since last check"})
                 return 0
 
-        def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
-            limit_statement = "" if limit is None else f"LIMIT {limit}"
-            sql = f"""
-                DELETE FROM device_inbox WHERE user_id = ? AND device_id = ? AND stream_id <= (
-                  SELECT MAX(stream_id) FROM (
-                    SELECT stream_id FROM device_inbox
-                    WHERE user_id = ? AND device_id = ? AND stream_id <= ?
-                    ORDER BY stream_id
-                    {limit_statement}
-                  ) AS q1
-                )
-                """
-            txn.execute(sql, (user_id, device_id, user_id, device_id, up_to_stream_id))
-            return txn.rowcount
-
-        count = await self.db_pool.runInteraction(
-            "delete_messages_for_device", delete_messages_for_device_txn
-        )
+        from_stream_id = None
+        count = 0
+        while True:
+            from_stream_id, loop_count = await self.delete_messages_for_device_between(
+                user_id,
+                device_id,
+                from_stream_id=from_stream_id,
+                to_stream_id=up_to_stream_id,
+                limit=1000,
+            )
+            count += loop_count
+            if from_stream_id is None:
+                break
 
         log_kv({"message": f"deleted {count} messages for device", "count": count})
 
-        # In this case we don't know if we hit the limit or the delete is complete
-        # so let's not update the cache.
-        if count == limit:
-            return count
-
         # Update the cache, ensuring that we only ever increase the value
         updated_last_deleted_stream_id = self._last_device_delete_cache.get(
             (user_id, device_id), 0
@@ -515,6 +503,74 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         return count
 
     @trace
+    async def delete_messages_for_device_between(
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        from_stream_id: Optional[int],
+        to_stream_id: int,
+        limit: int,
+    ) -> Tuple[Optional[int], int]:
+        """Delete N device messages between the stream IDs, returning the
+        highest stream ID deleted (or None if all messages in the range have
+        been deleted) and the number of messages deleted.
+
+        This is more efficient than `delete_messages_for_device` when calling in
+        a loop to batch delete messages.
+        """
+
+        # Keeping track of a lower bound of stream ID where we've deleted
+        # everything below makes the queries much faster. Otherwise, every time
+        # we scan for rows to delete we'd re-scan across all the rows that have
+        # previously deleted (until the next table VACUUM).
+
+        if from_stream_id is None:
+            # Minimum device stream ID is 1.
+            from_stream_id = 0
+
+        def delete_messages_for_device_between_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[Optional[int], int]:
+            txn.execute(
+                """
+                SELECT MAX(stream_id) FROM (
+                    SELECT stream_id FROM device_inbox
+                    WHERE user_id = ? AND device_id = ?
+                        AND ? < stream_id AND stream_id <= ?
+                    ORDER BY stream_id
+                    LIMIT ?
+                ) AS d
+                """,
+                (user_id, device_id, from_stream_id, to_stream_id, limit),
+            )
+            row = txn.fetchone()
+            if row is None or row[0] is None:
+                return None, 0
+
+            (max_stream_id,) = row
+
+            txn.execute(
+                """
+                DELETE FROM device_inbox
+                WHERE user_id = ? AND device_id = ?
+                AND ? < stream_id AND stream_id <= ?
+                """,
+                (user_id, device_id, from_stream_id, max_stream_id),
+            )
+
+            num_deleted = txn.rowcount
+            if num_deleted < limit:
+                return None, num_deleted
+
+            return max_stream_id, num_deleted
+
+        return await self.db_pool.runInteraction(
+            "delete_messages_for_device_between",
+            delete_messages_for_device_between_txn,
+            db_autocommit=True,  # We don't need to run in a transaction
+        )
+
+    @trace
     async def get_new_device_msgs_for_remote(
         self, destination: str, last_stream_id: int, current_stream_id: int, limit: int
     ) -> Tuple[List[JsonDict], int]:
diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py
index caf13b3474..29c561e555 100644
--- a/synapse/util/task_scheduler.py
+++ b/synapse/util/task_scheduler.py
@@ -193,7 +193,7 @@ class TaskScheduler:
         result: Optional[JsonMapping] = None,
         error: Optional[str] = None,
     ) -> bool:
-        """Update some task associated values. This is exposed publically so it can
+        """Update some task associated values. This is exposed publicly so it can
         be used inside task functions, mainly to update the result and be able to
         resume a task at a specific step after a restart of synapse.