summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/rest/client/receipts.py44
-rw-r--r--synapse/storage/databases/main/cache.py1
-rw-r--r--synapse/storage/databases/main/relations.py98
3 files changed, 135 insertions, 8 deletions
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 14dec7ac4e..18a282b22c 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Tuple
 
-from synapse.api.constants import ReceiptTypes
+from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -83,7 +83,7 @@ class ReceiptRestServlet(RestServlet):
                 )
 
             # Ensure the event ID roughly correlates to the thread ID.
-            if thread_id != await self._main_store.get_thread_id(event_id):
+            if not await self._is_event_in_thread(event_id, thread_id):
                 raise SynapseError(
                     400,
                     f"event_id {event_id} is not related to thread {thread_id}",
@@ -109,6 +109,46 @@ class ReceiptRestServlet(RestServlet):
 
         return 200, {}
 
+    async def _is_event_in_thread(self, event_id: str, thread_id: str) -> bool:
+        """
+        The event must be related to the thread ID (in a vague sense) to ensure
+        clients aren't sending bogus receipts.
+
+        A thread ID is considered valid for a given event E if:
+
+        1. E has a thread relation which matches the thread ID;
+        2. E has another event which has a thread relation to E matching the
+           thread ID; or
+        3. E is recursively related (via any rel_type) to an event which
+           satisfies 1 or 2.
+
+        Given the following DAG:
+
+            A <---[m.thread]-- B <--[m.annotation]-- C
+            ^
+            |--[m.reference]-- D <--[m.annotation]-- E
+
+        It is valid to send a receipt for thread A on A, B, C, D, or E.
+
+        It is valid to send a receipt for the main timeline on A, D, and E.
+
+        Args:
+            event_id: The event ID to check.
+            thread_id: The thread ID the event is potentially part of.
+
+        Returns:
+            True if the event belongs to the given thread, otherwise False.
+        """
+
+        # If the receipt is on the main timeline, it is enough to check whether
+        # the event is directly related to a thread.
+        if thread_id == MAIN_TIMELINE:
+            return MAIN_TIMELINE == await self._main_store.get_thread_id(event_id)
+
+        # Otherwise, check if the event is directly part of a thread, or is the
+        # root message (or related to the root message) of a thread.
+        return thread_id == await self._main_store.get_thread_id_for_receipts(event_id)
+
 
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ReceiptRestServlet(hs).register(http_server)
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index b47fc606c7..ed0be4abe5 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -245,6 +245,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,))
             self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,))
             self._attempt_to_invalidate_cache("get_thread_id", (redacts,))
+            self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,))
 
         if etype == EventTypes.Member:
             self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 7c54ce0b2e..1de62ee9df 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -946,6 +946,20 @@ class RelationsWorkerStore(SQLBaseStore):
         Get the thread ID for an event. This considers multi-level relations,
         e.g. an annotation to an event which is part of a thread.
 
+        It only searches up the relations tree, i.e. it only searches for events
+        which the given event is related to (and which those events are related
+        to, etc.)
+
+        Given the following DAG:
+
+            A <---[m.thread]-- B <--[m.annotation]-- C
+            ^
+            |--[m.reference]-- D <--[m.annotation]-- E
+
+        get_thread_id(X) considers events B and C as part of thread A.
+
+        See also get_thread_id_for_receipts.
+
         Args:
             event_id: The event ID to fetch the thread ID for.
 
@@ -953,22 +967,32 @@ class RelationsWorkerStore(SQLBaseStore):
             The event ID of the root event in the thread, if this event is part
             of a thread. "main", otherwise.
         """
-        # Since event relations form a tree, we should only ever find 0 or 1
-        # results from the below query.
+
+        # Recurse event relations up to the *root* event, then search that chain
+        # of relations for a thread relation. If one is found, the root event is
+        # returned.
+        #
+        # Note that this should only ever find 0 or 1 entries since it is invalid
+        # for an event to have a thread relation to an event which also has a
+        # relation.
         sql = """
             WITH RECURSIVE related_events AS (
-                SELECT event_id, relates_to_id, relation_type
+                SELECT event_id, relates_to_id, relation_type, 0 depth
                 FROM event_relations
                 WHERE event_id = ?
-                UNION SELECT e.event_id, e.relates_to_id, e.relation_type
+                UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
                 FROM event_relations e
                 INNER JOIN related_events r ON r.relates_to_id = e.event_id
-            ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread';
+                WHERE depth <= 3
+            )
+            SELECT relates_to_id FROM related_events
+            WHERE relation_type = 'm.thread'
+            ORDER BY depth DESC
+            LIMIT 1;
         """
 
         def _get_thread_id(txn: LoggingTransaction) -> str:
             txn.execute(sql, (event_id,))
-            # TODO Should we ensure there's only a single result here?
             row = txn.fetchone()
             if row:
                 return row[0]
@@ -978,6 +1002,68 @@ class RelationsWorkerStore(SQLBaseStore):
 
         return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)
 
+    @cached()
+    async def get_thread_id_for_receipts(self, event_id: str) -> str:
+        """
+        Get the thread ID for an event by traversing to the top-most related event
+        and confirming any children events form a thread.
+
+        Given the following DAG:
+
+            A <---[m.thread]-- B <--[m.annotation]-- C
+            ^
+            |--[m.reference]-- D <--[m.annotation]-- E
+
+        get_thread_id_for_receipts(X) considers events A, B, C, D, and E as part
+        of thread A.
+
+        See also get_thread_id.
+
+        Args:
+            event_id: The event ID to fetch the thread ID for.
+
+        Returns:
+            The event ID of the root event in the thread, if this event is part
+            of a thread. "main", otherwise.
+        """
+
+        # Recurse event relations up to the *root* event, then search for any events
+        # related to that root node for a thread relation. If one is found, the
+        # root event is returned.
+        #
+        # Note that there cannot be thread relations in the middle of the chain since
+        # it is invalid for an event to have a thread relation to an event which also
+        # has a relation.
+        sql = """
+        SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE((
+            WITH RECURSIVE related_events AS (
+                SELECT event_id, relates_to_id, relation_type, 0 depth
+                FROM event_relations
+                WHERE event_id = ?
+                UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
+                FROM event_relations e
+                INNER JOIN related_events r ON r.relates_to_id = e.event_id
+                WHERE depth <= 3
+            )
+            SELECT relates_to_id FROM related_events
+            ORDER BY depth DESC
+            LIMIT 1
+        ), ?) AND relation_type = 'm.thread' LIMIT 1;
+        """
+
+        def _get_related_thread_id(txn: LoggingTransaction) -> str:
+            txn.execute(sql, (event_id, event_id))
+            row = txn.fetchone()
+            if row:
+                return row[0]
+
+            # If no thread was found, it is part of the main timeline.
+            return MAIN_TIMELINE
+
+        return await self.db_pool.runInteraction(
+            "get_related_thread_id", _get_related_thread_id
+        )
+
 
 class RelationsStore(RelationsWorkerStore):
     pass