summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <patrickc@matrix.org>2022-06-09 13:18:25 -0400
committerPatrick Cloke <patrickc@matrix.org>2022-06-13 14:00:08 -0400
commitf03935dcb7c3397eefad16922454b6c14c1c1f5e (patch)
treeb55069bf113aacc5092b52e2233be217981336b8
parentInclude the thread ID in the event push actions. (diff)
downloadsynapse-github/clokep/ranged-read-receipts-poc.tar.xz
Return thread notification counts down sync. github/clokep/ranged-read-receipts-poc clokep/ranged-read-receipts-poc
-rw-r--r--synapse/handlers/sync.py18
-rw-r--r--synapse/push/push_tools.py7
-rw-r--r--synapse/storage/databases/main/event_push_actions.py91
3 files changed, 74 insertions, 42 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b4ead79f97..698c9c8876 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1052,7 +1052,7 @@ class SyncHandler:
 
     async def unread_notifs_for_room_id(
         self, room_id: str, sync_config: SyncConfig
-    ) -> NotifCounts:
+    ) -> Dict[Optional[str], NotifCounts]:
         with Measure(self.clock, "unread_notifs_for_room_id"):
             last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
                 user_id=sync_config.user.to_string(),
@@ -2122,7 +2122,7 @@ class SyncHandler:
                 )
 
             if room_builder.rtype == "joined":
-                unread_notifications: Dict[str, int] = {}
+                unread_notifications: JsonDict = {}
                 room_sync = JoinedSyncResult(
                     room_id=room_id,
                     timeline=batch,
@@ -2137,10 +2137,18 @@ class SyncHandler:
                 if room_sync or always_include:
                     notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
 
-                    unread_notifications["notification_count"] = notifs.notify_count
-                    unread_notifications["highlight_count"] = notifs.highlight_count
+                    # Notifications for the main timeline.
+                    main_notifs = notifs[None]
+                    unread_notifications.update(main_notifs.to_dict())
 
-                    room_sync.unread_count = notifs.unread_count
+                    room_sync.unread_count = main_notifs.unread_count
+
+                    # And add info for each thread.
+                    unread_notifications["unread_thread_notifications"] = {
+                        thread_id: thread_notifs.to_dict()
+                        for thread_id, thread_notifs in notifs.items()
+                        if thread_id is not None
+                    }
 
                     sync_result_builder.joined.append(room_sync)
 
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 8397229ccb..a7029e1b75 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -39,7 +39,10 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
                     room_id, user_id, last_unread_event_id
                 )
             )
-            if notifs.notify_count == 0:
+            # Combine the counts from all the threads.
+            notify_count = sum(n.notify_count for n in notifs.values())
+
+            if notify_count == 0:
                 continue
 
             if group_by_room:
@@ -47,7 +50,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
                 badge += 1
             else:
                 # increment the badge count by the number of unread messages in the room
-                badge += notifs.notify_count
+                badge += notify_count
     return badge
 
 
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 812ed1a3d4..1c2dd50404 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -24,6 +24,7 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -79,7 +80,7 @@ class UserPushAction(EmailPushAction):
     profile_tag: str
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
+@attr.s(slots=True, auto_attribs=True)
 class NotifCounts:
     """
     The per-user, per-room count of notifications. Used by sync and push.
@@ -89,6 +90,12 @@ class NotifCounts:
     unread_count: int
     highlight_count: int
 
+    def to_dict(self) -> JsonDict:
+        return {
+            "notification_count": self.notify_count,
+            "highlight_count": self.highlight_count,
+        }
+
 
 def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
     """Custom serializer for actions. This allows us to "compress" common actions.
@@ -148,13 +155,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 self._rotate_notifs, 30 * 60 * 1000
             )
 
-    @cached(num_args=3, tree=True, max_entries=5000)
+    @cached(max_entries=5000, tree=True, iterable=True)
     async def get_unread_event_push_actions_by_room_for_user(
         self,
         room_id: str,
         user_id: str,
         last_read_event_id: Optional[str],
-    ) -> NotifCounts:
+    ) -> Dict[Optional[str], NotifCounts]:
         """Get the notification count, the highlight count and the unread message count
         for a given user in a given room after the given read receipt.
 
@@ -187,7 +194,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         room_id: str,
         user_id: str,
         last_read_event_id: Optional[str],
-    ) -> NotifCounts:
+    ) -> Dict[Optional[str], NotifCounts]:
         stream_ordering = None
 
         if last_read_event_id is not None:
@@ -217,49 +224,63 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
     def _get_unread_counts_by_pos_txn(
         self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
-    ) -> NotifCounts:
-        sql = (
-            "SELECT"
-            "   COUNT(CASE WHEN notif = 1 THEN 1 END),"
-            "   COUNT(CASE WHEN highlight = 1 THEN 1 END),"
-            "   COUNT(CASE WHEN unread = 1 THEN 1 END)"
-            " FROM event_push_actions ea"
-            " WHERE user_id = ?"
-            "   AND room_id = ?"
-            "   AND stream_ordering > ?"
-        )
+    ) -> Dict[Optional[str], NotifCounts]:
+        sql = """
+        SELECT
+            COUNT(CASE WHEN notif = 1 THEN 1 END),
+            COUNT(CASE WHEN highlight = 1 THEN 1 END),
+            COUNT(CASE WHEN unread = 1 THEN 1 END),
+            thread_id
+        FROM event_push_actions ea
+        WHERE user_id = ?
+            AND room_id = ?
+            AND stream_ordering > ?
+        GROUP BY thread_id
+        """
 
         txn.execute(sql, (user_id, room_id, stream_ordering))
-        row = txn.fetchone()
-
-        (notif_count, highlight_count, unread_count) = (0, 0, 0)
-
-        if row:
-            (notif_count, highlight_count, unread_count) = row
+        rows = txn.fetchall()
+
+        notif_counts: Dict[Optional[str], NotifCounts] = {
+            # Ensure the main timeline has notification counts.
+            None: NotifCounts(
+                notify_count=0,
+                unread_count=0,
+                highlight_count=0,
+            )
+        }
+        for notif_count, highlight_count, unread_count, thread_id in rows:
+            notif_counts[thread_id] = NotifCounts(
+                notify_count=notif_count,
+                unread_count=unread_count,
+                highlight_count=highlight_count,
+            )
 
         txn.execute(
             """
-                SELECT notif_count, unread_count FROM event_push_summary
+                SELECT notif_count, unread_count, thread_id FROM event_push_summary
                 WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
             """,
             (room_id, user_id, stream_ordering),
         )
-        row = txn.fetchone()
+        rows = txn.fetchall()
 
-        if row:
-            notif_count += row[0]
+        for notif_count, unread_count, thread_id in rows:
+            if unread_count is None:
+                # The unread_count column of event_push_summary is NULLable.
+                unread_count = 0
 
-            if row[1] is not None:
-                # The unread_count column of event_push_summary is NULLable, so we need
-                # to make sure we don't try increasing the unread counts if it's NULL
-                # for this row.
-                unread_count += row[1]
+            if thread_id in notif_counts:
+                notif_counts[thread_id].notify_count += notif_count
+                notif_counts[thread_id].unread_count += unread_count
+            else:
+                notif_counts[thread_id] = NotifCounts(
+                    notify_count=notif_count,
+                    unread_count=unread_count,
+                    highlight_count=0,
+                )
 
-        return NotifCounts(
-            notify_count=notif_count,
-            unread_count=unread_count,
-            highlight_count=highlight_count,
-        )
+        return notif_counts
 
     async def get_push_action_users_in_range(
         self, min_stream_ordering: int, max_stream_ordering: int