diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3a3fb8c507..6b8668d2dc 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -98,6 +98,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -232,6 +233,104 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
replaces_index="event_push_summary_user_rm",
)
+ self.db_pool.updates.register_background_index_update(
+ "event_push_summary_unique_index2",
+ index_name="event_push_summary_unique_index2",
+ table="event_push_summary",
+ columns=["user_id", "room_id", "thread_id"],
+ unique=True,
+ )
+
+ self.db_pool.updates.register_background_update_handler(
+ "event_push_backfill_thread_id",
+ self._background_backfill_thread_id,
+ )
+
+ async def _background_backfill_thread_id(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """
+ Fill in the thread_id field for event_push_actions and event_push_summary.
+
+ This is preparatory so that it can be made non-nullable in the future.
+
+ Because all current (null) data is done in an unthreaded manner this
+ simply assumes it is on the "main" timeline. Since event_push_actions
+ are periodically cleared it is not possible to correctly re-calculate
+ the thread_id.
+ """
+ event_push_actions_done = progress.get("event_push_actions_done", False)
+
+ def add_thread_id_txn(
+ txn: LoggingTransaction, table_name: str, start_stream_ordering: int
+ ) -> int:
+ sql = f"""
+ SELECT stream_ordering
+ FROM {table_name}
+ WHERE
+ thread_id IS NULL
+ AND stream_ordering > ?
+ ORDER BY stream_ordering
+ LIMIT ?
+ """
+ txn.execute(sql, (start_stream_ordering, batch_size))
+
+ # No more rows to process.
+ rows = txn.fetchall()
+ if not rows:
+ progress[f"{table_name}_done"] = True
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "event_push_backfill_thread_id", progress
+ )
+ return 0
+
+ # Update the thread ID for any of those rows.
+ max_stream_ordering = rows[-1][0]
+
+ sql = f"""
+ UPDATE {table_name}
+ SET thread_id = 'main'
+ WHERE stream_ordering <= ? AND thread_id IS NULL
+ """
+ txn.execute(sql, (max_stream_ordering,))
+
+ # Update progress.
+ processed_rows = txn.rowcount
+ progress[f"max_{table_name}_stream_ordering"] = max_stream_ordering
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "event_push_backfill_thread_id", progress
+ )
+
+ return processed_rows
+
+ # First update the event_push_actions table, then the event_push_summary table.
+ #
+ # Note that the event_push_actions_staging table is ignored since it is
+ # assumed that items in that table will only exist for a short period of
+ # time.
+ if not event_push_actions_done:
+ result = await self.db_pool.runInteraction(
+ "event_push_backfill_thread_id",
+ add_thread_id_txn,
+ "event_push_actions",
+ progress.get("max_event_push_actions_stream_ordering", 0),
+ )
+ else:
+ result = await self.db_pool.runInteraction(
+ "event_push_backfill_thread_id",
+ add_thread_id_txn,
+ "event_push_summary",
+ progress.get("max_event_push_summary_stream_ordering", 0),
+ )
+
+ # Only done after the event_push_summary table is done.
+ if not result:
+ await self.db_pool.updates._end_background_update(
+ "event_push_backfill_thread_id"
+ )
+
+ return result
+
@cached(tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
self,
@@ -670,6 +769,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
event_id: str,
user_id_actions: Dict[str, Collection[Union[Mapping, str]]],
count_as_unread: bool,
+ thread_id: str,
) -> None:
"""Add the push actions for the event to the push action staging area.
@@ -678,6 +778,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id_actions: A mapping of user_id to list of push actions, where
an action can either be a string or dict.
count_as_unread: Whether this event should increment unread counts.
+ thread_id: The thread this event is parent of, if applicable.
"""
if not user_id_actions:
return
@@ -686,7 +787,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# can be used to insert into the `event_push_actions_staging` table.
def _gen_entry(
user_id: str, actions: Collection[Union[Mapping, str]]
- ) -> Tuple[str, str, str, int, int, int]:
+ ) -> Tuple[str, str, str, int, int, int, str]:
is_highlight = 1 if _action_has_highlight(actions) else 0
notif = 1 if "notify" in actions else 0
return (
@@ -696,11 +797,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
notif, # notif column
is_highlight, # highlight column
int(count_as_unread), # unread column
+ thread_id, # thread_id column
)
await self.db_pool.simple_insert_many(
"event_push_actions_staging",
- keys=("event_id", "user_id", "actions", "notif", "highlight", "unread"),
+ keys=(
+ "event_id",
+ "user_id",
+ "actions",
+ "notif",
+ "highlight",
+ "unread",
+ "thread_id",
+ ),
values=[
_gen_entry(user_id, actions)
for user_id, actions in user_id_actions.items()
@@ -981,6 +1091,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
# Replace the previous summary with the new counts.
+ #
+ # TODO(threads): Upsert per-thread instead of setting them all to main.
self.db_pool.simple_upsert_txn(
txn,
table="event_push_summary",
@@ -990,6 +1102,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
"unread_count": unread_count,
"stream_ordering": old_rotate_stream_ordering,
"last_receipt_stream_ordering": stream_ordering,
+ "thread_id": "main",
},
)
@@ -1138,17 +1251,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
logger.info("Rotating notifications, handling %d rows", len(summaries))
+ # TODO(threads): Update on a per-thread basis.
self.db_pool.simple_upsert_many_txn(
txn,
table="event_push_summary",
key_names=("user_id", "room_id"),
key_values=[(user_id, room_id) for user_id, room_id in summaries],
- value_names=("notif_count", "unread_count", "stream_ordering"),
+ value_names=("notif_count", "unread_count", "stream_ordering", "thread_id"),
value_values=[
(
summary.notif_count,
summary.unread_count,
summary.stream_ordering,
+ "main",
)
for summary in summaries.values()
],
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index a4010ee28d..c0b4080e4b 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -2192,9 +2192,9 @@ class PersistEventsStore:
sql = """
INSERT INTO event_push_actions (
room_id, event_id, user_id, actions, stream_ordering,
- topological_ordering, notif, highlight, unread
+ topological_ordering, notif, highlight, unread, thread_id
)
- SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
+ SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread, thread_id
FROM event_push_actions_staging
WHERE event_id = ?
"""
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 719a12b0ae..ddb8e80b69 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -113,6 +113,24 @@ class ReceiptsWorkerStore(SQLBaseStore):
prefilled_cache=receipts_stream_prefill,
)
+ self.db_pool.updates.register_background_index_update(
+ "receipts_linearized_unique_index",
+ index_name="receipts_linearized_unique_index",
+ table="receipts_linearized",
+ columns=["room_id", "receipt_type", "user_id"],
+ where_clause="thread_id IS NULL",
+ unique=True,
+ )
+
+ self.db_pool.updates.register_background_index_update(
+ "receipts_graph_unique_index",
+ index_name="receipts_graph_unique_index",
+ table="receipts_graph",
+ columns=["room_id", "receipt_type", "user_id"],
+ where_clause="thread_id IS NULL",
+ unique=True,
+ )
+
def get_max_receipt_stream_id(self) -> int:
"""Get the current max stream ID for receipts stream"""
return self._receipts_id_gen.get_current_token()
@@ -677,6 +695,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"event_id": event_id,
"event_stream_ordering": stream_ordering,
"data": json_encoder.encode(data),
+ "thread_id": None,
},
# receipts_linearized has a unique constraint on
# (user_id, room_id, receipt_type), so no need to lock
@@ -824,6 +843,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
values={
"event_ids": json_encoder.encode(event_ids),
"data": json_encoder.encode(data),
+ "thread_id": None,
},
# receipts_graph has a unique constraint on
# (user_id, room_id, receipt_type), so no need to lock
|