summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13824.feature1
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py5
-rw-r--r--synapse/rest/client/receipts.py22
-rw-r--r--synapse/storage/databases/main/relations.py36
-rw-r--r--tests/storage/test_event_push_actions.py100
5 files changed, 162 insertions, 2 deletions
diff --git a/changelog.d/13824.feature b/changelog.d/13824.feature
new file mode 100644
index 0000000000..d0cb902dff
--- /dev/null
+++ b/changelog.d/13824.feature
@@ -0,0 +1 @@
+Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 61d952742d..f8c4dd74f0 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -286,8 +286,13 @@ class BulkPushRuleEvaluator:
                 relation.parent_id,
                 itertools.chain(*(r.rules() for r in rules_by_user.values())),
             )
+            # Recursively attempt to find the thread this event relates to.
             if relation.rel_type == RelationTypes.THREAD:
                 thread_id = relation.parent_id
+            else:
+                # Since the event has not yet been persisted we check whether
+                # the parent is part of a thread.
+                thread_id = await self.store.get_thread_id(relation.parent_id) or "main"
 
         evaluator = PushRuleEvaluator(
             _flatten_dict(event),
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index f3ff156abe..287dfdd69e 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -16,7 +16,7 @@ import logging
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.constants import ReceiptTypes
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.http.site import SynapseRequest
@@ -43,6 +43,7 @@ class ReceiptRestServlet(RestServlet):
         self.receipts_handler = hs.get_receipts_handler()
         self.read_marker_handler = hs.get_read_marker_handler()
         self.presence_handler = hs.get_presence_handler()
+        self._main_store = hs.get_datastores().main
 
         self._known_receipt_types = {
             ReceiptTypes.READ,
@@ -71,7 +72,24 @@ class ReceiptRestServlet(RestServlet):
                 thread_id = body.get("thread_id")
                 if not thread_id or not isinstance(thread_id, str):
                     raise SynapseError(
-                        400, "thread_id field must be a non-empty string"
+                        400,
+                        "thread_id field must be a non-empty string",
+                        Codes.INVALID_PARAM,
+                    )
+
+                if receipt_type == ReceiptTypes.FULLY_READ:
+                    raise SynapseError(
+                        400,
+                        f"thread_id is not compatible with {ReceiptTypes.FULLY_READ} receipts.",
+                        Codes.INVALID_PARAM,
+                    )
+
+                # Ensure the event ID roughly correlates to the thread ID.
+                if thread_id != await self._main_store.get_thread_id(event_id):
+                    raise SynapseError(
+                        400,
+                        f"event_id {event_id} is not related to thread {thread_id}",
+                        Codes.INVALID_PARAM,
                     )
 
         await self.presence_handler.bump_presence_active_time(requester.user)
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 898947af95..154385b1e8 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -832,6 +832,42 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_event_relations", _get_event_relations
         )
 
+    @cached()
+    async def get_thread_id(self, event_id: str) -> Optional[str]:
+        """
+        Get the thread ID for an event. This considers multi-level relations,
+        e.g. an annotation to an event which is part of a thread.
+
+        Args:
+            event_id: The event ID to fetch the thread ID for.
+
+        Returns:
+            The event ID of the root event in the thread, if this event is part
+            of a thread. None, otherwise.
+        """
+        # Since event relations form a tree, we should only ever find 0 or 1
+        # results from the below query.
+        sql = """
+            WITH RECURSIVE related_events AS (
+                SELECT event_id, relates_to_id, relation_type
+                FROM event_relations
+                WHERE event_id = ?
+                UNION SELECT e.event_id, e.relates_to_id, e.relation_type
+                FROM event_relations e
+                INNER JOIN related_events r ON r.relates_to_id = e.event_id
+            ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread';
+        """
+
+        def _get_thread_id(txn: LoggingTransaction) -> Optional[str]:
+            txn.execute(sql, (event_id,))
+            # TODO Should we ensure there's only a single result here?
+            row = txn.fetchone()
+            if row:
+                return row[0]
+            return None
+
+        return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)
+
 
 class RelationsStore(RelationsWorkerStore):
     pass
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 6fa0cafb75..886585e9f2 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -588,6 +588,106 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
         _rotate()
         _assert_counts(0, 0, 0, 0)
 
+    def test_recursive_thread(self) -> None:
+        """
+        Events related to events in a thread should still be considered part of
+        that thread.
+        """
+
+        # Create a user to receive notifications and send receipts.
+        user_id = self.register_user("user1235", "pass")
+        token = self.login("user1235", "pass")
+
+        # And another users to send events.
+        other_id = self.register_user("other", "pass")
+        other_token = self.login("other", "pass")
+
+        # Create a room and put both users in it.
+        room_id = self.helper.create_room_as(user_id, tok=token)
+        self.helper.join(room_id, other_id, tok=other_token)
+
+        # Update the user's push rules to care about reaction events.
+        self.get_success(
+            self.store.add_push_rule(
+                user_id,
+                "related_events",
+                priority_class=5,
+                conditions=[
+                    {"kind": "event_match", "key": "type", "pattern": "m.reaction"}
+                ],
+                actions=["notify"],
+            )
+        )
+
+        def _create_event(type: str, content: JsonDict) -> str:
+            result = self.helper.send_event(
+                room_id, type=type, content=content, tok=other_token
+            )
+            return result["event_id"]
+
+        def _assert_counts(noitf_count: int, thread_notif_count: int) -> None:
+            counts = self.get_success(
+                self.store.db_pool.runInteraction(
+                    "get-unread-counts",
+                    self.store._get_unread_counts_by_receipt_txn,
+                    room_id,
+                    user_id,
+                )
+            )
+            self.assertEqual(
+                counts.main_timeline,
+                NotifCounts(
+                    notify_count=noitf_count, unread_count=0, highlight_count=0
+                ),
+            )
+            if thread_notif_count:
+                self.assertEqual(
+                    counts.threads,
+                    {
+                        thread_id: NotifCounts(
+                            notify_count=thread_notif_count,
+                            unread_count=0,
+                            highlight_count=0,
+                        ),
+                    },
+                )
+            else:
+                self.assertEqual(counts.threads, {})
+
+        # Create a root event.
+        thread_id = _create_event(
+            "m.room.message", {"msgtype": "m.text", "body": "msg"}
+        )
+        _assert_counts(1, 0)
+
+        # Reply, creating a thread.
+        reply_id = _create_event(
+            "m.room.message",
+            {
+                "msgtype": "m.text",
+                "body": "msg",
+                "m.relates_to": {
+                    "rel_type": "m.thread",
+                    "event_id": thread_id,
+                },
+            },
+        )
+        _assert_counts(1, 1)
+
+        # Create an event related to a thread event, this should still appear in
+        # the thread.
+        _create_event(
+            type="m.reaction",
+            content={
+                "m.relates_to": {
+                    "rel_type": "m.annotation",
+                    "event_id": reply_id,
+                    "key": "A",
+                }
+            },
+        )
+        _assert_counts(1, 2)
+
     def test_find_first_stream_ordering_after_ts(self) -> None:
         def add_event(so: int, ts: int) -> None:
             self.get_success(