summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-05-02 07:59:55 -0400
committerGitHub <noreply@github.com>2023-05-02 07:59:55 -0400
commit07b1c70d6b11d6b8feca23442a09b60ab0c930e3 (patch)
tree707b1e7702b708a7b3a27bafbdf6b2e1822cff4f /synapse
parentBump anyhow from 1.0.70 to 1.0.71 (#15507) (diff)
downloadsynapse-07b1c70d6b11d6b8feca23442a09b60ab0c930e3.tar.xz
Initial implementation of MSC3981: recursive relations API (#15315)
Adds an optional keyword argument to the /relations API which
will recurse a limited number of event relationships.

This will cause the API to return not just the events related to the
parent event, but also events related to those related to the parent
event, etc.

This is disabled by default behind an experimental configuration
flag and is currently implemented using prefixed parameters.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/experimental.py5
-rw-r--r--synapse/handlers/relations.py3
-rw-r--r--synapse/rest/client/relations.py10
-rw-r--r--synapse/storage/databases/main/relations.py65
4 files changed, 65 insertions, 18 deletions
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 6599679731..cab7ccf4b7 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -192,5 +192,10 @@ class ExperimentalConfig(Config):
         # MSC2659: Application service ping endpoint
         self.msc2659_enabled = experimental.get("msc2659_enabled", False)
 
+        # MSC3981: Recurse relations
+        self.msc3981_recurse_relations = experimental.get(
+            "msc3981_recurse_relations", False
+        )
+
         # MSC3970: Scope transaction IDs to devices
         self.msc3970_enabled = experimental.get("msc3970_enabled", False)
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 1d09fdf135..4824635162 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -85,6 +85,7 @@ class RelationsHandler:
         event_id: str,
         room_id: str,
         pagin_config: PaginationConfig,
+        recurse: bool,
         include_original_event: bool,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
@@ -98,6 +99,7 @@ class RelationsHandler:
             event_id: Fetch events that relate to this event ID.
             room_id: The room the event belongs to.
             pagin_config: The pagination config rules to apply, if any.
+            recurse: Whether to recursively find relations.
             include_original_event: Whether to include the parent event.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
@@ -132,6 +134,7 @@ class RelationsHandler:
             direction=pagin_config.direction,
             from_token=pagin_config.from_token,
             to_token=pagin_config.to_token,
+            recurse=recurse,
         )
 
         events = await self._main_store.get_events_as_list(
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index b8b296bc0c..785dfa08d8 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
 from synapse.api.constants import Direction
 from synapse.handlers.relations import ThreadsListInclude
 from synapse.http.server import HttpServer
-from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.rest.client._base import client_patterns
 from synapse.storage.databases.main.relations import ThreadsNextBatch
@@ -49,6 +49,7 @@ class RelationPaginationServlet(RestServlet):
         self.auth = hs.get_auth()
         self._store = hs.get_datastores().main
         self._relations_handler = hs.get_relations_handler()
+        self._support_recurse = hs.config.experimental.msc3981_recurse_relations
 
     async def on_GET(
         self,
@@ -63,6 +64,12 @@ class RelationPaginationServlet(RestServlet):
         pagination_config = await PaginationConfig.from_request(
             self._store, request, default_limit=5, default_dir=Direction.BACKWARDS
         )
+        if self._support_recurse:
+            recurse = parse_boolean(
+                request, "org.matrix.msc3981.recurse", default=False
+            )
+        else:
+            recurse = False
 
         # The unstable version of this API returns an extra field for client
         # compatibility, see https://github.com/matrix-org/synapse/issues/12930.
@@ -75,6 +82,7 @@ class RelationPaginationServlet(RestServlet):
             event_id=parent_id,
             room_id=room_id,
             pagin_config=pagination_config,
+            recurse=recurse,
             include_original_event=include_original_event,
             relation_type=relation_type,
             event_type=event_type,
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 3955a8a9a5..4a6c6c724d 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -172,6 +172,7 @@ class RelationsWorkerStore(SQLBaseStore):
         direction: Direction = Direction.BACKWARDS,
         from_token: Optional[StreamToken] = None,
         to_token: Optional[StreamToken] = None,
+        recurse: bool = False,
     ) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]:
         """Get a list of relations for an event, ordered by topological ordering.
 
@@ -186,6 +187,7 @@ class RelationsWorkerStore(SQLBaseStore):
                 oldest first (forwards).
             from_token: Fetch rows from the given token, or from the start if None.
             to_token: Fetch rows up to the given token, or up to the end if None.
+            recurse: Whether to recursively find relations.
 
         Returns:
             A tuple of:
@@ -200,8 +202,8 @@ class RelationsWorkerStore(SQLBaseStore):
         # Ensure bad limits aren't being passed in.
         assert limit >= 0
 
-        where_clause = ["relates_to_id = ?", "room_id = ?"]
-        where_args: List[Union[str, int]] = [event.event_id, room_id]
+        where_clause = ["room_id = ?"]
+        where_args: List[Union[str, int]] = [room_id]
         is_redacted = event.internal_metadata.is_redacted()
 
         if relation_type is not None:
@@ -229,23 +231,52 @@ class RelationsWorkerStore(SQLBaseStore):
         if pagination_clause:
             where_clause.append(pagination_clause)
 
-        sql = """
-            SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
-            FROM event_relations
-            INNER JOIN events USING (event_id)
-            WHERE %s
-            ORDER BY topological_ordering %s, stream_ordering %s
-            LIMIT ?
-        """ % (
-            " AND ".join(where_clause),
-            order,
-            order,
-        )
+        # If a recursive query is requested then the filters are applied after
+        # recursively following relationships from the requested event to children
+        # up to 3-relations deep.
+        #
+        # If no recursion is needed then the event_relations table is queried
+        # for direct children of the requested event.
+        if recurse:
+            sql = """
+                WITH RECURSIVE related_events AS (
+                    SELECT event_id, relation_type, relates_to_id, 0 AS depth
+                    FROM event_relations
+                    WHERE relates_to_id = ?
+                    UNION SELECT e.event_id, e.relation_type, e.relates_to_id, depth + 1
+                    FROM event_relations e
+                    INNER JOIN related_events r ON r.event_id = e.relates_to_id
+                    WHERE depth <= 3
+                )
+                SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
+                FROM related_events
+                INNER JOIN events USING (event_id)
+                WHERE %s
+                ORDER BY topological_ordering %s, stream_ordering %s
+                LIMIT ?;
+            """ % (
+                " AND ".join(where_clause),
+                order,
+                order,
+            )
+        else:
+            sql = """
+                SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
+                FROM event_relations
+                INNER JOIN events USING (event_id)
+                WHERE relates_to_id = ? AND %s
+                ORDER BY topological_ordering %s, stream_ordering %s
+                LIMIT ?
+            """ % (
+                " AND ".join(where_clause),
+                order,
+                order,
+            )
 
         def _get_recent_references_for_event_txn(
             txn: LoggingTransaction,
         ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
-            txn.execute(sql, where_args + [limit + 1])
+            txn.execute(sql, [event.event_id] + where_args + [limit + 1])
 
             events = []
             topo_orderings: List[int] = []
@@ -965,7 +996,7 @@ class RelationsWorkerStore(SQLBaseStore):
         # relation.
         sql = """
             WITH RECURSIVE related_events AS (
-                SELECT event_id, relates_to_id, relation_type, 0 depth
+                SELECT event_id, relates_to_id, relation_type, 0 AS depth
                 FROM event_relations
                 WHERE event_id = ?
                 UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
@@ -1025,7 +1056,7 @@ class RelationsWorkerStore(SQLBaseStore):
         sql = """
         SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE((
             WITH RECURSIVE related_events AS (
-                SELECT event_id, relates_to_id, relation_type, 0 depth
+                SELECT event_id, relates_to_id, relation_type, 0 AS depth
                 FROM event_relations
                 WHERE event_id = ?
                 UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1