summary refs log tree commit diff
path: root/synapse/storage/databases/main/event_push_actions.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/event_push_actions.py')
-rw-r--r--synapse/storage/databases/main/event_push_actions.py376
1 files changed, 192 insertions, 184 deletions
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index eabf9c9739..6b8668d2dc 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -98,6 +98,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.stream import StreamWorkerStore
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -232,6 +233,104 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             replaces_index="event_push_summary_user_rm",
         )
 
+        self.db_pool.updates.register_background_index_update(
+            "event_push_summary_unique_index2",
+            index_name="event_push_summary_unique_index2",
+            table="event_push_summary",
+            columns=["user_id", "room_id", "thread_id"],
+            unique=True,
+        )
+
+        self.db_pool.updates.register_background_update_handler(
+            "event_push_backfill_thread_id",
+            self._background_backfill_thread_id,
+        )
+
+    async def _background_backfill_thread_id(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """
+        Fill in the thread_id field for event_push_actions and event_push_summary.
+
+        This is preparatory so that it can be made non-nullable in the future.
+
+        Because all current (null) data is done in an unthreaded manner this
+        simply assumes it is on the "main" timeline. Since event_push_actions
+        are periodically cleared it is not possible to correctly re-calculate
+        the thread_id.
+        """
+        event_push_actions_done = progress.get("event_push_actions_done", False)
+
+        def add_thread_id_txn(
+            txn: LoggingTransaction, table_name: str, start_stream_ordering: int
+        ) -> int:
+            sql = f"""
+            SELECT stream_ordering
+            FROM {table_name}
+            WHERE
+                thread_id IS NULL
+                AND stream_ordering > ?
+            ORDER BY stream_ordering
+            LIMIT ?
+            """
+            txn.execute(sql, (start_stream_ordering, batch_size))
+
+            # No more rows to process.
+            rows = txn.fetchall()
+            if not rows:
+                progress[f"{table_name}_done"] = True
+                self.db_pool.updates._background_update_progress_txn(
+                    txn, "event_push_backfill_thread_id", progress
+                )
+                return 0
+
+            # Update the thread ID for any of those rows.
+            max_stream_ordering = rows[-1][0]
+
+            sql = f"""
+            UPDATE {table_name}
+            SET thread_id = 'main'
+            WHERE stream_ordering <= ? AND thread_id IS NULL
+            """
+            txn.execute(sql, (max_stream_ordering,))
+
+            # Update progress.
+            processed_rows = txn.rowcount
+            progress[f"max_{table_name}_stream_ordering"] = max_stream_ordering
+            self.db_pool.updates._background_update_progress_txn(
+                txn, "event_push_backfill_thread_id", progress
+            )
+
+            return processed_rows
+
+        # First update the event_push_actions table, then the event_push_summary table.
+        #
+        # Note that the event_push_actions_staging table is ignored since it is
+        # assumed that items in that table will only exist for a short period of
+        # time.
+        if not event_push_actions_done:
+            result = await self.db_pool.runInteraction(
+                "event_push_backfill_thread_id",
+                add_thread_id_txn,
+                "event_push_actions",
+                progress.get("max_event_push_actions_stream_ordering", 0),
+            )
+        else:
+            result = await self.db_pool.runInteraction(
+                "event_push_backfill_thread_id",
+                add_thread_id_txn,
+                "event_push_summary",
+                progress.get("max_event_push_summary_stream_ordering", 0),
+            )
+
+            # Only done after the event_push_summary table is done.
+            if not result:
+                await self.db_pool.updates._end_background_update(
+                    "event_push_backfill_thread_id"
+                )
+
+        return result
+
     @cached(tree=True, max_entries=5000)
     async def get_unread_event_push_actions_by_room_for_user(
         self,
@@ -274,7 +373,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             receipt_types=(
                 ReceiptTypes.READ,
                 ReceiptTypes.READ_PRIVATE,
-                ReceiptTypes.UNSTABLE_READ_PRIVATE,
             ),
         )
 
@@ -459,6 +557,31 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         return await self.db_pool.runInteraction("get_push_action_users_in_range", f)
 
+    def _get_receipts_by_room_txn(
+        self, txn: LoggingTransaction, user_id: str
+    ) -> List[Tuple[str, int]]:
+        receipt_types_clause, args = make_in_list_sql_clause(
+            self.database_engine,
+            "receipt_type",
+            (
+                ReceiptTypes.READ,
+                ReceiptTypes.READ_PRIVATE,
+            ),
+        )
+
+        sql = f"""
+            SELECT room_id, MAX(stream_ordering)
+            FROM receipts_linearized
+            INNER JOIN events USING (room_id, event_id)
+            WHERE {receipt_types_clause}
+            AND user_id = ?
+            GROUP BY room_id
+        """
+
+        args.extend((user_id,))
+        txn.execute(sql, args)
+        return cast(List[Tuple[str, int]], txn.fetchall())
+
     async def get_unread_push_actions_for_user_in_range_for_http(
         self,
         user_id: str,
@@ -482,106 +605,45 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will have between 0~limit entries.
         """
 
-        # find rooms that have a read receipt in them and return the next
-        # push actions
-        def get_after_receipt(
-            txn: LoggingTransaction,
-        ) -> List[Tuple[str, str, int, str, bool]]:
-            # find rooms that have a read receipt in them and return the next
-            # push actions
-
-            receipt_types_clause, args = make_in_list_sql_clause(
-                self.database_engine,
-                "receipt_type",
-                (
-                    ReceiptTypes.READ,
-                    ReceiptTypes.READ_PRIVATE,
-                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
-                ),
-            )
-
-            sql = f"""
-                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
-                    ep.highlight
-                FROM (
-                    SELECT room_id,
-                        MAX(stream_ordering) as stream_ordering
-                    FROM events
-                    INNER JOIN receipts_linearized USING (room_id, event_id)
-                    WHERE {receipt_types_clause} AND user_id = ?
-                    GROUP BY room_id
-                ) AS rl,
-                event_push_actions AS ep
-                WHERE
-                    ep.room_id = rl.room_id
-                    AND ep.stream_ordering > rl.stream_ordering
-                    AND ep.user_id = ?
-                    AND ep.stream_ordering > ?
-                    AND ep.stream_ordering <= ?
-                    AND ep.notif = 1
-                ORDER BY ep.stream_ordering ASC LIMIT ?
-            """
-            args.extend(
-                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
-            )
-            txn.execute(sql, args)
-            return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
-
-        after_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
+        receipts_by_room = dict(
+            await self.db_pool.runInteraction(
+                "get_unread_push_actions_for_user_in_range_http_receipts",
+                self._get_receipts_by_room_txn,
+                user_id=user_id,
+            ),
         )
 
-        # There are rooms with push actions in them but you don't have a read receipt in
-        # them e.g. rooms you've been invited to, so get push actions for rooms which do
-        # not have read receipts in them too.
-        def get_no_receipt(
+        def get_push_actions_txn(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, int, str, bool]]:
-            receipt_types_clause, args = make_in_list_sql_clause(
-                self.database_engine,
-                "receipt_type",
-                (
-                    ReceiptTypes.READ,
-                    ReceiptTypes.READ_PRIVATE,
-                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
-                ),
-            )
-
-            sql = f"""
-                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
-                    ep.highlight
+            sql = """
+                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight
                 FROM event_push_actions AS ep
-                INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
-                    ep.room_id NOT IN (
-                        SELECT room_id FROM receipts_linearized
-                        WHERE {receipt_types_clause} AND user_id = ?
-                        GROUP BY room_id
-                    )
-                    AND ep.user_id = ?
+                    ep.user_id = ?
                     AND ep.stream_ordering > ?
                     AND ep.stream_ordering <= ?
                     AND ep.notif = 1
                 ORDER BY ep.stream_ordering ASC LIMIT ?
             """
-            args.extend(
-                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
-            )
-            txn.execute(sql, args)
+            txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
             return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
 
-        no_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
+        push_actions = await self.db_pool.runInteraction(
+            "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn
         )
 
         notifs = [
             HttpPushAction(
-                event_id=row[0],
-                room_id=row[1],
-                stream_ordering=row[2],
-                actions=_deserialize_action(row[3], row[4]),
+                event_id=event_id,
+                room_id=room_id,
+                stream_ordering=stream_ordering,
+                actions=_deserialize_action(actions, highlight),
             )
-            for row in after_read_receipt + no_read_receipt
+            for event_id, room_id, stream_ordering, actions, highlight in push_actions
+            # Only include push actions with a stream ordering after any receipt, or without any
+            # receipt present (invited to but never read rooms).
+            if stream_ordering > receipts_by_room.get(room_id, 0)
         ]
 
         # Now sort it so it's ordered correctly, since currently it will
@@ -617,106 +679,49 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will have between 0~limit entries.
         """
 
-        # find rooms that have a read receipt in them and return the most recent
-        # push actions
-        def get_after_receipt(
-            txn: LoggingTransaction,
-        ) -> List[Tuple[str, str, int, str, bool, int]]:
-            receipt_types_clause, args = make_in_list_sql_clause(
-                self.database_engine,
-                "receipt_type",
-                (
-                    ReceiptTypes.READ,
-                    ReceiptTypes.READ_PRIVATE,
-                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
-                ),
-            )
-
-            sql = f"""
-                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
-                    ep.highlight, e.received_ts
-                FROM (
-                    SELECT room_id,
-                        MAX(stream_ordering) as stream_ordering
-                    FROM events
-                    INNER JOIN receipts_linearized USING (room_id, event_id)
-                    WHERE {receipt_types_clause} AND user_id = ?
-                    GROUP BY room_id
-                ) AS rl,
-                event_push_actions AS ep
-                INNER JOIN events AS e USING (room_id, event_id)
-                WHERE
-                    ep.room_id = rl.room_id
-                    AND ep.stream_ordering > rl.stream_ordering
-                    AND ep.user_id = ?
-                    AND ep.stream_ordering > ?
-                    AND ep.stream_ordering <= ?
-                    AND ep.notif = 1
-                ORDER BY ep.stream_ordering DESC LIMIT ?
-            """
-            args.extend(
-                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
-            )
-            txn.execute(sql, args)
-            return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
-
-        after_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
+        receipts_by_room = dict(
+            await self.db_pool.runInteraction(
+                "get_unread_push_actions_for_user_in_range_email_receipts",
+                self._get_receipts_by_room_txn,
+                user_id=user_id,
+            ),
         )
 
-        # There are rooms with push actions in them but you don't have a read receipt in
-        # them e.g. rooms you've been invited to, so get push actions for rooms which do
-        # not have read receipts in them too.
-        def get_no_receipt(
+        def get_push_actions_txn(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, int, str, bool, int]]:
-            receipt_types_clause, args = make_in_list_sql_clause(
-                self.database_engine,
-                "receipt_type",
-                (
-                    ReceiptTypes.READ,
-                    ReceiptTypes.READ_PRIVATE,
-                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
-                ),
-            )
-
-            sql = f"""
+            sql = """
                 SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
                     ep.highlight, e.received_ts
                 FROM event_push_actions AS ep
                 INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
-                    ep.room_id NOT IN (
-                        SELECT room_id FROM receipts_linearized
-                        WHERE {receipt_types_clause} AND user_id = ?
-                        GROUP BY room_id
-                    )
-                    AND ep.user_id = ?
+                    ep.user_id = ?
                     AND ep.stream_ordering > ?
                     AND ep.stream_ordering <= ?
                     AND ep.notif = 1
                 ORDER BY ep.stream_ordering DESC LIMIT ?
             """
-            args.extend(
-                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
-            )
-            txn.execute(sql, args)
+            txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
             return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
 
-        no_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
+        push_actions = await self.db_pool.runInteraction(
+            "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn
         )
 
         # Make a list of dicts from the two sets of results.
         notifs = [
             EmailPushAction(
-                event_id=row[0],
-                room_id=row[1],
-                stream_ordering=row[2],
-                actions=_deserialize_action(row[3], row[4]),
-                received_ts=row[5],
+                event_id=event_id,
+                room_id=room_id,
+                stream_ordering=stream_ordering,
+                actions=_deserialize_action(actions, highlight),
+                received_ts=received_ts,
             )
-            for row in after_read_receipt + no_read_receipt
+            for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions
+            # Only include push actions with a stream ordering after any receipt, or without any
+            # receipt present (invited to but never read rooms).
+            if stream_ordering > receipts_by_room.get(room_id, 0)
         ]
 
         # Now sort it so it's ordered correctly, since currently it will
@@ -764,6 +769,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         event_id: str,
         user_id_actions: Dict[str, Collection[Union[Mapping, str]]],
         count_as_unread: bool,
+        thread_id: str,
     ) -> None:
         """Add the push actions for the event to the push action staging area.
 
@@ -772,6 +778,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             user_id_actions: A mapping of user_id to list of push actions, where
                 an action can either be a string or dict.
             count_as_unread: Whether this event should increment unread counts.
+            thread_id: The thread this event is parent of, if applicable.
         """
         if not user_id_actions:
             return
@@ -780,7 +787,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         # can be used to insert into the `event_push_actions_staging` table.
         def _gen_entry(
             user_id: str, actions: Collection[Union[Mapping, str]]
-        ) -> Tuple[str, str, str, int, int, int]:
+        ) -> Tuple[str, str, str, int, int, int, str]:
             is_highlight = 1 if _action_has_highlight(actions) else 0
             notif = 1 if "notify" in actions else 0
             return (
@@ -790,28 +797,25 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 notif,  # notif column
                 is_highlight,  # highlight column
                 int(count_as_unread),  # unread column
+                thread_id,  # thread_id column
             )
 
-        def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
-            # We don't use simple_insert_many here to avoid the overhead
-            # of generating lists of dicts.
-
-            sql = """
-                INSERT INTO event_push_actions_staging
-                    (event_id, user_id, actions, notif, highlight, unread)
-                VALUES (?, ?, ?, ?, ?, ?)
-            """
-
-            txn.execute_batch(
-                sql,
-                (
-                    _gen_entry(user_id, actions)
-                    for user_id, actions in user_id_actions.items()
-                ),
-            )
-
-        return await self.db_pool.runInteraction(
-            "add_push_actions_to_staging", _add_push_actions_to_staging_txn
+        await self.db_pool.simple_insert_many(
+            "event_push_actions_staging",
+            keys=(
+                "event_id",
+                "user_id",
+                "actions",
+                "notif",
+                "highlight",
+                "unread",
+                "thread_id",
+            ),
+            values=[
+                _gen_entry(user_id, actions)
+                for user_id, actions in user_id_actions.items()
+            ],
+            desc="add_push_actions_to_staging",
         )
 
     async def remove_push_actions_from_staging(self, event_id: str) -> None:
@@ -1087,6 +1091,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             )
 
             # Replace the previous summary with the new counts.
+            #
+            # TODO(threads): Upsert per-thread instead of setting them all to main.
             self.db_pool.simple_upsert_txn(
                 txn,
                 table="event_push_summary",
@@ -1096,6 +1102,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                     "unread_count": unread_count,
                     "stream_ordering": old_rotate_stream_ordering,
                     "last_receipt_stream_ordering": stream_ordering,
+                    "thread_id": "main",
                 },
             )
 
@@ -1244,17 +1251,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         logger.info("Rotating notifications, handling %d rows", len(summaries))
 
+        # TODO(threads): Update on a per-thread basis.
         self.db_pool.simple_upsert_many_txn(
             txn,
             table="event_push_summary",
             key_names=("user_id", "room_id"),
             key_values=[(user_id, room_id) for user_id, room_id in summaries],
-            value_names=("notif_count", "unread_count", "stream_ordering"),
+            value_names=("notif_count", "unread_count", "stream_ordering", "thread_id"),
             value_values=[
                 (
                     summary.notif_count,
                     summary.unread_count,
                     summary.stream_ordering,
+                    "main",
                 )
                 for summary in summaries.values()
             ],
@@ -1361,7 +1370,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             table="event_push_actions",
             columns=["highlight", "stream_ordering"],
             where_clause="highlight=0",
-            psql_only=True,
         )
 
     async def get_push_actions_for_user(