diff --git a/changelog.d/8059.feature b/changelog.d/8059.feature
new file mode 100644
index 0000000000..feb02be234
--- /dev/null
+++ b/changelog.d/8059.feature
@@ -0,0 +1 @@
+Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654).
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 9a86eb01c9..43fc40fc2f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -95,7 +95,12 @@ class TimelineBatch:
__bool__ = __nonzero__ # python3
-@attr.s(slots=True, frozen=True)
+# We can't freeze this class, because we need to update it after it's instantiated to
+# update its unread count. This is because we calculate the unread count for a room only
+# if there are updates for it, which we check after the instance has been created.
+# This should not be a big deal because we update the notification counts afterwards as
+# well anyway.
+@attr.s(slots=True)
class JoinedSyncResult:
room_id = attr.ib(type=str)
timeline = attr.ib(type=TimelineBatch)
@@ -104,6 +109,7 @@ class JoinedSyncResult:
account_data = attr.ib(type=List[JsonDict])
unread_notifications = attr.ib(type=JsonDict)
summary = attr.ib(type=Optional[JsonDict])
+ unread_count = attr.ib(type=int)
def __nonzero__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -931,7 +937,7 @@ class SyncHandler(object):
async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig
- ) -> Optional[Dict[str, str]]:
+ ) -> Dict[str, int]:
with Measure(self.clock, "unread_notifs_for_room_id"):
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(),
@@ -939,15 +945,10 @@ class SyncHandler(object):
receipt_type="m.read",
)
- if last_unread_event_id:
- notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
- room_id, sync_config.user.to_string(), last_unread_event_id
- )
- return notifs
-
- # There is no new information in this period, so your notification
- # count is whatever it was last time.
- return None
+ notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
+ room_id, sync_config.user.to_string(), last_unread_event_id
+ )
+ return notifs
async def generate_sync_result(
self,
@@ -1886,7 +1887,7 @@ class SyncHandler(object):
)
if room_builder.rtype == "joined":
- unread_notifications = {} # type: Dict[str, str]
+ unread_notifications = {} # type: Dict[str, int]
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=batch,
@@ -1895,14 +1896,16 @@ class SyncHandler(object):
account_data=account_data_events,
unread_notifications=unread_notifications,
summary=summary,
+ unread_count=0,
)
if room_sync or always_include:
notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
- if notifs is not None:
- unread_notifications["notification_count"] = notifs["notify_count"]
- unread_notifications["highlight_count"] = notifs["highlight_count"]
+ unread_notifications["notification_count"] = notifs["notify_count"]
+ unread_notifications["highlight_count"] = notifs["highlight_count"]
+
+ room_sync.unread_count = notifs["unread_count"]
sync_result_builder.joined.append(room_sync)
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index e7fcee0e87..e7fa02b78b 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -19,8 +19,10 @@ from collections import namedtuple
from prometheus_client import Counter
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RelationTypes
from synapse.event_auth import get_user_power_level
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import register_cache
@@ -51,6 +53,48 @@ push_rules_delta_state_cache_metric = register_cache(
)
+STATE_EVENT_TYPES_TO_MARK_UNREAD = {
+ EventTypes.Topic,
+ EventTypes.Name,
+ EventTypes.RoomAvatar,
+ EventTypes.Tombstone,
+}
+
+
+def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
+ # Exclude rejected and soft-failed events.
+ if context.rejected or event.internal_metadata.is_soft_failed():
+ return False
+
+ # Exclude notices.
+ if (
+ not event.is_state()
+ and event.type == EventTypes.Message
+ and event.content.get("msgtype") == "m.notice"
+ ):
+ return False
+
+ # Exclude edits.
+ relates_to = event.content.get("m.relates_to", {})
+ if relates_to.get("rel_type") == RelationTypes.REPLACE:
+ return False
+
+ # Mark events that have a non-empty string body as unread.
+ body = event.content.get("body")
+ if isinstance(body, str) and body:
+ return True
+
+ # Mark some state events as unread.
+ if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
+ return True
+
+ # Mark encrypted events as unread.
+ if not event.is_state() and event.type == EventTypes.Encrypted:
+ return True
+
+ return False
+
+
class BulkPushRuleEvaluator(object):
"""Calculates the outcome of push rules for an event for all users in the
room at once.
@@ -133,9 +177,12 @@ class BulkPushRuleEvaluator(object):
return pl_event.content if pl_event else {}, sender_level
async def action_for_event_by_user(self, event, context) -> None:
- """Given an event and context, evaluate the push rules and insert the
- results into the event_push_actions_staging table.
+ """Given an event and context, evaluate the push rules, check if the message
+ should increment the unread count, and insert the results into the
+ event_push_actions_staging table.
"""
+ count_as_unread = _should_count_as_unread(event, context)
+
rules_by_user = await self._get_rules_for_event(event, context)
actions_by_user = {}
@@ -172,6 +219,8 @@ class BulkPushRuleEvaluator(object):
if event.type == EventTypes.Member and event.state_key == uid:
display_name = event.content.get("displayname", None)
+ actions_by_user[uid] = []
+
for rule in rules:
if "enabled" in rule and not rule["enabled"]:
continue
@@ -189,7 +238,9 @@ class BulkPushRuleEvaluator(object):
# Mark in the DB staging area the push actions for users who should be
# notified for this event. (This will then get handled when we persist
# the event)
- await self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
+ await self.store.add_push_actions_to_staging(
+ event.event_id, actions_by_user, count_as_unread,
+ )
def _condition_checker(evaluator, conditions, uid, display_name, cache):
@@ -369,8 +420,8 @@ class RulesForRoom(object):
Args:
ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
updated with any new rules.
- member_event_ids (list): List of event ids for membership events that
- have happened since the last time we filled rules_by_user
+ member_event_ids (dict): Dict of user id to event id for membership events
+ that have happened since the last time we filled rules_by_user
state_group: The state group we are currently computing push rules
for. Used when updating the cache.
"""
@@ -390,34 +441,19 @@ class RulesForRoom(object):
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())
- interested_in_user_ids = {
+ user_ids = {
user_id
for user_id, membership in members.values()
if membership == Membership.JOIN
}
- logger.debug("Joined: %r", interested_in_user_ids)
-
- if_users_with_pushers = await self.store.get_if_users_have_pushers(
- interested_in_user_ids, on_invalidate=self.invalidate_all_cb
- )
-
- user_ids = {
- uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
- }
-
- logger.debug("With pushers: %r", user_ids)
-
- users_with_receipts = await self.store.get_users_with_read_receipts_in_room(
- self.room_id, on_invalidate=self.invalidate_all_cb
- )
-
- logger.debug("With receipts: %r", users_with_receipts)
+ logger.debug("Joined: %r", user_ids)
- # any users with pushers must be ours: they have pushers
- for uid in users_with_receipts:
- if uid in interested_in_user_ids:
- user_ids.add(uid)
+ # Previously we only considered users with pushers or read receipts in that
+ # room. We can't do this anymore because we use push actions to calculate unread
+ # counts, which don't rely on the user having pushers or sent a read receipt into
+ # the room. Therefore we just need to filter for local users here.
+ user_ids = list(filter(self.is_mine_id, user_ids))
rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index d0145666bf..f7a25571f3 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -36,7 +36,7 @@ async def get_badge_count(store, user_id):
)
# return one badge count per conversation, as count per
# message is so noisy as to be almost useless
- badge += 1 if notifs["notify_count"] else 0
+ badge += 1 if notifs["unread_count"] else 0
return badge
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 96488b131a..a0b00135e1 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -425,6 +425,7 @@ class SyncRestServlet(RestServlet):
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary
+ result["org.matrix.msc2654.unread_count"] = room.unread_count
return result
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 384c39f4d7..001d06378d 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -15,7 +15,9 @@
# limitations under the License.
import logging
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Tuple, Union
+
+import attr
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
@@ -88,8 +90,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
@cached(num_args=3, tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
- self, room_id, user_id, last_read_event_id
- ):
+ self, room_id: str, user_id: str, last_read_event_id: Optional[str],
+ ) -> Dict[str, int]:
+ """Get the notification count, the highlight count and the unread message count
+ for a given user in a given room after the given read receipt.
+
+ Note that this function assumes the user to be a current member of the room,
+ since it's either called by the sync handler to handle joined room entries, or by
+ the HTTP pusher to calculate the badge of unread joined rooms.
+
+ Args:
+ room_id: The room to retrieve the counts in.
+ user_id: The user to retrieve the counts for.
+ last_read_event_id: The event associated with the latest read receipt for
+ this user in this room. None if no receipt for this user in this room.
+
+ Returns
+ A dict containing the counts mentioned earlier in this docstring,
+ respectively under the keys "notify_count", "highlight_count" and
+ "unread_count".
+ """
return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
@@ -99,69 +119,71 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
def _get_unread_counts_by_receipt_txn(
- self, txn, room_id, user_id, last_read_event_id
+ self, txn, room_id, user_id, last_read_event_id,
):
- sql = (
- "SELECT stream_ordering"
- " FROM events"
- " WHERE room_id = ? AND event_id = ?"
- )
- txn.execute(sql, (room_id, last_read_event_id))
- results = txn.fetchall()
- if len(results) == 0:
- return {"notify_count": 0, "highlight_count": 0}
+ stream_ordering = None
+
+ if last_read_event_id is not None:
+ stream_ordering = self.get_stream_id_for_event_txn(
+ txn, last_read_event_id, allow_none=True,
+ )
+
+ if stream_ordering is None:
+ # Either last_read_event_id is None, or it's an event we don't have (e.g.
+ # because it's been purged), in which case retrieve the stream ordering for
+ # the latest membership event from this user in this room (which we assume is
+ # a join).
+ event_id = self.db_pool.simple_select_one_onecol_txn(
+ txn=txn,
+ table="local_current_membership",
+ keyvalues={"room_id": room_id, "user_id": user_id},
+ retcol="event_id",
+ )
- stream_ordering = results[0][0]
+ stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
return self._get_unread_counts_by_pos_txn(
txn, room_id, user_id, stream_ordering
)
def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
-
- # First get number of notifications.
- # We don't need to put a notif=1 clause as all rows always have
- # notif=1
sql = (
- "SELECT count(*)"
+ "SELECT"
+ " COUNT(CASE WHEN notif = 1 THEN 1 END),"
+ " COUNT(CASE WHEN highlight = 1 THEN 1 END),"
+ " COUNT(CASE WHEN unread = 1 THEN 1 END)"
" FROM event_push_actions ea"
- " WHERE"
- " user_id = ?"
- " AND room_id = ?"
- " AND stream_ordering > ?"
+ " WHERE user_id = ?"
+ " AND room_id = ?"
+ " AND stream_ordering > ?"
)
txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
- notify_count = row[0] if row else 0
+
+ (notif_count, highlight_count, unread_count) = (0, 0, 0)
+
+ if row:
+ (notif_count, highlight_count, unread_count) = row
txn.execute(
"""
- SELECT notif_count FROM event_push_summary
- WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
- """,
+ SELECT notif_count, unread_count FROM event_push_summary
+ WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
+ """,
(room_id, user_id, stream_ordering),
)
- rows = txn.fetchall()
- if rows:
- notify_count += rows[0][0]
-
- # Now get the number of highlights
- sql = (
- "SELECT count(*)"
- " FROM event_push_actions ea"
- " WHERE"
- " highlight = 1"
- " AND user_id = ?"
- " AND room_id = ?"
- " AND stream_ordering > ?"
- )
-
- txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
- highlight_count = row[0] if row else 0
- return {"notify_count": notify_count, "highlight_count": highlight_count}
+ if row:
+ notif_count += row[0]
+ unread_count += row[1]
+
+ return {
+ "notify_count": notif_count,
+ "unread_count": unread_count,
+ "highlight_count": highlight_count,
+ }
async def get_push_action_users_in_range(
self, min_stream_ordering, max_stream_ordering
@@ -222,6 +244,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -250,6 +273,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -324,6 +348,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -352,6 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -402,7 +428,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def _get_if_maybe_push_in_range_for_user_txn(txn):
sql = """
SELECT 1 FROM event_push_actions
- WHERE user_id = ? AND stream_ordering > ?
+ WHERE user_id = ? AND stream_ordering > ? AND notif = 1
LIMIT 1
"""
@@ -415,7 +441,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
async def add_push_actions_to_staging(
- self, event_id: str, user_id_actions: Dict[str, List[Union[dict, str]]]
+ self,
+ event_id: str,
+ user_id_actions: Dict[str, List[Union[dict, str]]],
+ count_as_unread: bool,
) -> None:
"""Add the push actions for the event to the push action staging area.
@@ -423,21 +452,23 @@ class EventPushActionsWorkerStore(SQLBaseStore):
event_id
user_id_actions: A mapping of user_id to list of push actions, where
an action can either be a string or dict.
+ count_as_unread: Whether this event should increment unread counts.
"""
-
if not user_id_actions:
return
# This is a helper function for generating the necessary tuple that
- # can be used to inert into the `event_push_actions_staging` table.
+ # can be used to insert into the `event_push_actions_staging` table.
def _gen_entry(user_id, actions):
is_highlight = 1 if _action_has_highlight(actions) else 0
+ notif = 1 if "notify" in actions else 0
return (
event_id, # event_id column
user_id, # user_id column
_serialize_action(actions, is_highlight), # actions column
- 1, # notif column
+ notif, # notif column
is_highlight, # highlight column
+ int(count_as_unread), # unread column
)
def _add_push_actions_to_staging_txn(txn):
@@ -446,8 +477,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
sql = """
INSERT INTO event_push_actions_staging
- (event_id, user_id, actions, notif, highlight)
- VALUES (?, ?, ?, ?, ?)
+ (event_id, user_id, actions, notif, highlight, unread)
+ VALUES (?, ?, ?, ?, ?, ?)
"""
txn.executemany(
@@ -811,24 +842,63 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# Calculate the new counts that should be upserted into event_push_summary
sql = """
SELECT user_id, room_id,
- coalesce(old.notif_count, 0) + upd.notif_count,
+ coalesce(old.%s, 0) + upd.cnt,
upd.stream_ordering,
old.user_id
FROM (
- SELECT user_id, room_id, count(*) as notif_count,
+ SELECT user_id, room_id, count(*) as cnt,
max(stream_ordering) as stream_ordering
FROM event_push_actions
WHERE ? <= stream_ordering AND stream_ordering < ?
AND highlight = 0
+ AND %s = 1
GROUP BY user_id, room_id
) AS upd
LEFT JOIN event_push_summary AS old USING (user_id, room_id)
"""
- txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
- rows = txn.fetchall()
+ # First get the count of unread messages.
+ txn.execute(
+ sql % ("unread_count", "unread"),
+ (old_rotate_stream_ordering, rotate_to_stream_ordering),
+ )
+
+ # We need to merge results from the two requests (the one that retrieves the
+ # unread count and the one that retrieves the notifications count) into a single
+ # object because we might not have the same amount of rows in each of them. To do
+ # this, we use a dict indexed on the user ID and room ID to make it easier to
+ # populate.
+ summaries = {} # type: Dict[Tuple[str, str], _EventPushSummary]
+ for row in txn:
+ summaries[(row[0], row[1])] = _EventPushSummary(
+ unread_count=row[2],
+ stream_ordering=row[3],
+ old_user_id=row[4],
+ notif_count=0,
+ )
+
+ # Then get the count of notifications.
+ txn.execute(
+ sql % ("notif_count", "notif"),
+ (old_rotate_stream_ordering, rotate_to_stream_ordering),
+ )
+
+ for row in txn:
+ if (row[0], row[1]) in summaries:
+ summaries[(row[0], row[1])].notif_count = row[2]
+ else:
+ # Because the rules on notifying are different than the rules on marking
+ # a message unread, we might end up with messages that notify but aren't
+ # marked unread, so we might not have a summary for this (user, room)
+ # tuple to complete.
+ summaries[(row[0], row[1])] = _EventPushSummary(
+ unread_count=0,
+ stream_ordering=row[3],
+ old_user_id=row[4],
+ notif_count=row[2],
+ )
- logger.info("Rotating notifications, handling %d rows", len(rows))
+ logger.info("Rotating notifications, handling %d rows", len(summaries))
# If the `old.user_id` above is NULL then we know there isn't already an
# entry in the table, so we simply insert it. Otherwise we update the
@@ -838,22 +908,34 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
table="event_push_summary",
values=[
{
- "user_id": row[0],
- "room_id": row[1],
- "notif_count": row[2],
- "stream_ordering": row[3],
+ "user_id": user_id,
+ "room_id": room_id,
+ "notif_count": summary.notif_count,
+ "unread_count": summary.unread_count,
+ "stream_ordering": summary.stream_ordering,
}
- for row in rows
- if row[4] is None
+ for ((user_id, room_id), summary) in summaries.items()
+ if summary.old_user_id is None
],
)
txn.executemany(
"""
- UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
+ UPDATE event_push_summary
+ SET notif_count = ?, unread_count = ?, stream_ordering = ?
WHERE user_id = ? AND room_id = ?
""",
- ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
+ (
+ (
+ summary.notif_count,
+ summary.unread_count,
+ summary.stream_ordering,
+ user_id,
+ room_id,
+ )
+ for ((user_id, room_id), summary) in summaries.items()
+ if summary.old_user_id is not None
+ ),
)
txn.execute(
@@ -879,3 +961,15 @@ def _action_has_highlight(actions):
pass
return False
+
+
+@attr.s
+class _EventPushSummary:
+ """Summary of pending event push actions for a given user in a given room.
+ Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
+ """
+
+ unread_count = attr.ib(type=int)
+ stream_ordering = attr.ib(type=int)
+ old_user_id = attr.ib(type=str)
+ notif_count = attr.ib(type=int)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 46b11e705b..b94fe7ac17 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1298,9 +1298,9 @@ class PersistEventsStore:
sql = """
INSERT INTO event_push_actions (
room_id, event_id, user_id, actions, stream_ordering,
- topological_ordering, notif, highlight
+ topological_ordering, notif, highlight, unread
)
- SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+ SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
FROM event_push_actions_staging
WHERE event_id = ?
"""
diff --git a/synapse/storage/databases/main/schema/delta/58/15unread_count.sql b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
new file mode 100644
index 0000000000..b451e8663a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
@@ -0,0 +1,26 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We're hijacking the push actions to store unread messages and unread counts (specified
+-- in MSC2654) because doing otherwise would result in either performance issues or
+-- reimplementing a consequent bit of the push actions.
+
+-- Add columns to event_push_actions and event_push_actions_staging to track unread
+-- messages and calculate unread counts.
+ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
+ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
+
+-- Add column to event_push_summary
+ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT NOT NULL DEFAULT 0;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 24f44a7e36..83c1ddf95a 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -47,7 +47,11 @@ from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingTransaction,
+ make_in_list_sql_clause,
+)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.types import RoomStreamToken
@@ -593,8 +597,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A stream ID.
"""
- return await self.db_pool.simple_select_one_onecol(
- table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
+ return await self.db_pool.runInteraction(
+ "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
+ )
+
+ def get_stream_id_for_event_txn(
+ self, txn: LoggingTransaction, event_id: str, allow_none=False,
+ ) -> int:
+ return self.db_pool.simple_select_one_onecol_txn(
+ txn=txn,
+ table="events",
+ keyvalues={"event_id": event_id},
+ retcol="stream_ordering",
+ allow_none=allow_none,
)
async def get_stream_token_for_event(self, event_id: str) -> str:
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 0b5204654c..561258a356 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -160,7 +160,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 0, "notify_count": 0},
+ {"highlight_count": 0, "unread_count": 0, "notify_count": 0},
)
self.persist(
@@ -173,7 +173,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 0, "notify_count": 1},
+ {"highlight_count": 0, "unread_count": 0, "notify_count": 1},
)
self.persist(
@@ -188,7 +188,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 1, "notify_count": 2},
+ {"highlight_count": 1, "unread_count": 0, "notify_count": 2},
)
def test_get_rooms_for_user_with_stream_ordering(self):
@@ -368,7 +368,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.get_success(
self.master_store.add_push_actions_to_staging(
- event.event_id, {user_id: actions for user_id, actions in push_actions}
+ event.event_id,
+ {user_id: actions for user_id, actions in push_actions},
+ False,
)
)
return event, context
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index fa3a3ec1bd..a31e44c97e 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -16,9 +16,9 @@
import json
import synapse.rest.admin
-from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client.v2_alpha import read_marker, sync
from tests import unittest
from tests.server import TimedOutException
@@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase):
"GET", sync_url % (access_token, next_batch)
)
self.assertRaises(TimedOutException, self.render, request)
+
+
+class UnreadMessagesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ read_marker.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.url = "/sync?since=%s"
+ self.next_batch = "s0"
+
+ # Register the first user (used to check the unread counts).
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ # Create the room we'll check unread counts for.
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ # Register the second user (used to send events to the room).
+ self.user2 = self.register_user("kermit2", "monkey")
+ self.tok2 = self.login("kermit2", "monkey")
+
+ # Change the power levels of the room so that the second user can send state
+ # events.
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.PowerLevels,
+ {
+ "users": {self.user_id: 100, self.user2: 100},
+ "users_default": 0,
+ "events": {
+ "m.room.name": 50,
+ "m.room.power_levels": 100,
+ "m.room.history_visibility": 100,
+ "m.room.canonical_alias": 50,
+ "m.room.avatar": 50,
+ "m.room.tombstone": 100,
+ "m.room.server_acl": 100,
+ "m.room.encryption": 100,
+ },
+ "events_default": 0,
+ "state_default": 50,
+ "ban": 50,
+ "kick": 50,
+ "redact": 50,
+ "invite": 0,
+ },
+ tok=self.tok,
+ )
+
+ def test_unread_counts(self):
+ """Tests that /sync returns the right value for the unread count (MSC2654)."""
+
+ # Check that our own messages don't increase the unread count.
+ self.helper.send(self.room_id, "hello", tok=self.tok)
+ self._check_unread_count(0)
+
+ # Join the new user and check that this doesn't increase the unread count.
+ self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
+ self._check_unread_count(0)
+
+ # Check that the new user sending a message increases our unread count.
+ res = self.helper.send(self.room_id, "hello", tok=self.tok2)
+ self._check_unread_count(1)
+
+ # Send a read receipt to tell the server we've read the latest event.
+ body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
+ request, channel = self.make_request(
+ "POST",
+ "/rooms/%s/read_markers" % self.room_id,
+ body,
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check that the unread counter is back to 0.
+ self._check_unread_count(0)
+
+ # Check that room name changes increase the unread counter.
+ self.helper.send_state(
+ self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
+ )
+ self._check_unread_count(1)
+
+ # Check that room topic changes increase the unread counter.
+ self.helper.send_state(
+ self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
+ )
+ self._check_unread_count(2)
+
+ # Check that encrypted messages increase the unread counter.
+ self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2)
+ self._check_unread_count(3)
+
+ # Check that custom events with a body increase the unread counter.
+ self.helper.send_event(
+ self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that edits don't increase the unread counter.
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "body": "hello",
+ "msgtype": "m.text",
+ "m.relates_to": {"rel_type": RelationTypes.REPLACE},
+ },
+ tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that notices don't increase the unread counter.
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"body": "hello", "msgtype": "m.notice"},
+ tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that tombstone events changes increase the unread counter.
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.Tombstone,
+ {"replacement_room": "!someroom:test"},
+ tok=self.tok2,
+ )
+ self._check_unread_count(5)
+
+ def _check_unread_count(self, expected_count: True):
+ """Syncs and compares the unread count with the expected value."""
+
+ request, channel = self.make_request(
+ "GET", self.url % self.next_batch, access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ room_entry = channel.json_body["rooms"]["join"][self.room_id]
+ self.assertEqual(
+ room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
+ )
+
+ # Store the next batch for the next request.
+ self.next_batch = channel.json_body["next_batch"]
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index cdfd2634aa..c0595963dd 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -67,7 +67,11 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
self.assertEquals(
counts,
- {"notify_count": noitf_count, "highlight_count": highlight_count},
+ {
+ "notify_count": noitf_count,
+ "unread_count": 0, # Unread counts are tested in the sync tests.
+ "highlight_count": highlight_count,
+ },
)
@defer.inlineCallbacks
@@ -80,7 +84,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield defer.ensureDeferred(
self.store.add_push_actions_to_staging(
- event.event_id, {user_id: action}
+ event.event_id, {user_id: action}, False,
)
)
yield defer.ensureDeferred(
|