summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/14174.feature1
-rw-r--r--synapse/rest/client/receipts.py44
-rw-r--r--synapse/storage/databases/main/cache.py1
-rw-r--r--synapse/storage/databases/main/relations.py98
-rw-r--r--tests/storage/test_relations.py111
5 files changed, 247 insertions, 8 deletions
diff --git a/changelog.d/14174.feature b/changelog.d/14174.feature
new file mode 100644
index 0000000000..5d0ae16e13
--- /dev/null
+++ b/changelog.d/14174.feature
@@ -0,0 +1 @@
+Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)).
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 14dec7ac4e..18a282b22c 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Tuple
 
-from synapse.api.constants import ReceiptTypes
+from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -83,7 +83,7 @@ class ReceiptRestServlet(RestServlet):
                 )
 
             # Ensure the event ID roughly correlates to the thread ID.
-            if thread_id != await self._main_store.get_thread_id(event_id):
+            if not await self._is_event_in_thread(event_id, thread_id):
                 raise SynapseError(
                     400,
                     f"event_id {event_id} is not related to thread {thread_id}",
@@ -109,6 +109,46 @@ class ReceiptRestServlet(RestServlet):
 
         return 200, {}
 
+    async def _is_event_in_thread(self, event_id: str, thread_id: str) -> bool:
+        """
+        The event must be related to the thread ID (in a vague sense) to ensure
+        clients aren't sending bogus receipts.
+
+        A thread ID is considered valid for a given event E if:
+
+        1. E has a thread relation which matches the thread ID;
+        2. E has another event which has a thread relation to E matching the
+           thread ID; or
+        3. E is recursively related (via any rel_type) to an event which
+           satisfies 1 or 2.
+
+        Given the following DAG:
+
+            A <---[m.thread]-- B <--[m.annotation]-- C
+            ^
+            |--[m.reference]-- D <--[m.annotation]-- E
+
+        It is valid to send a receipt for thread A on A, B, C, D, or E.
+
+        It is valid to send a receipt for the main timeline on A, D, and E.
+
+        Args:
+            event_id: The event ID to check.
+            thread_id: The thread ID the event is potentially part of.
+
+        Returns:
+            True if the event belongs to the given thread, otherwise False.
+        """
+
+        # If the receipt is on the main timeline, it is enough to check whether
+        # the event is directly related to a thread.
+        if thread_id == MAIN_TIMELINE:
+            return MAIN_TIMELINE == await self._main_store.get_thread_id(event_id)
+
+        # Otherwise, check if the event is directly part of a thread, or is the
+        # root message (or related to the root message) of a thread.
+        return thread_id == await self._main_store.get_thread_id_for_receipts(event_id)
+
 
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ReceiptRestServlet(hs).register(http_server)
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index b47fc606c7..ed0be4abe5 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -245,6 +245,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,))
             self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,))
             self._attempt_to_invalidate_cache("get_thread_id", (redacts,))
+            self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,))
 
         if etype == EventTypes.Member:
             self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 7c54ce0b2e..1de62ee9df 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -946,6 +946,20 @@ class RelationsWorkerStore(SQLBaseStore):
         Get the thread ID for an event. This considers multi-level relations,
         e.g. an annotation to an event which is part of a thread.
 
+        It only searches up the relations tree, i.e. it only searches for events
+        which the given event is related to (and which those events are related
+        to, etc.)
+
+        Given the following DAG:
+
+            A <---[m.thread]-- B <--[m.annotation]-- C
+            ^
+            |--[m.reference]-- D <--[m.annotation]-- E
+
+        get_thread_id(X) considers events B and C as part of thread A.
+
+        See also get_thread_id_for_receipts.
+
         Args:
             event_id: The event ID to fetch the thread ID for.
 
@@ -953,22 +967,32 @@ class RelationsWorkerStore(SQLBaseStore):
             The event ID of the root event in the thread, if this event is part
             of a thread. "main", otherwise.
         """
-        # Since event relations form a tree, we should only ever find 0 or 1
-        # results from the below query.
+
+        # Recurse event relations up to the *root* event, then search that chain
+        # of relations for a thread relation. If one is found, the root event is
+        # returned.
+        #
+        # Note that this should only ever find 0 or 1 entries since it is invalid
+        # for an event to have a thread relation to an event which also has a
+        # relation.
         sql = """
             WITH RECURSIVE related_events AS (
-                SELECT event_id, relates_to_id, relation_type
+                SELECT event_id, relates_to_id, relation_type, 0 depth
                 FROM event_relations
                 WHERE event_id = ?
-                UNION SELECT e.event_id, e.relates_to_id, e.relation_type
+                UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
                 FROM event_relations e
                 INNER JOIN related_events r ON r.relates_to_id = e.event_id
-            ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread';
+                WHERE depth <= 3
+            )
+            SELECT relates_to_id FROM related_events
+            WHERE relation_type = 'm.thread'
+            ORDER BY depth DESC
+            LIMIT 1;
         """
 
         def _get_thread_id(txn: LoggingTransaction) -> str:
             txn.execute(sql, (event_id,))
-            # TODO Should we ensure there's only a single result here?
             row = txn.fetchone()
             if row:
                 return row[0]
@@ -978,6 +1002,68 @@ class RelationsWorkerStore(SQLBaseStore):
 
         return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)
 
+    @cached()
+    async def get_thread_id_for_receipts(self, event_id: str) -> str:
+        """
+        Get the thread ID for an event by traversing to the top-most related event
+        and confirming any children events form a thread.
+
+        Given the following DAG:
+
+            A <---[m.thread]-- B <--[m.annotation]-- C
+            ^
+            |--[m.reference]-- D <--[m.annotation]-- E
+
+        get_thread_id_for_receipts(X) considers events A, B, C, D, and E as part
+        of thread A.
+
+        See also get_thread_id.
+
+        Args:
+            event_id: The event ID to fetch the thread ID for.
+
+        Returns:
+            The event ID of the root event in the thread, if this event is part
+            of a thread. "main", otherwise.
+        """
+
+        # Recurse event relations up to the *root* event, then search for any events
+        # related to that root node for a thread relation. If one is found, the
+        # root event is returned.
+        #
+        # Note that there cannot be thread relations in the middle of the chain since
+        # it is invalid for an event to have a thread relation to an event which also
+        # has a relation.
+        sql = """
+        SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE((
+            WITH RECURSIVE related_events AS (
+                SELECT event_id, relates_to_id, relation_type, 0 depth
+                FROM event_relations
+                WHERE event_id = ?
+                UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
+                FROM event_relations e
+                INNER JOIN related_events r ON r.relates_to_id = e.event_id
+                WHERE depth <= 3
+            )
+            SELECT relates_to_id FROM related_events
+            ORDER BY depth DESC
+            LIMIT 1
+        ), ?) AND relation_type = 'm.thread' LIMIT 1;
+        """
+
+        def _get_related_thread_id(txn: LoggingTransaction) -> str:
+            txn.execute(sql, (event_id, event_id))
+            row = txn.fetchone()
+            if row:
+                return row[0]
+
+            # If no thread was found, it is part of the main timeline.
+            return MAIN_TIMELINE
+
+        return await self.db_pool.runInteraction(
+            "get_related_thread_id", _get_related_thread_id
+        )
+
 
 class RelationsStore(RelationsWorkerStore):
     pass
diff --git a/tests/storage/test_relations.py b/tests/storage/test_relations.py
new file mode 100644
index 0000000000..cd1d00208b
--- /dev/null
+++ b/tests/storage/test_relations.py
@@ -0,0 +1,111 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import MAIN_TIMELINE
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+
+
+class RelationsStoreTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        """
+        Creates a DAG:
+
+            A <---[m.thread]-- B <--[m.annotation]-- C
+            ^
+            |--[m.reference]-- D <--[m.annotation]-- E
+
+            F <--[m.annotation]-- G
+
+        """
+        self._main_store = self.hs.get_datastores().main
+
+        self._create_relation("A", "B", "m.thread")
+        self._create_relation("B", "C", "m.annotation")
+        self._create_relation("A", "D", "m.reference")
+        self._create_relation("D", "E", "m.annotation")
+        self._create_relation("F", "G", "m.annotation")
+
+    def _create_relation(self, parent_id: str, event_id: str, rel_type: str) -> None:
+        self.get_success(
+            self._main_store.db_pool.simple_insert(
+                table="event_relations",
+                values={
+                    "event_id": event_id,
+                    "relates_to_id": parent_id,
+                    "relation_type": rel_type,
+                },
+            )
+        )
+
+    def test_get_thread_id(self) -> None:
+        """
+        Ensure that get_thread_id only searches up the tree for threads.
+        """
+        # The thread itself and children of it return the thread.
+        thread_id = self.get_success(self._main_store.get_thread_id("B"))
+        self.assertEqual("A", thread_id)
+
+        thread_id = self.get_success(self._main_store.get_thread_id("C"))
+        self.assertEqual("A", thread_id)
+
+        # But the root and events related to the root do not.
+        thread_id = self.get_success(self._main_store.get_thread_id("A"))
+        self.assertEqual(MAIN_TIMELINE, thread_id)
+
+        thread_id = self.get_success(self._main_store.get_thread_id("D"))
+        self.assertEqual(MAIN_TIMELINE, thread_id)
+
+        thread_id = self.get_success(self._main_store.get_thread_id("E"))
+        self.assertEqual(MAIN_TIMELINE, thread_id)
+
+        # Events which are not related to a thread at all should return the
+        # main timeline.
+        thread_id = self.get_success(self._main_store.get_thread_id("F"))
+        self.assertEqual(MAIN_TIMELINE, thread_id)
+
+        thread_id = self.get_success(self._main_store.get_thread_id("G"))
+        self.assertEqual(MAIN_TIMELINE, thread_id)
+
+    def test_get_thread_id_for_receipts(self) -> None:
+        """
+        Ensure that get_thread_id_for_receipts searches up and down the tree for a thread.
+        """
+        # All of the events are considered related to this thread.
+        thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("A"))
+        self.assertEqual("A", thread_id)
+
+        thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("B"))
+        self.assertEqual("A", thread_id)
+
+        thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("C"))
+        self.assertEqual("A", thread_id)
+
+        thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("D"))
+        self.assertEqual("A", thread_id)
+
+        thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("E"))
+        self.assertEqual("A", thread_id)
+
+        # Events which are not related to a thread at all should return the
+        # main timeline.
+        thread_id = self.get_success(self._main_store.get_thread_id("F"))
+        self.assertEqual(MAIN_TIMELINE, thread_id)
+
+        thread_id = self.get_success(self._main_store.get_thread_id("G"))
+        self.assertEqual(MAIN_TIMELINE, thread_id)