summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/sync.py6
-rw-r--r--synapse/push/push_tools.py17
-rw-r--r--synapse/rest/client/v2_alpha/sync.py1
-rw-r--r--synapse/storage/data_stores/main/cache.py1
-rw-r--r--synapse/storage/data_stores/main/events.py48
-rw-r--r--synapse/storage/data_stores/main/events_worker.py86
-rw-r--r--synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql18
7 files changed, 162 insertions, 15 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index ebd3e98105..eaa4eeadf7 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -103,6 +103,7 @@ class JoinedSyncResult:
     account_data = attr.ib(type=List[JsonDict])
     unread_notifications = attr.ib(type=JsonDict)
     summary = attr.ib(type=Optional[JsonDict])
+    unread_count = attr.ib(type=int)
 
     def __nonzero__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
@@ -1886,6 +1887,10 @@ class SyncHandler(object):
 
         if room_builder.rtype == "joined":
             unread_notifications = {}  # type: Dict[str, str]
+
+            unread_count = await self.store.get_unread_message_count_for_user(
+                room_id, sync_config.user.to_string(),
+            )
             room_sync = JoinedSyncResult(
                 room_id=room_id,
                 timeline=batch,
@@ -1894,6 +1899,7 @@ class SyncHandler(object):
                 account_data=account_data_events,
                 unread_notifications=unread_notifications,
                 summary=summary,
+                unread_count=unread_count,
             )
 
             if room_sync or always_include:
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index d0145666bf..bc8f71916b 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -21,22 +21,13 @@ async def get_badge_count(store, user_id):
     invites = await store.get_invited_rooms_for_local_user(user_id)
     joins = await store.get_rooms_for_user(user_id)
 
-    my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read")
-
     badge = len(invites)
 
     for room_id in joins:
-        if room_id in my_receipts_by_room:
-            last_unread_event_id = my_receipts_by_room[room_id]
-
-            notifs = await (
-                store.get_unread_event_push_actions_by_room_for_user(
-                    room_id, user_id, last_unread_event_id
-                )
-            )
-            # return one badge count per conversation, as count per
-            # message is so noisy as to be almost useless
-            badge += 1 if notifs["notify_count"] else 0
+        unread_count = await store.get_unread_message_count_for_user(room_id, user_id)
+        # return one badge count per conversation, as count per
+        # message is so noisy as to be almost useless
+        badge += 1 if unread_count else 0
     return badge
 
 
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index a5c24fbd63..3f5bf75e59 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -426,6 +426,7 @@ class SyncRestServlet(RestServlet):
             result["ephemeral"] = {"events": ephemeral_events}
             result["unread_notifications"] = room.unread_notifications
             result["summary"] = room.summary
+            result["org.matrix.msc2654.unread_count"] = room.unread_count
 
         return result
 
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index f39f556c20..edc3624fed 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -172,6 +172,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         self.get_latest_event_ids_in_room.invalidate((room_id,))
 
+        self.get_unread_message_count_for_user.invalidate_many((room_id,))
         self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
 
         if not backfilled:
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 6f2e0d15cc..0c9c02afa1 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -53,6 +53,47 @@ event_counter = Counter(
     ["type", "origin_type", "origin_entity"],
 )
 
+STATE_EVENT_TYPES_TO_MARK_UNREAD = {
+    EventTypes.Topic,
+    EventTypes.Name,
+    EventTypes.RoomAvatar,
+    EventTypes.Tombstone,
+}
+
+
+def should_count_as_unread(event: EventBase, context: EventContext) -> bool:
+    # Exclude rejected and soft-failed events.
+    if context.rejected or event.internal_metadata.is_soft_failed():
+        return False
+
+    # Exclude notices.
+    if (
+        not event.is_state()
+        and event.type == EventTypes.Message
+        and event.content.get("msgtype") == "m.notice"
+    ):
+        return False
+
+    # Exclude edits.
+    relates_to = event.content.get("m.relates_to", {})
+    if relates_to.get("rel_type") == RelationTypes.REPLACE:
+        return False
+
+    # Mark events that have a non-empty string body as unread.
+    body = event.content.get("body")
+    if isinstance(body, str) and body:
+        return True
+
+    # Mark some state events as unread.
+    if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
+        return True
+
+    # Mark encrypted events as unread.
+    if not event.is_state() and event.type == EventTypes.Encrypted:
+        return True
+
+    return False
+
 
 def encode_json(json_object):
     """
@@ -196,6 +237,10 @@ class PersistEventsStore:
 
                 event_counter.labels(event.type, origin_type, origin_entity).inc()
 
+                self.store.get_unread_message_count_for_user.invalidate_many(
+                    (event.room_id,),
+                )
+
             for room_id, new_state in current_state_for_room.items():
                 self.store.get_current_state_ids.prefill((room_id,), new_state)
 
@@ -817,8 +862,9 @@ class PersistEventsStore:
                     "contains_url": (
                         "url" in event.content and isinstance(event.content["url"], str)
                     ),
+                    "count_as_unread": should_count_as_unread(event, context),
                 }
-                for event, _ in events_and_contexts
+                for event, context in events_and_contexts
             ],
         )
 
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index e812c67078..b03b259636 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -41,9 +41,15 @@ from synapse.replication.tcp.streams import BackfillStream
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import Database
+from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import (
+    Cache,
+    _CacheContext,
+    cached,
+    cachedInlineCallbacks,
+)
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -1358,6 +1364,84 @@ class EventsWorkerStore(SQLBaseStore):
             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
         )
 
+    @cached(tree=True, cache_context=True)
+    async def get_unread_message_count_for_user(
+        self, room_id: str, user_id: str, cache_context: _CacheContext,
+    ) -> int:
+        """Retrieve the count of unread messages for the given room and user.
+
+        Args:
+            room_id: The ID of the room to count unread messages in.
+            user_id: The ID of the user to count unread messages for.
+
+        Returns:
+            The number of unread messages for the given user in the given room.
+        """
+        with Measure(self._clock, "get_unread_message_count_for_user"):
+            last_read_event_id = await self.get_last_receipt_event_id_for_user(
+                user_id=user_id,
+                room_id=room_id,
+                receipt_type="m.read",
+                on_invalidate=cache_context.invalidate,
+            )
+
+            return await self.db.runInteraction(
+                "get_unread_message_count_for_user",
+                self._get_unread_message_count_for_user_txn,
+                user_id,
+                room_id,
+                last_read_event_id,
+            )
+
+    def _get_unread_message_count_for_user_txn(
+        self,
+        txn: Cursor,
+        user_id: str,
+        room_id: str,
+        last_read_event_id: Optional[str],
+    ) -> int:
+        if last_read_event_id:
+            # Get the stream ordering for the last read event.
+            stream_ordering = self.db.simple_select_one_onecol_txn(
+                txn=txn,
+                table="events",
+                keyvalues={"room_id": room_id, "event_id": last_read_event_id},
+                retcol="stream_ordering",
+            )
+        else:
+            # If there's no read receipt for that room, it probably means the user hasn't
+            # opened it yet, in which case use the stream ID of their join event.
+            # We can't just set it to 0 otherwise messages from other local users from
+            # before this user joined will be counted as well.
+            txn.execute(
+                """
+                SELECT stream_ordering FROM local_current_membership
+                LEFT JOIN events USING (event_id, room_id)
+                WHERE membership = 'join'
+                    AND user_id = ?
+                    AND room_id = ?
+                """,
+                (user_id, room_id),
+            )
+            row = txn.fetchone()
+
+            if row is None:
+                return 0
+
+            stream_ordering = row[0]
+
+        # Count the messages that qualify as unread after the stream ordering we've just
+        # retrieved.
+        sql = """
+            SELECT COUNT(*) FROM events
+            WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread
+        """
+
+        txn.execute(sql, (user_id, room_id, stream_ordering))
+        row = txn.fetchone()
+
+        return row[0] if row else 0
+
 
 AllNewEventsResult = namedtuple(
     "AllNewEventsResult",
diff --git a/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql
new file mode 100644
index 0000000000..531b532c73
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Store a boolean value in the events table for whether the event should be counted in
+-- the unread_count property of sync responses.
+ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN;