diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 3955a8a9a5..4a6c6c724d 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -172,6 +172,7 @@ class RelationsWorkerStore(SQLBaseStore):
direction: Direction = Direction.BACKWARDS,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
+ recurse: bool = False,
) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of relations for an event, ordered by topological ordering.
@@ -186,6 +187,7 @@ class RelationsWorkerStore(SQLBaseStore):
oldest first (forwards).
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
+ recurse: Whether to recursively find relations.
Returns:
A tuple of:
@@ -200,8 +202,8 @@ class RelationsWorkerStore(SQLBaseStore):
# Ensure bad limits aren't being passed in.
assert limit >= 0
- where_clause = ["relates_to_id = ?", "room_id = ?"]
- where_args: List[Union[str, int]] = [event.event_id, room_id]
+ where_clause = ["room_id = ?"]
+ where_args: List[Union[str, int]] = [room_id]
is_redacted = event.internal_metadata.is_redacted()
if relation_type is not None:
@@ -229,23 +231,52 @@ class RelationsWorkerStore(SQLBaseStore):
if pagination_clause:
where_clause.append(pagination_clause)
- sql = """
- SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
- FROM event_relations
- INNER JOIN events USING (event_id)
- WHERE %s
- ORDER BY topological_ordering %s, stream_ordering %s
- LIMIT ?
- """ % (
- " AND ".join(where_clause),
- order,
- order,
- )
+ # If a recursive query is requested then the filters are applied after
+ # recursively following relationships from the requested event to children
+ # up to 3-relations deep.
+ #
+ # If no recursion is needed then the event_relations table is queried
+ # for direct children of the requested event.
+ if recurse:
+ sql = """
+ WITH RECURSIVE related_events AS (
+ SELECT event_id, relation_type, relates_to_id, 0 AS depth
+ FROM event_relations
+ WHERE relates_to_id = ?
+ UNION SELECT e.event_id, e.relation_type, e.relates_to_id, depth + 1
+ FROM event_relations e
+ INNER JOIN related_events r ON r.event_id = e.relates_to_id
+ WHERE depth <= 3
+ )
+ SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
+ FROM related_events
+ INNER JOIN events USING (event_id)
+ WHERE %s
+ ORDER BY topological_ordering %s, stream_ordering %s
+ LIMIT ?;
+ """ % (
+ " AND ".join(where_clause),
+ order,
+ order,
+ )
+ else:
+ sql = """
+ SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE relates_to_id = ? AND %s
+ ORDER BY topological_ordering %s, stream_ordering %s
+ LIMIT ?
+ """ % (
+ " AND ".join(where_clause),
+ order,
+ order,
+ )
def _get_recent_references_for_event_txn(
txn: LoggingTransaction,
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
- txn.execute(sql, where_args + [limit + 1])
+ txn.execute(sql, [event.event_id] + where_args + [limit + 1])
events = []
topo_orderings: List[int] = []
@@ -965,7 +996,7 @@ class RelationsWorkerStore(SQLBaseStore):
# relation.
sql = """
WITH RECURSIVE related_events AS (
- SELECT event_id, relates_to_id, relation_type, 0 depth
+ SELECT event_id, relates_to_id, relation_type, 0 AS depth
FROM event_relations
WHERE event_id = ?
UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
@@ -1025,7 +1056,7 @@ class RelationsWorkerStore(SQLBaseStore):
sql = """
SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE((
WITH RECURSIVE related_events AS (
- SELECT event_id, relates_to_id, relation_type, 0 depth
+ SELECT event_id, relates_to_id, relation_type, 0 AS depth
FROM event_relations
WHERE event_id = ?
UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
|