diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 332e13d1c9..7ebe34f773 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -74,6 +74,7 @@ receipt.
"""
import logging
+from collections import defaultdict
from typing import (
TYPE_CHECKING,
Collection,
@@ -95,6 +96,7 @@ from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
+ PostgresEngine,
)
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore
@@ -294,6 +296,44 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
self._background_backfill_thread_id,
)
+ # Indexes which will be used to quickly make the thread_id column non-null.
+ self.db_pool.updates.register_background_index_update(
+ "event_push_actions_thread_id_null",
+ index_name="event_push_actions_thread_id_null",
+ table="event_push_actions",
+ columns=["thread_id"],
+ where_clause="thread_id IS NULL",
+ )
+ self.db_pool.updates.register_background_index_update(
+ "event_push_summary_thread_id_null",
+ index_name="event_push_summary_thread_id_null",
+ table="event_push_summary",
+ columns=["thread_id"],
+ where_clause="thread_id IS NULL",
+ )
+
+ # Check ASAP (and then again after each 15.0 delay) to see if we have finished
+ # the background update of the event_push_actions and event_push_summary tables.
+ self._clock.call_later(0.0, self._check_event_push_backfill_thread_id)
+ self._event_push_backfill_thread_id_done = False
+
+    @wrap_as_background_process("check_event_push_backfill_thread_id")
+    async def _check_event_push_backfill_thread_id(self) -> None:
+        """
+        Poll whether the `event_push_backfill_thread_id` background update is complete.
+
+        Once it is, record that on the store. Until then, queue another run of
+        this check so the flag is eventually set without a restart.
+        """
+        backfill_finished = await self.db_pool.updates.has_completed_background_update(
+            "event_push_backfill_thread_id"
+        )
+
+        if not backfill_finished:
+            # Not there yet: look again after a 15.0 delay.
+            self._clock.call_later(15.0, self._check_event_push_backfill_thread_id)
+            return
+
+        self._event_push_backfill_thread_id_done = True
+
async def _background_backfill_thread_id(
self, progress: JsonDict, batch_size: int
) -> int:
@@ -310,11 +350,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
event_push_actions_done = progress.get("event_push_actions_done", False)
def add_thread_id_txn(
- txn: LoggingTransaction, table_name: str, start_stream_ordering: int
+ txn: LoggingTransaction, start_stream_ordering: int
) -> int:
- sql = f"""
+ sql = """
SELECT stream_ordering
- FROM {table_name}
+ FROM event_push_actions
WHERE
thread_id IS NULL
AND stream_ordering > ?
@@ -326,7 +366,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# No more rows to process.
rows = txn.fetchall()
if not rows:
- progress[f"{table_name}_done"] = True
+ progress["event_push_actions_done"] = True
self.db_pool.updates._background_update_progress_txn(
txn, "event_push_backfill_thread_id", progress
)
@@ -335,16 +375,65 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# Update the thread ID for any of those rows.
max_stream_ordering = rows[-1][0]
- sql = f"""
- UPDATE {table_name}
+ sql = """
+ UPDATE event_push_actions
SET thread_id = 'main'
- WHERE stream_ordering <= ? AND thread_id IS NULL
+ WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL
"""
- txn.execute(sql, (max_stream_ordering,))
+ txn.execute(
+ sql,
+ (
+ start_stream_ordering,
+ max_stream_ordering,
+ ),
+ )
# Update progress.
processed_rows = txn.rowcount
- progress[f"max_{table_name}_stream_ordering"] = max_stream_ordering
+ progress["max_event_push_actions_stream_ordering"] = max_stream_ordering
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "event_push_backfill_thread_id", progress
+ )
+
+ return processed_rows
+
+ def add_thread_id_summary_txn(txn: LoggingTransaction) -> int:
+ min_user_id = progress.get("max_summary_user_id", "")
+ min_room_id = progress.get("max_summary_room_id", "")
+
+ # Slightly overcomplicated query for getting the Nth user ID / room
+ # ID tuple, or the last one if there are fewer than N remaining.
+ sql = """
+ SELECT user_id, room_id FROM (
+ SELECT user_id, room_id FROM event_push_summary
+ WHERE (user_id, room_id) > (?, ?)
+ AND thread_id IS NULL
+ ORDER BY user_id, room_id
+ LIMIT ?
+ ) AS e
+ ORDER BY user_id DESC, room_id DESC
+ LIMIT 1
+ """
+
+ txn.execute(sql, (min_user_id, min_room_id, batch_size))
+ row = txn.fetchone()
+ if not row:
+ return 0
+
+ max_user_id, max_room_id = row
+
+ sql = """
+ UPDATE event_push_summary
+ SET thread_id = 'main'
+ WHERE
+ (?, ?) < (user_id, room_id) AND (user_id, room_id) <= (?, ?)
+ AND thread_id IS NULL
+ """
+ txn.execute(sql, (min_user_id, min_room_id, max_user_id, max_room_id))
+ processed_rows = txn.rowcount
+
+ progress["max_summary_user_id"] = max_user_id
+ progress["max_summary_room_id"] = max_room_id
self.db_pool.updates._background_update_progress_txn(
txn, "event_push_backfill_thread_id", progress
)
@@ -360,15 +449,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
result = await self.db_pool.runInteraction(
"event_push_backfill_thread_id",
add_thread_id_txn,
- "event_push_actions",
progress.get("max_event_push_actions_stream_ordering", 0),
)
else:
result = await self.db_pool.runInteraction(
"event_push_backfill_thread_id",
- add_thread_id_txn,
- "event_push_summary",
- progress.get("max_event_push_summary_stream_ordering", 0),
+ add_thread_id_summary_txn,
)
# Only done after the event_push_summary table is done.
@@ -379,6 +465,153 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return result
+    async def get_unread_counts_by_room_for_user(self, user_id: str) -> Dict[str, int]:
+        """Fetch per-room notification counts for a single user.
+
+        Only notification counts are included (no highlight or unread counts), and
+        counts for threads are currently folded into the total of their room.
+
+        Deliberately left uncached: this feeds the unread badge calculation for
+        push notifications, so the answer is expected to keep changing.
+
+        Membership is not checked here. Summary rows are not removed when a user
+        leaves a room, so the caller must discard entries for such rooms.
+
+        Args:
+            user_id: The user to calculate counts for.
+
+        Returns:
+            A map of room ID to notification counts for the given user.
+        """
+        counts_by_room = await self.db_pool.runInteraction(
+            "get_unread_counts_by_room_for_user",
+            self._get_unread_counts_by_room_for_user_txn,
+            user_id,
+        )
+        return counts_by_room
+
+ def _get_unread_counts_by_room_for_user_txn(
+ self, txn: LoggingTransaction, user_id: str
+ ) -> Dict[str, int]:
+ """Sum the notification counts per room for a user, inside a transaction.
+
+ Three queries are run, all sharing the `all_receipts` CTE (the maximum
+ receipt stream ordering per room / thread for the user):
+
+ 1. `event_push_summary` rows with a non-zero `notif_count` whose
+ `last_receipt_stream_ordering` agrees with the latest receipt.
+ 2. `event_push_actions` rows newer than the position stored in
+ `event_push_summary_stream_ordering` and newer than the receipts; these are
+ only counted for thread IDs returned by the first query.
+ 3. `event_push_actions` rows newer than the receipts for every thread ID
+ *not* returned by the first query, ignoring the rotation position.
+
+ Args:
+ txn: The database transaction to run the queries on.
+ user_id: The user to calculate counts for.
+
+ Returns:
+ A map of room ID to the summed notification count. The returned object
+ is the `defaultdict(int)` used for accumulation.
+ """
+ receipt_types_clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "receipt_type",
+ (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
+ )
+ # Positional args must follow placeholder order: the receipt type clause comes
+ # first (inside the CTE), then `user_id = ?` in the CTE, then `user_id = ?` in
+ # the WHERE clause of each outer query.
+ args.extend([user_id, user_id])
+
+ receipts_cte = f"""
+ WITH all_receipts AS (
+ SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering
+ FROM receipts_linearized
+ LEFT JOIN events USING (room_id, event_id)
+ WHERE
+ {receipt_types_clause}
+ AND user_id = ?
+ GROUP BY room_id, thread_id
+ )
+ """
+
+ # Joined onto each query below: threaded receipts match on room and thread,
+ # unthreaded receipts (NULL thread_id) match on room only.
+ receipts_joins = """
+ LEFT JOIN (
+ SELECT room_id, thread_id,
+ max_receipt_stream_ordering AS threaded_receipt_stream_ordering
+ FROM all_receipts
+ WHERE thread_id IS NOT NULL
+ ) AS threaded_receipts USING (room_id, thread_id)
+ LEFT JOIN (
+ SELECT room_id, thread_id,
+ max_receipt_stream_ordering AS unthreaded_receipt_stream_ordering
+ FROM all_receipts
+ WHERE thread_id IS NULL
+ ) AS unthreaded_receipts USING (room_id)
+ """
+
+ # First get summary counts by room / thread for the user. We use the max receipt
+ # stream ordering of both threaded & unthreaded receipts to compare against the
+ # summary table.
+ #
+ # PostgreSQL and SQLite differ in comparing scalar numerics.
+ if isinstance(self.database_engine, PostgresEngine):
+ # GREATEST ignores NULLs.
+ max_clause = """GREATEST(
+ threaded_receipt_stream_ordering,
+ unthreaded_receipt_stream_ordering
+ )"""
+ else:
+ # MAX returns NULL if any are NULL, so COALESCE to 0 first.
+ max_clause = """MAX(
+ COALESCE(threaded_receipt_stream_ordering, 0),
+ COALESCE(unthreaded_receipt_stream_ordering, 0)
+ )"""
+
+ sql = f"""
+ {receipts_cte}
+ SELECT eps.room_id, eps.thread_id, notif_count
+ FROM event_push_summary AS eps
+ {receipts_joins}
+ WHERE user_id = ?
+ AND notif_count != 0
+ AND (
+ (last_receipt_stream_ordering IS NULL AND stream_ordering > {max_clause})
+ OR last_receipt_stream_ordering = {max_clause}
+ )
+ """
+ txn.execute(sql, args)
+
+ # NOTE(review): this set holds thread IDs only, not (room ID, thread ID) pairs,
+ # so a thread ID value shared by several rooms that is seen in one room's
+ # summary also affects the filtering of the other rooms below. Confirm this is
+ # intended.
+ seen_thread_ids = set()
+ room_to_count: Dict[str, int] = defaultdict(int)
+
+ for room_id, thread_id, notif_count in txn:
+ room_to_count[room_id] += notif_count
+ seen_thread_ids.add(thread_id)
+
+ # Now get any event push actions that haven't been rotated using the same OR
+ # join and filter by receipt and event push summary rotated up to stream ordering.
+ sql = f"""
+ {receipts_cte}
+ SELECT epa.room_id, epa.thread_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
+ FROM event_push_actions AS epa
+ {receipts_joins}
+ WHERE user_id = ?
+ AND epa.notif = 1
+ AND stream_ordering > (SELECT stream_ordering FROM event_push_summary_stream_ordering)
+ AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
+ AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
+ GROUP BY epa.room_id, epa.thread_id
+ """
+ txn.execute(sql, args)
+
+ for room_id, thread_id, notif_count in txn:
+ # Note: only count push actions we have valid summaries for with up to date receipt.
+ if thread_id not in seen_thread_ids:
+ continue
+ room_to_count[room_id] += notif_count
+
+ thread_id_clause, thread_ids_args = make_in_list_sql_clause(
+ self.database_engine, "epa.thread_id", seen_thread_ids
+ )
+
+ # Finally re-check event_push_actions for any rooms not in the summary, ignoring
+ # the rotated up-to position. This handles the case where a read receipt has arrived
+ # but not been rotated meaning the summary table is out of date, so we go back to
+ # the push actions table.
+ #
+ # The exclusion itself is expressed on thread IDs (`NOT {thread_id_clause}`),
+ # i.e. rows whose thread ID was already returned by the summary query are skipped.
+ sql = f"""
+ {receipts_cte}
+ SELECT epa.room_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
+ FROM event_push_actions AS epa
+ {receipts_joins}
+ WHERE user_id = ?
+ AND NOT {thread_id_clause}
+ AND epa.notif = 1
+ AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
+ AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
+ GROUP BY epa.room_id
+ """
+
+ # The thread ID placeholders come after `user_id = ?` in the statement above,
+ # so their values are appended to the end of the shared argument list.
+ args.extend(thread_ids_args)
+ txn.execute(sql, args)
+
+ for room_id, notif_count in txn:
+ room_to_count[room_id] += notif_count
+
+ return room_to_count
+
@cached(tree=True, max_entries=5000, iterable=True)
async def get_unread_event_push_actions_by_room_for_user(
self,
@@ -480,6 +713,25 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)
+ # First ensure that the existing rows have an updated thread_id field.
+ if not self._event_push_backfill_thread_id_done:
+ txn.execute(
+ """
+ UPDATE event_push_summary
+ SET thread_id = ?
+ WHERE room_id = ? AND user_id = ? AND thread_id is NULL
+ """,
+ (MAIN_TIMELINE, room_id, user_id),
+ )
+ txn.execute(
+ """
+ UPDATE event_push_actions
+ SET thread_id = ?
+ WHERE room_id = ? AND user_id = ? AND thread_id is NULL
+ """,
+ (MAIN_TIMELINE, room_id, user_id),
+ )
+
# First we pull the counts from the summary table.
#
# We check that `last_receipt_stream_ordering` matches the stream ordering of the
@@ -1295,6 +1547,25 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
(room_id, user_id, stream_ordering, *thread_args),
)
+ # First ensure that the existing rows have an updated thread_id field.
+ if not self._event_push_backfill_thread_id_done:
+ txn.execute(
+ """
+ UPDATE event_push_summary
+ SET thread_id = ?
+ WHERE room_id = ? AND user_id = ? AND thread_id is NULL
+ """,
+ (MAIN_TIMELINE, room_id, user_id),
+ )
+ txn.execute(
+ """
+ UPDATE event_push_actions
+ SET thread_id = ?
+ WHERE room_id = ? AND user_id = ? AND thread_id is NULL
+ """,
+ (MAIN_TIMELINE, room_id, user_id),
+ )
+
# Fetch the notification counts between the stream ordering of the
# latest receipt and what was previously summarised.
unread_counts = self._get_notif_unread_count_for_user_room(
@@ -1429,6 +1700,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
rotate_to_stream_ordering: The new maximum event stream ordering to summarise.
"""
+ # Ensure that any new actions have an updated thread_id.
+ if not self._event_push_backfill_thread_id_done:
+ txn.execute(
+ """
+ UPDATE event_push_actions
+ SET thread_id = ?
+ WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL
+ """,
+ (MAIN_TIMELINE, old_rotate_stream_ordering, rotate_to_stream_ordering),
+ )
+
+ # XXX Do we need to update summaries here too?
+
# Calculate the new counts that should be upserted into event_push_summary
sql = """
SELECT user_id, room_id, thread_id,
@@ -1491,6 +1775,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
logger.info("Rotating notifications, handling %d rows", len(summaries))
+ # Ensure that any updated threads have the proper thread_id.
+ if not self._event_push_backfill_thread_id_done:
+ txn.execute_batch(
+ """
+ UPDATE event_push_summary
+ SET thread_id = ?
+ WHERE room_id = ? AND user_id = ? AND thread_id is NULL
+ """,
+ [
+ (MAIN_TIMELINE, room_id, user_id)
+ for user_id, room_id, _ in summaries
+ ],
+ )
+
self.db_pool.simple_upsert_many_txn(
txn,
table="event_push_summary",
|