diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index eabf9c9739..6b8668d2dc 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -98,6 +98,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -232,6 +233,104 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
replaces_index="event_push_summary_user_rm",
)
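+        # Add a unique index on (user_id, room_id, thread_id) so that
+        # event_push_summary can hold, and upsert, one row per thread.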
+ self.db_pool.updates.register_background_index_update(
+ "event_push_summary_unique_index2",
+ index_name="event_push_summary_unique_index2",
+ table="event_push_summary",
+ columns=["user_id", "room_id", "thread_id"],
+ unique=True,
+ )
+
+ self.db_pool.updates.register_background_update_handler(
+ "event_push_backfill_thread_id",
+ self._background_backfill_thread_id,
+ )
+
+ async def _background_backfill_thread_id(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """
+ Fill in the thread_id field for event_push_actions and event_push_summary.
+
+        This is preparatory work so that the column can be made non-nullable
+        in the future.
+
+        Because all of the existing (NULL) data predates threading, this
+        simply assumes it belongs on the "main" timeline. Since
+        event_push_actions rows are periodically cleared it is not possible
+        to correctly re-calculate the thread_id.
+ """
+ event_push_actions_done = progress.get("event_push_actions_done", False)
+
+ def add_thread_id_txn(
+ txn: LoggingTransaction, table_name: str, start_stream_ordering: int
+ ) -> int:
+ sql = f"""
+ SELECT stream_ordering
+ FROM {table_name}
+ WHERE
+ thread_id IS NULL
+ AND stream_ordering > ?
+ ORDER BY stream_ordering
+ LIMIT ?
+ """
+ txn.execute(sql, (start_stream_ordering, batch_size))
+
+            rows = txn.fetchall()
+
+            # If there are no rows left to process, mark this table as done.
+            if not rows:
+ progress[f"{table_name}_done"] = True
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "event_push_backfill_thread_id", progress
+ )
+ return 0
+
+ # Update the thread ID for any of those rows.
+ max_stream_ordering = rows[-1][0]
+
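+            # The batch is selected first and then updated up to its maximum
+            # stream ordering, since UPDATE cannot portably be combined with
+            # LIMIT across the supported database engines.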
+ sql = f"""
+ UPDATE {table_name}
+ SET thread_id = 'main'
+ WHERE stream_ordering <= ? AND thread_id IS NULL
+ """
+ txn.execute(sql, (max_stream_ordering,))
+
+ # Update progress.
+ processed_rows = txn.rowcount
+ progress[f"max_{table_name}_stream_ordering"] = max_stream_ordering
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "event_push_backfill_thread_id", progress
+ )
+
+ return processed_rows
+
+ # First update the event_push_actions table, then the event_push_summary table.
+ #
+ # Note that the event_push_actions_staging table is ignored since it is
+ # assumed that items in that table will only exist for a short period of
+ # time.
+ if not event_push_actions_done:
+ result = await self.db_pool.runInteraction(
+ "event_push_backfill_thread_id",
+ add_thread_id_txn,
+ "event_push_actions",
+ progress.get("max_event_push_actions_stream_ordering", 0),
+ )
+ else:
+ result = await self.db_pool.runInteraction(
+ "event_push_backfill_thread_id",
+ add_thread_id_txn,
+ "event_push_summary",
+ progress.get("max_event_push_summary_stream_ordering", 0),
+ )
+
+            # The background update is only complete once the
+            # event_push_summary table has been fully backfilled.
+ if not result:
+ await self.db_pool.updates._end_background_update(
+ "event_push_backfill_thread_id"
+ )
+
+ return result
+
@cached(tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
self,
@@ -274,7 +373,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
receipt_types=(
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
- ReceiptTypes.UNSTABLE_READ_PRIVATE,
),
)
@@ -459,6 +557,31 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return await self.db_pool.runInteraction("get_push_action_users_in_range", f)
+ def _get_receipts_by_room_txn(
+ self, txn: LoggingTransaction, user_id: str
+ ) -> List[Tuple[str, int]]:
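+        """
+        Generate a map of room ID to the latest stream ordering that has been
+        read by the given user.
+
+        Args:
+            txn: The database transaction.
+            user_id: The user to fetch receipts for.
+
+        Returns:
+            A list of (room ID, stream ordering) tuples, one per room in which
+            the user has a read receipt, giving the latest receipted event's
+            stream ordering.
+        """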
+ receipt_types_clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "receipt_type",
+ (
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ),
+ )
+
+ sql = f"""
+ SELECT room_id, MAX(stream_ordering)
+ FROM receipts_linearized
+ INNER JOIN events USING (room_id, event_id)
+ WHERE {receipt_types_clause}
+ AND user_id = ?
+ GROUP BY room_id
+ """
+
+ args.extend((user_id,))
+ txn.execute(sql, args)
+ return cast(List[Tuple[str, int]], txn.fetchall())
+
async def get_unread_push_actions_for_user_in_range_for_http(
self,
user_id: str,
@@ -482,106 +605,45 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will have between 0~limit entries.
"""
- # find rooms that have a read receipt in them and return the next
- # push actions
- def get_after_receipt(
- txn: LoggingTransaction,
- ) -> List[Tuple[str, str, int, str, bool]]:
- # find rooms that have a read receipt in them and return the next
- # push actions
-
- receipt_types_clause, args = make_in_list_sql_clause(
- self.database_engine,
- "receipt_type",
- (
- ReceiptTypes.READ,
- ReceiptTypes.READ_PRIVATE,
- ReceiptTypes.UNSTABLE_READ_PRIVATE,
- ),
- )
-
- sql = f"""
- SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
- ep.highlight
- FROM (
- SELECT room_id,
- MAX(stream_ordering) as stream_ordering
- FROM events
- INNER JOIN receipts_linearized USING (room_id, event_id)
- WHERE {receipt_types_clause} AND user_id = ?
- GROUP BY room_id
- ) AS rl,
- event_push_actions AS ep
- WHERE
- ep.room_id = rl.room_id
- AND ep.stream_ordering > rl.stream_ordering
- AND ep.user_id = ?
- AND ep.stream_ordering > ?
- AND ep.stream_ordering <= ?
- AND ep.notif = 1
- ORDER BY ep.stream_ordering ASC LIMIT ?
- """
- args.extend(
- (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
- )
- txn.execute(sql, args)
- return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
-
- after_read_receipt = await self.db_pool.runInteraction(
- "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
+ receipts_by_room = dict(
+ await self.db_pool.runInteraction(
+ "get_unread_push_actions_for_user_in_range_http_receipts",
+ self._get_receipts_by_room_txn,
+ user_id=user_id,
+ ),
)
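+        # The latest receipt in each room is fetched up front so that the
+        # query below only needs to scan event_push_actions; filtering
+        # against receipts then happens in Python, in the list comprehension
+        # further down.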
- # There are rooms with push actions in them but you don't have a read receipt in
- # them e.g. rooms you've been invited to, so get push actions for rooms which do
- # not have read receipts in them too.
- def get_no_receipt(
+ def get_push_actions_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool]]:
- receipt_types_clause, args = make_in_list_sql_clause(
- self.database_engine,
- "receipt_type",
- (
- ReceiptTypes.READ,
- ReceiptTypes.READ_PRIVATE,
- ReceiptTypes.UNSTABLE_READ_PRIVATE,
- ),
- )
-
- sql = f"""
- SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
- ep.highlight
+ sql = """
+ SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight
FROM event_push_actions AS ep
- INNER JOIN events AS e USING (room_id, event_id)
WHERE
- ep.room_id NOT IN (
- SELECT room_id FROM receipts_linearized
- WHERE {receipt_types_clause} AND user_id = ?
- GROUP BY room_id
- )
- AND ep.user_id = ?
+ ep.user_id = ?
AND ep.stream_ordering > ?
AND ep.stream_ordering <= ?
AND ep.notif = 1
ORDER BY ep.stream_ordering ASC LIMIT ?
"""
- args.extend(
- (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
- )
- txn.execute(sql, args)
+ txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
- no_read_receipt = await self.db_pool.runInteraction(
- "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
+ push_actions = await self.db_pool.runInteraction(
+ "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn
)
notifs = [
HttpPushAction(
- event_id=row[0],
- room_id=row[1],
- stream_ordering=row[2],
- actions=_deserialize_action(row[3], row[4]),
+ event_id=event_id,
+ room_id=room_id,
+ stream_ordering=stream_ordering,
+ actions=_deserialize_action(actions, highlight),
)
- for row in after_read_receipt + no_read_receipt
+ for event_id, room_id, stream_ordering, actions, highlight in push_actions
+            # Only include push actions with a stream ordering after the
+            # latest receipt in the room, or in rooms with no receipt at all
+            # (e.g. rooms the user was invited to but has never read).
+ if stream_ordering > receipts_by_room.get(room_id, 0)
]
# Now sort it so it's ordered correctly, since currently it will
@@ -617,106 +679,49 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will have between 0~limit entries.
"""
- # find rooms that have a read receipt in them and return the most recent
- # push actions
- def get_after_receipt(
- txn: LoggingTransaction,
- ) -> List[Tuple[str, str, int, str, bool, int]]:
- receipt_types_clause, args = make_in_list_sql_clause(
- self.database_engine,
- "receipt_type",
- (
- ReceiptTypes.READ,
- ReceiptTypes.READ_PRIVATE,
- ReceiptTypes.UNSTABLE_READ_PRIVATE,
- ),
- )
-
- sql = f"""
- SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
- ep.highlight, e.received_ts
- FROM (
- SELECT room_id,
- MAX(stream_ordering) as stream_ordering
- FROM events
- INNER JOIN receipts_linearized USING (room_id, event_id)
- WHERE {receipt_types_clause} AND user_id = ?
- GROUP BY room_id
- ) AS rl,
- event_push_actions AS ep
- INNER JOIN events AS e USING (room_id, event_id)
- WHERE
- ep.room_id = rl.room_id
- AND ep.stream_ordering > rl.stream_ordering
- AND ep.user_id = ?
- AND ep.stream_ordering > ?
- AND ep.stream_ordering <= ?
- AND ep.notif = 1
- ORDER BY ep.stream_ordering DESC LIMIT ?
- """
- args.extend(
- (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
- )
- txn.execute(sql, args)
- return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
-
- after_read_receipt = await self.db_pool.runInteraction(
- "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
+ receipts_by_room = dict(
+ await self.db_pool.runInteraction(
+ "get_unread_push_actions_for_user_in_range_email_receipts",
+ self._get_receipts_by_room_txn,
+ user_id=user_id,
+ ),
)
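+        # As in the HTTP variant above, receipts are fetched once up front
+        # and filtering happens in Python below.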
- # There are rooms with push actions in them but you don't have a read receipt in
- # them e.g. rooms you've been invited to, so get push actions for rooms which do
- # not have read receipts in them too.
- def get_no_receipt(
+ def get_push_actions_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool, int]]:
- receipt_types_clause, args = make_in_list_sql_clause(
- self.database_engine,
- "receipt_type",
- (
- ReceiptTypes.READ,
- ReceiptTypes.READ_PRIVATE,
- ReceiptTypes.UNSTABLE_READ_PRIVATE,
- ),
- )
-
- sql = f"""
+ sql = """
SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
ep.highlight, e.received_ts
FROM event_push_actions AS ep
INNER JOIN events AS e USING (room_id, event_id)
WHERE
- ep.room_id NOT IN (
- SELECT room_id FROM receipts_linearized
- WHERE {receipt_types_clause} AND user_id = ?
- GROUP BY room_id
- )
- AND ep.user_id = ?
+ ep.user_id = ?
AND ep.stream_ordering > ?
AND ep.stream_ordering <= ?
AND ep.notif = 1
ORDER BY ep.stream_ordering DESC LIMIT ?
"""
- args.extend(
- (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
- )
- txn.execute(sql, args)
+ txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
- no_read_receipt = await self.db_pool.runInteraction(
- "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
+ push_actions = await self.db_pool.runInteraction(
+ "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn
)
# Make a list of dicts from the two sets of results.
notifs = [
EmailPushAction(
- event_id=row[0],
- room_id=row[1],
- stream_ordering=row[2],
- actions=_deserialize_action(row[3], row[4]),
- received_ts=row[5],
+ event_id=event_id,
+ room_id=room_id,
+ stream_ordering=stream_ordering,
+ actions=_deserialize_action(actions, highlight),
+ received_ts=received_ts,
)
- for row in after_read_receipt + no_read_receipt
+ for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions
+            # Only include push actions with a stream ordering after the
+            # latest receipt in the room, or in rooms with no receipt at all
+            # (e.g. rooms the user was invited to but has never read).
+ if stream_ordering > receipts_by_room.get(room_id, 0)
]
# Now sort it so it's ordered correctly, since currently it will
@@ -764,6 +769,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
event_id: str,
user_id_actions: Dict[str, Collection[Union[Mapping, str]]],
count_as_unread: bool,
+ thread_id: str,
) -> None:
"""Add the push actions for the event to the push action staging area.
@@ -772,6 +778,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id_actions: A mapping of user_id to list of push actions, where
an action can either be a string or dict.
count_as_unread: Whether this event should increment unread counts.
+            thread_id: The thread this event is relevant to, if applicable.
"""
if not user_id_actions:
return
@@ -780,7 +787,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# can be used to insert into the `event_push_actions_staging` table.
def _gen_entry(
user_id: str, actions: Collection[Union[Mapping, str]]
- ) -> Tuple[str, str, str, int, int, int]:
+ ) -> Tuple[str, str, str, int, int, int, str]:
is_highlight = 1 if _action_has_highlight(actions) else 0
notif = 1 if "notify" in actions else 0
return (
@@ -790,28 +797,25 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
notif, # notif column
is_highlight, # highlight column
int(count_as_unread), # unread column
+ thread_id, # thread_id column
)
- def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
- # We don't use simple_insert_many here to avoid the overhead
- # of generating lists of dicts.
-
- sql = """
- INSERT INTO event_push_actions_staging
- (event_id, user_id, actions, notif, highlight, unread)
- VALUES (?, ?, ?, ?, ?, ?)
- """
-
- txn.execute_batch(
- sql,
- (
- _gen_entry(user_id, actions)
- for user_id, actions in user_id_actions.items()
- ),
- )
-
- return await self.db_pool.runInteraction(
- "add_push_actions_to_staging", _add_push_actions_to_staging_txn
+ await self.db_pool.simple_insert_many(
+ "event_push_actions_staging",
+ keys=(
+ "event_id",
+ "user_id",
+ "actions",
+ "notif",
+ "highlight",
+ "unread",
+ "thread_id",
+ ),
+ values=[
+ _gen_entry(user_id, actions)
+ for user_id, actions in user_id_actions.items()
+ ],
+ desc="add_push_actions_to_staging",
)
async def remove_push_actions_from_staging(self, event_id: str) -> None:
@@ -1087,6 +1091,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
# Replace the previous summary with the new counts.
+ #
+ # TODO(threads): Upsert per-thread instead of setting them all to main.
self.db_pool.simple_upsert_txn(
txn,
table="event_push_summary",
@@ -1096,6 +1102,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
"unread_count": unread_count,
"stream_ordering": old_rotate_stream_ordering,
"last_receipt_stream_ordering": stream_ordering,
+ "thread_id": "main",
},
)
@@ -1244,17 +1251,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
logger.info("Rotating notifications, handling %d rows", len(summaries))
+ # TODO(threads): Update on a per-thread basis.
self.db_pool.simple_upsert_many_txn(
txn,
table="event_push_summary",
key_names=("user_id", "room_id"),
key_values=[(user_id, room_id) for user_id, room_id in summaries],
- value_names=("notif_count", "unread_count", "stream_ordering"),
+ value_names=("notif_count", "unread_count", "stream_ordering", "thread_id"),
value_values=[
(
summary.notif_count,
summary.unread_count,
summary.stream_ordering,
+ "main",
)
for summary in summaries.values()
],
@@ -1361,7 +1370,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
table="event_push_actions",
columns=["highlight", "stream_ordering"],
where_clause="highlight=0",
- psql_only=True,
)
async def get_push_actions_for_user(
|