author    Erik Johnston <erik@matrix.org>  2022-06-15 16:17:14 +0100
committer GitHub <noreply@github.com>      2022-06-15 15:17:14 +0000
commit    0d1d3e070886694eff1fa862cd203206b1a63372 (patch)
tree      e75ba818b9fbb9b39b53b649c850799b2634aeee /synapse/storage/databases
parent    Rename complement-developonly (#13046)
download  synapse-0d1d3e070886694eff1fa862cd203206b1a63372.tar.xz
Speed up `get_unread_event_push_actions_by_room` (#13005)
Hopefully fixes #11887.

The core change here is that `event_push_summary` now holds a summary of counts up to a much more recent point, which means the range of rows we need to count in `event_push_actions` is much smaller.
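
To make that concrete, here is a minimal sketch of the new read path against plain `sqlite3`, with the table shapes simplified (the function name and schema details are illustrative, not Synapse's actual storage API; highlights, which are never summarized, are counted separately and omitted here): read the cheap summary row first, then directly count only the rows newer than both the receipt and the summary.

```python
import sqlite3
from typing import Tuple

def unread_counts(
    txn: sqlite3.Cursor, room_id: str, user_id: str, receipt_stream_ordering: int
) -> Tuple[int, int]:
    """Illustrative sketch: (notify, unread) counts for a user/room since their receipt."""
    notify = unread = 0

    # Cheap part: counts already rolled up into the summary row, if the
    # summary covers a point after the receipt.
    txn.execute(
        "SELECT stream_ordering, notif_count, COALESCE(unread_count, 0)"
        " FROM event_push_summary"
        " WHERE room_id = ? AND user_id = ? AND stream_ordering > ?",
        (room_id, user_id, receipt_stream_ordering),
    )
    row = txn.fetchone()
    summary_stream_ordering = 0
    if row:
        summary_stream_ordering, n, u = row
        notify += n
        unread += u

    # Small part: only push actions newer than both the receipt and the
    # summary still need counting row by row.
    since = max(receipt_stream_ordering, summary_stream_ordering)
    txn.execute(
        "SELECT COUNT(CASE WHEN notif = 1 THEN 1 END),"
        "       COUNT(CASE WHEN unread = 1 THEN 1 END)"
        " FROM event_push_actions"
        " WHERE room_id = ? AND user_id = ? AND stream_ordering > ?",
        (room_id, user_id, since),
    )
    n, u = txn.fetchone()
    return notify + n, unread + u
```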

This requires two major changes:
1. When we get a receipt, we now recalculate `event_push_summary` rather than just deleting it (sketched below).
2. The logic for deleting `event_push_actions` is now decoupled from calculating `event_push_summary` (see the batched cleanup sketch further down).
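
A hedged sketch of change 1 under the same simplified schema (`reseed_summary_on_receipt` is a hypothetical name; the real code goes through Synapse's `DatabasePool` helpers): once the per-event actions covered by the receipt have been deleted, recount whatever the background rotation has not yet summarized and upsert it, rather than dropping the summary row. The upsert leans on the unique `(user_id, room_id)` index this commit adds to `event_push_summary`.

```python
import sqlite3

def reseed_summary_on_receipt(
    txn: sqlite3.Cursor,
    room_id: str,
    user_id: str,
    receipt_stream_ordering: int,
    rotated_upto: int,
) -> None:
    """Illustrative sketch: refresh event_push_summary after a read receipt."""
    # Recount the remaining actions between the receipt and the point the
    # rotation job has already summarized up to.
    txn.execute(
        "SELECT COUNT(CASE WHEN notif = 1 THEN 1 END),"
        "       COUNT(CASE WHEN unread = 1 THEN 1 END)"
        " FROM event_push_actions"
        " WHERE room_id = ? AND user_id = ?"
        "   AND stream_ordering > ? AND stream_ordering <= ?",
        (room_id, user_id, receipt_stream_ordering, rotated_upto),
    )
    notif_count, unread_count = txn.fetchone()

    # Upsert instead of delete: the summary row now reflects the receipt.
    # Requires a unique index on (user_id, room_id), added by this commit.
    txn.execute(
        "INSERT INTO event_push_summary"
        " (room_id, user_id, notif_count, unread_count, stream_ordering)"
        " VALUES (?, ?, ?, ?, ?)"
        " ON CONFLICT (user_id, room_id) DO UPDATE SET"
        "   notif_count = excluded.notif_count,"
        "   unread_count = excluded.unread_count,"
        "   stream_ordering = excluded.stream_ordering",
        (room_id, user_id, notif_count, unread_count, rotated_upto),
    )
```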

In the future it would be good to update `event_push_summary` as we persist each new event (it should just be a case of incrementing the relevant rows in `event_push_summary`), as that would further simplify the logic for getting counts and remove the need to periodically update `event_push_summary` in a background job.
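
The deletion half of change 2 becomes a standalone, bounded cleanup loop. A sketch of the batching pattern (again simplified to raw `sqlite3`; Synapse runs the equivalent via `runInteraction` in a background job):

```python
import sqlite3

def prune_rotated_actions(
    conn: sqlite3.Connection, max_stream_ordering: int, batch_size: int = 10000
) -> None:
    """Illustrative sketch: delete summarized, non-highlight push actions in batches."""
    while True:
        txn = conn.cursor()
        # Peek at the stream ordering `batch_size` rows in, so each DELETE
        # removes at most one batch.
        txn.execute(
            "SELECT stream_ordering FROM event_push_actions"
            " WHERE stream_ordering < ? AND highlight = 0"
            " ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?",
            (max_stream_ordering, batch_size),
        )
        row = txn.fetchone()
        bound = row[0] if row else max_stream_ordering
        txn.execute(
            "DELETE FROM event_push_actions"
            " WHERE stream_ordering < ? AND highlight = 0",
            (bound,),
        )
        conn.commit()
        # Fewer rows than a full batch means we have caught up.
        if txn.rowcount < batch_size:
            return
```
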
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--  synapse/storage/databases/main/__init__.py            |   4
-rw-r--r--  synapse/storage/databases/main/event_push_actions.py  | 258
-rw-r--r--  synapse/storage/databases/main/push_rule.py           |   2
-rw-r--r--  synapse/storage/databases/main/receipts.py            |  74
4 files changed, 233 insertions, 105 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 9121badb3a..cb3d1242bb 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -104,13 +104,14 @@ class DataStore(
     PusherStore,
     PushRuleStore,
     ApplicationServiceTransactionStore,
+    EventPushActionsStore,
+    ServerMetricsStore,
     ReceiptsStore,
     EndToEndKeyStore,
     EndToEndRoomKeyStore,
     SearchStore,
     TagsStore,
     AccountDataStore,
-    EventPushActionsStore,
     OpenIdStore,
     ClientIpWorkerStore,
     DeviceStore,
@@ -124,7 +125,6 @@ class DataStore(
     UIAuthStore,
     EventForwardExtremitiesStore,
     CacheInvalidationWorkerStore,
-    ServerMetricsStore,
     LockStore,
     SessionStore,
 ):
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index b019979350..ae705889a5 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
 
 import attr
 
+from synapse.api.constants import ReceiptTypes
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import (
@@ -24,6 +25,8 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -79,15 +82,15 @@ class UserPushAction(EmailPushAction):
     profile_tag: str
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
+@attr.s(slots=True, auto_attribs=True)
 class NotifCounts:
     """
     The per-user, per-room count of notifications. Used by sync and push.
     """
 
-    notify_count: int
-    unread_count: int
-    highlight_count: int
+    notify_count: int = 0
+    unread_count: int = 0
+    highlight_count: int = 0
 
 
 def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
@@ -119,7 +122,7 @@ def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, str]]:
         return DEFAULT_NOTIF_ACTION
 
 
-class EventPushActionsWorkerStore(SQLBaseStore):
+class EventPushActionsWorkerStore(ReceiptsWorkerStore, EventsWorkerStore, SQLBaseStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -148,12 +151,20 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 self._rotate_notifs, 30 * 60 * 1000
             )
 
-    @cached(num_args=3, tree=True, max_entries=5000)
+        self.db_pool.updates.register_background_index_update(
+            "event_push_summary_unique_index",
+            index_name="event_push_summary_unique_index",
+            table="event_push_summary",
+            columns=["user_id", "room_id"],
+            unique=True,
+            replaces_index="event_push_summary_user_rm",
+        )
+
+    @cached(tree=True, max_entries=5000)
     async def get_unread_event_push_actions_by_room_for_user(
         self,
         room_id: str,
         user_id: str,
-        last_read_event_id: Optional[str],
     ) -> NotifCounts:
         """Get the notification count, the highlight count and the unread message count
         for a given user in a given room after the given read receipt.
@@ -165,8 +176,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         Args:
             room_id: The room to retrieve the counts in.
             user_id: The user to retrieve the counts for.
-            last_read_event_id: The event associated with the latest read receipt for
-                this user in this room. None if no receipt for this user in this room.
 
         Returns
             A dict containing the counts mentioned earlier in this docstring,
@@ -178,7 +187,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             self._get_unread_counts_by_receipt_txn,
             room_id,
             user_id,
-            last_read_event_id,
         )
 
     def _get_unread_counts_by_receipt_txn(
@@ -186,16 +194,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         txn: LoggingTransaction,
         room_id: str,
         user_id: str,
-        last_read_event_id: Optional[str],
     ) -> NotifCounts:
-        stream_ordering = None
+        result = self.get_last_receipt_for_user_txn(
+            txn,
+            user_id,
+            room_id,
+            receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
+        )
 
-        if last_read_event_id is not None:
-            stream_ordering = self.get_stream_id_for_event_txn(  # type: ignore[attr-defined]
-                txn,
-                last_read_event_id,
-                allow_none=True,
-            )
+        stream_ordering = None
+        if result:
+            _, stream_ordering = result
 
         if stream_ordering is None:
             # Either last_read_event_id is None, or it's an event we don't have (e.g.
@@ -218,49 +227,95 @@ class EventPushActionsWorkerStore(SQLBaseStore):
     def _get_unread_counts_by_pos_txn(
         self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
     ) -> NotifCounts:
-        sql = (
-            "SELECT"
-            "   COUNT(CASE WHEN notif = 1 THEN 1 END),"
-            "   COUNT(CASE WHEN highlight = 1 THEN 1 END),"
-            "   COUNT(CASE WHEN unread = 1 THEN 1 END)"
-            " FROM event_push_actions ea"
-            " WHERE user_id = ?"
-            "   AND room_id = ?"
-            "   AND stream_ordering > ?"
-        )
-
-        txn.execute(sql, (user_id, room_id, stream_ordering))
-        row = txn.fetchone()
-
-        (notif_count, highlight_count, unread_count) = (0, 0, 0)
+        """Get the number of unread messages for a user/room that have happened
+        since the given stream ordering.
+        """
 
-        if row:
-            (notif_count, highlight_count, unread_count) = row
+        counts = NotifCounts()
 
+        # First we pull the counts from the summary table
         txn.execute(
             """
-                SELECT notif_count, unread_count FROM event_push_summary
+                SELECT stream_ordering, notif_count, COALESCE(unread_count, 0)
+                FROM event_push_summary
                 WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
             """,
             (room_id, user_id, stream_ordering),
         )
         row = txn.fetchone()
 
+        summary_stream_ordering = 0
         if row:
-            notif_count += row[0]
-
-            if row[1] is not None:
-                # The unread_count column of event_push_summary is NULLable, so we need
-                # to make sure we don't try increasing the unread counts if it's NULL
-                # for this row.
-                unread_count += row[1]
-
-        return NotifCounts(
-            notify_count=notif_count,
-            unread_count=unread_count,
-            highlight_count=highlight_count,
+            summary_stream_ordering = row[0]
+            counts.notify_count += row[1]
+            counts.unread_count += row[2]
+
+        # Next we need to count highlights, which aren't summarized
+        sql = """
+            SELECT COUNT(*) FROM event_push_actions
+            WHERE user_id = ?
+                AND room_id = ?
+                AND stream_ordering > ?
+                AND highlight = 1
+        """
+        txn.execute(sql, (user_id, room_id, stream_ordering))
+        row = txn.fetchone()
+        if row:
+            counts.highlight_count += row[0]
+
+        # Finally we need to count the push actions that have not yet been
+        # rolled into the summary, i.e. rows after both the receipt and the
+        # summary's stream ordering.
+        stream_ordering = max(stream_ordering, summary_stream_ordering)
+        notify_count, unread_count = self._get_notif_unread_count_for_user_room(
+            txn, room_id, user_id, stream_ordering
         )
 
+        counts.notify_count += notify_count
+        counts.unread_count += unread_count
+
+        return counts
+
+    def _get_notif_unread_count_for_user_room(
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        user_id: str,
+        stream_ordering: int,
+        max_stream_ordering: Optional[int] = None,
+    ) -> Tuple[int, int]:
+        """Returns the notify and unread counts from `event_push_actions` for
+        the given user/room in the given range.
+
+        Does not consult the `event_push_summary` table, which may include
+        push actions that have already been deleted from `event_push_actions`.
+        """
+
+        clause = ""
+        args = [user_id, room_id, stream_ordering]
+        if max_stream_ordering is not None:
+            clause = "AND ea.stream_ordering <= ?"
+            args.append(max_stream_ordering)
+
+        sql = f"""
+            SELECT
+               COUNT(CASE WHEN notif = 1 THEN 1 END),
+               COUNT(CASE WHEN unread = 1 THEN 1 END)
+             FROM event_push_actions ea
+             WHERE user_id = ?
+               AND room_id = ?
+               AND ea.stream_ordering > ?
+               {clause}
+        """
+
+        txn.execute(sql, args)
+        row = txn.fetchone()
+
+        if row:
+            return cast(Tuple[int, int], row)
+
+        return 0, 0
+
     async def get_push_action_users_in_range(
         self, min_stream_ordering: int, max_stream_ordering: int
     ) -> List[str]:
@@ -754,6 +809,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 if caught_up:
                     break
                 await self.hs.get_clock().sleep(self._rotate_delay)
+
+            await self._remove_old_push_actions_that_have_rotated()
         finally:
             self._doing_notif_rotation = False
 
@@ -782,20 +839,16 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         stream_row = txn.fetchone()
         if stream_row:
             (offset_stream_ordering,) = stream_row
-            assert self.stream_ordering_day_ago is not None
-            rotate_to_stream_ordering = min(
-                self.stream_ordering_day_ago, offset_stream_ordering
-            )
-            caught_up = offset_stream_ordering >= self.stream_ordering_day_ago
+            rotate_to_stream_ordering = offset_stream_ordering
+            caught_up = False
         else:
-            rotate_to_stream_ordering = self.stream_ordering_day_ago
+            rotate_to_stream_ordering = self._stream_id_gen.get_current_token()
             caught_up = True
 
         logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering)
 
         self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering)
 
-        # We have caught up iff we were limited by `stream_ordering_day_ago`
         return caught_up
 
     def _rotate_notifs_before_txn(
@@ -819,7 +872,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                     max(stream_ordering) as stream_ordering
                 FROM event_push_actions
                 WHERE ? <= stream_ordering AND stream_ordering < ?
-                    AND highlight = 0
                     AND %s = 1
                 GROUP BY user_id, room_id
             ) AS upd
@@ -915,18 +967,72 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
 
         txn.execute(
-            "DELETE FROM event_push_actions"
-            " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0",
-            (old_rotate_stream_ordering, rotate_to_stream_ordering),
+            "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?",
+            (rotate_to_stream_ordering,),
         )
 
-        logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)
+    async def _remove_old_push_actions_that_have_rotated(
+        self,
+    ) -> None:
+        """Clear out old push actions that have been summarized."""
 
-        txn.execute(
-            "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?",
-            (rotate_to_stream_ordering,),
+        # We want to clear out anything that is older than a day *and* has
+        # already been rotated.
+        rotated_upto_stream_ordering = await self.db_pool.simple_select_one_onecol(
+            table="event_push_summary_stream_ordering",
+            keyvalues={},
+            retcol="stream_ordering",
         )
 
+        max_stream_ordering_to_delete = min(
+            rotated_upto_stream_ordering, self.stream_ordering_day_ago
+        )
+
+        def remove_old_push_actions_that_have_rotated_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
+            # We don't want to clear out too much at a time, so we bound our
+            # deletes.
+            batch_size = 10000
+
+            txn.execute(
+                """
+                SELECT stream_ordering FROM event_push_actions
+                WHERE stream_ordering < ? AND highlight = 0
+                ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
+            """,
+                (
+                    max_stream_ordering_to_delete,
+                    batch_size,
+                ),
+            )
+            stream_row = txn.fetchone()
+
+            if stream_row:
+                (stream_ordering,) = stream_row
+            else:
+                stream_ordering = max_stream_ordering_to_delete
+
+            txn.execute(
+                """
+                DELETE FROM event_push_actions
+                WHERE stream_ordering < ? AND highlight = 0
+                """,
+                (stream_ordering,),
+            )
+
+            logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)
+
+            return txn.rowcount < batch_size
+
+        while True:
+            done = await self.db_pool.runInteraction(
+                "_remove_old_push_actions_that_have_rotated",
+                remove_old_push_actions_that_have_rotated_txn,
+            )
+            if done:
+                break
+
     def _remove_old_push_actions_before_txn(
         self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
     ) -> None:
@@ -965,12 +1071,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
         )
 
-        txn.execute(
-            """
-            DELETE FROM event_push_summary
-            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
-        """,
-            (room_id, user_id, stream_ordering),
+        old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="event_push_summary_stream_ordering",
+            keyvalues={},
+            retcol="stream_ordering",
+        )
+
+        notif_count, unread_count = self._get_notif_unread_count_for_user_room(
+            txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
+        )
+
+        self.db_pool.simple_upsert_txn(
+            txn,
+            table="event_push_summary",
+            keyvalues={"room_id": room_id, "user_id": user_id},
+            values={
+                "notif_count": notif_count,
+                "unread_count": unread_count,
+                "stream_ordering": old_rotate_stream_ordering,
+            },
         )
 
 
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index d5aefe02b6..86649c1e6c 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -110,9 +110,9 @@ def _load_rules(
 # the abstract methods being implemented.
 class PushRulesWorkerStore(
     ApplicationServiceWorkerStore,
-    ReceiptsWorkerStore,
     PusherWorkerStore,
     RoomMemberWorkerStore,
+    ReceiptsWorkerStore,
     EventsWorkerStore,
     SQLBaseStore,
     metaclass=abc.ABCMeta,
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index b6106affa6..bec6d60577 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -118,7 +118,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         return self._receipts_id_gen.get_current_token()
 
     async def get_last_receipt_event_id_for_user(
-        self, user_id: str, room_id: str, receipt_types: Iterable[str]
+        self, user_id: str, room_id: str, receipt_types: Collection[str]
     ) -> Optional[str]:
         """
         Fetch the event ID for the latest receipt in a room with one of the given receipt types.
@@ -126,58 +126,63 @@ class ReceiptsWorkerStore(SQLBaseStore):
         Args:
             user_id: The user to fetch receipts for.
             room_id: The room ID to fetch the receipt for.
-            receipt_type: The receipt types to fetch. Earlier receipt types
-                are given priority if multiple receipts point to the same event.
+            receipt_types: The receipt types to fetch.
 
         Returns:
             The latest receipt, if one exists.
         """
-        latest_event_id: Optional[str] = None
-        latest_stream_ordering = 0
-        for receipt_type in receipt_types:
-            result = await self._get_last_receipt_event_id_for_user(
-                user_id, room_id, receipt_type
-            )
-            if result is None:
-                continue
-            event_id, stream_ordering = result
-
-            if latest_event_id is None or latest_stream_ordering < stream_ordering:
-                latest_event_id = event_id
-                latest_stream_ordering = stream_ordering
+        result = await self.db_pool.runInteraction(
+            "get_last_receipt_event_id_for_user",
+            self.get_last_receipt_for_user_txn,
+            user_id,
+            room_id,
+            receipt_types,
+        )
+        if not result:
+            return None
 
-        return latest_event_id
+        event_id, _ = result
+        return event_id
 
-    @cached()
-    async def _get_last_receipt_event_id_for_user(
-        self, user_id: str, room_id: str, receipt_type: str
+    def get_last_receipt_for_user_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        room_id: str,
+        receipt_types: Collection[str],
     ) -> Optional[Tuple[str, int]]:
         """
-        Fetch the event ID and stream ordering for the latest receipt.
+        Fetch the event ID and stream_ordering for the latest receipt in a room
+        with one of the given receipt types.
 
         Args:
             user_id: The user to fetch receipts for.
             room_id: The room ID to fetch the receipt for.
-            receipt_type: The receipt type to fetch.
+            receipt_types: The receipt types to fetch.
 
         Returns:
-            The event ID and stream ordering of the latest receipt, if one exists;
-            otherwise `None`.
+            The event ID and stream ordering of the latest receipt, if one exists.
         """
-        sql = """
+
+        clause, args = make_in_list_sql_clause(
+            self.database_engine, "receipt_type", receipt_types
+        )
+
+        sql = f"""
             SELECT event_id, stream_ordering
             FROM receipts_linearized
             INNER JOIN events USING (room_id, event_id)
-            WHERE user_id = ?
+            WHERE {clause}
+            AND user_id = ?
             AND room_id = ?
-            AND receipt_type = ?
+            ORDER BY stream_ordering DESC
+            LIMIT 1
         """
 
-        def f(txn: LoggingTransaction) -> Optional[Tuple[str, int]]:
-            txn.execute(sql, (user_id, room_id, receipt_type))
-            return cast(Optional[Tuple[str, int]], txn.fetchone())
+        args.extend((user_id, room_id))
+        txn.execute(sql, args)
 
-        return await self.db_pool.runInteraction("get_own_receipt_for_user", f)
+        return cast(Optional[Tuple[str, int]], txn.fetchone())
 
     async def get_receipts_for_user(
         self, user_id: str, receipt_types: Iterable[str]
@@ -577,8 +582,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
     ) -> None:
         self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type))
         self._get_linearized_receipts_for_room.invalidate((room_id,))
-        self._get_last_receipt_event_id_for_user.invalidate(
-            (user_id, room_id, receipt_type)
+
+        # We invalidate the cache by name here so that we don't end up with a
+        # circular dependency between the receipts and push action stores.
+        self._attempt_to_invalidate_cache(
+            "get_unread_event_push_actions_by_room_for_user", (room_id,)
         )
 
     def process_replication_rows(