diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 6599679731..cab7ccf4b7 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -192,5 +192,10 @@ class ExperimentalConfig(Config):
# MSC2659: Application service ping endpoint
self.msc2659_enabled = experimental.get("msc2659_enabled", False)
+ # MSC3981: Recurse relations
+ self.msc3981_recurse_relations = experimental.get(
+ "msc3981_recurse_relations", False
+ )
+
# MSC3970: Scope transaction IDs to devices
self.msc3970_enabled = experimental.get("msc3970_enabled", False)
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 1d09fdf135..4824635162 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -85,6 +85,7 @@ class RelationsHandler:
event_id: str,
room_id: str,
pagin_config: PaginationConfig,
+ recurse: bool,
include_original_event: bool,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
@@ -98,6 +99,7 @@ class RelationsHandler:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
pagin_config: The pagination config rules to apply, if any.
+ recurse: Whether to recursively find relations.
include_original_event: Whether to include the parent event.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
@@ -132,6 +134,7 @@ class RelationsHandler:
direction=pagin_config.direction,
from_token=pagin_config.from_token,
to_token=pagin_config.to_token,
+ recurse=recurse,
)
events = await self._main_store.get_events_as_list(
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index b8b296bc0c..785dfa08d8 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.constants import Direction
from synapse.handlers.relations import ThreadsListInclude
from synapse.http.server import HttpServer
-from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.storage.databases.main.relations import ThreadsNextBatch
@@ -49,6 +49,7 @@ class RelationPaginationServlet(RestServlet):
self.auth = hs.get_auth()
self._store = hs.get_datastores().main
self._relations_handler = hs.get_relations_handler()
+ self._support_recurse = hs.config.experimental.msc3981_recurse_relations
async def on_GET(
self,
@@ -63,6 +64,12 @@ class RelationPaginationServlet(RestServlet):
pagination_config = await PaginationConfig.from_request(
self._store, request, default_limit=5, default_dir=Direction.BACKWARDS
)
+ if self._support_recurse:
+ recurse = parse_boolean(
+ request, "org.matrix.msc3981.recurse", default=False
+ )
+ else:
+ recurse = False
# The unstable version of this API returns an extra field for client
# compatibility, see https://github.com/matrix-org/synapse/issues/12930.
@@ -75,6 +82,7 @@ class RelationPaginationServlet(RestServlet):
event_id=parent_id,
room_id=room_id,
pagin_config=pagination_config,
+ recurse=recurse,
include_original_event=include_original_event,
relation_type=relation_type,
event_type=event_type,
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 3955a8a9a5..4a6c6c724d 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -172,6 +172,7 @@ class RelationsWorkerStore(SQLBaseStore):
direction: Direction = Direction.BACKWARDS,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
+ recurse: bool = False,
) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of relations for an event, ordered by topological ordering.
@@ -186,6 +187,7 @@ class RelationsWorkerStore(SQLBaseStore):
oldest first (forwards).
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
+ recurse: Whether to recursively find relations.
Returns:
A tuple of:
@@ -200,8 +202,8 @@ class RelationsWorkerStore(SQLBaseStore):
# Ensure bad limits aren't being passed in.
assert limit >= 0
- where_clause = ["relates_to_id = ?", "room_id = ?"]
- where_args: List[Union[str, int]] = [event.event_id, room_id]
+ where_clause = ["room_id = ?"]
+ where_args: List[Union[str, int]] = [room_id]
is_redacted = event.internal_metadata.is_redacted()
if relation_type is not None:
@@ -229,23 +231,52 @@ class RelationsWorkerStore(SQLBaseStore):
if pagination_clause:
where_clause.append(pagination_clause)
- sql = """
- SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
- FROM event_relations
- INNER JOIN events USING (event_id)
- WHERE %s
- ORDER BY topological_ordering %s, stream_ordering %s
- LIMIT ?
- """ % (
- " AND ".join(where_clause),
- order,
- order,
- )
+ # If a recursive query is requested then the filters are applied after
+ # recursively following relationships from the requested event to children
+ # up to 3-relations deep.
+ #
+ # If no recursion is needed then the event_relations table is queried
+ # for direct children of the requested event.
+ if recurse:
+ sql = """
+ WITH RECURSIVE related_events AS (
+ SELECT event_id, relation_type, relates_to_id, 0 AS depth
+ FROM event_relations
+ WHERE relates_to_id = ?
+ UNION SELECT e.event_id, e.relation_type, e.relates_to_id, depth + 1
+ FROM event_relations e
+ INNER JOIN related_events r ON r.event_id = e.relates_to_id
+ WHERE depth <= 3
+ )
+ SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
+ FROM related_events
+ INNER JOIN events USING (event_id)
+ WHERE %s
+ ORDER BY topological_ordering %s, stream_ordering %s
+ LIMIT ?;
+ """ % (
+ " AND ".join(where_clause),
+ order,
+ order,
+ )
+ else:
+ sql = """
+ SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE relates_to_id = ? AND %s
+ ORDER BY topological_ordering %s, stream_ordering %s
+ LIMIT ?
+ """ % (
+ " AND ".join(where_clause),
+ order,
+ order,
+ )
def _get_recent_references_for_event_txn(
txn: LoggingTransaction,
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
- txn.execute(sql, where_args + [limit + 1])
+ txn.execute(sql, [event.event_id] + where_args + [limit + 1])
events = []
topo_orderings: List[int] = []
@@ -965,7 +996,7 @@ class RelationsWorkerStore(SQLBaseStore):
# relation.
sql = """
WITH RECURSIVE related_events AS (
- SELECT event_id, relates_to_id, relation_type, 0 depth
+ SELECT event_id, relates_to_id, relation_type, 0 AS depth
FROM event_relations
WHERE event_id = ?
UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
@@ -1025,7 +1056,7 @@ class RelationsWorkerStore(SQLBaseStore):
sql = """
SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE((
WITH RECURSIVE related_events AS (
- SELECT event_id, relates_to_id, relation_type, 0 depth
+ SELECT event_id, relates_to_id, relation_type, 0 AS depth
FROM event_relations
WHERE event_id = ?
UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
|