diff options
-rw-r--r-- | synapse/handlers/sync.py | 19 | ||||
-rw-r--r-- | synapse/push/push_tools.py | 11 | ||||
-rw-r--r-- | synapse/rest/client/sync.py | 1 | ||||
-rw-r--r-- | synapse/storage/databases/main/event_push_actions.py | 143 | ||||
-rw-r--r-- | tests/replication/slave/storage/test_events.py | 6 | ||||
-rw-r--r-- | tests/storage/test_event_push_actions.py | 175 |
6 files changed, 303 insertions, 52 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index d827c03ad1..94af490cc6 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -115,6 +115,7 @@ class JoinedSyncResult: ephemeral: List[JsonDict] account_data: List[JsonDict] unread_notifications: JsonDict + unread_thread_notifications: JsonDict summary: Optional[JsonDict] unread_count: int @@ -1053,7 +1054,7 @@ class SyncHandler: async def unread_notifs_for_room_id( self, room_id: str, sync_config: SyncConfig - ) -> NotifCounts: + ) -> Tuple[NotifCounts, Dict[str, NotifCounts]]: with Measure(self.clock, "unread_notifs_for_room_id"): return await self.store.get_unread_event_push_actions_by_room_for_user( @@ -2115,18 +2116,32 @@ class SyncHandler: ephemeral=ephemeral, account_data=account_data_events, unread_notifications=unread_notifications, + unread_thread_notifications={}, summary=summary, unread_count=0, ) if room_sync or always_include: - notifs = await self.unread_notifs_for_room_id(room_id, sync_config) + notifs, thread_notifs = await self.unread_notifs_for_room_id( + room_id, sync_config + ) + # Notifications for the main timeline. unread_notifications["notification_count"] = notifs.notify_count unread_notifications["highlight_count"] = notifs.highlight_count room_sync.unread_count = notifs.unread_count + # And add info for each thread. + room_sync.unread_thread_notifications = { + thread_id: { + "notification_count": thread_notifs.notify_count, + "highlight_count": thread_notifs.highlight_count, + } + for thread_id, thread_notifs in thread_notifs.items() + if thread_id is not None + } + sync_result_builder.joined.append(room_sync) if batch.limited and since_token: diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 6661887d9f..edabb8e136 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -26,13 +26,18 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - badge = len(invites) for room_id in joins: - notifs = await ( + notifs, thread_notifs = await ( store.get_unread_event_push_actions_by_room_for_user( room_id, user_id, ) ) - if notifs.notify_count == 0: + # Combine the counts from all the threads. + notify_count = notifs.notify_count + sum( + n.notify_count for n in thread_notifs.values() + ) + + if notify_count == 0: continue if group_by_room: @@ -40,7 +45,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - badge += 1 else: # increment the badge count by the number of unread messages in the room - badge += notifs.notify_count + badge += notify_count return badge diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index c2989765ce..16b0bc9f04 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -509,6 +509,7 @@ class SyncRestServlet(RestServlet): ephemeral_events = room.ephemeral result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications + result["unread_thread_notifications"] = room.unread_thread_notifications result["summary"] = room.summary if self._msc2654_enabled: result["org.matrix.msc2654.unread_count"] = room.unread_count diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 8cdbc242e3..78191ee4bd 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -229,12 +229,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas replaces_index="event_push_summary_unique_index", ) - @cached(tree=True, max_entries=5000) + @cached(tree=True, max_entries=5000, iterable=True) async def get_unread_event_push_actions_by_room_for_user( self, room_id: str, user_id: str, - ) -> NotifCounts: + ) -> Tuple[NotifCounts, Dict[str, NotifCounts]]: """Get the notification count, the highlight count and the unread message count for a given user in a given room after the given read receipt. @@ -263,7 +263,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas txn: LoggingTransaction, room_id: str, user_id: str, - ) -> NotifCounts: + ) -> Tuple[NotifCounts, Dict[str, NotifCounts]]: result = self.get_last_receipt_for_user_txn( txn, user_id, @@ -295,12 +295,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def _get_unread_counts_by_pos_txn( self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int - ) -> NotifCounts: + ) -> Tuple[NotifCounts, Dict[str, NotifCounts]]: """Get the number of unread messages for a user/room that have happened since the given stream ordering. + + Returns: + A tuple of: + The unread messages for the main timeline + + A dictionary of thread ID to unread messages for that thread. + Only contains threads with unread messages. """ counts = NotifCounts() + thread_counts = {} # First we pull the counts from the summary table. # @@ -317,7 +325,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # receipt). txn.execute( """ - SELECT stream_ordering, notif_count, COALESCE(unread_count, 0) + SELECT stream_ordering, notif_count, COALESCE(unread_count, 0), thread_id FROM event_push_summary WHERE room_id = ? AND user_id = ? AND ( @@ -327,39 +335,67 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas """, (room_id, user_id, stream_ordering, stream_ordering), ) - row = txn.fetchone() + max_summary_stream_ordering = 0 + for summary_stream_ordering, notif_count, unread_count, thread_id in txn: + if not thread_id: + counts = NotifCounts( + notify_count=notif_count, unread_count=unread_count + ) + # TODO Delete zeroed out threads completely from the database. + elif notif_count or unread_count: + thread_counts[thread_id] = NotifCounts( + notify_count=notif_count, unread_count=unread_count + ) - summary_stream_ordering = 0 - if row: - summary_stream_ordering = row[0] - counts.notify_count += row[1] - counts.unread_count += row[2] + # XXX All threads should have the same stream ordering? + max_summary_stream_ordering = max( + summary_stream_ordering, max_summary_stream_ordering + ) # Next we need to count highlights, which aren't summarised sql = """ - SELECT COUNT(*) FROM event_push_actions + SELECT COUNT(*), thread_id FROM event_push_actions WHERE user_id = ? AND room_id = ? AND stream_ordering > ? AND highlight = 1 + GROUP BY thread_id """ txn.execute(sql, (user_id, room_id, stream_ordering)) - row = txn.fetchone() - if row: - counts.highlight_count += row[0] + for highlight_count, thread_id in txn: + if not thread_id: + counts.highlight_count += highlight_count + elif highlight_count: + if thread_id in thread_counts: + thread_counts[thread_id].highlight_count += highlight_count + else: + thread_counts[thread_id] = NotifCounts( + notify_count=0, unread_count=0, highlight_count=highlight_count + ) # Finally we need to count push actions that aren't included in the # summary returned above, e.g. recent events that haven't been # summarised yet, or the summary is empty due to a recent read receipt. - stream_ordering = max(stream_ordering, summary_stream_ordering) - notify_count, unread_count = self._get_notif_unread_count_for_user_room( + stream_ordering = max(stream_ordering, max_summary_stream_ordering) + unread_counts = self._get_notif_unread_count_for_user_room( txn, room_id, user_id, stream_ordering ) - counts.notify_count += notify_count - counts.unread_count += unread_count + for notif_count, unread_count, thread_id in unread_counts: + if not thread_id: + counts.notify_count += notif_count + counts.unread_count += unread_count + elif thread_id in thread_counts: + thread_counts[thread_id].notify_count += notif_count + thread_counts[thread_id].unread_count += unread_count + else: + thread_counts[thread_id] = NotifCounts( + notify_count=notif_count, + unread_count=unread_count, + highlight_count=0, + ) - return counts + return counts, thread_counts def _get_notif_unread_count_for_user_room( self, @@ -368,7 +404,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas user_id: str, stream_ordering: int, max_stream_ordering: Optional[int] = None, - ) -> Tuple[int, int]: + ) -> List[Tuple[int, int, str]]: """Returns the notify and unread counts from `event_push_actions` for the given user/room in the given range. @@ -390,7 +426,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # If there have been no events in the room since the stream ordering, # there can't be any push actions either. if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering): - return 0, 0 + return [] clause = "" args = [user_id, room_id, stream_ordering] @@ -401,26 +437,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # If the max stream ordering is less than the min stream ordering, # then obviously there are zero push actions in that range. if max_stream_ordering <= stream_ordering: - return 0, 0 + return [] sql = f""" SELECT COUNT(CASE WHEN notif = 1 THEN 1 END), - COUNT(CASE WHEN unread = 1 THEN 1 END) - FROM event_push_actions ea - WHERE user_id = ? + COUNT(CASE WHEN unread = 1 THEN 1 END), + thread_id + FROM event_push_actions ea + WHERE user_id = ? AND room_id = ? AND ea.stream_ordering > ? {clause} + GROUP BY thread_id """ txn.execute(sql, args) - row = txn.fetchone() - - if row: - return cast(Tuple[int, int], row) - - return 0, 0 + return cast(List[Tuple[int, int, str]], txn.fetchall()) async def get_push_action_users_in_range( self, min_stream_ordering: int, max_stream_ordering: int @@ -1010,21 +1043,42 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # Fetch the notification counts between the stream ordering of the # latest receipt and what was previously summarised. - notif_count, unread_count = self._get_notif_unread_count_for_user_room( + unread_counts = self._get_notif_unread_count_for_user_room( txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering ) - # Replace the previous summary with the new counts. - self.db_pool.simple_upsert_txn( + # Updated threads get their notification count and unread count updated. + self.db_pool.simple_upsert_many_txn( txn, table="event_push_summary", - keyvalues={"room_id": room_id, "user_id": user_id}, - values={ - "notif_count": notif_count, - "unread_count": unread_count, - "stream_ordering": old_rotate_stream_ordering, - "last_receipt_stream_ordering": stream_ordering, - }, + key_names=("room_id", "user_id", "thread_id"), + key_values=[(room_id, user_id, row[2]) for row in unread_counts], + value_names=( + "notif_count", + "unread_count", + "stream_ordering", + "last_receipt_stream_ordering", + ), + value_values=[ + (row[0], row[1], old_rotate_stream_ordering, stream_ordering) + for row in unread_counts + ], + ) + + # Other threads should be marked as reset at the old stream ordering. + txn.execute( + """ + UPDATE event_push_summary SET notif_count = 0, unread_count = 0, stream_ordering = ?, last_receipt_stream_ordering = ? + WHERE user_id = ? AND room_id = ? AND + stream_ordering <= ? + """, + ( + old_rotate_stream_ordering, + stream_ordering, + user_id, + room_id, + old_rotate_stream_ordering, + ), ) # We always update `event_push_summary_last_receipt_stream_id` to @@ -1178,7 +1232,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas txn, table="event_push_summary", key_names=("user_id", "room_id", "thread_id"), - key_values=[(user_id, room_id, thread_id) for user_id, room_id, thread_id in summaries], + key_values=[ + (user_id, room_id, thread_id) + for user_id, room_id, thread_id in summaries + ], value_names=("notif_count", "unread_count", "stream_ordering"), value_values=[ ( diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index f16554cd5c..1ac3260984 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -178,7 +178,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2], - NotifCounts(highlight_count=0, unread_count=0, notify_count=0), + (NotifCounts(highlight_count=0, unread_count=0, notify_count=0), {}), ) self.persist( @@ -191,7 +191,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2], - NotifCounts(highlight_count=0, unread_count=0, notify_count=1), + (NotifCounts(highlight_count=0, unread_count=0, notify_count=1), {}), ) self.persist( @@ -206,7 +206,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2], - NotifCounts(highlight_count=1, unread_count=0, notify_count=2), + (NotifCounts(highlight_count=1, unread_count=0, notify_count=2), {}), ) def test_get_rooms_for_user_with_stream_ordering(self): diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index ba40124c8a..d1c9035fa8 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + from twisted.test.proto_helpers import MemoryReactor from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.types import JsonDict from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -70,7 +73,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): def _assert_counts( noitf_count: int, unread_count: int, highlight_count: int ) -> None: - counts = self.get_success( + counts, thread_counts = self.get_success( self.store.db_pool.runInteraction( "get-unread-counts", self.store._get_unread_counts_by_receipt_txn, @@ -86,6 +89,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): highlight_count=highlight_count, ), ) + self.assertEqual(thread_counts, {}) def _create_event(highlight: bool = False) -> str: result = self.helper.send_event( @@ -131,6 +135,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): _assert_counts(0, 0, 0) _create_event() + _assert_counts(1, 1, 0) _rotate() _assert_counts(1, 1, 0) @@ -166,6 +171,174 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): _rotate() _assert_counts(0, 0, 0) + def test_count_aggregation_threads(self) -> None: + # Create a user to receive notifications and send receipts. + user_id = self.register_user("user1235", "pass") + token = self.login("user1235", "pass") + + # And another users to send events. + other_id = self.register_user("other", "pass") + other_token = self.login("other", "pass") + + # Create a room and put both users in it. + room_id = self.helper.create_room_as(user_id, tok=token) + self.helper.join(room_id, other_id, tok=other_token) + thread_id: str + + last_event_id: str + + def _assert_counts( + noitf_count: int, + unread_count: int, + highlight_count: int, + thread_notif_count: int, + thread_unread_count: int, + thread_highlight_count: int, + ) -> None: + counts, thread_counts = self.get_success( + self.store.db_pool.runInteraction( + "get-unread-counts", + self.store._get_unread_counts_by_receipt_txn, + room_id, + user_id, + ) + ) + self.assertEqual( + counts, + NotifCounts( + notify_count=noitf_count, + unread_count=unread_count, + highlight_count=highlight_count, + ), + ) + if thread_notif_count or thread_unread_count or thread_highlight_count: + self.assertEqual( + thread_counts, + { + thread_id: NotifCounts( + notify_count=thread_notif_count, + unread_count=thread_unread_count, + highlight_count=thread_highlight_count, + ), + }, + ) + else: + self.assertEqual(thread_counts, {}) + + def _create_event( + highlight: bool = False, thread_id: Optional[str] = None + ) -> str: + content: JsonDict = { + "msgtype": "m.text", + "body": user_id if highlight else "", + } + if thread_id: + content["m.relates_to"] = { + "rel_type": "m.thread", + "event_id": thread_id, + } + + result = self.helper.send_event( + room_id, + type="m.room.message", + content=content, + tok=other_token, + ) + nonlocal last_event_id + last_event_id = result["event_id"] + return last_event_id + + def _rotate() -> None: + self.get_success(self.store._rotate_notifs()) + + def _mark_read(event_id: str) -> None: + self.get_success( + self.store.insert_receipt( + room_id, + "m.read", + user_id=user_id, + event_ids=[event_id], + data={}, + ) + ) + + _assert_counts(0, 0, 0, 0, 0, 0) + thread_id = _create_event() + _assert_counts(1, 0, 0, 0, 0, 0) + _rotate() + _assert_counts(1, 0, 0, 0, 0, 0) + + _create_event(thread_id=thread_id) + _assert_counts(1, 0, 0, 1, 0, 0) + _rotate() + _assert_counts(1, 0, 0, 1, 0, 0) + + _create_event() + _assert_counts(2, 0, 0, 1, 0, 0) + _rotate() + _assert_counts(2, 0, 0, 1, 0, 0) + + event_id = _create_event(thread_id=thread_id) + _assert_counts(2, 0, 0, 2, 0, 0) + _rotate() + _assert_counts(2, 0, 0, 2, 0, 0) + + _create_event() + _create_event(thread_id=thread_id) + _mark_read(event_id) + _assert_counts(1, 0, 0, 1, 0, 0) + + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0, 0, 0) + + _create_event() + _create_event(thread_id=thread_id) + _assert_counts(1, 0, 0, 1, 0, 0) + _rotate() + _assert_counts(1, 0, 0, 1, 0, 0) + + # Delete old event push actions, this should not affect the (summarised) count. + self.get_success(self.store._remove_old_push_actions_that_have_rotated()) + _assert_counts(1, 0, 0, 1, 0, 0) + + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0, 0, 0) + + _create_event(True) + _assert_counts(1, 1, 1, 0, 0, 0) + _rotate() + _assert_counts(1, 1, 1, 0, 0, 0) + + event_id = _create_event(True, thread_id) + _assert_counts(1, 1, 1, 1, 1, 1) + _rotate() + _assert_counts(1, 1, 1, 1, 1, 1) + + # Check that adding another notification and rotating after highlight + # works. + _create_event() + _rotate() + _assert_counts(2, 0, 1, 1, 1, 1) + + _create_event(thread_id=thread_id) + _rotate() + _assert_counts(2, 0, 1, 2, 0, 1) + + # Check that sending read receipts at different points results in the + # right counts. + _mark_read(event_id) + _assert_counts(1, 0, 0, 1, 0, 0) + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0, 0, 0) + + _create_event(True) + _create_event(True, thread_id) + _assert_counts(1, 1, 1, 1, 1, 1) + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0, 0, 0) + _rotate() + _assert_counts(0, 0, 0, 0, 0, 0) + def test_find_first_stream_ordering_after_ts(self) -> None: def add_event(so: int, ts: int) -> None: self.get_success( |