diff --git a/changelog.d/7736.feature b/changelog.d/7736.feature
new file mode 100644
index 0000000000..c97864677a
--- /dev/null
+++ b/changelog.d/7736.feature
@@ -0,0 +1 @@
+Add unread messages count to sync responses.
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 22a6abd7d2..bee525197f 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -69,7 +69,7 @@ logger = logging.getLogger("synapse_port_db")
BOOLEAN_COLUMNS = {
- "events": ["processed", "outlier", "contains_url"],
+ "events": ["processed", "outlier", "contains_url", "count_as_unread"],
"rooms": ["is_public"],
"event_edges": ["is_state"],
"presence_list": ["accepted"],
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index ebd3e98105..eaa4eeadf7 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -103,6 +103,7 @@ class JoinedSyncResult:
account_data = attr.ib(type=List[JsonDict])
unread_notifications = attr.ib(type=JsonDict)
summary = attr.ib(type=Optional[JsonDict])
+ unread_count = attr.ib(type=int)
def __nonzero__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -1886,6 +1887,10 @@ class SyncHandler(object):
if room_builder.rtype == "joined":
unread_notifications = {} # type: Dict[str, str]
+
+ unread_count = await self.store.get_unread_message_count_for_user(
+ room_id, sync_config.user.to_string(),
+ )
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=batch,
@@ -1894,6 +1899,7 @@ class SyncHandler(object):
account_data=account_data_events,
unread_notifications=unread_notifications,
summary=summary,
+ unread_count=unread_count,
)
if room_sync or always_include:
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index d0145666bf..bc8f71916b 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -21,22 +21,13 @@ async def get_badge_count(store, user_id):
invites = await store.get_invited_rooms_for_local_user(user_id)
joins = await store.get_rooms_for_user(user_id)
- my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read")
-
badge = len(invites)
for room_id in joins:
- if room_id in my_receipts_by_room:
- last_unread_event_id = my_receipts_by_room[room_id]
-
- notifs = await (
- store.get_unread_event_push_actions_by_room_for_user(
- room_id, user_id, last_unread_event_id
- )
- )
- # return one badge count per conversation, as count per
- # message is so noisy as to be almost useless
- badge += 1 if notifs["notify_count"] else 0
+ unread_count = await store.get_unread_message_count_for_user(room_id, user_id)
+ # return one badge count per conversation, as count per
+ # message is so noisy as to be almost useless
+ badge += 1 if unread_count else 0
return badge
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index a5c24fbd63..3f5bf75e59 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -426,6 +426,7 @@ class SyncRestServlet(RestServlet):
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary
+ result["org.matrix.msc2654.unread_count"] = room.unread_count
return result
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index f39f556c20..edc3624fed 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -172,6 +172,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_latest_event_ids_in_room.invalidate((room_id,))
+ self.get_unread_message_count_for_user.invalidate_many((room_id,))
self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
if not backfilled:
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 6f2e0d15cc..0c9c02afa1 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -53,6 +53,47 @@ event_counter = Counter(
["type", "origin_type", "origin_entity"],
)
+STATE_EVENT_TYPES_TO_MARK_UNREAD = {
+ EventTypes.Topic,
+ EventTypes.Name,
+ EventTypes.RoomAvatar,
+ EventTypes.Tombstone,
+}
+
+
+def should_count_as_unread(event: EventBase, context: EventContext) -> bool:
+ # Exclude rejected and soft-failed events.
+ if context.rejected or event.internal_metadata.is_soft_failed():
+ return False
+
+ # Exclude notices.
+ if (
+ not event.is_state()
+ and event.type == EventTypes.Message
+ and event.content.get("msgtype") == "m.notice"
+ ):
+ return False
+
+ # Exclude edits.
+ relates_to = event.content.get("m.relates_to", {})
+ if relates_to.get("rel_type") == RelationTypes.REPLACE:
+ return False
+
+ # Mark events that have a non-empty string body as unread.
+ body = event.content.get("body")
+ if isinstance(body, str) and body:
+ return True
+
+ # Mark some state events as unread.
+ if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
+ return True
+
+ # Mark encrypted events as unread.
+ if not event.is_state() and event.type == EventTypes.Encrypted:
+ return True
+
+ return False
+
def encode_json(json_object):
"""
@@ -196,6 +237,10 @@ class PersistEventsStore:
event_counter.labels(event.type, origin_type, origin_entity).inc()
+ self.store.get_unread_message_count_for_user.invalidate_many(
+ (event.room_id,),
+ )
+
for room_id, new_state in current_state_for_room.items():
self.store.get_current_state_ids.prefill((room_id,), new_state)
@@ -817,8 +862,9 @@ class PersistEventsStore:
"contains_url": (
"url" in event.content and isinstance(event.content["url"], str)
),
+ "count_as_unread": should_count_as_unread(event, context),
}
- for event, _ in events_and_contexts
+ for event, context in events_and_contexts
],
)
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index e812c67078..b03b259636 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -41,9 +41,15 @@ from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
+from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import (
+ Cache,
+ _CacheContext,
+ cached,
+ cachedInlineCallbacks,
+)
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -1358,6 +1364,84 @@ class EventsWorkerStore(SQLBaseStore):
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
+ @cached(tree=True, cache_context=True)
+ async def get_unread_message_count_for_user(
+ self, room_id: str, user_id: str, cache_context: _CacheContext,
+ ) -> int:
+ """Retrieve the count of unread messages for the given room and user.
+
+ Args:
+ room_id: The ID of the room to count unread messages in.
+ user_id: The ID of the user to count unread messages for.
+
+ Returns:
+ The number of unread messages for the given user in the given room.
+ """
+ with Measure(self._clock, "get_unread_message_count_for_user"):
+ last_read_event_id = await self.get_last_receipt_event_id_for_user(
+ user_id=user_id,
+ room_id=room_id,
+ receipt_type="m.read",
+ on_invalidate=cache_context.invalidate,
+ )
+
+ return await self.db.runInteraction(
+ "get_unread_message_count_for_user",
+ self._get_unread_message_count_for_user_txn,
+ user_id,
+ room_id,
+ last_read_event_id,
+ )
+
+ def _get_unread_message_count_for_user_txn(
+ self,
+ txn: Cursor,
+ user_id: str,
+ room_id: str,
+ last_read_event_id: Optional[str],
+ ) -> int:
+ if last_read_event_id:
+ # Get the stream ordering for the last read event.
+ stream_ordering = self.db.simple_select_one_onecol_txn(
+ txn=txn,
+ table="events",
+ keyvalues={"room_id": room_id, "event_id": last_read_event_id},
+ retcol="stream_ordering",
+ )
+ else:
+ # If there's no read receipt for that room, it probably means the user hasn't
+ # opened it yet, in which case use the stream ID of their join event.
+ # We can't just set it to 0 otherwise messages from other local users from
+ # before this user joined will be counted as well.
+ txn.execute(
+ """
+ SELECT stream_ordering FROM local_current_membership
+ LEFT JOIN events USING (event_id, room_id)
+ WHERE membership = 'join'
+ AND user_id = ?
+ AND room_id = ?
+ """,
+ (user_id, room_id),
+ )
+ row = txn.fetchone()
+
+ if row is None:
+ return 0
+
+ stream_ordering = row[0]
+
+ # Count the messages that qualify as unread after the stream ordering we've just
+ # retrieved.
+ sql = """
+ SELECT COUNT(*) FROM events
+ WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread
+ """
+
+ txn.execute(sql, (user_id, room_id, stream_ordering))
+ row = txn.fetchone()
+
+ return row[0] if row else 0
+
AllNewEventsResult = namedtuple(
"AllNewEventsResult",
diff --git a/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql
new file mode 100644
index 0000000000..531b532c73
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Store a boolean value in the events table for whether the event should be counted in
+-- the unread_count property of sync responses.
+ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN;
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 22d734e763..7f8252330a 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -143,6 +143,26 @@ class RestHelper(object):
return channel.json_body
+ def redact(self, room_id, event_id, txn_id=None, tok=None, expect_code=200):
+ if txn_id is None:
+ txn_id = "m%s" % (str(time.time()))
+
+ path = "/_matrix/client/r0/rooms/%s/redact/%s/%s" % (room_id, event_id, txn_id)
+ if tok:
+ path = path + "?access_token=%s" % tok
+
+ request, channel = make_request(
+ self.hs.get_reactor(), "PUT", path, json.dumps({}).encode("utf8")
+ )
+ render(request, self.resource, self.hs.get_reactor())
+
+ assert int(channel.result["code"]) == expect_code, (
+ "Expected: %d, got: %d, resp: %r"
+ % (expect_code, int(channel.result["code"]), channel.result["body"])
+ )
+
+ return channel.json_body
+
def _read_write_state(
self,
room_id: str,
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index fa3a3ec1bd..a31e44c97e 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -16,9 +16,9 @@
import json
import synapse.rest.admin
-from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client.v2_alpha import read_marker, sync
from tests import unittest
from tests.server import TimedOutException
@@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase):
"GET", sync_url % (access_token, next_batch)
)
self.assertRaises(TimedOutException, self.render, request)
+
+
+class UnreadMessagesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ read_marker.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.url = "/sync?since=%s"
+ self.next_batch = "s0"
+
+ # Register the first user (used to check the unread counts).
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ # Create the room we'll check unread counts for.
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ # Register the second user (used to send events to the room).
+ self.user2 = self.register_user("kermit2", "monkey")
+ self.tok2 = self.login("kermit2", "monkey")
+
+ # Change the power levels of the room so that the second user can send state
+ # events.
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.PowerLevels,
+ {
+ "users": {self.user_id: 100, self.user2: 100},
+ "users_default": 0,
+ "events": {
+ "m.room.name": 50,
+ "m.room.power_levels": 100,
+ "m.room.history_visibility": 100,
+ "m.room.canonical_alias": 50,
+ "m.room.avatar": 50,
+ "m.room.tombstone": 100,
+ "m.room.server_acl": 100,
+ "m.room.encryption": 100,
+ },
+ "events_default": 0,
+ "state_default": 50,
+ "ban": 50,
+ "kick": 50,
+ "redact": 50,
+ "invite": 0,
+ },
+ tok=self.tok,
+ )
+
+ def test_unread_counts(self):
+ """Tests that /sync returns the right value for the unread count (MSC2654)."""
+
+ # Check that our own messages don't increase the unread count.
+ self.helper.send(self.room_id, "hello", tok=self.tok)
+ self._check_unread_count(0)
+
+ # Join the new user and check that this doesn't increase the unread count.
+ self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
+ self._check_unread_count(0)
+
+ # Check that the new user sending a message increases our unread count.
+ res = self.helper.send(self.room_id, "hello", tok=self.tok2)
+ self._check_unread_count(1)
+
+ # Send a read receipt to tell the server we've read the latest event.
+ body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
+ request, channel = self.make_request(
+ "POST",
+ "/rooms/%s/read_markers" % self.room_id,
+ body,
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check that the unread counter is back to 0.
+ self._check_unread_count(0)
+
+ # Check that room name changes increase the unread counter.
+ self.helper.send_state(
+ self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
+ )
+ self._check_unread_count(1)
+
+ # Check that room topic changes increase the unread counter.
+ self.helper.send_state(
+ self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
+ )
+ self._check_unread_count(2)
+
+ # Check that encrypted messages increase the unread counter.
+ self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2)
+ self._check_unread_count(3)
+
+ # Check that custom events with a body increase the unread counter.
+ self.helper.send_event(
+ self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that edits don't increase the unread counter.
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "body": "hello",
+ "msgtype": "m.text",
+ "m.relates_to": {"rel_type": RelationTypes.REPLACE},
+ },
+ tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that notices don't increase the unread counter.
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"body": "hello", "msgtype": "m.notice"},
+ tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that tombstone events changes increase the unread counter.
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.Tombstone,
+ {"replacement_room": "!someroom:test"},
+ tok=self.tok2,
+ )
+ self._check_unread_count(5)
+
+ def _check_unread_count(self, expected_count: True):
+ """Syncs and compares the unread count with the expected value."""
+
+ request, channel = self.make_request(
+ "GET", self.url % self.next_batch, access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ room_entry = channel.json_body["rooms"]["join"][self.room_id]
+ self.assertEqual(
+ room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
+ )
+
+ # Store the next batch for the next request.
+ self.next_batch = channel.json_body["next_batch"]
|