summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8059.feature1
-rw-r--r--synapse/handlers/sync.py33
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py92
-rw-r--r--synapse/push/push_tools.py2
-rw-r--r--synapse/rest/client/v2_alpha/sync.py1
-rw-r--r--synapse/storage/databases/main/event_push_actions.py224
-rw-r--r--synapse/storage/databases/main/events.py4
-rw-r--r--synapse/storage/databases/main/schema/delta/58/15unread_count.sql26
-rw-r--r--synapse/storage/databases/main/stream.py21
-rw-r--r--tests/replication/slave/storage/test_events.py10
-rw-r--r--tests/rest/client/v2_alpha/test_sync.py157
-rw-r--r--tests/storage/test_event_push_actions.py8
12 files changed, 457 insertions, 122 deletions
diff --git a/changelog.d/8059.feature b/changelog.d/8059.feature
new file mode 100644
index 0000000000..feb02be234
--- /dev/null
+++ b/changelog.d/8059.feature
@@ -0,0 +1 @@
+Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654).
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 9a86eb01c9..43fc40fc2f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -95,7 +95,12 @@ class TimelineBatch:
     __bool__ = __nonzero__  # python3
 
 
-@attr.s(slots=True, frozen=True)
+# We can't freeze this class, because we need to update it after it's instantiated to
+# update its unread count. This is because we calculate the unread count for a room only
+# if there are updates for it, which we check after the instance has been created.
+# This should not be a big deal because we update the notification counts afterwards as
+# well anyway.
+@attr.s(slots=True)
 class JoinedSyncResult:
     room_id = attr.ib(type=str)
     timeline = attr.ib(type=TimelineBatch)
@@ -104,6 +109,7 @@ class JoinedSyncResult:
     account_data = attr.ib(type=List[JsonDict])
     unread_notifications = attr.ib(type=JsonDict)
     summary = attr.ib(type=Optional[JsonDict])
+    unread_count = attr.ib(type=int)
 
     def __nonzero__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
@@ -931,7 +937,7 @@ class SyncHandler(object):
 
     async def unread_notifs_for_room_id(
         self, room_id: str, sync_config: SyncConfig
-    ) -> Optional[Dict[str, str]]:
+    ) -> Dict[str, int]:
         with Measure(self.clock, "unread_notifs_for_room_id"):
             last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
                 user_id=sync_config.user.to_string(),
@@ -939,15 +945,10 @@ class SyncHandler(object):
                 receipt_type="m.read",
             )
 
-            if last_unread_event_id:
-                notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
-                    room_id, sync_config.user.to_string(), last_unread_event_id
-                )
-                return notifs
-
-        # There is no new information in this period, so your notification
-        # count is whatever it was last time.
-        return None
+            notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
+                room_id, sync_config.user.to_string(), last_unread_event_id
+            )
+            return notifs
 
     async def generate_sync_result(
         self,
@@ -1886,7 +1887,7 @@ class SyncHandler(object):
             )
 
         if room_builder.rtype == "joined":
-            unread_notifications = {}  # type: Dict[str, str]
+            unread_notifications = {}  # type: Dict[str, int]
             room_sync = JoinedSyncResult(
                 room_id=room_id,
                 timeline=batch,
@@ -1895,14 +1896,16 @@ class SyncHandler(object):
                 account_data=account_data_events,
                 unread_notifications=unread_notifications,
                 summary=summary,
+                unread_count=0,
             )
 
             if room_sync or always_include:
                 notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
 
-                if notifs is not None:
-                    unread_notifications["notification_count"] = notifs["notify_count"]
-                    unread_notifications["highlight_count"] = notifs["highlight_count"]
+                unread_notifications["notification_count"] = notifs["notify_count"]
+                unread_notifications["highlight_count"] = notifs["highlight_count"]
+
+                room_sync.unread_count = notifs["unread_count"]
 
                 sync_result_builder.joined.append(room_sync)
 
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index e7fcee0e87..e7fa02b78b 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -19,8 +19,10 @@ from collections import namedtuple
 
 from prometheus_client import Counter
 
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RelationTypes
 from synapse.event_auth import get_user_power_level
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
 from synapse.state import POWER_KEY
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import register_cache
@@ -51,6 +53,48 @@ push_rules_delta_state_cache_metric = register_cache(
 )
 
 
+STATE_EVENT_TYPES_TO_MARK_UNREAD = {
+    EventTypes.Topic,
+    EventTypes.Name,
+    EventTypes.RoomAvatar,
+    EventTypes.Tombstone,
+}
+
+
+def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
+    # Exclude rejected and soft-failed events.
+    if context.rejected or event.internal_metadata.is_soft_failed():
+        return False
+
+    # Exclude notices.
+    if (
+        not event.is_state()
+        and event.type == EventTypes.Message
+        and event.content.get("msgtype") == "m.notice"
+    ):
+        return False
+
+    # Exclude edits.
+    relates_to = event.content.get("m.relates_to", {})
+    if relates_to.get("rel_type") == RelationTypes.REPLACE:
+        return False
+
+    # Mark events that have a non-empty string body as unread.
+    body = event.content.get("body")
+    if isinstance(body, str) and body:
+        return True
+
+    # Mark some state events as unread.
+    if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
+        return True
+
+    # Mark encrypted events as unread.
+    if not event.is_state() and event.type == EventTypes.Encrypted:
+        return True
+
+    return False
+
+
 class BulkPushRuleEvaluator(object):
     """Calculates the outcome of push rules for an event for all users in the
     room at once.
@@ -133,9 +177,12 @@ class BulkPushRuleEvaluator(object):
         return pl_event.content if pl_event else {}, sender_level
 
     async def action_for_event_by_user(self, event, context) -> None:
-        """Given an event and context, evaluate the push rules and insert the
-        results into the event_push_actions_staging table.
+        """Given an event and context, evaluate the push rules, check if the message
+        should increment the unread count, and insert the results into the
+        event_push_actions_staging table.
         """
+        count_as_unread = _should_count_as_unread(event, context)
+
         rules_by_user = await self._get_rules_for_event(event, context)
         actions_by_user = {}
 
@@ -172,6 +219,8 @@ class BulkPushRuleEvaluator(object):
                 if event.type == EventTypes.Member and event.state_key == uid:
                     display_name = event.content.get("displayname", None)
 
+            actions_by_user[uid] = []
+
             for rule in rules:
                 if "enabled" in rule and not rule["enabled"]:
                     continue
@@ -189,7 +238,9 @@ class BulkPushRuleEvaluator(object):
         # Mark in the DB staging area the push actions for users who should be
         # notified for this event. (This will then get handled when we persist
         # the event)
-        await self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
+        await self.store.add_push_actions_to_staging(
+            event.event_id, actions_by_user, count_as_unread,
+        )
 
 
 def _condition_checker(evaluator, conditions, uid, display_name, cache):
@@ -369,8 +420,8 @@ class RulesForRoom(object):
         Args:
             ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
                 updated with any new rules.
-            member_event_ids (list): List of event ids for membership events that
-                have happened since the last time we filled rules_by_user
+            member_event_ids (dict): Dict of user id to event id for membership events
+                that have happened since the last time we filled rules_by_user
             state_group: The state group we are currently computing push rules
                 for. Used when updating the cache.
         """
@@ -390,34 +441,19 @@ class RulesForRoom(object):
         if logger.isEnabledFor(logging.DEBUG):
             logger.debug("Found members %r: %r", self.room_id, members.values())
 
-        interested_in_user_ids = {
+        user_ids = {
             user_id
             for user_id, membership in members.values()
             if membership == Membership.JOIN
         }
 
-        logger.debug("Joined: %r", interested_in_user_ids)
-
-        if_users_with_pushers = await self.store.get_if_users_have_pushers(
-            interested_in_user_ids, on_invalidate=self.invalidate_all_cb
-        )
-
-        user_ids = {
-            uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
-        }
-
-        logger.debug("With pushers: %r", user_ids)
-
-        users_with_receipts = await self.store.get_users_with_read_receipts_in_room(
-            self.room_id, on_invalidate=self.invalidate_all_cb
-        )
-
-        logger.debug("With receipts: %r", users_with_receipts)
+        logger.debug("Joined: %r", user_ids)
 
-        # any users with pushers must be ours: they have pushers
-        for uid in users_with_receipts:
-            if uid in interested_in_user_ids:
-                user_ids.add(uid)
+        # Previously we only considered users with pushers or read receipts in that
+        # room. We can't do this anymore because we use push actions to calculate unread
+        # counts, which don't rely on the user having pushers or sent a read receipt into
+        # the room. Therefore we just need to filter for local users here.
+        user_ids = list(filter(self.is_mine_id, user_ids))
 
         rules_by_user = await self.store.bulk_get_push_rules(
             user_ids, on_invalidate=self.invalidate_all_cb
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index d0145666bf..f7a25571f3 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -36,7 +36,7 @@ async def get_badge_count(store, user_id):
             )
             # return one badge count per conversation, as count per
             # message is so noisy as to be almost useless
-            badge += 1 if notifs["notify_count"] else 0
+            badge += 1 if notifs["unread_count"] else 0
     return badge
 
 
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 96488b131a..a0b00135e1 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -425,6 +425,7 @@ class SyncRestServlet(RestServlet):
             result["ephemeral"] = {"events": ephemeral_events}
             result["unread_notifications"] = room.unread_notifications
             result["summary"] = room.summary
+            result["org.matrix.msc2654.unread_count"] = room.unread_count
 
         return result
 
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 384c39f4d7..001d06378d 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -15,7 +15,9 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Tuple, Union
+
+import attr
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
@@ -88,8 +90,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
     @cached(num_args=3, tree=True, max_entries=5000)
     async def get_unread_event_push_actions_by_room_for_user(
-        self, room_id, user_id, last_read_event_id
-    ):
+        self, room_id: str, user_id: str, last_read_event_id: Optional[str],
+    ) -> Dict[str, int]:
+        """Get the notification count, the highlight count and the unread message count
+        for a given user in a given room after the given read receipt.
+
+        Note that this function assumes the user to be a current member of the room,
+        since it's either called by the sync handler to handle joined room entries, or by
+        the HTTP pusher to calculate the badge of unread joined rooms.
+
+        Args:
+            room_id: The room to retrieve the counts in.
+            user_id: The user to retrieve the counts for.
+            last_read_event_id: The event associated with the latest read receipt for
+                this user in this room. None if no receipt for this user in this room.
+
+        Returns
+            A dict containing the counts mentioned earlier in this docstring,
+            respectively under the keys "notify_count", "highlight_count" and
+            "unread_count".
+        """
         return await self.db_pool.runInteraction(
             "get_unread_event_push_actions_by_room",
             self._get_unread_counts_by_receipt_txn,
@@ -99,69 +119,71 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
 
     def _get_unread_counts_by_receipt_txn(
-        self, txn, room_id, user_id, last_read_event_id
+        self, txn, room_id, user_id, last_read_event_id,
     ):
-        sql = (
-            "SELECT stream_ordering"
-            " FROM events"
-            " WHERE room_id = ? AND event_id = ?"
-        )
-        txn.execute(sql, (room_id, last_read_event_id))
-        results = txn.fetchall()
-        if len(results) == 0:
-            return {"notify_count": 0, "highlight_count": 0}
+        stream_ordering = None
+
+        if last_read_event_id is not None:
+            stream_ordering = self.get_stream_id_for_event_txn(
+                txn, last_read_event_id, allow_none=True,
+            )
+
+        if stream_ordering is None:
+            # Either last_read_event_id is None, or it's an event we don't have (e.g.
+            # because it's been purged), in which case retrieve the stream ordering for
+            # the latest membership event from this user in this room (which we assume is
+            # a join).
+            event_id = self.db_pool.simple_select_one_onecol_txn(
+                txn=txn,
+                table="local_current_membership",
+                keyvalues={"room_id": room_id, "user_id": user_id},
+                retcol="event_id",
+            )
 
-        stream_ordering = results[0][0]
+            stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
 
         return self._get_unread_counts_by_pos_txn(
             txn, room_id, user_id, stream_ordering
         )
 
     def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
-
-        # First get number of notifications.
-        # We don't need to put a notif=1 clause as all rows always have
-        # notif=1
         sql = (
-            "SELECT count(*)"
+            "SELECT"
+            "   COUNT(CASE WHEN notif = 1 THEN 1 END),"
+            "   COUNT(CASE WHEN highlight = 1 THEN 1 END),"
+            "   COUNT(CASE WHEN unread = 1 THEN 1 END)"
             " FROM event_push_actions ea"
-            " WHERE"
-            " user_id = ?"
-            " AND room_id = ?"
-            " AND stream_ordering > ?"
+            " WHERE user_id = ?"
+            "   AND room_id = ?"
+            "   AND stream_ordering > ?"
         )
 
         txn.execute(sql, (user_id, room_id, stream_ordering))
         row = txn.fetchone()
-        notify_count = row[0] if row else 0
+
+        (notif_count, highlight_count, unread_count) = (0, 0, 0)
+
+        if row:
+            (notif_count, highlight_count, unread_count) = row
 
         txn.execute(
             """
-            SELECT notif_count FROM event_push_summary
-            WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
-        """,
+                SELECT notif_count, unread_count FROM event_push_summary
+                WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
+            """,
             (room_id, user_id, stream_ordering),
         )
-        rows = txn.fetchall()
-        if rows:
-            notify_count += rows[0][0]
-
-        # Now get the number of highlights
-        sql = (
-            "SELECT count(*)"
-            " FROM event_push_actions ea"
-            " WHERE"
-            " highlight = 1"
-            " AND user_id = ?"
-            " AND room_id = ?"
-            " AND stream_ordering > ?"
-        )
-
-        txn.execute(sql, (user_id, room_id, stream_ordering))
         row = txn.fetchone()
-        highlight_count = row[0] if row else 0
 
-        return {"notify_count": notify_count, "highlight_count": highlight_count}
+        if row:
+            notif_count += row[0]
+            unread_count += row[1]
+
+        return {
+            "notify_count": notif_count,
+            "unread_count": unread_count,
+            "highlight_count": highlight_count,
+        }
 
     async def get_push_action_users_in_range(
         self, min_stream_ordering, max_stream_ordering
@@ -222,6 +244,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering ASC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -250,6 +273,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering ASC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -324,6 +348,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering DESC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -352,6 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering DESC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -402,7 +428,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         def _get_if_maybe_push_in_range_for_user_txn(txn):
             sql = """
                 SELECT 1 FROM event_push_actions
-                WHERE user_id = ? AND stream_ordering > ?
+                WHERE user_id = ? AND stream_ordering > ? AND notif = 1
                 LIMIT 1
             """
 
@@ -415,7 +441,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
 
     async def add_push_actions_to_staging(
-        self, event_id: str, user_id_actions: Dict[str, List[Union[dict, str]]]
+        self,
+        event_id: str,
+        user_id_actions: Dict[str, List[Union[dict, str]]],
+        count_as_unread: bool,
     ) -> None:
         """Add the push actions for the event to the push action staging area.
 
@@ -423,21 +452,23 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             event_id
             user_id_actions: A mapping of user_id to list of push actions, where
                 an action can either be a string or dict.
+            count_as_unread: Whether this event should increment unread counts.
         """
-
         if not user_id_actions:
             return
 
         # This is a helper function for generating the necessary tuple that
-        # can be used to inert into the `event_push_actions_staging` table.
+        # can be used to insert into the `event_push_actions_staging` table.
         def _gen_entry(user_id, actions):
             is_highlight = 1 if _action_has_highlight(actions) else 0
+            notif = 1 if "notify" in actions else 0
             return (
                 event_id,  # event_id column
                 user_id,  # user_id column
                 _serialize_action(actions, is_highlight),  # actions column
-                1,  # notif column
+                notif,  # notif column
                 is_highlight,  # highlight column
+                int(count_as_unread),  # unread column
             )
 
         def _add_push_actions_to_staging_txn(txn):
@@ -446,8 +477,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
             sql = """
                 INSERT INTO event_push_actions_staging
-                    (event_id, user_id, actions, notif, highlight)
-                VALUES (?, ?, ?, ?, ?)
+                    (event_id, user_id, actions, notif, highlight, unread)
+                VALUES (?, ?, ?, ?, ?, ?)
             """
 
             txn.executemany(
@@ -811,24 +842,63 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         # Calculate the new counts that should be upserted into event_push_summary
         sql = """
             SELECT user_id, room_id,
-                coalesce(old.notif_count, 0) + upd.notif_count,
+                coalesce(old.%s, 0) + upd.cnt,
                 upd.stream_ordering,
                 old.user_id
             FROM (
-                SELECT user_id, room_id, count(*) as notif_count,
+                SELECT user_id, room_id, count(*) as cnt,
                     max(stream_ordering) as stream_ordering
                 FROM event_push_actions
                 WHERE ? <= stream_ordering AND stream_ordering < ?
                     AND highlight = 0
+                    AND %s = 1
                 GROUP BY user_id, room_id
             ) AS upd
             LEFT JOIN event_push_summary AS old USING (user_id, room_id)
         """
 
-        txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
-        rows = txn.fetchall()
+        # First get the count of unread messages.
+        txn.execute(
+            sql % ("unread_count", "unread"),
+            (old_rotate_stream_ordering, rotate_to_stream_ordering),
+        )
+
+        # We need to merge results from the two requests (the one that retrieves the
+        # unread count and the one that retrieves the notifications count) into a single
+        # object because we might not have the same amount of rows in each of them. To do
+        # this, we use a dict indexed on the user ID and room ID to make it easier to
+        # populate.
+        summaries = {}  # type: Dict[Tuple[str, str], _EventPushSummary]
+        for row in txn:
+            summaries[(row[0], row[1])] = _EventPushSummary(
+                unread_count=row[2],
+                stream_ordering=row[3],
+                old_user_id=row[4],
+                notif_count=0,
+            )
+
+        # Then get the count of notifications.
+        txn.execute(
+            sql % ("notif_count", "notif"),
+            (old_rotate_stream_ordering, rotate_to_stream_ordering),
+        )
+
+        for row in txn:
+            if (row[0], row[1]) in summaries:
+                summaries[(row[0], row[1])].notif_count = row[2]
+            else:
+                # Because the rules on notifying are different than the rules on marking
+                # a message unread, we might end up with messages that notify but aren't
+                # marked unread, so we might not have a summary for this (user, room)
+                # tuple to complete.
+                summaries[(row[0], row[1])] = _EventPushSummary(
+                    unread_count=0,
+                    stream_ordering=row[3],
+                    old_user_id=row[4],
+                    notif_count=row[2],
+                )
 
-        logger.info("Rotating notifications, handling %d rows", len(rows))
+        logger.info("Rotating notifications, handling %d rows", len(summaries))
 
         # If the `old.user_id` above is NULL then we know there isn't already an
         # entry in the table, so we simply insert it. Otherwise we update the
@@ -838,22 +908,34 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             table="event_push_summary",
             values=[
                 {
-                    "user_id": row[0],
-                    "room_id": row[1],
-                    "notif_count": row[2],
-                    "stream_ordering": row[3],
+                    "user_id": user_id,
+                    "room_id": room_id,
+                    "notif_count": summary.notif_count,
+                    "unread_count": summary.unread_count,
+                    "stream_ordering": summary.stream_ordering,
                 }
-                for row in rows
-                if row[4] is None
+                for ((user_id, room_id), summary) in summaries.items()
+                if summary.old_user_id is None
             ],
         )
 
         txn.executemany(
             """
-                UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
+                UPDATE event_push_summary
+                SET notif_count = ?, unread_count = ?, stream_ordering = ?
                 WHERE user_id = ? AND room_id = ?
             """,
-            ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
+            (
+                (
+                    summary.notif_count,
+                    summary.unread_count,
+                    summary.stream_ordering,
+                    user_id,
+                    room_id,
+                )
+                for ((user_id, room_id), summary) in summaries.items()
+                if summary.old_user_id is not None
+            ),
         )
 
         txn.execute(
@@ -879,3 +961,15 @@ def _action_has_highlight(actions):
             pass
 
     return False
+
+
+@attr.s
+class _EventPushSummary:
+    """Summary of pending event push actions for a given user in a given room.
+    Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
+    """
+
+    unread_count = attr.ib(type=int)
+    stream_ordering = attr.ib(type=int)
+    old_user_id = attr.ib(type=str)
+    notif_count = attr.ib(type=int)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 46b11e705b..b94fe7ac17 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1298,9 +1298,9 @@ class PersistEventsStore:
         sql = """
             INSERT INTO event_push_actions (
                 room_id, event_id, user_id, actions, stream_ordering,
-                topological_ordering, notif, highlight
+                topological_ordering, notif, highlight, unread
             )
-            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
             FROM event_push_actions_staging
             WHERE event_id = ?
         """
diff --git a/synapse/storage/databases/main/schema/delta/58/15unread_count.sql b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
new file mode 100644
index 0000000000..b451e8663a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
@@ -0,0 +1,26 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We're hijacking the push actions to store unread messages and unread counts (specified
+-- in MSC2654) because doing otherwise would result in either performance issues or
+-- reimplementing a consequent bit of the push actions.
+
+-- Add columns to event_push_actions and event_push_actions_staging to track unread
+-- messages and calculate unread counts.
+ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
+ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
+
+-- Add column to event_push_summary
+ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT NOT NULL DEFAULT 0;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 24f44a7e36..83c1ddf95a 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -47,7 +47,11 @@ from synapse.api.filtering import Filter
 from synapse.events import EventBase
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingTransaction,
+    make_in_list_sql_clause,
+)
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
 from synapse.types import RoomStreamToken
@@ -593,8 +597,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         Returns:
             A stream ID.
         """
-        return await self.db_pool.simple_select_one_onecol(
-            table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
+        return await self.db_pool.runInteraction(
+            "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
+        )
+
+    def get_stream_id_for_event_txn(
+        self, txn: LoggingTransaction, event_id: str, allow_none=False,
+    ) -> int:
+        return self.db_pool.simple_select_one_onecol_txn(
+            txn=txn,
+            table="events",
+            keyvalues={"event_id": event_id},
+            retcol="stream_ordering",
+            allow_none=allow_none,
         )
 
     async def get_stream_token_for_event(self, event_id: str) -> str:
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 0b5204654c..561258a356 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -160,7 +160,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
-            {"highlight_count": 0, "notify_count": 0},
+            {"highlight_count": 0, "unread_count": 0, "notify_count": 0},
         )
 
         self.persist(
@@ -173,7 +173,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
-            {"highlight_count": 0, "notify_count": 1},
+            {"highlight_count": 0, "unread_count": 0, "notify_count": 1},
         )
 
         self.persist(
@@ -188,7 +188,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
-            {"highlight_count": 1, "notify_count": 2},
+            {"highlight_count": 1, "unread_count": 0, "notify_count": 2},
         )
 
     def test_get_rooms_for_user_with_stream_ordering(self):
@@ -368,7 +368,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
         self.get_success(
             self.master_store.add_push_actions_to_staging(
-                event.event_id, {user_id: actions for user_id, actions in push_actions}
+                event.event_id,
+                {user_id: actions for user_id, actions in push_actions},
+                False,
             )
         )
         return event, context
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index fa3a3ec1bd..a31e44c97e 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -16,9 +16,9 @@
 import json
 
 import synapse.rest.admin
-from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client.v2_alpha import read_marker, sync
 
 from tests import unittest
 from tests.server import TimedOutException
@@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase):
             "GET", sync_url % (access_token, next_batch)
         )
         self.assertRaises(TimedOutException, self.render, request)
+
+
+class UnreadMessagesTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        read_marker.register_servlets,
+        room.register_servlets,
+        sync.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.url = "/sync?since=%s"
+        self.next_batch = "s0"
+
+        # Register the first user (used to check the unread counts).
+        self.user_id = self.register_user("kermit", "monkey")
+        self.tok = self.login("kermit", "monkey")
+
+        # Create the room we'll check unread counts for.
+        self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+        # Register the second user (used to send events to the room).
+        self.user2 = self.register_user("kermit2", "monkey")
+        self.tok2 = self.login("kermit2", "monkey")
+
+        # Change the power levels of the room so that the second user can send state
+        # events.
+        self.helper.send_state(
+            self.room_id,
+            EventTypes.PowerLevels,
+            {
+                "users": {self.user_id: 100, self.user2: 100},
+                "users_default": 0,
+                "events": {
+                    "m.room.name": 50,
+                    "m.room.power_levels": 100,
+                    "m.room.history_visibility": 100,
+                    "m.room.canonical_alias": 50,
+                    "m.room.avatar": 50,
+                    "m.room.tombstone": 100,
+                    "m.room.server_acl": 100,
+                    "m.room.encryption": 100,
+                },
+                "events_default": 0,
+                "state_default": 50,
+                "ban": 50,
+                "kick": 50,
+                "redact": 50,
+                "invite": 0,
+            },
+            tok=self.tok,
+        )
+
+    def test_unread_counts(self):
+        """Tests that /sync returns the right value for the unread count (MSC2654)."""
+
+        # Check that our own messages don't increase the unread count.
+        self.helper.send(self.room_id, "hello", tok=self.tok)
+        self._check_unread_count(0)
+
+        # Join the new user and check that this doesn't increase the unread count.
+        self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
+        self._check_unread_count(0)
+
+        # Check that the new user sending a message increases our unread count.
+        res = self.helper.send(self.room_id, "hello", tok=self.tok2)
+        self._check_unread_count(1)
+
+        # Send a read receipt to tell the server we've read the latest event.
+        body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
+        request, channel = self.make_request(
+            "POST",
+            "/rooms/%s/read_markers" % self.room_id,
+            body,
+            access_token=self.tok,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that the unread counter is back to 0.
+        self._check_unread_count(0)
+
+        # Check that room name changes increase the unread counter.
+        self.helper.send_state(
+            self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
+        )
+        self._check_unread_count(1)
+
+        # Check that room topic changes increase the unread counter.
+        self.helper.send_state(
+            self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
+        )
+        self._check_unread_count(2)
+
+        # Check that encrypted messages increase the unread counter.
+        self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2)
+        self._check_unread_count(3)
+
+        # Check that custom events with a body increase the unread counter.
+        self.helper.send_event(
+            self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
+        )
+        self._check_unread_count(4)
+
+        # Check that edits don't increase the unread counter.
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "body": "hello",
+                "msgtype": "m.text",
+                "m.relates_to": {"rel_type": RelationTypes.REPLACE},
+            },
+            tok=self.tok2,
+        )
+        self._check_unread_count(4)
+
+        # Check that notices don't increase the unread counter.
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={"body": "hello", "msgtype": "m.notice"},
+            tok=self.tok2,
+        )
+        self._check_unread_count(4)
+
+        # Check that tombstone events changes increase the unread counter.
+        self.helper.send_state(
+            self.room_id,
+            EventTypes.Tombstone,
+            {"replacement_room": "!someroom:test"},
+            tok=self.tok2,
+        )
+        self._check_unread_count(5)
+
+    def _check_unread_count(self, expected_count: True):
+        """Syncs and compares the unread count with the expected value."""
+
+        request, channel = self.make_request(
+            "GET", self.url % self.next_batch, access_token=self.tok,
+        )
+        self.render(request)
+
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        room_entry = channel.json_body["rooms"]["join"][self.room_id]
+        self.assertEqual(
+            room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
+        )
+
+        # Store the next batch for the next request.
+        self.next_batch = channel.json_body["next_batch"]
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index cdfd2634aa..c0595963dd 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -67,7 +67,11 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             )
             self.assertEquals(
                 counts,
-                {"notify_count": noitf_count, "highlight_count": highlight_count},
+                {
+                    "notify_count": noitf_count,
+                    "unread_count": 0,  # Unread counts are tested in the sync tests.
+                    "highlight_count": highlight_count,
+                },
             )
 
         @defer.inlineCallbacks
@@ -80,7 +84,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
 
             yield defer.ensureDeferred(
                 self.store.add_push_actions_to_staging(
-                    event.event_id, {user_id: action}
+                    event.event_id, {user_id: action}, False,
                 )
             )
             yield defer.ensureDeferred(