summary refs log tree commit diff
path: root/synapse/storage/databases/main/deviceinbox.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/deviceinbox.py')
-rw-r--r--synapse/storage/databases/main/deviceinbox.py189
1 files changed, 186 insertions, 3 deletions
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 8143168107..264e625bd7 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -19,9 +19,10 @@ from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.replication.tcp.streams import ToDeviceStream
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -488,10 +489,12 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             devices = list(messages_by_device.keys())
             if len(devices) == 1 and devices[0] == "*":
                 # Handle wildcard device_ids.
+                # We exclude hidden devices (such as cross-signing keys) here as they are
+                # not expected to receive to-device messages.
                 devices = self.db_pool.simple_select_onecol_txn(
                     txn,
                     table="devices",
-                    keyvalues={"user_id": user_id},
+                    keyvalues={"user_id": user_id, "hidden": False},
                     retcol="device_id",
                 )
 
@@ -504,10 +507,12 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 if not devices:
                     continue
 
+                # We exclude hidden devices (such as cross-signing keys) here as they are
+                # not expected to receive to-device messages.
                 rows = self.db_pool.simple_select_many_txn(
                     txn,
                     table="devices",
-                    keyvalues={"user_id": user_id},
+                    keyvalues={"user_id": user_id, "hidden": False},
                     column="device_id",
                     iterable=devices,
                     retcols=("device_id",),
@@ -555,6 +560,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+    REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox"
+    REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
 
     def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
@@ -570,6 +577,16 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
             self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            self.REMOVE_DELETED_DEVICES,
+            self._remove_deleted_devices_from_device_inbox,
+        )
+
+        self.db_pool.updates.register_background_update_handler(
+            self.REMOVE_HIDDEN_DEVICES,
+            self._remove_hidden_devices_from_device_inbox,
+        )
+
     async def _background_drop_index_device_inbox(self, progress, batch_size):
         def reindex_txn(conn):
             txn = conn.cursor()
@@ -582,6 +599,172 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
 
         return 1
 
+    async def _remove_deleted_devices_from_device_inbox(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """A background update that deletes all device_inboxes for deleted devices.
+
+        This should only need to be run once (when users upgrade to v1.47.0)
+
+        Args:
+            progress: JsonDict used to store progress of this background update
+            batch_size: the maximum number of rows to retrieve in a single select query
+
+        Returns:
+            The number of deleted rows
+        """
+
+        def _remove_deleted_devices_from_device_inbox_txn(
+            txn: LoggingTransaction,
+        ) -> int:
+            """stream_id is not unique
+            we need to use an inclusive `stream_id >= ?` clause,
+            since we might not have deleted all dead device messages for the stream_id
+            returned from the previous query
+
+            Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
+            to avoid problems of deleting a large number of rows all at once
+            due to a single device having lots of device messages.
+            """
+
+            last_stream_id = progress.get("stream_id", 0)
+
+            sql = """
+                SELECT device_id, user_id, stream_id
+                FROM device_inbox
+                WHERE
+                    stream_id >= ?
+                    AND (device_id, user_id) NOT IN (
+                        SELECT device_id, user_id FROM devices
+                    )
+                ORDER BY stream_id
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_stream_id, batch_size))
+            rows = txn.fetchall()
+
+            num_deleted = 0
+            for row in rows:
+                num_deleted += self.db_pool.simple_delete_txn(
+                    txn,
+                    "device_inbox",
+                    {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
+                )
+
+            if rows:
+                # send more than stream_id to progress
+                # otherwise it can happen in large deployments that
+                # no change of status is visible in the log file
+                # it may be that the stream_id does not change in several runs
+                self.db_pool.updates._background_update_progress_txn(
+                    txn,
+                    self.REMOVE_DELETED_DEVICES,
+                    {
+                        "device_id": rows[-1][0],
+                        "user_id": rows[-1][1],
+                        "stream_id": rows[-1][2],
+                    },
+                )
+
+            return num_deleted
+
+        number_deleted = await self.db_pool.runInteraction(
+            "_remove_deleted_devices_from_device_inbox",
+            _remove_deleted_devices_from_device_inbox_txn,
+        )
+
+        # The task is finished when no more lines are deleted.
+        if not number_deleted:
+            await self.db_pool.updates._end_background_update(
+                self.REMOVE_DELETED_DEVICES
+            )
+
+        return number_deleted
+
+    async def _remove_hidden_devices_from_device_inbox(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """A background update that deletes all device_inboxes for hidden devices.
+
+        This should only need to be run once (when users upgrade to v1.47.0)
+
+        Args:
+            progress: JsonDict used to store progress of this background update
+            batch_size: the maximum number of rows to retrieve in a single select query
+
+        Returns:
+            The number of deleted rows
+        """
+
+        def _remove_hidden_devices_from_device_inbox_txn(
+            txn: LoggingTransaction,
+        ) -> int:
+            """stream_id is not unique
+            we need to use an inclusive `stream_id >= ?` clause,
+            since we might not have deleted all hidden device messages for the stream_id
+            returned from the previous query
+
+            Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
+            to avoid problems of deleting a large number of rows all at once
+            due to a single device having lots of device messages.
+            """
+
+            last_stream_id = progress.get("stream_id", 0)
+
+            sql = """
+                SELECT device_id, user_id, stream_id
+                FROM device_inbox
+                WHERE
+                    stream_id >= ?
+                    AND (device_id, user_id) IN (
+                        SELECT device_id, user_id FROM devices WHERE hidden = ?
+                    )
+                ORDER BY stream_id
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_stream_id, True, batch_size))
+            rows = txn.fetchall()
+
+            num_deleted = 0
+            for row in rows:
+                num_deleted += self.db_pool.simple_delete_txn(
+                    txn,
+                    "device_inbox",
+                    {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
+                )
+
+            if rows:
+                # We don't just save the `stream_id` in progress as
+                # otherwise it can happen in large deployments that
+                # no change of status is visible in the log file, as
+                # it may be that the stream_id does not change in several runs
+                self.db_pool.updates._background_update_progress_txn(
+                    txn,
+                    self.REMOVE_HIDDEN_DEVICES,
+                    {
+                        "device_id": rows[-1][0],
+                        "user_id": rows[-1][1],
+                        "stream_id": rows[-1][2],
+                    },
+                )
+
+            return num_deleted
+
+        number_deleted = await self.db_pool.runInteraction(
+            "_remove_hidden_devices_from_device_inbox",
+            _remove_hidden_devices_from_device_inbox_txn,
+        )
+
+        # The task is finished when no more lines are deleted.
+        if not number_deleted:
+            await self.db_pool.updates._end_background_update(
+                self.REMOVE_HIDDEN_DEVICES
+            )
+
+        return number_deleted
+
 
 class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
     pass