summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/experimental.py5
-rw-r--r--synapse/handlers/relations.py3
-rw-r--r--synapse/rest/client/relations.py10
-rw-r--r--synapse/storage/databases/main/relations.py65
4 files changed, 65 insertions, 18 deletions
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 6599679731..cab7ccf4b7 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -192,5 +192,10 @@ class ExperimentalConfig(Config):
         # MSC2659: Application service ping endpoint
         self.msc2659_enabled = experimental.get("msc2659_enabled", False)
 
+        # MSC3981: Recurse relations
+        self.msc3981_recurse_relations = experimental.get(
+            "msc3981_recurse_relations", False
+        )
+
         # MSC3970: Scope transaction IDs to devices
         self.msc3970_enabled = experimental.get("msc3970_enabled", False)
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 1d09fdf135..4824635162 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -85,6 +85,7 @@ class RelationsHandler:
         event_id: str,
         room_id: str,
         pagin_config: PaginationConfig,
+        recurse: bool,
         include_original_event: bool,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
@@ -98,6 +99,7 @@ class RelationsHandler:
             event_id: Fetch events that relate to this event ID.
             room_id: The room the event belongs to.
             pagin_config: The pagination config rules to apply, if any.
+            recurse: Whether to recursively find relations.
             include_original_event: Whether to include the parent event.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
@@ -132,6 +134,7 @@ class RelationsHandler:
             direction=pagin_config.direction,
             from_token=pagin_config.from_token,
             to_token=pagin_config.to_token,
+            recurse=recurse,
         )
 
         events = await self._main_store.get_events_as_list(
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index b8b296bc0c..785dfa08d8 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
 from synapse.api.constants import Direction
 from synapse.handlers.relations import ThreadsListInclude
 from synapse.http.server import HttpServer
-from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.rest.client._base import client_patterns
 from synapse.storage.databases.main.relations import ThreadsNextBatch
@@ -49,6 +49,7 @@ class RelationPaginationServlet(RestServlet):
         self.auth = hs.get_auth()
         self._store = hs.get_datastores().main
         self._relations_handler = hs.get_relations_handler()
+        self._support_recurse = hs.config.experimental.msc3981_recurse_relations
 
     async def on_GET(
         self,
@@ -63,6 +64,12 @@ class RelationPaginationServlet(RestServlet):
         pagination_config = await PaginationConfig.from_request(
             self._store, request, default_limit=5, default_dir=Direction.BACKWARDS
         )
+        if self._support_recurse:
+            recurse = parse_boolean(
+                request, "org.matrix.msc3981.recurse", default=False
+            )
+        else:
+            recurse = False
 
         # The unstable version of this API returns an extra field for client
         # compatibility, see https://github.com/matrix-org/synapse/issues/12930.
@@ -75,6 +82,7 @@ class RelationPaginationServlet(RestServlet):
             event_id=parent_id,
             room_id=room_id,
             pagin_config=pagination_config,
+            recurse=recurse,
             include_original_event=include_original_event,
             relation_type=relation_type,
             event_type=event_type,
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 3955a8a9a5..4a6c6c724d 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -172,6 +172,7 @@ class RelationsWorkerStore(SQLBaseStore):
         direction: Direction = Direction.BACKWARDS,
         from_token: Optional[StreamToken] = None,
         to_token: Optional[StreamToken] = None,
+        recurse: bool = False,
     ) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]:
         """Get a list of relations for an event, ordered by topological ordering.
 
@@ -186,6 +187,7 @@ class RelationsWorkerStore(SQLBaseStore):
                 oldest first (forwards).
             from_token: Fetch rows from the given token, or from the start if None.
             to_token: Fetch rows up to the given token, or up to the end if None.
+            recurse: Whether to recursively find relations.
 
         Returns:
             A tuple of:
@@ -200,8 +202,8 @@ class RelationsWorkerStore(SQLBaseStore):
         # Ensure bad limits aren't being passed in.
         assert limit >= 0
 
-        where_clause = ["relates_to_id = ?", "room_id = ?"]
-        where_args: List[Union[str, int]] = [event.event_id, room_id]
+        where_clause = ["room_id = ?"]
+        where_args: List[Union[str, int]] = [room_id]
         is_redacted = event.internal_metadata.is_redacted()
 
         if relation_type is not None:
@@ -229,23 +231,52 @@ class RelationsWorkerStore(SQLBaseStore):
         if pagination_clause:
             where_clause.append(pagination_clause)
 
-        sql = """
-            SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
-            FROM event_relations
-            INNER JOIN events USING (event_id)
-            WHERE %s
-            ORDER BY topological_ordering %s, stream_ordering %s
-            LIMIT ?
-        """ % (
-            " AND ".join(where_clause),
-            order,
-            order,
-        )
+        # If a recursive query is requested then the filters are applied after
+        # recursively following relationships from the requested event to children
+        # up to 3-relations deep.
+        #
+        # If no recursion is needed then the event_relations table is queried
+        # for direct children of the requested event.
+        if recurse:
+            sql = """
+                WITH RECURSIVE related_events AS (
+                    SELECT event_id, relation_type, relates_to_id, 0 AS depth
+                    FROM event_relations
+                    WHERE relates_to_id = ?
+                    UNION SELECT e.event_id, e.relation_type, e.relates_to_id, depth + 1
+                    FROM event_relations e
+                    INNER JOIN related_events r ON r.event_id = e.relates_to_id
+                    WHERE depth <= 3
+                )
+                SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
+                FROM related_events
+                INNER JOIN events USING (event_id)
+                WHERE %s
+                ORDER BY topological_ordering %s, stream_ordering %s
+                LIMIT ?;
+            """ % (
+                " AND ".join(where_clause),
+                order,
+                order,
+            )
+        else:
+            sql = """
+                SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
+                FROM event_relations
+                INNER JOIN events USING (event_id)
+                WHERE relates_to_id = ? AND %s
+                ORDER BY topological_ordering %s, stream_ordering %s
+                LIMIT ?
+            """ % (
+                " AND ".join(where_clause),
+                order,
+                order,
+            )
 
         def _get_recent_references_for_event_txn(
             txn: LoggingTransaction,
         ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
-            txn.execute(sql, where_args + [limit + 1])
+            txn.execute(sql, [event.event_id] + where_args + [limit + 1])
 
             events = []
             topo_orderings: List[int] = []
@@ -965,7 +996,7 @@ class RelationsWorkerStore(SQLBaseStore):
         # relation.
         sql = """
             WITH RECURSIVE related_events AS (
-                SELECT event_id, relates_to_id, relation_type, 0 depth
+                SELECT event_id, relates_to_id, relation_type, 0 AS depth
                 FROM event_relations
                 WHERE event_id = ?
                 UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
@@ -1025,7 +1056,7 @@ class RelationsWorkerStore(SQLBaseStore):
         sql = """
         SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE((
             WITH RECURSIVE related_events AS (
-                SELECT event_id, relates_to_id, relation_type, 0 depth
+                SELECT event_id, relates_to_id, relation_type, 0 AS depth
                 FROM event_relations
                 WHERE event_id = ?
                 UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1