diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 409bfd43c1..18a282b22c 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -15,8 +15,8 @@
import logging
from typing import TYPE_CHECKING, Tuple
-from synapse.api.constants import ReceiptTypes
-from synapse.api.errors import SynapseError
+from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes
+from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
@@ -43,12 +43,13 @@ class ReceiptRestServlet(RestServlet):
self.receipts_handler = hs.get_receipts_handler()
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
+ self._main_store = hs.get_datastores().main
- self._known_receipt_types = {ReceiptTypes.READ}
- if hs.config.experimental.msc2285_enabled:
- self._known_receipt_types.update(
- (ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ)
- )
+ self._known_receipt_types = {
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ReceiptTypes.FULLY_READ,
+ }
async def on_POST(
self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
@@ -61,7 +62,33 @@ class ReceiptRestServlet(RestServlet):
f"Receipt type must be {', '.join(self._known_receipt_types)}",
)
- parse_json_object_from_request(request, allow_empty_body=False)
+ body = parse_json_object_from_request(request)
+
+ # Pull the thread ID, if one exists.
+ thread_id = None
+ if "thread_id" in body:
+ thread_id = body.get("thread_id")
+ if not thread_id or not isinstance(thread_id, str):
+ raise SynapseError(
+ 400,
+ "thread_id field must be a non-empty string",
+ Codes.INVALID_PARAM,
+ )
+
+ if receipt_type == ReceiptTypes.FULLY_READ:
+ raise SynapseError(
+ 400,
+ f"thread_id is not compatible with {ReceiptTypes.FULLY_READ} receipts.",
+ Codes.INVALID_PARAM,
+ )
+
+ # Ensure the event ID roughly correlates to the thread ID.
+ if not await self._is_event_in_thread(event_id, thread_id):
+ raise SynapseError(
+ 400,
+ f"event_id {event_id} is not related to thread {thread_id}",
+ Codes.INVALID_PARAM,
+ )
await self.presence_handler.bump_presence_active_time(requester.user)
@@ -77,10 +104,51 @@ class ReceiptRestServlet(RestServlet):
receipt_type,
user_id=requester.user.to_string(),
event_id=event_id,
+ thread_id=thread_id,
)
return 200, {}
+ async def _is_event_in_thread(self, event_id: str, thread_id: str) -> bool:
+ """
+ The event must be related to the thread ID (in a vague sense) to ensure
+ clients aren't sending bogus receipts.
+
+ A thread ID is considered valid for a given event E if:
+
+ 1. E has a thread relation which matches the thread ID;
+ 2. E has another event which has a thread relation to E matching the
+ thread ID; or
+ 3. E is recursively related (via any rel_type) to an event which
+ satisfies 1 or 2.
+
+ Given the following DAG:
+
+ A <---[m.thread]-- B <--[m.annotation]-- C
+ ^
+ |--[m.reference]-- D <--[m.annotation]-- E
+
+ It is valid to send a receipt for thread A on A, B, C, D, or E.
+
+ It is valid to send a receipt for the main timeline on A, D, and E.
+
+ Args:
+ event_id: The event ID to check.
+ thread_id: The thread ID the event is potentially part of.
+
+ Returns:
+ True if the event belongs to the given thread, otherwise False.
+ """
+
+ # If the receipt is on the main timeline, it is enough to check whether
+ # the event is directly related to a thread.
+ if thread_id == MAIN_TIMELINE:
+ return MAIN_TIMELINE == await self._main_store.get_thread_id(event_id)
+
+ # Otherwise, check if the event is directly part of a thread, or is the
+ # root message (or related to the root message) of a thread.
+ return thread_id == await self._main_store.get_thread_id_for_receipts(event_id)
+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReceiptRestServlet(hs).register(http_server)
|