diff --git a/changelog.d/14174.feature b/changelog.d/14174.feature
new file mode 100644
index 0000000000..5d0ae16e13
--- /dev/null
+++ b/changelog.d/14174.feature
@@ -0,0 +1 @@
+Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)).
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 14dec7ac4e..18a282b22c 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple
-from synapse.api.constants import ReceiptTypes
+from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -83,7 +83,7 @@ class ReceiptRestServlet(RestServlet):
)
# Ensure the event ID roughly correlates to the thread ID.
- if thread_id != await self._main_store.get_thread_id(event_id):
+ if not await self._is_event_in_thread(event_id, thread_id):
raise SynapseError(
400,
f"event_id {event_id} is not related to thread {thread_id}",
@@ -109,6 +109,46 @@ class ReceiptRestServlet(RestServlet):
return 200, {}
+ async def _is_event_in_thread(self, event_id: str, thread_id: str) -> bool:
+ """
+ The event must be related to the thread ID (in a vague sense) to ensure
+ clients aren't sending bogus receipts.
+
+ A thread ID is considered valid for a given event E if:
+
+ 1. E has a thread relation which matches the thread ID;
+ 2. E has another event which has a thread relation to E matching the
+ thread ID; or
+ 3. E is recursively related (via any rel_type) to an event which
+ satisfies 1 or 2.
+
+ Given the following DAG:
+
+ A <---[m.thread]-- B <--[m.annotation]-- C
+ ^
+ |--[m.reference]-- D <--[m.annotation]-- E
+
+ It is valid to send a receipt for thread A on A, B, C, D, or E.
+
+ It is valid to send a receipt for the main timeline on A, D, and E.
+
+ Args:
+ event_id: The event ID to check.
+ thread_id: The thread ID the event is potentially part of.
+
+ Returns:
+ True if the event belongs to the given thread, otherwise False.
+ """
+
+ # If the receipt is on the main timeline, it is enough to check whether
+ # the event is directly related to a thread.
+ if thread_id == MAIN_TIMELINE:
+ return MAIN_TIMELINE == await self._main_store.get_thread_id(event_id)
+
+ # Otherwise, check if the event is directly part of a thread, or is the
+ # root message (or related to the root message) of a thread.
+ return thread_id == await self._main_store.get_thread_id_for_receipts(event_id)
+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReceiptRestServlet(hs).register(http_server)
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index b47fc606c7..ed0be4abe5 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -245,6 +245,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,))
self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,))
self._attempt_to_invalidate_cache("get_thread_id", (redacts,))
+ self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,))
if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 7c54ce0b2e..1de62ee9df 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -946,6 +946,20 @@ class RelationsWorkerStore(SQLBaseStore):
Get the thread ID for an event. This considers multi-level relations,
e.g. an annotation to an event which is part of a thread.
+ It only searches up the relations tree, i.e. it only searches for events
+ which the given event is related to (and which those events are related
+ to, etc.)
+
+ Given the following DAG:
+
+ A <---[m.thread]-- B <--[m.annotation]-- C
+ ^
+ |--[m.reference]-- D <--[m.annotation]-- E
+
+ get_thread_id(X) considers events B and C as part of thread A.
+
+ See also get_thread_id_for_receipts.
+
Args:
event_id: The event ID to fetch the thread ID for.
@@ -953,22 +967,32 @@ class RelationsWorkerStore(SQLBaseStore):
The event ID of the root event in the thread, if this event is part
of a thread. "main", otherwise.
"""
- # Since event relations form a tree, we should only ever find 0 or 1
- # results from the below query.
+
+ # Recurse event relations up to the *root* event, then search that chain
+ # of relations for a thread relation. If one is found, the root event is
+ # returned.
+ #
+ # Note that this should only ever find 0 or 1 entries since it is invalid
+ # for an event to have a thread relation to an event which also has a
+ # relation.
sql = """
WITH RECURSIVE related_events AS (
- SELECT event_id, relates_to_id, relation_type
+ SELECT event_id, relates_to_id, relation_type, 0 depth
FROM event_relations
WHERE event_id = ?
- UNION SELECT e.event_id, e.relates_to_id, e.relation_type
+ UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
FROM event_relations e
INNER JOIN related_events r ON r.relates_to_id = e.event_id
- ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread';
+ WHERE depth <= 3
+ )
+ SELECT relates_to_id FROM related_events
+ WHERE relation_type = 'm.thread'
+ ORDER BY depth DESC
+ LIMIT 1;
"""
def _get_thread_id(txn: LoggingTransaction) -> str:
txn.execute(sql, (event_id,))
- # TODO Should we ensure there's only a single result here?
row = txn.fetchone()
if row:
return row[0]
@@ -978,6 +1002,68 @@ class RelationsWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)
+ @cached()
+ async def get_thread_id_for_receipts(self, event_id: str) -> str:
+ """
+ Get the thread ID for an event by traversing to the top-most related event
+ and confirming any children events form a thread.
+
+ Given the following DAG:
+
+ A <---[m.thread]-- B <--[m.annotation]-- C
+ ^
+ |--[m.reference]-- D <--[m.annotation]-- E
+
+ get_thread_id_for_receipts(X) considers events A, B, C, D, and E as part
+ of thread A.
+
+ See also get_thread_id.
+
+ Args:
+ event_id: The event ID to fetch the thread ID for.
+
+ Returns:
+ The event ID of the root event in the thread, if this event is part
+ of a thread. "main", otherwise.
+ """
+
+ # Recurse event relations up to the *root* event, then search for any events
+ # related to that root node for a thread relation. If one is found, the
+ # root event is returned.
+ #
+ # Note that there cannot be thread relations in the middle of the chain since
+ # it is invalid for an event to have a thread relation to an event which also
+ # has a relation.
+ sql = """
+ SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE((
+ WITH RECURSIVE related_events AS (
+ SELECT event_id, relates_to_id, relation_type, 0 depth
+ FROM event_relations
+ WHERE event_id = ?
+ UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
+ FROM event_relations e
+ INNER JOIN related_events r ON r.relates_to_id = e.event_id
+ WHERE depth <= 3
+ )
+ SELECT relates_to_id FROM related_events
+ ORDER BY depth DESC
+ LIMIT 1
+ ), ?) AND relation_type = 'm.thread' LIMIT 1;
+ """
+
+ def _get_related_thread_id(txn: LoggingTransaction) -> str:
+ txn.execute(sql, (event_id, event_id))
+ row = txn.fetchone()
+ if row:
+ return row[0]
+
+ # If no thread was found, it is part of the main timeline.
+ return MAIN_TIMELINE
+
+ return await self.db_pool.runInteraction(
+ "get_related_thread_id", _get_related_thread_id
+ )
+
class RelationsStore(RelationsWorkerStore):
pass
diff --git a/tests/storage/test_relations.py b/tests/storage/test_relations.py
new file mode 100644
index 0000000000..cd1d00208b
--- /dev/null
+++ b/tests/storage/test_relations.py
@@ -0,0 +1,111 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import MAIN_TIMELINE
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+
+
+class RelationsStoreTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ """
+ Creates a DAG:
+
+ A <---[m.thread]-- B <--[m.annotation]-- C
+ ^
+ |--[m.reference]-- D <--[m.annotation]-- E
+
+ F <--[m.annotation]-- G
+
+ """
+ self._main_store = self.hs.get_datastores().main
+
+ self._create_relation("A", "B", "m.thread")
+ self._create_relation("B", "C", "m.annotation")
+ self._create_relation("A", "D", "m.reference")
+ self._create_relation("D", "E", "m.annotation")
+ self._create_relation("F", "G", "m.annotation")
+
+ def _create_relation(self, parent_id: str, event_id: str, rel_type: str) -> None:
+ self.get_success(
+ self._main_store.db_pool.simple_insert(
+ table="event_relations",
+ values={
+ "event_id": event_id,
+ "relates_to_id": parent_id,
+ "relation_type": rel_type,
+ },
+ )
+ )
+
+ def test_get_thread_id(self) -> None:
+ """
+ Ensure that get_thread_id only searches up the tree for threads.
+ """
+ # The thread itself and children of it return the thread.
+ thread_id = self.get_success(self._main_store.get_thread_id("B"))
+ self.assertEqual("A", thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id("C"))
+ self.assertEqual("A", thread_id)
+
+ # But the root and events related to the root do not.
+ thread_id = self.get_success(self._main_store.get_thread_id("A"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id("D"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id("E"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
+
+ # Events which are not related to a thread at all should return the
+ # main timeline.
+ thread_id = self.get_success(self._main_store.get_thread_id("F"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id("G"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
+
+ def test_get_thread_id_for_receipts(self) -> None:
+ """
+ Ensure that get_thread_id_for_receipts searches up and down the tree for a thread.
+ """
+ # All of the events are considered related to this thread.
+ thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("A"))
+ self.assertEqual("A", thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("B"))
+ self.assertEqual("A", thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("C"))
+ self.assertEqual("A", thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("D"))
+ self.assertEqual("A", thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("E"))
+ self.assertEqual("A", thread_id)
+
+ # Events which are not related to a thread at all should return the
+ # main timeline.
+ thread_id = self.get_success(self._main_store.get_thread_id("F"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id("G"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
|