summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/15315.feature1
-rw-r--r--synapse/config/experimental.py5
-rw-r--r--synapse/handlers/relations.py3
-rw-r--r--synapse/rest/client/relations.py10
-rw-r--r--synapse/storage/databases/main/relations.py65
-rw-r--r--tests/rest/client/test_relations.py120
6 files changed, 186 insertions, 18 deletions
diff --git a/changelog.d/15315.feature b/changelog.d/15315.feature
new file mode 100644
index 0000000000..30b2abdc62
--- /dev/null
+++ b/changelog.d/15315.feature
@@ -0,0 +1 @@
+Experimental support to recursively provide relations per [MSC3981](https://github.com/matrix-org/matrix-spec-proposals/pull/3981).
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 6599679731..cab7ccf4b7 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -192,5 +192,10 @@ class ExperimentalConfig(Config):
         # MSC2659: Application service ping endpoint
         self.msc2659_enabled = experimental.get("msc2659_enabled", False)
 
+        # MSC3981: Recurse relations
+        self.msc3981_recurse_relations = experimental.get(
+            "msc3981_recurse_relations", False
+        )
+
         # MSC3970: Scope transaction IDs to devices
         self.msc3970_enabled = experimental.get("msc3970_enabled", False)
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 1d09fdf135..4824635162 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -85,6 +85,7 @@ class RelationsHandler:
         event_id: str,
         room_id: str,
         pagin_config: PaginationConfig,
+        recurse: bool,
         include_original_event: bool,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
@@ -98,6 +99,7 @@ class RelationsHandler:
             event_id: Fetch events that relate to this event ID.
             room_id: The room the event belongs to.
             pagin_config: The pagination config rules to apply, if any.
+            recurse: Whether to recursively find relations.
             include_original_event: Whether to include the parent event.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
@@ -132,6 +134,7 @@ class RelationsHandler:
             direction=pagin_config.direction,
             from_token=pagin_config.from_token,
             to_token=pagin_config.to_token,
+            recurse=recurse,
         )
 
         events = await self._main_store.get_events_as_list(
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index b8b296bc0c..785dfa08d8 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
 from synapse.api.constants import Direction
 from synapse.handlers.relations import ThreadsListInclude
 from synapse.http.server import HttpServer
-from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.rest.client._base import client_patterns
 from synapse.storage.databases.main.relations import ThreadsNextBatch
@@ -49,6 +49,7 @@ class RelationPaginationServlet(RestServlet):
         self.auth = hs.get_auth()
         self._store = hs.get_datastores().main
         self._relations_handler = hs.get_relations_handler()
+        self._support_recurse = hs.config.experimental.msc3981_recurse_relations
 
     async def on_GET(
         self,
@@ -63,6 +64,12 @@ class RelationPaginationServlet(RestServlet):
         pagination_config = await PaginationConfig.from_request(
             self._store, request, default_limit=5, default_dir=Direction.BACKWARDS
         )
+        if self._support_recurse:
+            recurse = parse_boolean(
+                request, "org.matrix.msc3981.recurse", default=False
+            )
+        else:
+            recurse = False
 
         # The unstable version of this API returns an extra field for client
         # compatibility, see https://github.com/matrix-org/synapse/issues/12930.
@@ -75,6 +82,7 @@ class RelationPaginationServlet(RestServlet):
             event_id=parent_id,
             room_id=room_id,
             pagin_config=pagination_config,
+            recurse=recurse,
             include_original_event=include_original_event,
             relation_type=relation_type,
             event_type=event_type,
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 3955a8a9a5..4a6c6c724d 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -172,6 +172,7 @@ class RelationsWorkerStore(SQLBaseStore):
         direction: Direction = Direction.BACKWARDS,
         from_token: Optional[StreamToken] = None,
         to_token: Optional[StreamToken] = None,
+        recurse: bool = False,
     ) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]:
         """Get a list of relations for an event, ordered by topological ordering.
 
@@ -186,6 +187,7 @@ class RelationsWorkerStore(SQLBaseStore):
                 oldest first (forwards).
             from_token: Fetch rows from the given token, or from the start if None.
             to_token: Fetch rows up to the given token, or up to the end if None.
+            recurse: Whether to recursively find relations.
 
         Returns:
             A tuple of:
@@ -200,8 +202,8 @@ class RelationsWorkerStore(SQLBaseStore):
         # Ensure bad limits aren't being passed in.
         assert limit >= 0
 
-        where_clause = ["relates_to_id = ?", "room_id = ?"]
-        where_args: List[Union[str, int]] = [event.event_id, room_id]
+        where_clause = ["room_id = ?"]
+        where_args: List[Union[str, int]] = [room_id]
         is_redacted = event.internal_metadata.is_redacted()
 
         if relation_type is not None:
@@ -229,23 +231,52 @@ class RelationsWorkerStore(SQLBaseStore):
         if pagination_clause:
             where_clause.append(pagination_clause)
 
-        sql = """
-            SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
-            FROM event_relations
-            INNER JOIN events USING (event_id)
-            WHERE %s
-            ORDER BY topological_ordering %s, stream_ordering %s
-            LIMIT ?
-        """ % (
-            " AND ".join(where_clause),
-            order,
-            order,
-        )
+        # If a recursive query is requested then the filters are applied after
+        # recursively following relationships from the requested event to children
+        # up to 3-relations deep.
+        #
+        # If no recursion is needed then the event_relations table is queried
+        # for direct children of the requested event.
+        if recurse:
+            sql = """
+                WITH RECURSIVE related_events AS (
+                    SELECT event_id, relation_type, relates_to_id, 0 AS depth
+                    FROM event_relations
+                    WHERE relates_to_id = ?
+                    UNION SELECT e.event_id, e.relation_type, e.relates_to_id, depth + 1
+                    FROM event_relations e
+                    INNER JOIN related_events r ON r.event_id = e.relates_to_id
+                    WHERE depth <= 3
+                )
+                SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
+                FROM related_events
+                INNER JOIN events USING (event_id)
+                WHERE %s
+                ORDER BY topological_ordering %s, stream_ordering %s
+                LIMIT ?;
+            """ % (
+                " AND ".join(where_clause),
+                order,
+                order,
+            )
+        else:
+            sql = """
+                SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
+                FROM event_relations
+                INNER JOIN events USING (event_id)
+                WHERE relates_to_id = ? AND %s
+                ORDER BY topological_ordering %s, stream_ordering %s
+                LIMIT ?
+            """ % (
+                " AND ".join(where_clause),
+                order,
+                order,
+            )
 
         def _get_recent_references_for_event_txn(
             txn: LoggingTransaction,
         ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
-            txn.execute(sql, where_args + [limit + 1])
+            txn.execute(sql, [event.event_id] + where_args + [limit + 1])
 
             events = []
             topo_orderings: List[int] = []
@@ -965,7 +996,7 @@ class RelationsWorkerStore(SQLBaseStore):
         # relation.
         sql = """
             WITH RECURSIVE related_events AS (
-                SELECT event_id, relates_to_id, relation_type, 0 depth
+                SELECT event_id, relates_to_id, relation_type, 0 AS depth
                 FROM event_relations
                 WHERE event_id = ?
                 UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
@@ -1025,7 +1056,7 @@ class RelationsWorkerStore(SQLBaseStore):
         sql = """
         SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE((
             WITH RECURSIVE related_events AS (
-                SELECT event_id, relates_to_id, relation_type, 0 depth
+                SELECT event_id, relates_to_id, relation_type, 0 AS depth
                 FROM event_relations
                 WHERE event_id = ?
                 UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index fbbbcb23f1..75439416c1 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -30,6 +30,7 @@ from tests import unittest
 from tests.server import FakeChannel
 from tests.test_utils import make_awaitable
 from tests.test_utils.event_injection import inject_event
+from tests.unittest import override_config
 
 
 class BaseRelationsTestCase(unittest.HomeserverTestCase):
@@ -949,6 +950,125 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
             )
 
 
+class RecursiveRelationTestCase(BaseRelationsTestCase):
+    @override_config({"experimental_features": {"msc3981_recurse_relations": True}})
+    def test_recursive_relations(self) -> None:
+        """Generate a complex, multi-level relationship tree and query it."""
+        # Create a thread with a few messages in it.
+        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        thread_1 = channel.json_body["event_id"]
+
+        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        thread_2 = channel.json_body["event_id"]
+
+        # Add annotations.
+        channel = self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_2
+        )
+        annotation_1 = channel.json_body["event_id"]
+
+        channel = self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=thread_1
+        )
+        annotation_2 = channel.json_body["event_id"]
+
+        # Add a reference to part of the thread, then edit the reference and annotate it.
+        channel = self._send_relation(
+            RelationTypes.REFERENCE, "m.room.test", parent_id=thread_2
+        )
+        reference_1 = channel.json_body["event_id"]
+
+        channel = self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", "c", parent_id=reference_1
+        )
+        annotation_3 = channel.json_body["event_id"]
+
+        channel = self._send_relation(
+            RelationTypes.REPLACE,
+            "m.room.test",
+            parent_id=reference_1,
+        )
+        edit = channel.json_body["event_id"]
+
+        # Also more events off the root.
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "d")
+        annotation_4 = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}"
+            "?dir=f&limit=20&org.matrix.msc3981.recurse=true",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+
+        # The above events should be returned in creation order.
+        event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(
+            event_ids,
+            [
+                thread_1,
+                thread_2,
+                annotation_1,
+                annotation_2,
+                reference_1,
+                annotation_3,
+                edit,
+                annotation_4,
+            ],
+        )
+
+    @override_config({"experimental_features": {"msc3981_recurse_relations": True}})
+    def test_recursive_relations_with_filter(self) -> None:
+        """The event_type and rel_type still apply."""
+        # Create a thread with a few messages in it.
+        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        thread_1 = channel.json_body["event_id"]
+
+        # Add annotations.
+        channel = self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=thread_1
+        )
+        annotation_1 = channel.json_body["event_id"]
+
+        # Add a reference to part of the thread, then edit the reference and annotate it.
+        channel = self._send_relation(
+            RelationTypes.REFERENCE, "m.room.test", parent_id=thread_1
+        )
+        reference_1 = channel.json_body["event_id"]
+
+        channel = self._send_relation(
+            RelationTypes.ANNOTATION, "org.matrix.reaction", "c", parent_id=reference_1
+        )
+        annotation_2 = channel.json_body["event_id"]
+
+        # Fetch only annotations, but recursively.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+            "?dir=f&limit=20&org.matrix.msc3981.recurse=true",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+
+        # The above events should be returned in creation order.
+        event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(event_ids, [annotation_1, annotation_2])
+
+        # Fetch only m.reactions, but recursively.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}/{RelationTypes.ANNOTATION}/m.reaction"
+            "?dir=f&limit=20&org.matrix.msc3981.recurse=true",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+
+        # The above events should be returned in creation order.
+        event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(event_ids, [annotation_1])
+
+
 class BundledAggregationsTestCase(BaseRelationsTestCase):
     """
     See RelationsTestCase.test_edit for a similar test for edits.