summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/16240.misc1
-rw-r--r--synapse/handlers/device.py48
-rw-r--r--synapse/handlers/presence.py4
-rw-r--r--synapse/handlers/sync.py16
-rw-r--r--synapse/storage/databases/main/deviceinbox.py26
-rw-r--r--synapse/storage/databases/main/devices.py8
-rw-r--r--synapse/storage/databases/main/receipts.py6
-rw-r--r--synapse/storage/engines/_base.py6
-rw-r--r--synapse/storage/engines/postgres.py4
-rw-r--r--synapse/storage/engines/sqlite.py4
-rw-r--r--synapse/storage/schema/main/delta/48/group_unique_indexes.py4
-rw-r--r--synapse/util/task_scheduler.py17
-rw-r--r--tests/handlers/test_device.py47
13 files changed, 154 insertions, 37 deletions
diff --git a/changelog.d/16240.misc b/changelog.d/16240.misc
new file mode 100644
index 0000000000..4f266c1fb0
--- /dev/null
+++ b/changelog.d/16240.misc
@@ -0,0 +1 @@
+Delete device messages asynchronously and in staged batches using the task scheduler.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 763f56dfc1..9e52af5f13 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -43,9 +43,12 @@ from synapse.metrics.background_process_metrics import (
 )
 from synapse.types import (
     JsonDict,
+    JsonMapping,
+    ScheduledTask,
     StrCollection,
     StreamKeyType,
     StreamToken,
+    TaskStatus,
     UserID,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
@@ -62,6 +65,7 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+DELETE_DEVICE_MSGS_TASK_NAME = "delete_device_messages"
 MAX_DEVICE_DISPLAY_NAME_LEN = 100
 DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
 
@@ -78,6 +82,7 @@ class DeviceWorkerHandler:
         self._appservice_handler = hs.get_application_service_handler()
         self._state_storage = hs.get_storage_controllers().state
         self._auth_handler = hs.get_auth_handler()
+        self._event_sources = hs.get_event_sources()
         self.server_name = hs.hostname
         self._msc3852_enabled = hs.config.experimental.msc3852_enabled
         self._query_appservices_for_keys = (
@@ -386,6 +391,7 @@ class DeviceHandler(DeviceWorkerHandler):
         self._account_data_handler = hs.get_account_data_handler()
         self._storage_controllers = hs.get_storage_controllers()
         self.db_pool = hs.get_datastores().main.db_pool
+        self._task_scheduler = hs.get_task_scheduler()
 
         self.device_list_updater = DeviceListUpdater(hs, self)
 
@@ -419,6 +425,10 @@ class DeviceHandler(DeviceWorkerHandler):
                 self._delete_stale_devices,
             )
 
+        self._task_scheduler.register_action(
+            self._delete_device_messages, DELETE_DEVICE_MSGS_TASK_NAME
+        )
+
     def _check_device_name_length(self, name: Optional[str]) -> None:
         """
         Checks whether a device name is longer than the maximum allowed length.
@@ -530,6 +540,7 @@ class DeviceHandler(DeviceWorkerHandler):
             user_id: The user to delete devices from.
             device_ids: The list of device IDs to delete
         """
+        to_device_stream_id = self._event_sources.get_current_token().to_device_key
 
         try:
             await self.store.delete_devices(user_id, device_ids)
@@ -559,12 +570,49 @@ class DeviceHandler(DeviceWorkerHandler):
                     f"org.matrix.msc3890.local_notification_settings.{device_id}",
                 )
 
+            # Delete device messages asynchronously and in batches using the task scheduler
+            await self._task_scheduler.schedule_task(
+                DELETE_DEVICE_MSGS_TASK_NAME,
+                resource_id=device_id,
+                params={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "up_to_stream_id": to_device_stream_id,
+                },
+            )
+
         # Pushers are deleted after `delete_access_tokens_for_user` is called so that
         # modules using `on_logged_out` hook can use them if needed.
         await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids)
 
         await self.notify_device_update(user_id, device_ids)
 
+    DEVICE_MSGS_DELETE_BATCH_LIMIT = 100
+
+    async def _delete_device_messages(
+        self,
+        task: ScheduledTask,
+    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+        """Scheduler task to delete device messages in batch of `DEVICE_MSGS_DELETE_BATCH_LIMIT`."""
+        assert task.params is not None
+        user_id = task.params["user_id"]
+        device_id = task.params["device_id"]
+        up_to_stream_id = task.params["up_to_stream_id"]
+
+        res = await self.store.delete_messages_for_device(
+            user_id=user_id,
+            device_id=device_id,
+            up_to_stream_id=up_to_stream_id,
+            limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT,
+        )
+
+        if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT:
+            return TaskStatus.COMPLETE, None, None
+        else:
+            # There is probably still device messages to be deleted, let's keep the task active and it will be run
+            # again in a subsequent scheduler loop run (probably the next one, if not too many tasks are running).
+            return TaskStatus.ACTIVE, None, None
+
     async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
         """Update the given device
 
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index a4b05b72e7..375c7d0901 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -183,6 +183,7 @@ class BasePresenceHandler(abc.ABC):
     writer"""
 
     def __init__(self, hs: "HomeServer"):
+        self.hs = hs
         self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
         self._storage_controllers = hs.get_storage_controllers()
@@ -473,8 +474,6 @@ class _NullContextManager(ContextManager[None]):
 class WorkerPresenceHandler(BasePresenceHandler):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
-        self.hs = hs
-
         self._presence_writer_instance = hs.config.worker.writers.presence[0]
 
         # Route presence EDUs to the right worker
@@ -738,7 +737,6 @@ class WorkerPresenceHandler(BasePresenceHandler):
 class PresenceHandler(BasePresenceHandler):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
-        self.hs = hs
         self.wheel_timer: WheelTimer[str] = WheelTimer()
         self.notifier = hs.get_notifier()
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 60a9f341b5..0ccd7d250c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -40,6 +40,7 @@ from synapse.api.filtering import FilterCollection
 from synapse.api.presence import UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
+from synapse.handlers.device import DELETE_DEVICE_MSGS_TASK_NAME
 from synapse.handlers.relations import BundledAggregations
 from synapse.logging import issue9533_logger
 from synapse.logging.context import current_context
@@ -268,6 +269,7 @@ class SyncHandler:
         self._storage_controllers = hs.get_storage_controllers()
         self._state_storage_controller = self._storage_controllers.state
         self._device_handler = hs.get_device_handler()
+        self._task_scheduler = hs.get_task_scheduler()
 
         self.should_calculate_push_rules = hs.config.push.enable_push
 
@@ -360,11 +362,19 @@ class SyncHandler:
         # (since we now know that the device has received them)
         if since_token is not None:
             since_stream_id = since_token.to_device_key
-            deleted = await self.store.delete_messages_for_device(
-                sync_config.user.to_string(), sync_config.device_id, since_stream_id
+            # Delete device messages asynchronously and in batches using the task scheduler
+            await self._task_scheduler.schedule_task(
+                DELETE_DEVICE_MSGS_TASK_NAME,
+                resource_id=sync_config.device_id,
+                params={
+                    "user_id": sync_config.user.to_string(),
+                    "device_id": sync_config.device_id,
+                    "up_to_stream_id": since_stream_id,
+                },
             )
             logger.debug(
-                "Deleted %d to-device messages up to %d", deleted, since_stream_id
+                "Deletion of to-device messages up to %d scheduled",
+                since_stream_id,
             )
 
         if timeout == 0 or since_token is None or full_state:
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 271cdf923c..744e98c6d0 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -445,13 +445,18 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
     @trace
     async def delete_messages_for_device(
-        self, user_id: str, device_id: Optional[str], up_to_stream_id: int
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        up_to_stream_id: int,
+        limit: int,
     ) -> int:
         """
         Args:
             user_id: The recipient user_id.
             device_id: The recipient device_id.
             up_to_stream_id: Where to delete messages up to.
+            limit: maximum number of messages to delete
 
         Returns:
             The number of messages deleted.
@@ -472,12 +477,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 log_kv({"message": "No changes in cache since last check"})
                 return 0
 
+        ROW_ID_NAME = self.database_engine.row_id_name
+
         def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
-            sql = (
-                "DELETE FROM device_inbox"
-                " WHERE user_id = ? AND device_id = ?"
-                " AND stream_id <= ?"
-            )
+            sql = f"""
+                DELETE FROM device_inbox WHERE {ROW_ID_NAME} IN (
+                  SELECT {ROW_ID_NAME} FROM device_inbox
+                  WHERE user_id = ? AND device_id = ? AND stream_id <= ?
+                  LIMIT {limit}
+                )
+                """
             txn.execute(sql, (user_id, device_id, up_to_stream_id))
             return txn.rowcount
 
@@ -487,6 +496,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
         log_kv({"message": f"deleted {count} messages for device", "count": count})
 
+        # In this case we don't know if we hit the limit or the delete is complete
+        # so let's not update the cache.
+        if count == limit:
+            return count
+
         # Update the cache, ensuring that we only ever increase the value
         updated_last_deleted_stream_id = self._last_device_delete_cache.get(
             (user_id, device_id), 0
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index fa69a4a298..7208fc8b33 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1768,14 +1768,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
             self.db_pool.simple_delete_many_txn(
                 txn,
-                table="device_inbox",
-                column="device_id",
-                values=device_ids,
-                keyvalues={"user_id": user_id},
-            )
-
-            self.db_pool.simple_delete_many_txn(
-                txn,
                 table="device_auth_providers",
                 column="device_id",
                 values=device_ids,
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 5ee5c7ad9f..e4d10ff250 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -939,11 +939,7 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
         receipts."""
 
         def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None:
-            if isinstance(self.database_engine, PostgresEngine):
-                ROW_ID_NAME = "ctid"
-            else:
-                ROW_ID_NAME = "rowid"
-
+            ROW_ID_NAME = self.database_engine.row_id_name
             # Identify any duplicate receipts arising from
             # https://github.com/matrix-org/synapse/issues/14406.
             # The following query takes less than a minute on matrix.org.
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index 0b5b3bf03e..b1a2418cbd 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -100,6 +100,12 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM
         """Gets a string giving the server version. For example: '3.22.0'"""
         ...
 
+    @property
+    @abc.abstractmethod
+    def row_id_name(self) -> str:
+        """Gets the literal name representing a row id for this engine."""
+        ...
+
     @abc.abstractmethod
     def in_transaction(self, conn: ConnectionType) -> bool:
         """Whether the connection is currently in a transaction."""
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 05a72dc554..6309363217 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -211,6 +211,10 @@ class PostgresEngine(
         else:
             return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)
 
+    @property
+    def row_id_name(self) -> str:
+        return "ctid"
+
     def in_transaction(self, conn: psycopg2.extensions.connection) -> bool:
         return conn.status != psycopg2.extensions.STATUS_READY
 
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index ca8c59297c..802069e1e1 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -123,6 +123,10 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]):
         """Gets a string giving the server version. For example: '3.22.0'."""
         return "%i.%i.%i" % sqlite3.sqlite_version_info
 
+    @property
+    def row_id_name(self) -> str:
+        return "rowid"
+
     def in_transaction(self, conn: sqlite3.Connection) -> bool:
         return conn.in_transaction
 
diff --git a/synapse/storage/schema/main/delta/48/group_unique_indexes.py b/synapse/storage/schema/main/delta/48/group_unique_indexes.py
index ad2da4c8af..622686d28f 100644
--- a/synapse/storage/schema/main/delta/48/group_unique_indexes.py
+++ b/synapse/storage/schema/main/delta/48/group_unique_indexes.py
@@ -14,7 +14,7 @@
 
 
 from synapse.storage.database import LoggingTransaction
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine
 from synapse.storage.prepare_database import get_statements
 
 FIX_INDEXES = """
@@ -37,7 +37,7 @@ CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
 
 
 def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None:
-    rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid"
+    rowid = database_engine.row_id_name
 
     # remove duplicates from group_users & group_invites tables
     cur.execute(
diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py
index 9e89aeb748..9b2581e51a 100644
--- a/synapse/util/task_scheduler.py
+++ b/synapse/util/task_scheduler.py
@@ -77,6 +77,7 @@ class TaskScheduler:
     LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000  # 24hrs
 
     def __init__(self, hs: "HomeServer"):
+        self._hs = hs
         self._store = hs.get_datastores().main
         self._clock = hs.get_clock()
         self._running_tasks: Set[str] = set()
@@ -97,8 +98,6 @@ class TaskScheduler:
                 "handle_scheduled_tasks",
                 self._handle_scheduled_tasks,
             )
-        else:
-            self.replication_client = hs.get_replication_command_handler()
 
     def register_action(
         self,
@@ -133,7 +132,7 @@ class TaskScheduler:
         params: Optional[JsonMapping] = None,
     ) -> str:
         """Schedule a new potentially resumable task. A function matching the specified
-        `action` should have been previously registered with `register_action`.
+        `action` should have be registered with `register_action` before the task is run.
 
         Args:
             action: the name of a previously registered action
@@ -149,11 +148,6 @@ class TaskScheduler:
         Returns:
             The id of the scheduled task
         """
-        if action not in self._actions:
-            raise Exception(
-                f"No function associated with action {action} of the scheduled task"
-            )
-
         status = TaskStatus.SCHEDULED
         if timestamp is None or timestamp < self._clock.time_msec():
             timestamp = self._clock.time_msec()
@@ -175,7 +169,7 @@ class TaskScheduler:
             if self._run_background_tasks:
                 await self._launch_task(task)
             else:
-                self.replication_client.send_new_active_task(task.id)
+                self._hs.get_replication_command_handler().send_new_active_task(task.id)
 
         return task.id
 
@@ -315,7 +309,10 @@ class TaskScheduler:
         """
         assert self._run_background_tasks
 
-        assert task.action in self._actions
+        if task.action not in self._actions:
+            raise Exception(
+                f"No function associated with action {task.action} of the scheduled task {task.id}"
+            )
         function = self._actions[task.action]
 
         async def wrapper() -> None:
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 55a4f95ef3..9659a4a355 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -30,6 +30,7 @@ from synapse.server import HomeServer
 from synapse.storage.databases.main.appservice import _make_exclusive_regex
 from synapse.types import JsonDict, create_requester
 from synapse.util import Clock
+from synapse.util.task_scheduler import TaskScheduler
 
 from tests import unittest
 from tests.unittest import override_config
@@ -49,6 +50,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         assert isinstance(handler, DeviceHandler)
         self.handler = handler
         self.store = hs.get_datastores().main
+        self.device_message_handler = hs.get_device_message_handler()
         return hs
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -211,6 +213,51 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         )
         self.assertIsNone(res)
 
+    def test_delete_device_and_big_device_inbox(self) -> None:
+        """Check that deleting a big device inbox is staged and batched asynchronously."""
+        DEVICE_ID = "abc"
+        sender = "@sender:" + self.hs.hostname
+        receiver = "@receiver:" + self.hs.hostname
+        self._record_user(sender, DEVICE_ID, DEVICE_ID)
+        self._record_user(receiver, DEVICE_ID, DEVICE_ID)
+
+        # queue a bunch of messages in the inbox
+        requester = create_requester(sender, device_id=DEVICE_ID)
+        for i in range(0, DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT + 10):
+            self.get_success(
+                self.device_message_handler.send_device_message(
+                    requester, "message_type", {receiver: {"*": {"val": i}}}
+                )
+            )
+
+        # delete the device
+        self.get_success(self.handler.delete_devices(receiver, [DEVICE_ID]))
+
+        # messages should be deleted up to DEVICE_MSGS_DELETE_BATCH_LIMIT straight away
+        res = self.get_success(
+            self.store.db_pool.simple_select_list(
+                table="device_inbox",
+                keyvalues={"user_id": receiver},
+                retcols=("user_id", "device_id", "stream_id"),
+                desc="get_device_id_from_device_inbox",
+            )
+        )
+        self.assertEqual(10, len(res))
+
+        # wait for the task scheduler to do a second delete pass
+        self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)
+
+        # remaining messages should now be deleted
+        res = self.get_success(
+            self.store.db_pool.simple_select_list(
+                table="device_inbox",
+                keyvalues={"user_id": receiver},
+                retcols=("user_id", "device_id", "stream_id"),
+                desc="get_device_id_from_device_inbox",
+            )
+        )
+        self.assertEqual(0, len(res))
+
     def test_update_device(self) -> None:
         self._record_users()