diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index cdc9ee5a37..3210e9cca1 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -88,7 +88,7 @@ from typing import (
import attr
-from synapse.api.constants import ReceiptTypes
+from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
@@ -157,7 +157,7 @@ class UserPushAction(EmailPushAction):
@attr.s(slots=True, auto_attribs=True)
class NotifCounts:
"""
- The per-user, per-room count of notifications. Used by sync and push.
+ The per-user, per-room, per-thread count of notifications. Used by sync and push.
"""
notify_count: int = 0
@@ -165,6 +165,21 @@ class NotifCounts:
highlight_count: int = 0
+@attr.s(slots=True, auto_attribs=True)
+class RoomNotifCounts:
+ """
+ The per-user, per-room count of notifications. Used by sync and push.
+ """
+
+ main_timeline: NotifCounts
+ # Map of thread ID to the notification counts.
+ threads: Dict[str, NotifCounts]
+
+ def __len__(self) -> int:
+ # To properly account for the amount of space in any caches.
+ return len(self.threads) + 1
+
+
def _serialize_action(
actions: Collection[Union[Mapping, str]], is_highlight: bool
) -> str:
@@ -338,12 +353,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return result
- @cached(tree=True, max_entries=5000)
+ @cached(tree=True, max_entries=5000, iterable=True)
async def get_unread_event_push_actions_by_room_for_user(
self,
room_id: str,
user_id: str,
- ) -> NotifCounts:
+ ) -> RoomNotifCounts:
"""Get the notification count, the highlight count and the unread message count
for a given user in a given room after their latest read receipt.
@@ -356,8 +371,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: The user to retrieve the counts for.
Returns
- A NotifCounts object containing the notification count, the highlight count
- and the unread message count.
+ A RoomNotifCounts object containing the notification count, the
+ highlight count and the unread message count for both the main timeline
+ and threads.
"""
return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
@@ -371,7 +387,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
txn: LoggingTransaction,
room_id: str,
user_id: str,
- ) -> NotifCounts:
+ ) -> RoomNotifCounts:
# Get the stream ordering of the user's latest receipt in the room.
result = self.get_last_unthreaded_receipt_for_user_txn(
txn,
@@ -406,7 +422,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
room_id: str,
user_id: str,
receipt_stream_ordering: int,
- ) -> NotifCounts:
+ ) -> RoomNotifCounts:
"""Get the number of unread messages for a user/room that have happened
since the given stream ordering.
@@ -418,12 +434,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
receipt in the room. If there are no receipts, the stream ordering
of the user's join event.
- Returns
- A NotifCounts object containing the notification count, the highlight count
- and the unread message count.
+ Returns:
+ A RoomNotifCounts object containing the notification count, the
+ highlight count and the unread message count for both the main timeline
+ and threads.
"""
- counts = NotifCounts()
+ main_counts = NotifCounts()
+ thread_counts: Dict[str, NotifCounts] = {}
+
+ def _get_thread(thread_id: str) -> NotifCounts:
+ if thread_id == MAIN_TIMELINE:
+ return main_counts
+ return thread_counts.setdefault(thread_id, NotifCounts())
# First we pull the counts from the summary table.
#
@@ -440,52 +463,61 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# receipt).
txn.execute(
"""
- SELECT stream_ordering, notif_count, COALESCE(unread_count, 0)
+ SELECT stream_ordering, notif_count, COALESCE(unread_count, 0), thread_id
FROM event_push_summary
WHERE room_id = ? AND user_id = ?
AND (
(last_receipt_stream_ordering IS NULL AND stream_ordering > ?)
OR last_receipt_stream_ordering = ?
- )
+ ) AND (notif_count != 0 OR COALESCE(unread_count, 0) != 0)
""",
(room_id, user_id, receipt_stream_ordering, receipt_stream_ordering),
)
- row = txn.fetchone()
-
- summary_stream_ordering = 0
- if row:
- summary_stream_ordering = row[0]
- counts.notify_count += row[1]
- counts.unread_count += row[2]
+ max_summary_stream_ordering = 0
+ for summary_stream_ordering, notif_count, unread_count, thread_id in txn:
+ counts = _get_thread(thread_id)
+ counts.notify_count += notif_count
+ counts.unread_count += unread_count
+
+ # Summaries will only be used if they have not been invalidated by
+ # a recent receipt; track the latest stream ordering or a valid summary.
+ #
+ # Note that since there's only one read receipt in the room per user,
+ # valid summaries are contiguous.
+ max_summary_stream_ordering = max(
+ summary_stream_ordering, max_summary_stream_ordering
+ )
# Next we need to count highlights, which aren't summarised
sql = """
- SELECT COUNT(*) FROM event_push_actions
+ SELECT COUNT(*), thread_id FROM event_push_actions
WHERE user_id = ?
AND room_id = ?
AND stream_ordering > ?
AND highlight = 1
+ GROUP BY thread_id
"""
txn.execute(sql, (user_id, room_id, receipt_stream_ordering))
- row = txn.fetchone()
- if row:
- counts.highlight_count += row[0]
+ for highlight_count, thread_id in txn:
+ _get_thread(thread_id).highlight_count += highlight_count
# Finally we need to count push actions that aren't included in the
# summary returned above. This might be due to recent events that haven't
# been summarised yet or the summary is out of date due to a recent read
# receipt.
start_unread_stream_ordering = max(
- receipt_stream_ordering, summary_stream_ordering
+ receipt_stream_ordering, max_summary_stream_ordering
)
- notify_count, unread_count = self._get_notif_unread_count_for_user_room(
+ unread_counts = self._get_notif_unread_count_for_user_room(
txn, room_id, user_id, start_unread_stream_ordering
)
- counts.notify_count += notify_count
- counts.unread_count += unread_count
+ for notif_count, unread_count, thread_id in unread_counts:
+ counts = _get_thread(thread_id)
+ counts.notify_count += notif_count
+ counts.unread_count += unread_count
- return counts
+ return RoomNotifCounts(main_counts, thread_counts)
def _get_notif_unread_count_for_user_room(
self,
@@ -494,7 +526,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: str,
stream_ordering: int,
max_stream_ordering: Optional[int] = None,
- ) -> Tuple[int, int]:
+ ) -> List[Tuple[int, int, str]]:
"""Returns the notify and unread counts from `event_push_actions` for
the given user/room in the given range.
@@ -510,13 +542,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
If this is not given, then no maximum is applied.
Return:
- A tuple of the notif count and unread count in the given range.
+ A tuple of the notif count and unread count in the given range for
+ each thread.
"""
# If there have been no events in the room since the stream ordering,
# there can't be any push actions either.
if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering):
- return 0, 0
+ return []
clause = ""
args = [user_id, room_id, stream_ordering]
@@ -527,26 +560,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# If the max stream ordering is less than the min stream ordering,
# then obviously there are zero push actions in that range.
if max_stream_ordering <= stream_ordering:
- return 0, 0
+ return []
sql = f"""
SELECT
COUNT(CASE WHEN notif = 1 THEN 1 END),
- COUNT(CASE WHEN unread = 1 THEN 1 END)
- FROM event_push_actions ea
- WHERE user_id = ?
+ COUNT(CASE WHEN unread = 1 THEN 1 END),
+ thread_id
+ FROM event_push_actions ea
+ WHERE user_id = ?
AND room_id = ?
AND ea.stream_ordering > ?
{clause}
+ GROUP BY thread_id
"""
txn.execute(sql, args)
- row = txn.fetchone()
-
- if row:
- return cast(Tuple[int, int], row)
-
- return 0, 0
+ return cast(List[Tuple[int, int, str]], txn.fetchall())
async def get_push_action_users_in_range(
self, min_stream_ordering: int, max_stream_ordering: int
@@ -1099,26 +1129,34 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# Fetch the notification counts between the stream ordering of the
# latest receipt and what was previously summarised.
- notif_count, unread_count = self._get_notif_unread_count_for_user_room(
+ unread_counts = self._get_notif_unread_count_for_user_room(
txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
)
- # Replace the previous summary with the new counts.
- #
- # TODO(threads): Upsert per-thread instead of setting them all to main.
- self.db_pool.simple_upsert_txn(
+ # First mark the summary for all threads in the room as cleared.
+ self.db_pool.simple_update_txn(
txn,
table="event_push_summary",
- keyvalues={"room_id": room_id, "user_id": user_id},
- values={
- "notif_count": notif_count,
- "unread_count": unread_count,
+ keyvalues={"user_id": user_id, "room_id": room_id},
+ updatevalues={
+ "notif_count": 0,
+ "unread_count": 0,
"stream_ordering": old_rotate_stream_ordering,
"last_receipt_stream_ordering": stream_ordering,
- "thread_id": "main",
},
)
+ # Then any updated threads get their notification count and unread
+ # count updated.
+ self.db_pool.simple_update_many_txn(
+ txn,
+ table="event_push_summary",
+ key_names=("room_id", "user_id", "thread_id"),
+ key_values=[(room_id, user_id, row[2]) for row in unread_counts],
+ value_names=("notif_count", "unread_count"),
+ value_values=[(row[0], row[1]) for row in unread_counts],
+ )
+
# We always update `event_push_summary_last_receipt_stream_id` to
# ensure that we don't rescan the same receipts for remote users.
@@ -1204,23 +1242,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# Calculate the new counts that should be upserted into event_push_summary
sql = """
- SELECT user_id, room_id,
+ SELECT user_id, room_id, thread_id,
coalesce(old.%s, 0) + upd.cnt,
upd.stream_ordering
FROM (
- SELECT user_id, room_id, count(*) as cnt,
+ SELECT user_id, room_id, thread_id, count(*) as cnt,
max(ea.stream_ordering) as stream_ordering
FROM event_push_actions AS ea
- LEFT JOIN event_push_summary AS old USING (user_id, room_id)
+ LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ?
AND (
old.last_receipt_stream_ordering IS NULL
OR old.last_receipt_stream_ordering < ea.stream_ordering
)
AND %s = 1
- GROUP BY user_id, room_id
+ GROUP BY user_id, room_id, thread_id
) AS upd
- LEFT JOIN event_push_summary AS old USING (user_id, room_id)
+ LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
"""
# First get the count of unread messages.
@@ -1234,11 +1272,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# object because we might not have the same amount of rows in each of them. To do
# this, we use a dict indexed on the user ID and room ID to make it easier to
# populate.
- summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
+ summaries: Dict[Tuple[str, str, str], _EventPushSummary] = {}
for row in txn:
- summaries[(row[0], row[1])] = _EventPushSummary(
- unread_count=row[2],
- stream_ordering=row[3],
+ summaries[(row[0], row[1], row[2])] = _EventPushSummary(
+ unread_count=row[3],
+ stream_ordering=row[4],
notif_count=0,
)
@@ -1249,34 +1287,35 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
for row in txn:
- if (row[0], row[1]) in summaries:
- summaries[(row[0], row[1])].notif_count = row[2]
+ if (row[0], row[1], row[2]) in summaries:
+ summaries[(row[0], row[1], row[2])].notif_count = row[3]
else:
# Because the rules on notifying are different than the rules on marking
# a message unread, we might end up with messages that notify but aren't
# marked unread, so we might not have a summary for this (user, room)
# tuple to complete.
- summaries[(row[0], row[1])] = _EventPushSummary(
+ summaries[(row[0], row[1], row[2])] = _EventPushSummary(
unread_count=0,
- stream_ordering=row[3],
- notif_count=row[2],
+ stream_ordering=row[4],
+ notif_count=row[3],
)
logger.info("Rotating notifications, handling %d rows", len(summaries))
- # TODO(threads): Update on a per-thread basis.
self.db_pool.simple_upsert_many_txn(
txn,
table="event_push_summary",
- key_names=("user_id", "room_id"),
- key_values=[(user_id, room_id) for user_id, room_id in summaries],
- value_names=("notif_count", "unread_count", "stream_ordering", "thread_id"),
+ key_names=("user_id", "room_id", "thread_id"),
+ key_values=[
+ (user_id, room_id, thread_id)
+ for user_id, room_id, thread_id in summaries
+ ],
+ value_names=("notif_count", "unread_count", "stream_ordering"),
value_values=[
(
summary.notif_count,
summary.unread_count,
summary.stream_ordering,
- "main",
)
for summary in summaries.values()
],
@@ -1288,7 +1327,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
async def _remove_old_push_actions_that_have_rotated(self) -> None:
- """Clear out old push actions that have been summarised."""
+ """
+ Clear out old push actions that have been summarised (and are older than
+ 1 day ago).
+ """
# We want to clear out anything that is older than a day that *has* already
# been rotated.
|