summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2020-09-02 17:19:37 +0100
committerGitHub <noreply@github.com>2020-09-02 17:19:37 +0100
commit5a1dd297c3ce105a7f516d9d9fe87b94b9d356c8 (patch)
treee7be5283e17b93e1e9de477b5cf08a8d87bfcbb3 /synapse
parentRefactor `_get_e2e_device_keys_for_federation_query_txn` (#8225) (diff)
downloadsynapse-5a1dd297c3ce105a7f516d9d9fe87b94b9d356c8.tar.xz
Re-implement unread counts (again) (#8059)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/sync.py33
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py92
-rw-r--r--synapse/push/push_tools.py2
-rw-r--r--synapse/rest/client/v2_alpha/sync.py1
-rw-r--r--synapse/storage/databases/main/event_push_actions.py224
-rw-r--r--synapse/storage/databases/main/events.py4
-rw-r--r--synapse/storage/databases/main/schema/delta/58/15unread_count.sql26
-rw-r--r--synapse/storage/databases/main/stream.py21
8 files changed, 289 insertions, 114 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 9a86eb01c9..43fc40fc2f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -95,7 +95,12 @@ class TimelineBatch:
     __bool__ = __nonzero__  # python3
 
 
-@attr.s(slots=True, frozen=True)
+# We can't freeze this class, because we need to update it after it's instantiated to
+# update its unread count. This is because we calculate the unread count for a room only
+# if there are updates for it, which we check after the instance has been created.
+# This should not be a big deal because we update the notification counts afterwards as
+# well anyway.
+@attr.s(slots=True)
 class JoinedSyncResult:
     room_id = attr.ib(type=str)
     timeline = attr.ib(type=TimelineBatch)
@@ -104,6 +109,7 @@ class JoinedSyncResult:
     account_data = attr.ib(type=List[JsonDict])
     unread_notifications = attr.ib(type=JsonDict)
     summary = attr.ib(type=Optional[JsonDict])
+    unread_count = attr.ib(type=int)
 
     def __nonzero__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
@@ -931,7 +937,7 @@ class SyncHandler(object):
 
     async def unread_notifs_for_room_id(
         self, room_id: str, sync_config: SyncConfig
-    ) -> Optional[Dict[str, str]]:
+    ) -> Dict[str, int]:
         with Measure(self.clock, "unread_notifs_for_room_id"):
             last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
                 user_id=sync_config.user.to_string(),
@@ -939,15 +945,10 @@ class SyncHandler(object):
                 receipt_type="m.read",
             )
 
-            if last_unread_event_id:
-                notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
-                    room_id, sync_config.user.to_string(), last_unread_event_id
-                )
-                return notifs
-
-        # There is no new information in this period, so your notification
-        # count is whatever it was last time.
-        return None
+            notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
+                room_id, sync_config.user.to_string(), last_unread_event_id
+            )
+            return notifs
 
     async def generate_sync_result(
         self,
@@ -1886,7 +1887,7 @@ class SyncHandler(object):
             )
 
         if room_builder.rtype == "joined":
-            unread_notifications = {}  # type: Dict[str, str]
+            unread_notifications = {}  # type: Dict[str, int]
             room_sync = JoinedSyncResult(
                 room_id=room_id,
                 timeline=batch,
@@ -1895,14 +1896,16 @@ class SyncHandler(object):
                 account_data=account_data_events,
                 unread_notifications=unread_notifications,
                 summary=summary,
+                unread_count=0,
             )
 
             if room_sync or always_include:
                 notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
 
-                if notifs is not None:
-                    unread_notifications["notification_count"] = notifs["notify_count"]
-                    unread_notifications["highlight_count"] = notifs["highlight_count"]
+                unread_notifications["notification_count"] = notifs["notify_count"]
+                unread_notifications["highlight_count"] = notifs["highlight_count"]
+
+                room_sync.unread_count = notifs["unread_count"]
 
                 sync_result_builder.joined.append(room_sync)
 
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index e7fcee0e87..e7fa02b78b 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -19,8 +19,10 @@ from collections import namedtuple
 
 from prometheus_client import Counter
 
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RelationTypes
 from synapse.event_auth import get_user_power_level
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
 from synapse.state import POWER_KEY
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import register_cache
@@ -51,6 +53,48 @@ push_rules_delta_state_cache_metric = register_cache(
 )
 
 
+STATE_EVENT_TYPES_TO_MARK_UNREAD = {
+    EventTypes.Topic,
+    EventTypes.Name,
+    EventTypes.RoomAvatar,
+    EventTypes.Tombstone,
+}
+
+
+def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
+    # Exclude rejected and soft-failed events.
+    if context.rejected or event.internal_metadata.is_soft_failed():
+        return False
+
+    # Exclude notices.
+    if (
+        not event.is_state()
+        and event.type == EventTypes.Message
+        and event.content.get("msgtype") == "m.notice"
+    ):
+        return False
+
+    # Exclude edits.
+    relates_to = event.content.get("m.relates_to", {})
+    if relates_to.get("rel_type") == RelationTypes.REPLACE:
+        return False
+
+    # Mark events that have a non-empty string body as unread.
+    body = event.content.get("body")
+    if isinstance(body, str) and body:
+        return True
+
+    # Mark some state events as unread.
+    if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
+        return True
+
+    # Mark encrypted events as unread.
+    if not event.is_state() and event.type == EventTypes.Encrypted:
+        return True
+
+    return False
+
+
 class BulkPushRuleEvaluator(object):
     """Calculates the outcome of push rules for an event for all users in the
     room at once.
@@ -133,9 +177,12 @@ class BulkPushRuleEvaluator(object):
         return pl_event.content if pl_event else {}, sender_level
 
     async def action_for_event_by_user(self, event, context) -> None:
-        """Given an event and context, evaluate the push rules and insert the
-        results into the event_push_actions_staging table.
+        """Given an event and context, evaluate the push rules, check if the message
+        should increment the unread count, and insert the results into the
+        event_push_actions_staging table.
         """
+        count_as_unread = _should_count_as_unread(event, context)
+
         rules_by_user = await self._get_rules_for_event(event, context)
         actions_by_user = {}
 
@@ -172,6 +219,8 @@ class BulkPushRuleEvaluator(object):
                 if event.type == EventTypes.Member and event.state_key == uid:
                     display_name = event.content.get("displayname", None)
 
+            actions_by_user[uid] = []
+
             for rule in rules:
                 if "enabled" in rule and not rule["enabled"]:
                     continue
@@ -189,7 +238,9 @@ class BulkPushRuleEvaluator(object):
         # Mark in the DB staging area the push actions for users who should be
         # notified for this event. (This will then get handled when we persist
         # the event)
-        await self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
+        await self.store.add_push_actions_to_staging(
+            event.event_id, actions_by_user, count_as_unread,
+        )
 
 
 def _condition_checker(evaluator, conditions, uid, display_name, cache):
@@ -369,8 +420,8 @@ class RulesForRoom(object):
         Args:
             ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
                 updated with any new rules.
-            member_event_ids (list): List of event ids for membership events that
-                have happened since the last time we filled rules_by_user
+            member_event_ids (dict): Dict of user id to event id for membership events
+                that have happened since the last time we filled rules_by_user
             state_group: The state group we are currently computing push rules
                 for. Used when updating the cache.
         """
@@ -390,34 +441,19 @@ class RulesForRoom(object):
         if logger.isEnabledFor(logging.DEBUG):
             logger.debug("Found members %r: %r", self.room_id, members.values())
 
-        interested_in_user_ids = {
+        user_ids = {
             user_id
             for user_id, membership in members.values()
             if membership == Membership.JOIN
         }
 
-        logger.debug("Joined: %r", interested_in_user_ids)
-
-        if_users_with_pushers = await self.store.get_if_users_have_pushers(
-            interested_in_user_ids, on_invalidate=self.invalidate_all_cb
-        )
-
-        user_ids = {
-            uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
-        }
-
-        logger.debug("With pushers: %r", user_ids)
-
-        users_with_receipts = await self.store.get_users_with_read_receipts_in_room(
-            self.room_id, on_invalidate=self.invalidate_all_cb
-        )
-
-        logger.debug("With receipts: %r", users_with_receipts)
+        logger.debug("Joined: %r", user_ids)
 
-        # any users with pushers must be ours: they have pushers
-        for uid in users_with_receipts:
-            if uid in interested_in_user_ids:
-                user_ids.add(uid)
+        # Previously we only considered users with pushers or read receipts in that
+        # room. We can't do this anymore because we use push actions to calculate unread
+        # counts, which don't rely on the user having pushers or sent a read receipt into
+        # the room. Therefore we just need to filter for local users here.
+        user_ids = list(filter(self.is_mine_id, user_ids))
 
         rules_by_user = await self.store.bulk_get_push_rules(
             user_ids, on_invalidate=self.invalidate_all_cb
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index d0145666bf..f7a25571f3 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -36,7 +36,7 @@ async def get_badge_count(store, user_id):
             )
             # return one badge count per conversation, as count per
             # message is so noisy as to be almost useless
-            badge += 1 if notifs["notify_count"] else 0
+            badge += 1 if notifs["unread_count"] else 0
     return badge
 
 
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 96488b131a..a0b00135e1 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -425,6 +425,7 @@ class SyncRestServlet(RestServlet):
             result["ephemeral"] = {"events": ephemeral_events}
             result["unread_notifications"] = room.unread_notifications
             result["summary"] = room.summary
+            result["org.matrix.msc2654.unread_count"] = room.unread_count
 
         return result
 
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 384c39f4d7..001d06378d 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -15,7 +15,9 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Tuple, Union
+
+import attr
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
@@ -88,8 +90,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
     @cached(num_args=3, tree=True, max_entries=5000)
     async def get_unread_event_push_actions_by_room_for_user(
-        self, room_id, user_id, last_read_event_id
-    ):
+        self, room_id: str, user_id: str, last_read_event_id: Optional[str],
+    ) -> Dict[str, int]:
+        """Get the notification count, the highlight count and the unread message count
+        for a given user in a given room after the given read receipt.
+
+        Note that this function assumes the user to be a current member of the room,
+        since it's either called by the sync handler to handle joined room entries, or by
+        the HTTP pusher to calculate the badge of unread joined rooms.
+
+        Args:
+            room_id: The room to retrieve the counts in.
+            user_id: The user to retrieve the counts for.
+            last_read_event_id: The event associated with the latest read receipt for
+                this user in this room. None if no receipt for this user in this room.
+
+        Returns
+            A dict containing the counts mentioned earlier in this docstring,
+            respectively under the keys "notify_count", "highlight_count" and
+            "unread_count".
+        """
         return await self.db_pool.runInteraction(
             "get_unread_event_push_actions_by_room",
             self._get_unread_counts_by_receipt_txn,
@@ -99,69 +119,71 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
 
     def _get_unread_counts_by_receipt_txn(
-        self, txn, room_id, user_id, last_read_event_id
+        self, txn, room_id, user_id, last_read_event_id,
     ):
-        sql = (
-            "SELECT stream_ordering"
-            " FROM events"
-            " WHERE room_id = ? AND event_id = ?"
-        )
-        txn.execute(sql, (room_id, last_read_event_id))
-        results = txn.fetchall()
-        if len(results) == 0:
-            return {"notify_count": 0, "highlight_count": 0}
+        stream_ordering = None
+
+        if last_read_event_id is not None:
+            stream_ordering = self.get_stream_id_for_event_txn(
+                txn, last_read_event_id, allow_none=True,
+            )
+
+        if stream_ordering is None:
+            # Either last_read_event_id is None, or it's an event we don't have (e.g.
+            # because it's been purged), in which case retrieve the stream ordering for
+            # the latest membership event from this user in this room (which we assume is
+            # a join).
+            event_id = self.db_pool.simple_select_one_onecol_txn(
+                txn=txn,
+                table="local_current_membership",
+                keyvalues={"room_id": room_id, "user_id": user_id},
+                retcol="event_id",
+            )
 
-        stream_ordering = results[0][0]
+            stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
 
         return self._get_unread_counts_by_pos_txn(
             txn, room_id, user_id, stream_ordering
         )
 
     def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
-
-        # First get number of notifications.
-        # We don't need to put a notif=1 clause as all rows always have
-        # notif=1
         sql = (
-            "SELECT count(*)"
+            "SELECT"
+            "   COUNT(CASE WHEN notif = 1 THEN 1 END),"
+            "   COUNT(CASE WHEN highlight = 1 THEN 1 END),"
+            "   COUNT(CASE WHEN unread = 1 THEN 1 END)"
             " FROM event_push_actions ea"
-            " WHERE"
-            " user_id = ?"
-            " AND room_id = ?"
-            " AND stream_ordering > ?"
+            " WHERE user_id = ?"
+            "   AND room_id = ?"
+            "   AND stream_ordering > ?"
         )
 
         txn.execute(sql, (user_id, room_id, stream_ordering))
         row = txn.fetchone()
-        notify_count = row[0] if row else 0
+
+        (notif_count, highlight_count, unread_count) = (0, 0, 0)
+
+        if row:
+            (notif_count, highlight_count, unread_count) = row
 
         txn.execute(
             """
-            SELECT notif_count FROM event_push_summary
-            WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
-        """,
+                SELECT notif_count, unread_count FROM event_push_summary
+                WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
+            """,
             (room_id, user_id, stream_ordering),
         )
-        rows = txn.fetchall()
-        if rows:
-            notify_count += rows[0][0]
-
-        # Now get the number of highlights
-        sql = (
-            "SELECT count(*)"
-            " FROM event_push_actions ea"
-            " WHERE"
-            " highlight = 1"
-            " AND user_id = ?"
-            " AND room_id = ?"
-            " AND stream_ordering > ?"
-        )
-
-        txn.execute(sql, (user_id, room_id, stream_ordering))
         row = txn.fetchone()
-        highlight_count = row[0] if row else 0
 
-        return {"notify_count": notify_count, "highlight_count": highlight_count}
+        if row:
+            notif_count += row[0]
+            unread_count += row[1]
+
+        return {
+            "notify_count": notif_count,
+            "unread_count": unread_count,
+            "highlight_count": highlight_count,
+        }
 
     async def get_push_action_users_in_range(
         self, min_stream_ordering, max_stream_ordering
@@ -222,6 +244,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering ASC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -250,6 +273,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering ASC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -324,6 +348,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering DESC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -352,6 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering DESC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -402,7 +428,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         def _get_if_maybe_push_in_range_for_user_txn(txn):
             sql = """
                 SELECT 1 FROM event_push_actions
-                WHERE user_id = ? AND stream_ordering > ?
+                WHERE user_id = ? AND stream_ordering > ? AND notif = 1
                 LIMIT 1
             """
 
@@ -415,7 +441,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
 
     async def add_push_actions_to_staging(
-        self, event_id: str, user_id_actions: Dict[str, List[Union[dict, str]]]
+        self,
+        event_id: str,
+        user_id_actions: Dict[str, List[Union[dict, str]]],
+        count_as_unread: bool,
     ) -> None:
         """Add the push actions for the event to the push action staging area.
 
@@ -423,21 +452,23 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             event_id
             user_id_actions: A mapping of user_id to list of push actions, where
                 an action can either be a string or dict.
+            count_as_unread: Whether this event should increment unread counts.
         """
-
         if not user_id_actions:
             return
 
         # This is a helper function for generating the necessary tuple that
-        # can be used to inert into the `event_push_actions_staging` table.
+        # can be used to insert into the `event_push_actions_staging` table.
         def _gen_entry(user_id, actions):
             is_highlight = 1 if _action_has_highlight(actions) else 0
+            notif = 1 if "notify" in actions else 0
             return (
                 event_id,  # event_id column
                 user_id,  # user_id column
                 _serialize_action(actions, is_highlight),  # actions column
-                1,  # notif column
+                notif,  # notif column
                 is_highlight,  # highlight column
+                int(count_as_unread),  # unread column
             )
 
         def _add_push_actions_to_staging_txn(txn):
@@ -446,8 +477,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
             sql = """
                 INSERT INTO event_push_actions_staging
-                    (event_id, user_id, actions, notif, highlight)
-                VALUES (?, ?, ?, ?, ?)
+                    (event_id, user_id, actions, notif, highlight, unread)
+                VALUES (?, ?, ?, ?, ?, ?)
             """
 
             txn.executemany(
@@ -811,24 +842,63 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         # Calculate the new counts that should be upserted into event_push_summary
         sql = """
             SELECT user_id, room_id,
-                coalesce(old.notif_count, 0) + upd.notif_count,
+                coalesce(old.%s, 0) + upd.cnt,
                 upd.stream_ordering,
                 old.user_id
             FROM (
-                SELECT user_id, room_id, count(*) as notif_count,
+                SELECT user_id, room_id, count(*) as cnt,
                     max(stream_ordering) as stream_ordering
                 FROM event_push_actions
                 WHERE ? <= stream_ordering AND stream_ordering < ?
                     AND highlight = 0
+                    AND %s = 1
                 GROUP BY user_id, room_id
             ) AS upd
             LEFT JOIN event_push_summary AS old USING (user_id, room_id)
         """
 
-        txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
-        rows = txn.fetchall()
+        # First get the count of unread messages.
+        txn.execute(
+            sql % ("unread_count", "unread"),
+            (old_rotate_stream_ordering, rotate_to_stream_ordering),
+        )
+
+        # We need to merge results from the two requests (the one that retrieves the
+        # unread count and the one that retrieves the notifications count) into a single
+        # object because we might not have the same amount of rows in each of them. To do
+        # this, we use a dict indexed on the user ID and room ID to make it easier to
+        # populate.
+        summaries = {}  # type: Dict[Tuple[str, str], _EventPushSummary]
+        for row in txn:
+            summaries[(row[0], row[1])] = _EventPushSummary(
+                unread_count=row[2],
+                stream_ordering=row[3],
+                old_user_id=row[4],
+                notif_count=0,
+            )
+
+        # Then get the count of notifications.
+        txn.execute(
+            sql % ("notif_count", "notif"),
+            (old_rotate_stream_ordering, rotate_to_stream_ordering),
+        )
+
+        for row in txn:
+            if (row[0], row[1]) in summaries:
+                summaries[(row[0], row[1])].notif_count = row[2]
+            else:
+                # Because the rules on notifying are different than the rules on marking
+                # a message unread, we might end up with messages that notify but aren't
+                # marked unread, so we might not have a summary for this (user, room)
+                # tuple to complete.
+                summaries[(row[0], row[1])] = _EventPushSummary(
+                    unread_count=0,
+                    stream_ordering=row[3],
+                    old_user_id=row[4],
+                    notif_count=row[2],
+                )
 
-        logger.info("Rotating notifications, handling %d rows", len(rows))
+        logger.info("Rotating notifications, handling %d rows", len(summaries))
 
         # If the `old.user_id` above is NULL then we know there isn't already an
         # entry in the table, so we simply insert it. Otherwise we update the
@@ -838,22 +908,34 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             table="event_push_summary",
             values=[
                 {
-                    "user_id": row[0],
-                    "room_id": row[1],
-                    "notif_count": row[2],
-                    "stream_ordering": row[3],
+                    "user_id": user_id,
+                    "room_id": room_id,
+                    "notif_count": summary.notif_count,
+                    "unread_count": summary.unread_count,
+                    "stream_ordering": summary.stream_ordering,
                 }
-                for row in rows
-                if row[4] is None
+                for ((user_id, room_id), summary) in summaries.items()
+                if summary.old_user_id is None
             ],
         )
 
         txn.executemany(
             """
-                UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
+                UPDATE event_push_summary
+                SET notif_count = ?, unread_count = ?, stream_ordering = ?
                 WHERE user_id = ? AND room_id = ?
             """,
-            ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
+            (
+                (
+                    summary.notif_count,
+                    summary.unread_count,
+                    summary.stream_ordering,
+                    user_id,
+                    room_id,
+                )
+                for ((user_id, room_id), summary) in summaries.items()
+                if summary.old_user_id is not None
+            ),
         )
 
         txn.execute(
@@ -879,3 +961,15 @@ def _action_has_highlight(actions):
             pass
 
     return False
+
+
+@attr.s
+class _EventPushSummary:
+    """Summary of pending event push actions for a given user in a given room.
+    Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
+    """
+
+    unread_count = attr.ib(type=int)
+    stream_ordering = attr.ib(type=int)
+    old_user_id = attr.ib(type=str)
+    notif_count = attr.ib(type=int)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 46b11e705b..b94fe7ac17 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1298,9 +1298,9 @@ class PersistEventsStore:
         sql = """
             INSERT INTO event_push_actions (
                 room_id, event_id, user_id, actions, stream_ordering,
-                topological_ordering, notif, highlight
+                topological_ordering, notif, highlight, unread
             )
-            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
             FROM event_push_actions_staging
             WHERE event_id = ?
         """
diff --git a/synapse/storage/databases/main/schema/delta/58/15unread_count.sql b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
new file mode 100644
index 0000000000..b451e8663a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
@@ -0,0 +1,26 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We're hijacking the push actions to store unread messages and unread counts (specified
+-- in MSC2654) because doing otherwise would result in either performance issues or
+-- reimplementing a consequent bit of the push actions.
+
+-- Add columns to event_push_actions and event_push_actions_staging to track unread
+-- messages and calculate unread counts.
+ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
+ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
+
+-- Add column to event_push_summary
+ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT NOT NULL DEFAULT 0;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 24f44a7e36..83c1ddf95a 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -47,7 +47,11 @@ from synapse.api.filtering import Filter
 from synapse.events import EventBase
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingTransaction,
+    make_in_list_sql_clause,
+)
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
 from synapse.types import RoomStreamToken
@@ -593,8 +597,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         Returns:
             A stream ID.
         """
-        return await self.db_pool.simple_select_one_onecol(
-            table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
+        return await self.db_pool.runInteraction(
+            "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
+        )
+
+    def get_stream_id_for_event_txn(
+        self, txn: LoggingTransaction, event_id: str, allow_none=False,
+    ) -> int:
+        return self.db_pool.simple_select_one_onecol_txn(
+            txn=txn,
+            table="events",
+            keyvalues={"event_id": event_id},
+            retcol="stream_ordering",
+            allow_none=allow_none,
         )
 
     async def get_stream_token_for_event(self, event_id: str) -> str: