diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b4ead79f97..698c9c8876 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1052,7 +1052,7 @@ class SyncHandler:
async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig
- ) -> NotifCounts:
+ ) -> Dict[Optional[str], NotifCounts]:
with Measure(self.clock, "unread_notifs_for_room_id"):
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(),
@@ -2122,7 +2122,7 @@ class SyncHandler:
)
if room_builder.rtype == "joined":
- unread_notifications: Dict[str, int] = {}
+ unread_notifications: JsonDict = {}
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=batch,
@@ -2137,10 +2137,18 @@ class SyncHandler:
if room_sync or always_include:
notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
- unread_notifications["notification_count"] = notifs.notify_count
- unread_notifications["highlight_count"] = notifs.highlight_count
+ # Notifications for the main timeline.
+ main_notifs = notifs[None]
+ unread_notifications.update(main_notifs.to_dict())
- room_sync.unread_count = notifs.unread_count
+ room_sync.unread_count = main_notifs.unread_count
+
+ # And add info for each thread.
+ unread_notifications["unread_thread_notifications"] = {
+ thread_id: thread_notifs.to_dict()
+ for thread_id, thread_notifs in notifs.items()
+ if thread_id is not None
+ }
sync_result_builder.joined.append(room_sync)
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 8397229ccb..a7029e1b75 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -39,7 +39,10 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
room_id, user_id, last_unread_event_id
)
)
- if notifs.notify_count == 0:
+ # Combine the counts from all the threads.
+ notify_count = sum(n.notify_count for n in notifs.values())
+
+ if notify_count == 0:
continue
if group_by_room:
@@ -47,7 +50,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
badge += 1
else:
# increment the badge count by the number of unread messages in the room
- badge += notifs.notify_count
+ badge += notify_count
return badge
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 812ed1a3d4..1c2dd50404 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -24,6 +24,7 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -79,7 +80,7 @@ class UserPushAction(EmailPushAction):
profile_tag: str
-@attr.s(slots=True, frozen=True, auto_attribs=True)
+@attr.s(slots=True, auto_attribs=True)
class NotifCounts:
"""
The per-user, per-room count of notifications. Used by sync and push.
@@ -89,6 +90,12 @@ class NotifCounts:
unread_count: int
highlight_count: int
+ def to_dict(self) -> JsonDict:
+ return {
+ "notification_count": self.notify_count,
+ "highlight_count": self.highlight_count,
+ }
+
def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
"""Custom serializer for actions. This allows us to "compress" common actions.
@@ -148,13 +155,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self._rotate_notifs, 30 * 60 * 1000
)
- @cached(num_args=3, tree=True, max_entries=5000)
+ @cached(max_entries=5000, tree=True, iterable=True)
async def get_unread_event_push_actions_by_room_for_user(
self,
room_id: str,
user_id: str,
last_read_event_id: Optional[str],
- ) -> NotifCounts:
+ ) -> Dict[Optional[str], NotifCounts]:
"""Get the notification count, the highlight count and the unread message count
for a given user in a given room after the given read receipt.
@@ -187,7 +194,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
room_id: str,
user_id: str,
last_read_event_id: Optional[str],
- ) -> NotifCounts:
+ ) -> Dict[Optional[str], NotifCounts]:
stream_ordering = None
if last_read_event_id is not None:
@@ -217,49 +224,63 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def _get_unread_counts_by_pos_txn(
self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
- ) -> NotifCounts:
- sql = (
- "SELECT"
- " COUNT(CASE WHEN notif = 1 THEN 1 END),"
- " COUNT(CASE WHEN highlight = 1 THEN 1 END),"
- " COUNT(CASE WHEN unread = 1 THEN 1 END)"
- " FROM event_push_actions ea"
- " WHERE user_id = ?"
- " AND room_id = ?"
- " AND stream_ordering > ?"
- )
+ ) -> Dict[Optional[str], NotifCounts]:
+ sql = """
+ SELECT
+ COUNT(CASE WHEN notif = 1 THEN 1 END),
+ COUNT(CASE WHEN highlight = 1 THEN 1 END),
+ COUNT(CASE WHEN unread = 1 THEN 1 END),
+ thread_id
+ FROM event_push_actions ea
+ WHERE user_id = ?
+ AND room_id = ?
+ AND stream_ordering > ?
+ GROUP BY thread_id
+ """
txn.execute(sql, (user_id, room_id, stream_ordering))
- row = txn.fetchone()
-
- (notif_count, highlight_count, unread_count) = (0, 0, 0)
-
- if row:
- (notif_count, highlight_count, unread_count) = row
+ rows = txn.fetchall()
+
+ notif_counts: Dict[Optional[str], NotifCounts] = {
+ # Ensure the main timeline has notification counts.
+ None: NotifCounts(
+ notify_count=0,
+ unread_count=0,
+ highlight_count=0,
+ )
+ }
+ for notif_count, highlight_count, unread_count, thread_id in rows:
+ notif_counts[thread_id] = NotifCounts(
+ notify_count=notif_count,
+ unread_count=unread_count,
+ highlight_count=highlight_count,
+ )
txn.execute(
"""
- SELECT notif_count, unread_count FROM event_push_summary
+ SELECT notif_count, unread_count, thread_id FROM event_push_summary
WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
""",
(room_id, user_id, stream_ordering),
)
- row = txn.fetchone()
+ rows = txn.fetchall()
- if row:
- notif_count += row[0]
+ for notif_count, unread_count, thread_id in rows:
+ if unread_count is None:
+ # The unread_count column of event_push_summary is NULLable.
+ unread_count = 0
- if row[1] is not None:
- # The unread_count column of event_push_summary is NULLable, so we need
- # to make sure we don't try increasing the unread counts if it's NULL
- # for this row.
- unread_count += row[1]
+ if thread_id in notif_counts:
+ notif_counts[thread_id].notify_count += notif_count
+ notif_counts[thread_id].unread_count += unread_count
+ else:
+ notif_counts[thread_id] = NotifCounts(
+ notify_count=notif_count,
+ unread_count=unread_count,
+ highlight_count=0,
+ )
- return NotifCounts(
- notify_count=notif_count,
- unread_count=unread_count,
- highlight_count=highlight_count,
- )
+ return notif_counts
async def get_push_action_users_in_range(
self, min_stream_ordering: int, max_stream_ordering: int
|