diff --git a/changelog.d/13824.feature b/changelog.d/13824.feature
new file mode 100644
index 0000000000..d0cb902dff
--- /dev/null
+++ b/changelog.d/13824.feature
@@ -0,0 +1 @@
+Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 61d952742d..f8c4dd74f0 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -286,8 +286,13 @@ class BulkPushRuleEvaluator:
relation.parent_id,
itertools.chain(*(r.rules() for r in rules_by_user.values())),
)
+ # Recursively attempt to find the thread this event relates to.
if relation.rel_type == RelationTypes.THREAD:
thread_id = relation.parent_id
+ else:
+ # Since the event has not yet been persisted we check whether
+ # the parent is part of a thread.
+ thread_id = await self.store.get_thread_id(relation.parent_id) or "main"
evaluator = PushRuleEvaluator(
_flatten_dict(event),
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index f3ff156abe..287dfdd69e 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -16,7 +16,7 @@ import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReceiptTypes
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
@@ -43,6 +43,7 @@ class ReceiptRestServlet(RestServlet):
self.receipts_handler = hs.get_receipts_handler()
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
+ self._main_store = hs.get_datastores().main
self._known_receipt_types = {
ReceiptTypes.READ,
@@ -71,7 +72,24 @@ class ReceiptRestServlet(RestServlet):
thread_id = body.get("thread_id")
if not thread_id or not isinstance(thread_id, str):
raise SynapseError(
- 400, "thread_id field must be a non-empty string"
+ 400,
+ "thread_id field must be a non-empty string",
+ Codes.INVALID_PARAM,
+ )
+
+ if receipt_type == ReceiptTypes.FULLY_READ:
+ raise SynapseError(
+ 400,
+ f"thread_id is not compatible with {ReceiptTypes.FULLY_READ} receipts.",
+ Codes.INVALID_PARAM,
+ )
+
+ # Ensure the event ID roughly correlates to the thread ID.
+ if thread_id != await self._main_store.get_thread_id(event_id):
+ raise SynapseError(
+ 400,
+ f"event_id {event_id} is not related to thread {thread_id}",
+ Codes.INVALID_PARAM,
)
await self.presence_handler.bump_presence_active_time(requester.user)
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 898947af95..154385b1e8 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -832,6 +832,42 @@ class RelationsWorkerStore(SQLBaseStore):
"get_event_relations", _get_event_relations
)
+ @cached()
+ async def get_thread_id(self, event_id: str) -> Optional[str]:
+ """
+ Get the thread ID for an event. This considers multi-level relations,
+ e.g. an annotation to an event which is part of a thread.
+
+ Args:
+ event_id: The event ID to fetch the thread ID for.
+
+ Returns:
+ The event ID of the root event in the thread, if this event is part
+ of a thread. None, otherwise.
+ """
+ # Since event relations form a tree, we should only ever find 0 or 1
+ # results from the below query.
+ sql = """
+ WITH RECURSIVE related_events AS (
+ SELECT event_id, relates_to_id, relation_type
+ FROM event_relations
+ WHERE event_id = ?
+ UNION SELECT e.event_id, e.relates_to_id, e.relation_type
+ FROM event_relations e
+ INNER JOIN related_events r ON r.relates_to_id = e.event_id
+ ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread';
+ """
+
+ def _get_thread_id(txn: LoggingTransaction) -> Optional[str]:
+ txn.execute(sql, (event_id,))
+ # TODO Should we ensure there's only a single result here?
+ row = txn.fetchone()
+ if row:
+ return row[0]
+ return None
+
+ return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)
+
class RelationsStore(RelationsWorkerStore):
pass
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 6fa0cafb75..886585e9f2 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -588,6 +588,106 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
_rotate()
_assert_counts(0, 0, 0, 0)
+ def test_recursive_thread(self) -> None:
+ """
+ Events related to events in a thread should still be considered part of
+ that thread.
+ """
+
+ # Create a user to receive notifications and send receipts.
+ user_id = self.register_user("user1235", "pass")
+ token = self.login("user1235", "pass")
+
+ # And another users to send events.
+ other_id = self.register_user("other", "pass")
+ other_token = self.login("other", "pass")
+
+ # Create a room and put both users in it.
+ room_id = self.helper.create_room_as(user_id, tok=token)
+ self.helper.join(room_id, other_id, tok=other_token)
+
+ # Update the user's push rules to care about reaction events.
+ self.get_success(
+ self.store.add_push_rule(
+ user_id,
+ "related_events",
+ priority_class=5,
+ conditions=[
+ {"kind": "event_match", "key": "type", "pattern": "m.reaction"}
+ ],
+ actions=["notify"],
+ )
+ )
+
+ def _create_event(type: str, content: JsonDict) -> str:
+ result = self.helper.send_event(
+ room_id, type=type, content=content, tok=other_token
+ )
+ return result["event_id"]
+
+ def _assert_counts(noitf_count: int, thread_notif_count: int) -> None:
+ counts = self.get_success(
+ self.store.db_pool.runInteraction(
+ "get-unread-counts",
+ self.store._get_unread_counts_by_receipt_txn,
+ room_id,
+ user_id,
+ )
+ )
+ self.assertEqual(
+ counts.main_timeline,
+ NotifCounts(
+ notify_count=noitf_count, unread_count=0, highlight_count=0
+ ),
+ )
+ if thread_notif_count:
+ self.assertEqual(
+ counts.threads,
+ {
+ thread_id: NotifCounts(
+ notify_count=thread_notif_count,
+ unread_count=0,
+ highlight_count=0,
+ ),
+ },
+ )
+ else:
+ self.assertEqual(counts.threads, {})
+
+ # Create a root event.
+ thread_id = _create_event(
+ "m.room.message", {"msgtype": "m.text", "body": "msg"}
+ )
+ _assert_counts(1, 0)
+
+ # Reply, creating a thread.
+ reply_id = _create_event(
+ "m.room.message",
+ {
+ "msgtype": "m.text",
+ "body": "msg",
+ "m.relates_to": {
+ "rel_type": "m.thread",
+ "event_id": thread_id,
+ },
+ },
+ )
+ _assert_counts(1, 1)
+
+ # Create an event related to a thread event, this should still appear in
+ # the thread.
+ _create_event(
+ type="m.reaction",
+ content={
+ "m.relates_to": {
+ "rel_type": "m.annotation",
+ "event_id": reply_id,
+ "key": "A",
+ }
+ },
+ )
+ _assert_counts(1, 2)
+
def test_find_first_stream_ordering_after_ts(self) -> None:
def add_event(so: int, ts: int) -> None:
self.get_success(
|