summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <patrickc@matrix.org>2022-06-14 09:56:47 -0400
committerPatrick Cloke <patrickc@matrix.org>2022-08-05 08:18:08 -0400
commit2c7a5681b44fe70feb3baf5ed8364031f0410db7 (patch)
tree6e12702672688a04eabbb175466f16142330fe2e
parentMark token-authenticaticated-registration API as not-experimental (#11897) (diff)
downloadsynapse-2c7a5681b44fe70feb3baf5ed8364031f0410db7.tar.xz
Extract the thread ID when processing push rules.
-rw-r--r--changelog.d/13181.feature1
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py30
-rw-r--r--synapse/storage/databases/main/event_push_actions.py50
-rw-r--r--synapse/storage/databases/main/events.py4
-rw-r--r--synapse/storage/schema/main/delta/72/03thread_notifications.sql27
-rw-r--r--tests/replication/slave/storage/test_events.py1
6 files changed, 78 insertions, 35 deletions
diff --git a/changelog.d/13181.feature b/changelog.d/13181.feature
new file mode 100644
index 0000000000..22bce125ce
--- /dev/null
+++ b/changelog.d/13181.feature
@@ -0,0 +1 @@
+Experimental support for thread-specific notifications ([MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)).
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 713dcf6950..b63b45cd75 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -186,7 +186,7 @@ class BulkPushRuleEvaluator:
         return pl_event.content if pl_event else {}, sender_level
 
     async def _get_mutual_relations(
-        self, event: EventBase, rules: Iterable[Dict[str, Any]]
+        self, parent_id: str, rules: Iterable[Dict[str, Any]]
     ) -> Dict[str, Set[Tuple[str, str]]]:
         """
         Fetch event metadata for events which related to the same event as the given event.
@@ -194,7 +194,7 @@ class BulkPushRuleEvaluator:
         If the given event has no relation information, returns an empty dictionary.
 
         Args:
-            event_id: The event ID which is targeted by relations.
+            parent_id: The event ID which is targeted by relations.
             rules: The push rules which will be processed for this event.
 
         Returns:
@@ -208,12 +208,6 @@ class BulkPushRuleEvaluator:
         if not self._relations_match_enabled:
             return {}
 
-        # If the event does not have a relation, then cannot have any mutual
-        # relations.
-        relation = relation_from_event(event)
-        if not relation:
-            return {}
-
         # Pre-filter to figure out which relation types are interesting.
         rel_types = set()
         for rule in rules:
@@ -235,9 +229,7 @@ class BulkPushRuleEvaluator:
             return {}
 
         # If any valid rules were found, fetch the mutual relations.
-        return await self.store.get_mutual_event_relations(
-            relation.parent_id, rel_types
-        )
+        return await self.store.get_mutual_event_relations(parent_id, rel_types)
 
     @measure_func("action_for_event_by_user")
     async def action_for_event_by_user(
@@ -265,9 +257,18 @@ class BulkPushRuleEvaluator:
             sender_power_level,
         ) = await self._get_power_levels_and_sender_level(event, context)
 
-        relations = await self._get_mutual_relations(
-            event, itertools.chain(*rules_by_user.values())
-        )
+        relation = relation_from_event(event)
+        # If the event does not have a relation, then cannot have any mutual
+        # relations or thread ID.
+        relations = {}
+        thread_id = None
+        if relation:
+            relations = await self._get_mutual_relations(
+                relation.parent_id, itertools.chain(*rules_by_user.values())
+            )
+            # XXX Does this need to point to a valid parent ID or anything?
+            if relation.rel_type == RelationTypes.THREAD:
+                thread_id = relation.parent_id
 
         evaluator = PushRuleEvaluatorForEvent(
             event,
@@ -338,6 +339,7 @@ class BulkPushRuleEvaluator:
             event.event_id,
             actions_by_user,
             count_as_unread,
+            thread_id,
         )
 
 
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 5db70f9a60..8cdbc242e3 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -220,6 +220,15 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             replaces_index="event_push_summary_user_rm",
         )
 
+        self.db_pool.updates.register_background_index_update(
+            "event_push_summary_unique_index2",
+            index_name="event_push_summary_unique_index2",
+            table="event_push_summary",
+            columns=["user_id", "room_id", "thread_id"],
+            unique=True,
+            replaces_index="event_push_summary_unique_index",
+        )
+
     @cached(tree=True, max_entries=5000)
     async def get_unread_event_push_actions_by_room_for_user(
         self,
@@ -680,6 +689,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         event_id: str,
         user_id_actions: Dict[str, List[Union[dict, str]]],
         count_as_unread: bool,
+        thread_id: Optional[str],
     ) -> None:
         """Add the push actions for the event to the push action staging area.
 
@@ -688,6 +698,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             user_id_actions: A mapping of user_id to list of push actions, where
                 an action can either be a string or dict.
             count_as_unread: Whether this event should increment unread counts.
+            thread_id: The thread this event is parent of, if applicable.
         """
         if not user_id_actions:
             return
@@ -696,7 +707,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         # can be used to insert into the `event_push_actions_staging` table.
         def _gen_entry(
             user_id: str, actions: List[Union[dict, str]]
-        ) -> Tuple[str, str, str, int, int, int]:
+        ) -> Tuple[str, str, str, int, int, int, Optional[str]]:
             is_highlight = 1 if _action_has_highlight(actions) else 0
             notif = 1 if "notify" in actions else 0
             return (
@@ -706,6 +717,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 notif,  # notif column
                 is_highlight,  # highlight column
                 int(count_as_unread),  # unread column
+                thread_id or "",  # thread_id column
             )
 
         def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
@@ -714,8 +726,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
             sql = """
                 INSERT INTO event_push_actions_staging
-                    (event_id, user_id, actions, notif, highlight, unread)
-                VALUES (?, ?, ?, ?, ?, ?)
+                    (event_id, user_id, actions, notif, highlight, unread, thread_id)
+                VALUES (?, ?, ?, ?, ?, ?, ?)
             """
 
             txn.execute_batch(
@@ -1102,23 +1114,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         # Calculate the new counts that should be upserted into event_push_summary
         sql = """
-            SELECT user_id, room_id,
+            SELECT user_id, room_id, thread_id,
                 coalesce(old.%s, 0) + upd.cnt,
                 upd.stream_ordering
             FROM (
-                SELECT user_id, room_id, count(*) as cnt,
+                SELECT user_id, room_id, thread_id, count(*) as cnt,
                     max(ea.stream_ordering) as stream_ordering
                 FROM event_push_actions AS ea
-                LEFT JOIN event_push_summary AS old USING (user_id, room_id)
+                LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
                 WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ?
                     AND (
                         old.last_receipt_stream_ordering IS NULL
                         OR old.last_receipt_stream_ordering < ea.stream_ordering
                     )
                     AND %s = 1
-                GROUP BY user_id, room_id
+                GROUP BY user_id, room_id, thread_id
             ) AS upd
-            LEFT JOIN event_push_summary AS old USING (user_id, room_id)
+            LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
         """
 
         # First get the count of unread messages.
@@ -1132,11 +1144,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         # object because we might not have the same amount of rows in each of them. To do
         # this, we use a dict indexed on the user ID and room ID to make it easier to
         # populate.
-        summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
+        summaries: Dict[Tuple[str, str, str], _EventPushSummary] = {}
         for row in txn:
-            summaries[(row[0], row[1])] = _EventPushSummary(
-                unread_count=row[2],
-                stream_ordering=row[3],
+            summaries[(row[0], row[1], row[2])] = _EventPushSummary(
+                unread_count=row[3],
+                stream_ordering=row[4],
                 notif_count=0,
             )
 
@@ -1147,17 +1159,17 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         )
 
         for row in txn:
-            if (row[0], row[1]) in summaries:
-                summaries[(row[0], row[1])].notif_count = row[2]
+            if (row[0], row[1], row[2]) in summaries:
+                summaries[(row[0], row[1], row[2])].notif_count = row[3]
             else:
                 # Because the rules on notifying are different than the rules on marking
                 # a message unread, we might end up with messages that notify but aren't
                 # marked unread, so we might not have a summary for this (user, room)
                 # tuple to complete.
-                summaries[(row[0], row[1])] = _EventPushSummary(
+                summaries[(row[0], row[1], row[2])] = _EventPushSummary(
                     unread_count=0,
-                    stream_ordering=row[3],
-                    notif_count=row[2],
+                    stream_ordering=row[4],
+                    notif_count=row[3],
                 )
 
         logger.info("Rotating notifications, handling %d rows", len(summaries))
@@ -1165,8 +1177,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         self.db_pool.simple_upsert_many_txn(
             txn,
             table="event_push_summary",
-            key_names=("user_id", "room_id"),
-            key_values=[(user_id, room_id) for user_id, room_id in summaries],
+            key_names=("user_id", "room_id", "thread_id"),
+            key_values=[(user_id, room_id, thread_id) for user_id, room_id, thread_id in summaries],
             value_names=("notif_count", "unread_count", "stream_ordering"),
             value_values=[
                 (
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5560b38a48..da84820b76 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -2190,9 +2190,9 @@ class PersistEventsStore:
         sql = """
             INSERT INTO event_push_actions (
                 room_id, event_id, user_id, actions, stream_ordering,
-                topological_ordering, notif, highlight, unread
+                topological_ordering, notif, highlight, unread, thread_id
             )
-            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
+            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread, thread_id
             FROM event_push_actions_staging
             WHERE event_id = ?
         """
diff --git a/synapse/storage/schema/main/delta/72/03thread_notifications.sql b/synapse/storage/schema/main/delta/72/03thread_notifications.sql
new file mode 100644
index 0000000000..27eb52e347
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/03thread_notifications.sql
@@ -0,0 +1,27 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE event_push_actions_staging
+  ADD COLUMN thread_id TEXT NOT NULL DEFAULT '';
+
+ALTER TABLE event_push_actions
+  ADD COLUMN thread_id TEXT NOT NULL DEFAULT '';
+
+ALTER TABLE event_push_summary
+  ADD COLUMN thread_id TEXT NOT NULL DEFAULT '';
+
+-- Update the unique index for `event_push_summary`
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (7003, 'event_push_summary_unique_index2', '{}');
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 531a0db2d0..f16554cd5c 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -404,6 +404,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
                 event.event_id,
                 {user_id: actions for user_id, actions in push_actions},
                 False,
+                None,
             )
         )
         return event, context