summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/event_push_actions.py240
1 files changed, 240 insertions, 0 deletions
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 6afc51320a..eeccf5db24 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -100,6 +100,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.stream import StreamWorkerStore
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -288,6 +289,180 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             unique=True,
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            "event_push_backfill_thread_id",
+            self._background_backfill_thread_id,
+        )
+
+        # Indexes which will be used to quickly make the thread_id column non-null.
+        self.db_pool.updates.register_background_index_update(
+            "event_push_actions_thread_id_null",
+            index_name="event_push_actions_thread_id_null",
+            table="event_push_actions",
+            columns=["thread_id"],
+            where_clause="thread_id IS NULL",
+        )
+        self.db_pool.updates.register_background_index_update(
+            "event_push_summary_thread_id_null",
+            index_name="event_push_summary_thread_id_null",
+            table="event_push_summary",
+            columns=["thread_id"],
+            where_clause="thread_id IS NULL",
+        )
+
+        # Check ASAP (and then later, every 1s) to see if we have finished
+        # background updates the event_push_actions and event_push_summary tables.
+        self._clock.call_later(0.0, self._check_event_push_backfill_thread_id)
+        self._event_push_backfill_thread_id_done = False
+
+    @wrap_as_background_process("check_event_push_backfill_thread_id")
+    async def _check_event_push_backfill_thread_id(self) -> None:
+        """
+        Has thread_id finished backfilling?
+
+        If not, we need to just-in-time update it so the queries work.
+        """
+        done = await self.db_pool.updates.has_completed_background_update(
+            "event_push_backfill_thread_id"
+        )
+
+        if done:
+            self._event_push_backfill_thread_id_done = True
+        else:
+            # Reschedule to run.
+            self._clock.call_later(15.0, self._check_event_push_backfill_thread_id)
+
+    async def _background_backfill_thread_id(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """
+        Fill in the thread_id field for event_push_actions and event_push_summary.
+
+        This is preparatory so that it can be made non-nullable in the future.
+
+        Because all current (null) data is done in an unthreaded manner this
+        simply assumes it is on the "main" timeline. Since event_push_actions
+        are periodically cleared it is not possible to correctly re-calculate
+        the thread_id.
+        """
+        event_push_actions_done = progress.get("event_push_actions_done", False)
+
+        def add_thread_id_txn(
+            txn: LoggingTransaction, start_stream_ordering: int
+        ) -> int:
+            sql = """
+            SELECT stream_ordering
+            FROM event_push_actions
+            WHERE
+                thread_id IS NULL
+                AND stream_ordering > ?
+            ORDER BY stream_ordering
+            LIMIT ?
+            """
+            txn.execute(sql, (start_stream_ordering, batch_size))
+
+            # No more rows to process.
+            rows = txn.fetchall()
+            if not rows:
+                progress["event_push_actions_done"] = True
+                self.db_pool.updates._background_update_progress_txn(
+                    txn, "event_push_backfill_thread_id", progress
+                )
+                return 0
+
+            # Update the thread ID for any of those rows.
+            max_stream_ordering = rows[-1][0]
+
+            sql = """
+            UPDATE event_push_actions
+            SET thread_id = 'main'
+            WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL
+            """
+            txn.execute(
+                sql,
+                (
+                    start_stream_ordering,
+                    max_stream_ordering,
+                ),
+            )
+
+            # Update progress.
+            processed_rows = txn.rowcount
+            progress["max_event_push_actions_stream_ordering"] = max_stream_ordering
+            self.db_pool.updates._background_update_progress_txn(
+                txn, "event_push_backfill_thread_id", progress
+            )
+
+            return processed_rows
+
+        def add_thread_id_summary_txn(txn: LoggingTransaction) -> int:
+            min_user_id = progress.get("max_summary_user_id", "")
+            min_room_id = progress.get("max_summary_room_id", "")
+
+            # Slightly overcomplicated query for getting the Nth user ID / room
+            # ID tuple, or the last if there are less than N remaining.
+            sql = """
+            SELECT user_id, room_id FROM (
+                SELECT user_id, room_id FROM event_push_summary
+                WHERE (user_id, room_id) > (?, ?)
+                    AND thread_id IS NULL
+                ORDER BY user_id, room_id
+                LIMIT ?
+            ) AS e
+            ORDER BY user_id DESC, room_id DESC
+            LIMIT 1
+            """
+
+            txn.execute(sql, (min_user_id, min_room_id, batch_size))
+            row = txn.fetchone()
+            if not row:
+                return 0
+
+            max_user_id, max_room_id = row
+
+            sql = """
+            UPDATE event_push_summary
+            SET thread_id = 'main'
+            WHERE
+                (?, ?) < (user_id, room_id) AND (user_id, room_id) <= (?, ?)
+                AND thread_id IS NULL
+            """
+            txn.execute(sql, (min_user_id, min_room_id, max_user_id, max_room_id))
+            processed_rows = txn.rowcount
+
+            progress["max_summary_user_id"] = max_user_id
+            progress["max_summary_room_id"] = max_room_id
+            self.db_pool.updates._background_update_progress_txn(
+                txn, "event_push_backfill_thread_id", progress
+            )
+
+            return processed_rows
+
+        # First update the event_push_actions table, then the event_push_summary table.
+        #
+        # Note that the event_push_actions_staging table is ignored since it is
+        # assumed that items in that table will only exist for a short period of
+        # time.
+        if not event_push_actions_done:
+            result = await self.db_pool.runInteraction(
+                "event_push_backfill_thread_id",
+                add_thread_id_txn,
+                progress.get("max_event_push_actions_stream_ordering", 0),
+            )
+        else:
+            result = await self.db_pool.runInteraction(
+                "event_push_backfill_thread_id",
+                add_thread_id_summary_txn,
+            )
+
+            # Only done after the event_push_summary table is done.
+            if not result:
+                await self.db_pool.updates._end_background_update(
+                    "event_push_backfill_thread_id"
+                )
+
+        return result
+
     async def get_unread_counts_by_room_for_user(self, user_id: str) -> Dict[str, int]:
         """Get the notification count by room for a user. Only considers notifications,
         not highlight or unread counts, and threads are currently aggregated under their room.
@@ -536,6 +711,25 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
         )
 
+        # First ensure that the existing rows have an updated thread_id field.
+        if not self._event_push_backfill_thread_id_done:
+            txn.execute(
+                """
+                UPDATE event_push_summary
+                SET thread_id = ?
+                WHERE room_id = ? AND user_id = ? AND thread_id is NULL
+                """,
+                (MAIN_TIMELINE, room_id, user_id),
+            )
+            txn.execute(
+                """
+                UPDATE event_push_actions
+                SET thread_id = ?
+                WHERE room_id = ? AND user_id = ? AND thread_id is NULL
+                """,
+                (MAIN_TIMELINE, room_id, user_id),
+            )
+
         # First we pull the counts from the summary table.
         #
         # We check that `last_receipt_stream_ordering` matches the stream ordering of the
@@ -1351,6 +1545,25 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 (room_id, user_id, stream_ordering, *thread_args),
             )
 
+            # First ensure that the existing rows have an updated thread_id field.
+            if not self._event_push_backfill_thread_id_done:
+                txn.execute(
+                    """
+                    UPDATE event_push_summary
+                    SET thread_id = ?
+                    WHERE room_id = ? AND user_id = ? AND thread_id is NULL
+                    """,
+                    (MAIN_TIMELINE, room_id, user_id),
+                )
+                txn.execute(
+                    """
+                    UPDATE event_push_actions
+                    SET thread_id = ?
+                    WHERE room_id = ? AND user_id = ? AND thread_id is NULL
+                    """,
+                    (MAIN_TIMELINE, room_id, user_id),
+                )
+
             # Fetch the notification counts between the stream ordering of the
             # latest receipt and what was previously summarised.
             unread_counts = self._get_notif_unread_count_for_user_room(
@@ -1485,6 +1698,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             rotate_to_stream_ordering: The new maximum event stream ordering to summarise.
         """
 
+        # Ensure that any new actions have an updated thread_id.
+        if not self._event_push_backfill_thread_id_done:
+            txn.execute(
+                """
+                UPDATE event_push_actions
+                SET thread_id = ?
+                WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL
+                """,
+                (MAIN_TIMELINE, old_rotate_stream_ordering, rotate_to_stream_ordering),
+            )
+
+        # XXX Do we need to update summaries here too?
+
         # Calculate the new counts that should be upserted into event_push_summary
         sql = """
             SELECT user_id, room_id, thread_id,
@@ -1547,6 +1773,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         logger.info("Rotating notifications, handling %d rows", len(summaries))
 
+        # Ensure that any updated threads have the proper thread_id.
+        if not self._event_push_backfill_thread_id_done:
+            txn.execute_batch(
+                """
+                UPDATE event_push_summary
+                SET thread_id = ?
+                WHERE room_id = ? AND user_id = ? AND thread_id is NULL
+                """,
+                [
+                    (MAIN_TIMELINE, room_id, user_id)
+                    for user_id, room_id, _ in summaries
+                ],
+            )
+
         self.db_pool.simple_upsert_many_txn(
             txn,
             table="event_push_summary",