summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <patrickc@matrix.org>2022-06-08 13:46:49 -0400
committerPatrick Cloke <patrickc@matrix.org>2022-06-13 09:57:06 -0400
commitcbbe77f62073379da7ca77d5b743f64c1bbc3e82 (patch)
tree967626cfef14b42834ba1720ca1d3b842935fcfb
parentAdd demo script. (diff)
downloadsynapse-cbbe77f62073379da7ca77d5b743f64c1bbc3e82.tar.xz
Include the thread ID in the event push actions.
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py30
-rw-r--r--synapse/storage/databases/main/event_push_actions.py48
-rw-r--r--synapse/storage/databases/main/events.py4
-rw-r--r--synapse/storage/databases/main/receipts.py2
-rw-r--r--synapse/storage/schema/main/delta/70/03thread_notifications.sql23
-rw-r--r--tests/replication/slave/storage/test_events.py1
-rw-r--r--tests/storage/test_event_push_actions.py1
7 files changed, 71 insertions, 38 deletions
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 7791b289e2..d1c929e202 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -195,7 +195,7 @@ class BulkPushRuleEvaluator:
         return pl_event.content if pl_event else {}, sender_level
 
     async def _get_mutual_relations(
-        self, event: EventBase, rules: Iterable[Dict[str, Any]]
+        self, parent_id: str, rules: Iterable[Dict[str, Any]]
     ) -> Dict[str, Set[Tuple[str, str]]]:
         """
         Fetch event metadata for events which related to the same event as the given event.
@@ -203,7 +203,7 @@ class BulkPushRuleEvaluator:
         If the given event has no relation information, returns an empty dictionary.
 
         Args:
-            event_id: The event ID which is targeted by relations.
+            parent_id: The event ID which is targeted by relations.
             rules: The push rules which will be processed for this event.
 
         Returns:
@@ -217,12 +217,6 @@ class BulkPushRuleEvaluator:
         if not self._relations_match_enabled:
             return {}
 
-        # If the event does not have a relation, then cannot have any mutual
-        # relations.
-        relation = relation_from_event(event)
-        if not relation:
-            return {}
-
         # Pre-filter to figure out which relation types are interesting.
         rel_types = set()
         for rule in rules:
@@ -244,9 +238,7 @@ class BulkPushRuleEvaluator:
             return {}
 
         # If any valid rules were found, fetch the mutual relations.
-        return await self.store.get_mutual_event_relations(
-            relation.parent_id, rel_types
-        )
+        return await self.store.get_mutual_event_relations(parent_id, rel_types)
 
     @measure_func("action_for_event_by_user")
     async def action_for_event_by_user(
@@ -272,9 +264,18 @@ class BulkPushRuleEvaluator:
             sender_power_level,
         ) = await self._get_power_levels_and_sender_level(event, context)
 
-        relations = await self._get_mutual_relations(
-            event, itertools.chain(*rules_by_user.values())
-        )
+        relation = relation_from_event(event)
+        # If the event does not have a relation, then cannot have any mutual
+        # relations or thread ID.
+        relations = {}
+        thread_id = None
+        if relation:
+            relations = await self._get_mutual_relations(
+                relation.parent_id, itertools.chain(*rules_by_user.values())
+            )
+            # XXX Does this need to point to a valid parent ID or anything?
+            if relation.rel_type == RelationTypes.THREAD:
+                thread_id = relation.parent_id
 
         evaluator = PushRuleEvaluatorForEvent(
             event,
@@ -339,6 +340,7 @@ class BulkPushRuleEvaluator:
             event.event_id,
             actions_by_user,
             count_as_unread,
+            thread_id,
         )
 
 
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index b0b3695012..812ed1a3d4 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -528,6 +528,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         event_id: str,
         user_id_actions: Dict[str, List[Union[dict, str]]],
         count_as_unread: bool,
+        thread_id: Optional[str],
     ) -> None:
         """Add the push actions for the event to the push action staging area.
 
@@ -536,6 +537,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             user_id_actions: A mapping of user_id to list of push actions, where
                 an action can either be a string or dict.
             count_as_unread: Whether this event should increment unread counts.
+            thread_id: The thread this event is parent of, if applicable.
         """
         if not user_id_actions:
             return
@@ -544,7 +546,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # can be used to insert into the `event_push_actions_staging` table.
         def _gen_entry(
             user_id: str, actions: List[Union[dict, str]]
-        ) -> Tuple[str, str, str, int, int, int]:
+        ) -> Tuple[str, str, str, int, int, int, Optional[str]]:
             is_highlight = 1 if _action_has_highlight(actions) else 0
             notif = 1 if "notify" in actions else 0
             return (
@@ -554,6 +556,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 notif,  # notif column
                 is_highlight,  # highlight column
                 int(count_as_unread),  # unread column
+                thread_id,  # thread_id column
             )
 
         def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
@@ -562,8 +565,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
             sql = """
                 INSERT INTO event_push_actions_staging
-                    (event_id, user_id, actions, notif, highlight, unread)
-                VALUES (?, ?, ?, ?, ?, ?)
+                    (event_id, user_id, actions, notif, highlight, unread, thread_id)
+                VALUES (?, ?, ?, ?, ?, ?, ?)
             """
 
             txn.execute_batch(
@@ -810,20 +813,20 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
         # Calculate the new counts that should be upserted into event_push_summary
         sql = """
-            SELECT user_id, room_id,
+            SELECT user_id, room_id, thread_id,
                 coalesce(old.%s, 0) + upd.cnt,
                 upd.stream_ordering,
                 old.user_id
             FROM (
-                SELECT user_id, room_id, count(*) as cnt,
+                SELECT user_id, room_id, thread_id, count(*) as cnt,
                     max(stream_ordering) as stream_ordering
                 FROM event_push_actions
                 WHERE ? <= stream_ordering AND stream_ordering < ?
                     AND highlight = 0
                     AND %s = 1
-                GROUP BY user_id, room_id
+                GROUP BY user_id, room_id, thread_id
             ) AS upd
-            LEFT JOIN event_push_summary AS old USING (user_id, room_id)
+            LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
         """
 
         # First get the count of unread messages.
@@ -837,12 +840,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # object because we might not have the same amount of rows in each of them. To do
         # this, we use a dict indexed on the user ID and room ID to make it easier to
         # populate.
-        summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
+        summaries: Dict[Tuple[str, str, Optional[str]], _EventPushSummary] = {}
         for row in txn:
-            summaries[(row[0], row[1])] = _EventPushSummary(
-                unread_count=row[2],
-                stream_ordering=row[3],
-                old_user_id=row[4],
+            summaries[(row[0], row[1], row[2])] = _EventPushSummary(
+                unread_count=row[3],
+                stream_ordering=row[4],
+                old_user_id=row[5],
                 notif_count=0,
             )
 
@@ -853,18 +856,18 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
 
         for row in txn:
-            if (row[0], row[1]) in summaries:
-                summaries[(row[0], row[1])].notif_count = row[2]
+            if (row[0], row[1], row[2]) in summaries:
+                summaries[(row[0], row[1], row[2])].notif_count = row[3]
             else:
                 # Because the rules on notifying are different than the rules on marking
                 # a message unread, we might end up with messages that notify but aren't
                 # marked unread, so we might not have a summary for this (user, room)
                 # tuple to complete.
-                summaries[(row[0], row[1])] = _EventPushSummary(
+                summaries[(row[0], row[1], row[2])] = _EventPushSummary(
                     unread_count=0,
-                    stream_ordering=row[3],
-                    old_user_id=row[4],
-                    notif_count=row[2],
+                    stream_ordering=row[4],
+                    old_user_id=row[5],
+                    notif_count=row[3],
                 )
 
         logger.info("Rotating notifications, handling %d rows", len(summaries))
@@ -881,6 +884,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "notif_count",
                 "unread_count",
                 "stream_ordering",
+                "thread_id",
             ),
             values=[
                 (
@@ -889,8 +893,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                     summary.notif_count,
                     summary.unread_count,
                     summary.stream_ordering,
+                    thread_id,
                 )
-                for ((user_id, room_id), summary) in summaries.items()
+                for ((user_id, room_id, thread_id), summary) in summaries.items()
                 if summary.old_user_id is None
             ],
         )
@@ -899,7 +904,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             """
                 UPDATE event_push_summary
                 SET notif_count = ?, unread_count = ?, stream_ordering = ?
-                WHERE user_id = ? AND room_id = ?
+                WHERE user_id = ? AND room_id = ? AND thread_id = ?
             """,
             (
                 (
@@ -908,8 +913,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                     summary.stream_ordering,
                     user_id,
                     room_id,
+                    thread_id,
                 )
-                for ((user_id, room_id), summary) in summaries.items()
+                for ((user_id, room_id, thread_id), summary) in summaries.items()
                 if summary.old_user_id is not None
             ),
         )
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5b86ac55e9..6a1564349f 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -2302,9 +2302,9 @@ class PersistEventsStore:
         sql = """
             INSERT INTO event_push_actions (
                 room_id, event_id, user_id, actions, stream_ordering,
-                topological_ordering, notif, highlight, unread
+                topological_ordering, notif, highlight, unread, thread_id
             )
-            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
+            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread, thread_id
             FROM event_push_actions_staging
             WHERE event_id = ?
         """
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 4622e8910e..cece802c6e 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -731,7 +731,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 user_id,
                 receipt_type,
                 start_topo_ordering - 1 if start_topo_ordering is not None else None,
-                end_topo_ordering + 1,
+                end_topo_ordering + 1 if end_topo_ordering is not None else None,
             ),
         )
         overlapping_receipts = txn.fetchall()
diff --git a/synapse/storage/schema/main/delta/70/03thread_notifications.sql b/synapse/storage/schema/main/delta/70/03thread_notifications.sql
new file mode 100644
index 0000000000..6fd444ccc1
--- /dev/null
+++ b/synapse/storage/schema/main/delta/70/03thread_notifications.sql
@@ -0,0 +1,23 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE event_push_actions_staging
+  ADD COLUMN thread_id TEXT DEFAULT NULL;
+
+ALTER TABLE event_push_actions
+  ADD COLUMN thread_id TEXT DEFAULT NULL;
+
+ALTER TABLE event_push_summary
+  ADD COLUMN thread_id TEXT DEFAULT NULL;
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 6d3d4afe52..3ad2b2ad79 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -393,6 +393,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
                 event.event_id,
                 {user_id: actions for user_id, actions in push_actions},
                 False,
+                None,
             )
         )
         return event, context
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 0f9add4841..1b52da201c 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -79,6 +79,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
                     event.event_id,
                     {user_id: action},
                     False,
+                    None,
                 )
             )
             self.get_success(