summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-06-15 16:17:14 +0100
committerGitHub <noreply@github.com>2022-06-15 15:17:14 +0000
commit0d1d3e070886694eff1fa862cd203206b1a63372 (patch)
treee75ba818b9fbb9b39b53b649c850799b2634aeee
parentRename complement-developonly (#13046) (diff)
downloadsynapse-0d1d3e070886694eff1fa862cd203206b1a63372.tar.xz
Speed up `get_unread_event_push_actions_by_room` (#13005)
Fixes #11887 hopefully.

The core change here is that `event_push_summary` now holds a summary of counts up until a much more recent point, meaning that the range of rows we need to count in `event_push_actions` is much smaller.

This needs two major changes:
1. When we get a receipt we need to recalculate `event_push_summary` rather than just delete it
2. The logic for deleting `event_push_actions` is now divorced from calculating `event_push_summary`.

In future it would be good to calculate `event_push_summary` while we persist a new event (it should just be a case of adding one to the relevant rows in `event_push_summary`), as that will further simplify the get counts logic and remove the need for us to periodically update `event_push_summary` in a background job.
-rw-r--r--changelog.d/13005.misc1
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py4
-rw-r--r--synapse/handlers/sync.py10
-rw-r--r--synapse/push/push_tools.py33
-rw-r--r--synapse/storage/database.py1
-rw-r--r--synapse/storage/databases/main/__init__.py4
-rw-r--r--synapse/storage/databases/main/event_push_actions.py258
-rw-r--r--synapse/storage/databases/main/push_rule.py2
-rw-r--r--synapse/storage/databases/main/receipts.py74
-rw-r--r--synapse/storage/schema/main/delta/40/event_push_summary.sql7
-rw-r--r--synapse/storage/schema/main/delta/71/02event_push_summary_unique.sql18
-rw-r--r--tests/push/test_http.py16
-rw-r--r--tests/replication/slave/storage/test_events.py23
-rw-r--r--tests/storage/test_event_push_actions.py24
14 files changed, 323 insertions, 152 deletions
diff --git a/changelog.d/13005.misc b/changelog.d/13005.misc
new file mode 100644
index 0000000000..3bb51962e7
--- /dev/null
+++ b/changelog.d/13005.misc
@@ -0,0 +1 @@
+Reduce DB usage of `/sync` when a large number of unread messages have recently been sent in a room.
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 9586086c03..9c06c837dc 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -58,6 +58,9 @@ from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateSt
 from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
 from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
 from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyBackgroundStore
+from synapse.storage.databases.main.event_push_actions import (
+    EventPushActionsWorkerStore,
+)
 from synapse.storage.databases.main.events_bg_updates import (
     EventsBackgroundUpdatesStore,
 )
@@ -199,6 +202,7 @@ R = TypeVar("R")
 
 
 class Store(
+    EventPushActionsWorkerStore,
     ClientIpBackgroundUpdateStore,
     DeviceInboxBackgroundUpdateStore,
     DeviceBackgroundUpdateStore,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index af19c513be..6ad053f678 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tup
 import attr
 from prometheus_client import Counter
 
-from synapse.api.constants import EventTypes, Membership, ReceiptTypes
+from synapse.api.constants import EventTypes, Membership
 from synapse.api.filtering import FilterCollection
 from synapse.api.presence import UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -1054,14 +1054,10 @@ class SyncHandler:
         self, room_id: str, sync_config: SyncConfig
     ) -> NotifCounts:
         with Measure(self.clock, "unread_notifs_for_room_id"):
-            last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
-                user_id=sync_config.user.to_string(),
-                room_id=room_id,
-                receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
-            )
 
             return await self.store.get_unread_event_push_actions_by_room_for_user(
-                room_id, sync_config.user.to_string(), last_unread_event_id
+                room_id,
+                sync_config.user.to_string(),
             )
 
     async def generate_sync_result(
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 8397229ccb..6661887d9f 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 from typing import Dict
 
-from synapse.api.constants import ReceiptTypes
 from synapse.events import EventBase
 from synapse.push.presentable_names import calculate_room_name, name_from_member_event
 from synapse.storage.controllers import StorageControllers
@@ -24,30 +23,24 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
     invites = await store.get_invited_rooms_for_local_user(user_id)
     joins = await store.get_rooms_for_user(user_id)
 
-    my_receipts_by_room = await store.get_receipts_for_user(
-        user_id, (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE)
-    )
-
     badge = len(invites)
 
     for room_id in joins:
-        if room_id in my_receipts_by_room:
-            last_unread_event_id = my_receipts_by_room[room_id]
-
-            notifs = await (
-                store.get_unread_event_push_actions_by_room_for_user(
-                    room_id, user_id, last_unread_event_id
-                )
+        notifs = await (
+            store.get_unread_event_push_actions_by_room_for_user(
+                room_id,
+                user_id,
             )
-            if notifs.notify_count == 0:
-                continue
+        )
+        if notifs.notify_count == 0:
+            continue
 
-            if group_by_room:
-                # return one badge count per conversation
-                badge += 1
-            else:
-                # increment the badge count by the number of unread messages in the room
-                badge += notifs.notify_count
+        if group_by_room:
+            # return one badge count per conversation
+            badge += 1
+        else:
+            # increment the badge count by the number of unread messages in the room
+            badge += notifs.notify_count
     return badge
 
 
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a78d68a9d7..e8c63cf567 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -92,6 +92,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
     "event_search": "event_search_event_id_idx",
     "local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx",
     "remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx",
+    "event_push_summary": "event_push_summary_unique_index",
 }
 
 
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 9121badb3a..cb3d1242bb 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -104,13 +104,14 @@ class DataStore(
     PusherStore,
     PushRuleStore,
     ApplicationServiceTransactionStore,
+    EventPushActionsStore,
+    ServerMetricsStore,
     ReceiptsStore,
     EndToEndKeyStore,
     EndToEndRoomKeyStore,
     SearchStore,
     TagsStore,
     AccountDataStore,
-    EventPushActionsStore,
     OpenIdStore,
     ClientIpWorkerStore,
     DeviceStore,
@@ -124,7 +125,6 @@ class DataStore(
     UIAuthStore,
     EventForwardExtremitiesStore,
     CacheInvalidationWorkerStore,
-    ServerMetricsStore,
     LockStore,
     SessionStore,
 ):
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index b019979350..ae705889a5 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
 
 import attr
 
+from synapse.api.constants import ReceiptTypes
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import (
@@ -24,6 +25,8 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -79,15 +82,15 @@ class UserPushAction(EmailPushAction):
     profile_tag: str
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
+@attr.s(slots=True, auto_attribs=True)
 class NotifCounts:
     """
     The per-user, per-room count of notifications. Used by sync and push.
     """
 
-    notify_count: int
-    unread_count: int
-    highlight_count: int
+    notify_count: int = 0
+    unread_count: int = 0
+    highlight_count: int = 0
 
 
 def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
@@ -119,7 +122,7 @@ def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, st
         return DEFAULT_NOTIF_ACTION
 
 
-class EventPushActionsWorkerStore(SQLBaseStore):
+class EventPushActionsWorkerStore(ReceiptsWorkerStore, EventsWorkerStore, SQLBaseStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -148,12 +151,20 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 self._rotate_notifs, 30 * 60 * 1000
             )
 
-    @cached(num_args=3, tree=True, max_entries=5000)
+        self.db_pool.updates.register_background_index_update(
+            "event_push_summary_unique_index",
+            index_name="event_push_summary_unique_index",
+            table="event_push_summary",
+            columns=["user_id", "room_id"],
+            unique=True,
+            replaces_index="event_push_summary_user_rm",
+        )
+
+    @cached(tree=True, max_entries=5000)
     async def get_unread_event_push_actions_by_room_for_user(
         self,
         room_id: str,
         user_id: str,
-        last_read_event_id: Optional[str],
     ) -> NotifCounts:
         """Get the notification count, the highlight count and the unread message count
         for a given user in a given room after the given read receipt.
@@ -165,8 +176,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         Args:
             room_id: The room to retrieve the counts in.
             user_id: The user to retrieve the counts for.
-            last_read_event_id: The event associated with the latest read receipt for
-                this user in this room. None if no receipt for this user in this room.
 
         Returns
             A dict containing the counts mentioned earlier in this docstring,
@@ -178,7 +187,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             self._get_unread_counts_by_receipt_txn,
             room_id,
             user_id,
-            last_read_event_id,
         )
 
     def _get_unread_counts_by_receipt_txn(
@@ -186,16 +194,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         txn: LoggingTransaction,
         room_id: str,
         user_id: str,
-        last_read_event_id: Optional[str],
     ) -> NotifCounts:
-        stream_ordering = None
+        result = self.get_last_receipt_for_user_txn(
+            txn,
+            user_id,
+            room_id,
+            receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
+        )
 
-        if last_read_event_id is not None:
-            stream_ordering = self.get_stream_id_for_event_txn(  # type: ignore[attr-defined]
-                txn,
-                last_read_event_id,
-                allow_none=True,
-            )
+        stream_ordering = None
+        if result:
+            _, stream_ordering = result
 
         if stream_ordering is None:
             # Either last_read_event_id is None, or it's an event we don't have (e.g.
@@ -218,49 +227,95 @@ class EventPushActionsWorkerStore(SQLBaseStore):
     def _get_unread_counts_by_pos_txn(
         self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
     ) -> NotifCounts:
-        sql = (
-            "SELECT"
-            "   COUNT(CASE WHEN notif = 1 THEN 1 END),"
-            "   COUNT(CASE WHEN highlight = 1 THEN 1 END),"
-            "   COUNT(CASE WHEN unread = 1 THEN 1 END)"
-            " FROM event_push_actions ea"
-            " WHERE user_id = ?"
-            "   AND room_id = ?"
-            "   AND stream_ordering > ?"
-        )
-
-        txn.execute(sql, (user_id, room_id, stream_ordering))
-        row = txn.fetchone()
-
-        (notif_count, highlight_count, unread_count) = (0, 0, 0)
+        """Get the number of unread messages for a user/room that have happened
+        since the given stream ordering.
+        """
 
-        if row:
-            (notif_count, highlight_count, unread_count) = row
+        counts = NotifCounts()
 
+        # First we pull the counts from the summary table
         txn.execute(
             """
-                SELECT notif_count, unread_count FROM event_push_summary
+                SELECT stream_ordering, notif_count, COALESCE(unread_count, 0)
+                FROM event_push_summary
                 WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
             """,
             (room_id, user_id, stream_ordering),
         )
         row = txn.fetchone()
 
+        summary_stream_ordering = 0
         if row:
-            notif_count += row[0]
-
-            if row[1] is not None:
-                # The unread_count column of event_push_summary is NULLable, so we need
-                # to make sure we don't try increasing the unread counts if it's NULL
-                # for this row.
-                unread_count += row[1]
-
-        return NotifCounts(
-            notify_count=notif_count,
-            unread_count=unread_count,
-            highlight_count=highlight_count,
+            summary_stream_ordering = row[0]
+            counts.notify_count += row[1]
+            counts.unread_count += row[2]
+
+        # Next we need to count highlights, which aren't summarized
+        sql = """
+            SELECT COUNT(*) FROM event_push_actions
+            WHERE user_id = ?
+                AND room_id = ?
+                AND stream_ordering > ?
+                AND highlight = 1
+        """
+        txn.execute(sql, (user_id, room_id, stream_ordering))
+        row = txn.fetchone()
+        if row:
+            counts.highlight_count += row[0]
+
+        # Finally we need to count push actions that haven't been summarized
+        # yet.
+        # We only want to pull out push actions that we haven't summarized yet.
+        stream_ordering = max(stream_ordering, summary_stream_ordering)
+        notify_count, unread_count = self._get_notif_unread_count_for_user_room(
+            txn, room_id, user_id, stream_ordering
         )
 
+        counts.notify_count += notify_count
+        counts.unread_count += unread_count
+
+        return counts
+
+    def _get_notif_unread_count_for_user_room(
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        user_id: str,
+        stream_ordering: int,
+        max_stream_ordering: Optional[int] = None,
+    ) -> Tuple[int, int]:
+        """Returns the notify and unread counts from `event_push_actions` for
+        the given user/room in the given range.
+
+        Does not consult `event_push_summary` table, which may include push
+        actions that have been deleted from `event_push_actions` table.
+        """
+
+        clause = ""
+        args = [user_id, room_id, stream_ordering]
+        if max_stream_ordering is not None:
+            clause = "AND ea.stream_ordering <= ?"
+            args.append(max_stream_ordering)
+
+        sql = f"""
+            SELECT
+               COUNT(CASE WHEN notif = 1 THEN 1 END),
+               COUNT(CASE WHEN unread = 1 THEN 1 END)
+             FROM event_push_actions ea
+             WHERE user_id = ?
+               AND room_id = ?
+               AND ea.stream_ordering > ?
+               {clause}
+        """
+
+        txn.execute(sql, args)
+        row = txn.fetchone()
+
+        if row:
+            return cast(Tuple[int, int], row)
+
+        return 0, 0
+
     async def get_push_action_users_in_range(
         self, min_stream_ordering: int, max_stream_ordering: int
     ) -> List[str]:
@@ -754,6 +809,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 if caught_up:
                     break
                 await self.hs.get_clock().sleep(self._rotate_delay)
+
+            await self._remove_old_push_actions_that_have_rotated()
         finally:
             self._doing_notif_rotation = False
 
@@ -782,20 +839,16 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         stream_row = txn.fetchone()
         if stream_row:
             (offset_stream_ordering,) = stream_row
-            assert self.stream_ordering_day_ago is not None
-            rotate_to_stream_ordering = min(
-                self.stream_ordering_day_ago, offset_stream_ordering
-            )
-            caught_up = offset_stream_ordering >= self.stream_ordering_day_ago
+            rotate_to_stream_ordering = offset_stream_ordering
+            caught_up = False
         else:
-            rotate_to_stream_ordering = self.stream_ordering_day_ago
+            rotate_to_stream_ordering = self._stream_id_gen.get_current_token()
             caught_up = True
 
         logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering)
 
         self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering)
 
-        # We have caught up iff we were limited by `stream_ordering_day_ago`
         return caught_up
 
     def _rotate_notifs_before_txn(
@@ -819,7 +872,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                     max(stream_ordering) as stream_ordering
                 FROM event_push_actions
                 WHERE ? <= stream_ordering AND stream_ordering < ?
-                    AND highlight = 0
                     AND %s = 1
                 GROUP BY user_id, room_id
             ) AS upd
@@ -915,18 +967,72 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
 
         txn.execute(
-            "DELETE FROM event_push_actions"
-            " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0",
-            (old_rotate_stream_ordering, rotate_to_stream_ordering),
+            "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?",
+            (rotate_to_stream_ordering,),
         )
 
-        logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)
+    async def _remove_old_push_actions_that_have_rotated(
+        self,
+    ) -> None:
+        """Clear out old push actions that have been summarized."""
 
-        txn.execute(
-            "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?",
-            (rotate_to_stream_ordering,),
+        # We want to clear out anything that older than a day that *has* already
+        # been rotated.
+        rotated_upto_stream_ordering = await self.db_pool.simple_select_one_onecol(
+            table="event_push_summary_stream_ordering",
+            keyvalues={},
+            retcol="stream_ordering",
         )
 
+        max_stream_ordering_to_delete = min(
+            rotated_upto_stream_ordering, self.stream_ordering_day_ago
+        )
+
+        def remove_old_push_actions_that_have_rotated_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
+            # We don't want to clear out too much at a time, so we bound our
+            # deletes.
+            batch_size = 10000
+
+            txn.execute(
+                """
+                SELECT stream_ordering FROM event_push_actions
+                WHERE stream_ordering < ? AND highlight = 0
+                ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
+            """,
+                (
+                    max_stream_ordering_to_delete,
+                    batch_size,
+                ),
+            )
+            stream_row = txn.fetchone()
+
+            if stream_row:
+                (stream_ordering,) = stream_row
+            else:
+                stream_ordering = max_stream_ordering_to_delete
+
+            txn.execute(
+                """
+                DELETE FROM event_push_actions
+                WHERE stream_ordering < ? AND highlight = 0
+                """,
+                (stream_ordering,),
+            )
+
+            logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)
+
+            return txn.rowcount < batch_size
+
+        while True:
+            done = await self.db_pool.runInteraction(
+                "_remove_old_push_actions_that_have_rotated",
+                remove_old_push_actions_that_have_rotated_txn,
+            )
+            if done:
+                break
+
     def _remove_old_push_actions_before_txn(
         self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
     ) -> None:
@@ -965,12 +1071,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
         )
 
-        txn.execute(
-            """
-            DELETE FROM event_push_summary
-            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
-        """,
-            (room_id, user_id, stream_ordering),
+        old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="event_push_summary_stream_ordering",
+            keyvalues={},
+            retcol="stream_ordering",
+        )
+
+        notif_count, unread_count = self._get_notif_unread_count_for_user_room(
+            txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
+        )
+
+        self.db_pool.simple_upsert_txn(
+            txn,
+            table="event_push_summary",
+            keyvalues={"room_id": room_id, "user_id": user_id},
+            values={
+                "notif_count": notif_count,
+                "unread_count": unread_count,
+                "stream_ordering": old_rotate_stream_ordering,
+            },
         )
 
 
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index d5aefe02b6..86649c1e6c 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -110,9 +110,9 @@ def _load_rules(
 # the abstract methods being implemented.
 class PushRulesWorkerStore(
     ApplicationServiceWorkerStore,
-    ReceiptsWorkerStore,
     PusherWorkerStore,
     RoomMemberWorkerStore,
+    ReceiptsWorkerStore,
     EventsWorkerStore,
     SQLBaseStore,
     metaclass=abc.ABCMeta,
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index b6106affa6..bec6d60577 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -118,7 +118,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         return self._receipts_id_gen.get_current_token()
 
     async def get_last_receipt_event_id_for_user(
-        self, user_id: str, room_id: str, receipt_types: Iterable[str]
+        self, user_id: str, room_id: str, receipt_types: Collection[str]
     ) -> Optional[str]:
         """
         Fetch the event ID for the latest receipt in a room with one of the given receipt types.
@@ -126,58 +126,63 @@ class ReceiptsWorkerStore(SQLBaseStore):
         Args:
             user_id: The user to fetch receipts for.
             room_id: The room ID to fetch the receipt for.
-            receipt_type: The receipt types to fetch. Earlier receipt types
-                are given priority if multiple receipts point to the same event.
+            receipt_type: The receipt types to fetch.
 
         Returns:
             The latest receipt, if one exists.
         """
-        latest_event_id: Optional[str] = None
-        latest_stream_ordering = 0
-        for receipt_type in receipt_types:
-            result = await self._get_last_receipt_event_id_for_user(
-                user_id, room_id, receipt_type
-            )
-            if result is None:
-                continue
-            event_id, stream_ordering = result
-
-            if latest_event_id is None or latest_stream_ordering < stream_ordering:
-                latest_event_id = event_id
-                latest_stream_ordering = stream_ordering
+        result = await self.db_pool.runInteraction(
+            "get_last_receipt_event_id_for_user",
+            self.get_last_receipt_for_user_txn,
+            user_id,
+            room_id,
+            receipt_types,
+        )
+        if not result:
+            return None
 
-        return latest_event_id
+        event_id, _ = result
+        return event_id
 
-    @cached()
-    async def _get_last_receipt_event_id_for_user(
-        self, user_id: str, room_id: str, receipt_type: str
+    def get_last_receipt_for_user_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        room_id: str,
+        receipt_types: Collection[str],
     ) -> Optional[Tuple[str, int]]:
         """
-        Fetch the event ID and stream ordering for the latest receipt.
+        Fetch the event ID and stream_ordering for the latest receipt in a room
+        with one of the given receipt types.
 
         Args:
             user_id: The user to fetch receipts for.
             room_id: The room ID to fetch the receipt for.
-            receipt_type: The receipt type to fetch.
+            receipt_type: The receipt types to fetch.
 
         Returns:
-            The event ID and stream ordering of the latest receipt, if one exists;
-            otherwise `None`.
+            The latest receipt, if one exists.
         """
-        sql = """
+
+        clause, args = make_in_list_sql_clause(
+            self.database_engine, "receipt_type", receipt_types
+        )
+
+        sql = f"""
             SELECT event_id, stream_ordering
             FROM receipts_linearized
             INNER JOIN events USING (room_id, event_id)
-            WHERE user_id = ?
+            WHERE {clause}
+            AND user_id = ?
             AND room_id = ?
-            AND receipt_type = ?
+            ORDER BY stream_ordering DESC
+            LIMIT 1
         """
 
-        def f(txn: LoggingTransaction) -> Optional[Tuple[str, int]]:
-            txn.execute(sql, (user_id, room_id, receipt_type))
-            return cast(Optional[Tuple[str, int]], txn.fetchone())
+        args.extend((user_id, room_id))
+        txn.execute(sql, args)
 
-        return await self.db_pool.runInteraction("get_own_receipt_for_user", f)
+        return cast(Optional[Tuple[str, int]], txn.fetchone())
 
     async def get_receipts_for_user(
         self, user_id: str, receipt_types: Iterable[str]
@@ -577,8 +582,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
     ) -> None:
         self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type))
         self._get_linearized_receipts_for_room.invalidate((room_id,))
-        self._get_last_receipt_event_id_for_user.invalidate(
-            (user_id, room_id, receipt_type)
+
+        # We use this method to invalidate so that we don't end up with circular
+        # dependencies between the receipts and push action stores.
+        self._attempt_to_invalidate_cache(
+            "get_unread_event_push_actions_by_room_for_user", (room_id,)
         )
 
     def process_replication_rows(
diff --git a/synapse/storage/schema/main/delta/40/event_push_summary.sql b/synapse/storage/schema/main/delta/40/event_push_summary.sql
index 3918f0b794..499bf60178 100644
--- a/synapse/storage/schema/main/delta/40/event_push_summary.sql
+++ b/synapse/storage/schema/main/delta/40/event_push_summary.sql
@@ -13,9 +13,10 @@
  * limitations under the License.
  */
 
--- Aggregate of old notification counts that have been deleted out of the
--- main event_push_actions table. This count does not include those that were
--- highlights, as they remain in the event_push_actions table.
+-- Aggregate of notification counts up to `stream_ordering`, including those
+-- that may have been deleted out of the main event_push_actions table. This
+-- count does not include those that were highlights, as they remain in the
+-- event_push_actions table.
 CREATE TABLE event_push_summary (
     user_id TEXT NOT NULL,
     room_id TEXT NOT NULL,
diff --git a/synapse/storage/schema/main/delta/71/02event_push_summary_unique.sql b/synapse/storage/schema/main/delta/71/02event_push_summary_unique.sql
new file mode 100644
index 0000000000..9cdcea21ae
--- /dev/null
+++ b/synapse/storage/schema/main/delta/71/02event_push_summary_unique.sql
@@ -0,0 +1,18 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a unique index to `event_push_summary`
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (7002, 'event_push_summary_unique_index', '{}');
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index ba158f5d93..d9c68cdd2d 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -577,7 +577,7 @@ class HTTPPusherTests(HomeserverTestCase):
         # Carry out our option-value specific test
         #
         # This push should still only contain an unread count of 1 (for 1 unread room)
-        self._check_push_attempt(6, 1)
+        self._check_push_attempt(7, 1)
 
     @override_config({"push": {"group_unread_count_by_room": False}})
     def test_push_unread_count_message_count(self) -> None:
@@ -591,7 +591,7 @@ class HTTPPusherTests(HomeserverTestCase):
         #
         # We're counting every unread message, so there should now be 3 since the
         # last read receipt
-        self._check_push_attempt(6, 3)
+        self._check_push_attempt(7, 3)
 
     def _test_push_unread_count(self) -> None:
         """
@@ -641,18 +641,18 @@ class HTTPPusherTests(HomeserverTestCase):
         response = self.helper.send(
             room_id, body="Hello there!", tok=other_access_token
         )
-        # To get an unread count, the user who is getting notified has to have a read
-        # position in the room. We'll set the read position to this event in a moment
+
         first_message_event_id = response["event_id"]
 
         expected_push_attempts = 1
-        self._check_push_attempt(expected_push_attempts, 0)
+        self._check_push_attempt(expected_push_attempts, 1)
 
         self._send_read_request(access_token, first_message_event_id, room_id)
 
-        # Unread count has not changed. Therefore, ensure that read request does not
-        # trigger a push notification.
-        self.assertEqual(len(self.push_attempts), 1)
+        # Unread count has changed. Therefore, ensure that read request triggers
+        # a push notification.
+        expected_push_attempts += 1
+        self.assertEqual(len(self.push_attempts), expected_push_attempts)
 
         # Send another message
         response2 = self.helper.send(
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 6d3d4afe52..531a0db2d0 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -15,7 +15,9 @@ import logging
 from typing import Iterable, Optional
 
 from canonicaljson import encode_canonical_json
+from parameterized import parameterized
 
+from synapse.api.constants import ReceiptTypes
 from synapse.api.room_versions import RoomVersions
 from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
 from synapse.handlers.room import RoomEventSource
@@ -156,17 +158,26 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             ],
         )
 
-    def test_push_actions_for_user(self):
+    @parameterized.expand([(True,), (False,)])
+    def test_push_actions_for_user(self, send_receipt: bool):
         self.persist(type="m.room.create", key="", creator=USER_ID)
-        self.persist(type="m.room.join", key=USER_ID, membership="join")
+        self.persist(type="m.room.member", key=USER_ID, membership="join")
         self.persist(
-            type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"
+            type="m.room.member", sender=USER_ID, key=USER_ID_2, membership="join"
         )
         event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello")
         self.replicate()
+
+        if send_receipt:
+            self.get_success(
+                self.master_store.insert_receipt(
+                    ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], {}
+                )
+            )
+
         self.check(
             "get_unread_event_push_actions_by_room_for_user",
-            [ROOM_ID, USER_ID_2, event1.event_id],
+            [ROOM_ID, USER_ID_2],
             NotifCounts(highlight_count=0, unread_count=0, notify_count=0),
         )
 
@@ -179,7 +190,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.replicate()
         self.check(
             "get_unread_event_push_actions_by_room_for_user",
-            [ROOM_ID, USER_ID_2, event1.event_id],
+            [ROOM_ID, USER_ID_2],
             NotifCounts(highlight_count=0, unread_count=0, notify_count=1),
         )
 
@@ -194,7 +205,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.replicate()
         self.check(
             "get_unread_event_push_actions_by_room_for_user",
-            [ROOM_ID, USER_ID_2, event1.event_id],
+            [ROOM_ID, USER_ID_2],
             NotifCounts(highlight_count=1, unread_count=0, notify_count=2),
         )
 
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 0f9add4841..4273524c4c 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -51,10 +51,16 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
         room_id = "!foo:example.com"
         user_id = "@user1235:example.com"
 
+        last_read_stream_ordering = [0]
+
         def _assert_counts(noitf_count, highlight_count):
             counts = self.get_success(
                 self.store.db_pool.runInteraction(
-                    "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
+                    "",
+                    self.store._get_unread_counts_by_pos_txn,
+                    room_id,
+                    user_id,
+                    last_read_stream_ordering[0],
                 )
             )
             self.assertEqual(
@@ -98,6 +104,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
             )
 
         def _mark_read(stream, depth):
+            last_read_stream_ordering[0] = stream
             self.get_success(
                 self.store.db_pool.runInteraction(
                     "",
@@ -144,8 +151,19 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
         _assert_counts(1, 1)
         _rotate(9)
         _assert_counts(1, 1)
-        _rotate(10)
-        _assert_counts(1, 1)
+
+        # Check that adding another notification and rotating after highlight
+        # works.
+        _inject_actions(10, PlAIN_NOTIF)
+        _rotate(11)
+        _assert_counts(2, 1)
+
+        # Check that sending read receipts at different points results in the
+        # right counts.
+        _mark_read(8, 8)
+        _assert_counts(1, 0)
+        _mark_read(10, 10)
+        _assert_counts(0, 0)
 
     def test_find_first_stream_ordering_after_ts(self):
         def add_event(so, ts):