summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-05-26 13:16:08 -0400
committerGitHub <noreply@github.com>2023-05-26 13:16:08 -0400
commit2ad91ec628126753590c1a90c432270d6c8fa8fd (patch)
treed437ff10f3a4c604719146cb4b1e201dec0095d1 /synapse/storage/databases
parentMerge branch 'master' into develop (diff)
downloadsynapse-2ad91ec628126753590c1a90c432270d6c8fa8fd.tar.xz
Set thread_id column to non-null for event_push_{actions,actions_staging,summary} (#15597)
Updates the database schema to require a thread_id (by adding a
constraint that the column is non-null) for event_push_actions,
event_push_actions_staging, and event_push_summary.

For PostgreSQL we add the constraint as NOT VALID, then
VALIDATE the constraint in a background job to avoid locking
the table during an upgrade.

Each table is updated as a separate schema delta to avoid
deadlocks between them.

For SQLite we simply rebuild the table & copy the data.
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/event_push_actions.py254
1 files changed, 31 insertions, 223 deletions
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 6fdb1e292e..07bda7d6be 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -289,179 +289,52 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             unique=True,
         )
 
-        self.db_pool.updates.register_background_update_handler(
-            "event_push_backfill_thread_id",
-            self._background_backfill_thread_id,
+        self.db_pool.updates.register_background_validate_constraint(
+            "event_push_actions_staging_thread_id",
+            constraint_name="event_push_actions_staging_thread_id",
+            table="event_push_actions_staging",
         )
-
-        # Indexes which will be used to quickly make the thread_id column non-null.
-        self.db_pool.updates.register_background_index_update(
-            "event_push_actions_thread_id_null",
-            index_name="event_push_actions_thread_id_null",
+        self.db_pool.updates.register_background_validate_constraint(
+            "event_push_actions_thread_id",
+            constraint_name="event_push_actions_thread_id",
             table="event_push_actions",
-            columns=["thread_id"],
-            where_clause="thread_id IS NULL",
         )
-        self.db_pool.updates.register_background_index_update(
-            "event_push_summary_thread_id_null",
-            index_name="event_push_summary_thread_id_null",
+        self.db_pool.updates.register_background_validate_constraint(
+            "event_push_summary_thread_id",
+            constraint_name="event_push_summary_thread_id",
             table="event_push_summary",
-            columns=["thread_id"],
-            where_clause="thread_id IS NULL",
         )
 
-        # Check ASAP (and then later, every 1s) to see if we have finished
-        # background updates the event_push_actions and event_push_summary tables.
-        self._clock.call_later(0.0, self._check_event_push_backfill_thread_id)
-        self._event_push_backfill_thread_id_done = False
-
-    @wrap_as_background_process("check_event_push_backfill_thread_id")
-    async def _check_event_push_backfill_thread_id(self) -> None:
-        """
-        Has thread_id finished backfilling?
-
-        If not, we need to just-in-time update it so the queries work.
-        """
-        done = await self.db_pool.updates.has_completed_background_update(
-            "event_push_backfill_thread_id"
+        self.db_pool.updates.register_background_update_handler(
+            "event_push_drop_null_thread_id_indexes",
+            self._background_drop_null_thread_id_indexes,
         )
 
-        if done:
-            self._event_push_backfill_thread_id_done = True
-        else:
-            # Reschedule to run.
-            self._clock.call_later(15.0, self._check_event_push_backfill_thread_id)
-
-    async def _background_backfill_thread_id(
+    async def _background_drop_null_thread_id_indexes(
         self, progress: JsonDict, batch_size: int
     ) -> int:
         """
-        Fill in the thread_id field for event_push_actions and event_push_summary.
-
-        This is preparatory so that it can be made non-nullable in the future.
-
-        Because all current (null) data is done in an unthreaded manner this
-        simply assumes it is on the "main" timeline. Since event_push_actions
-        are periodically cleared it is not possible to correctly re-calculate
-        the thread_id.
+        Drop the indexes used to find null thread_ids for event_push_actions and
+        event_push_summary.
         """
-        event_push_actions_done = progress.get("event_push_actions_done", False)
 
-        def add_thread_id_txn(
-            txn: LoggingTransaction, start_stream_ordering: int
-        ) -> int:
-            sql = """
-            SELECT stream_ordering
-            FROM event_push_actions
-            WHERE
-                thread_id IS NULL
-                AND stream_ordering > ?
-            ORDER BY stream_ordering
-            LIMIT ?
-            """
-            txn.execute(sql, (start_stream_ordering, batch_size))
-
-            # No more rows to process.
-            rows = txn.fetchall()
-            if not rows:
-                progress["event_push_actions_done"] = True
-                self.db_pool.updates._background_update_progress_txn(
-                    txn, "event_push_backfill_thread_id", progress
-                )
-                return 0
+        def drop_null_thread_id_indexes_txn(txn: LoggingTransaction) -> None:
+            sql = "DROP INDEX IF EXISTS event_push_actions_thread_id_null"
+            logger.debug("[SQL] %s", sql)
+            txn.execute(sql)
 
-            # Update the thread ID for any of those rows.
-            max_stream_ordering = rows[-1][0]
+            sql = "DROP INDEX IF EXISTS event_push_summary_thread_id_null"
+            logger.debug("[SQL] %s", sql)
+            txn.execute(sql)
 
-            sql = """
-            UPDATE event_push_actions
-            SET thread_id = 'main'
-            WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL
-            """
-            txn.execute(
-                sql,
-                (
-                    start_stream_ordering,
-                    max_stream_ordering,
-                ),
-            )
-
-            # Update progress.
-            processed_rows = txn.rowcount
-            progress["max_event_push_actions_stream_ordering"] = max_stream_ordering
-            self.db_pool.updates._background_update_progress_txn(
-                txn, "event_push_backfill_thread_id", progress
-            )
-
-            return processed_rows
-
-        def add_thread_id_summary_txn(txn: LoggingTransaction) -> int:
-            min_user_id = progress.get("max_summary_user_id", "")
-            min_room_id = progress.get("max_summary_room_id", "")
-
-            # Slightly overcomplicated query for getting the Nth user ID / room
-            # ID tuple, or the last if there are less than N remaining.
-            sql = """
-            SELECT user_id, room_id FROM (
-                SELECT user_id, room_id FROM event_push_summary
-                WHERE (user_id, room_id) > (?, ?)
-                    AND thread_id IS NULL
-                ORDER BY user_id, room_id
-                LIMIT ?
-            ) AS e
-            ORDER BY user_id DESC, room_id DESC
-            LIMIT 1
-            """
-
-            txn.execute(sql, (min_user_id, min_room_id, batch_size))
-            row = txn.fetchone()
-            if not row:
-                return 0
-
-            max_user_id, max_room_id = row
-
-            sql = """
-            UPDATE event_push_summary
-            SET thread_id = 'main'
-            WHERE
-                (?, ?) < (user_id, room_id) AND (user_id, room_id) <= (?, ?)
-                AND thread_id IS NULL
-            """
-            txn.execute(sql, (min_user_id, min_room_id, max_user_id, max_room_id))
-            processed_rows = txn.rowcount
-
-            progress["max_summary_user_id"] = max_user_id
-            progress["max_summary_room_id"] = max_room_id
-            self.db_pool.updates._background_update_progress_txn(
-                txn, "event_push_backfill_thread_id", progress
-            )
-
-            return processed_rows
-
-        # First update the event_push_actions table, then the event_push_summary table.
-        #
-        # Note that the event_push_actions_staging table is ignored since it is
-        # assumed that items in that table will only exist for a short period of
-        # time.
-        if not event_push_actions_done:
-            result = await self.db_pool.runInteraction(
-                "event_push_backfill_thread_id",
-                add_thread_id_txn,
-                progress.get("max_event_push_actions_stream_ordering", 0),
-            )
-        else:
-            result = await self.db_pool.runInteraction(
-                "event_push_backfill_thread_id",
-                add_thread_id_summary_txn,
-            )
-
-            # Only done after the event_push_summary table is done.
-            if not result:
-                await self.db_pool.updates._end_background_update(
-                    "event_push_backfill_thread_id"
-                )
-
-        return result
+        await self.db_pool.runInteraction(
+            "drop_null_thread_id_indexes_txn",
+            drop_null_thread_id_indexes_txn,
+        )
+        await self.db_pool.updates._end_background_update(
+            "event_push_drop_null_thread_id_indexes"
+        )
+        return 0
 
     async def get_unread_counts_by_room_for_user(self, user_id: str) -> Dict[str, int]:
         """Get the notification count by room for a user. Only considers notifications,
@@ -711,25 +584,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
         )
 
-        # First ensure that the existing rows have an updated thread_id field.
-        if not self._event_push_backfill_thread_id_done:
-            txn.execute(
-                """
-                UPDATE event_push_summary
-                SET thread_id = ?
-                WHERE room_id = ? AND user_id = ? AND thread_id is NULL
-                """,
-                (MAIN_TIMELINE, room_id, user_id),
-            )
-            txn.execute(
-                """
-                UPDATE event_push_actions
-                SET thread_id = ?
-                WHERE room_id = ? AND user_id = ? AND thread_id is NULL
-                """,
-                (MAIN_TIMELINE, room_id, user_id),
-            )
-
         # First we pull the counts from the summary table.
         #
         # We check that `last_receipt_stream_ordering` matches the stream ordering of the
@@ -1545,25 +1399,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 (room_id, user_id, stream_ordering, *thread_args),
             )
 
-            # First ensure that the existing rows have an updated thread_id field.
-            if not self._event_push_backfill_thread_id_done:
-                txn.execute(
-                    """
-                    UPDATE event_push_summary
-                    SET thread_id = ?
-                    WHERE room_id = ? AND user_id = ? AND thread_id is NULL
-                    """,
-                    (MAIN_TIMELINE, room_id, user_id),
-                )
-                txn.execute(
-                    """
-                    UPDATE event_push_actions
-                    SET thread_id = ?
-                    WHERE room_id = ? AND user_id = ? AND thread_id is NULL
-                    """,
-                    (MAIN_TIMELINE, room_id, user_id),
-                )
-
             # Fetch the notification counts between the stream ordering of the
             # latest receipt and what was previously summarised.
             unread_counts = self._get_notif_unread_count_for_user_room(
@@ -1698,19 +1533,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             rotate_to_stream_ordering: The new maximum event stream ordering to summarise.
         """
 
-        # Ensure that any new actions have an updated thread_id.
-        if not self._event_push_backfill_thread_id_done:
-            txn.execute(
-                """
-                UPDATE event_push_actions
-                SET thread_id = ?
-                WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL
-                """,
-                (MAIN_TIMELINE, old_rotate_stream_ordering, rotate_to_stream_ordering),
-            )
-
-        # XXX Do we need to update summaries here too?
-
         # Calculate the new counts that should be upserted into event_push_summary
         sql = """
             SELECT user_id, room_id, thread_id,
@@ -1773,20 +1595,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         logger.info("Rotating notifications, handling %d rows", len(summaries))
 
-        # Ensure that any updated threads have the proper thread_id.
-        if not self._event_push_backfill_thread_id_done:
-            txn.execute_batch(
-                """
-                UPDATE event_push_summary
-                SET thread_id = ?
-                WHERE room_id = ? AND user_id = ? AND thread_id is NULL
-                """,
-                [
-                    (MAIN_TIMELINE, room_id, user_id)
-                    for user_id, room_id, _ in summaries
-                ],
-            )
-
         self.db_pool.simple_upsert_many_txn(
             txn,
             table="event_push_summary",