diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 7c54ce0b2e..1de62ee9df 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -946,6 +946,20 @@ class RelationsWorkerStore(SQLBaseStore):
Get the thread ID for an event. This considers multi-level relations,
e.g. an annotation to an event which is part of a thread.
+ It only searches up the relations tree, i.e. it only searches for events
+ which the given event is related to (and which those events are related
+ to, etc.)
+
+ Given the following DAG:
+
+ A <---[m.thread]-- B <--[m.annotation]-- C
+ ^
+ |--[m.reference]-- D <--[m.annotation]-- E
+
+ get_thread_id(X) considers events B and C as part of thread A.
+
+ See also get_thread_id_for_receipts.
+
Args:
event_id: The event ID to fetch the thread ID for.
@@ -953,22 +967,32 @@ class RelationsWorkerStore(SQLBaseStore):
The event ID of the root event in the thread, if this event is part
of a thread. "main", otherwise.
"""
- # Since event relations form a tree, we should only ever find 0 or 1
- # results from the below query.
+
+ # Recurse event relations up to the *root* event, then search that chain
+ # of relations for a thread relation. If one is found, the root event is
+ # returned.
+ #
+ # Note that this should only ever find 0 or 1 entries since it is invalid
+ # for an event to have a thread relation to an event which also has a
+ # relation.
sql = """
WITH RECURSIVE related_events AS (
- SELECT event_id, relates_to_id, relation_type
+ SELECT event_id, relates_to_id, relation_type, 0 depth
FROM event_relations
WHERE event_id = ?
- UNION SELECT e.event_id, e.relates_to_id, e.relation_type
+ UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
FROM event_relations e
INNER JOIN related_events r ON r.relates_to_id = e.event_id
- ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread';
+ WHERE depth <= 3
+ )
+ SELECT relates_to_id FROM related_events
+ WHERE relation_type = 'm.thread'
+ ORDER BY depth DESC
+ LIMIT 1;
"""
def _get_thread_id(txn: LoggingTransaction) -> str:
txn.execute(sql, (event_id,))
- # TODO Should we ensure there's only a single result here?
row = txn.fetchone()
if row:
return row[0]
@@ -978,6 +1002,68 @@ class RelationsWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)
+ @cached()
+ async def get_thread_id_for_receipts(self, event_id: str) -> str:
+ """
+ Get the thread ID for an event by traversing to the top-most related event
+ and confirming any children events form a thread.
+
+ Given the following DAG:
+
+ A <---[m.thread]-- B <--[m.annotation]-- C
+ ^
+ |--[m.reference]-- D <--[m.annotation]-- E
+
+ get_thread_id_for_receipts(X) considers events A, B, C, D, and E as part
+ of thread A.
+
+ See also get_thread_id.
+
+ Args:
+ event_id: The event ID to fetch the thread ID for.
+
+ Returns:
+ The event ID of the root event in the thread, if this event is part
+ of a thread. "main", otherwise.
+ """
+
+ # Recurse event relations up to the *root* event, then search for any events
+ # related to that root node for a thread relation. If one is found, the
+ # root event is returned.
+ #
+ # Note that there cannot be thread relations in the middle of the chain since
+ # it is invalid for an event to have a thread relation to an event which also
+ # has a relation.
+ sql = """
+ SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE((
+ WITH RECURSIVE related_events AS (
+ SELECT event_id, relates_to_id, relation_type, 0 depth
+ FROM event_relations
+ WHERE event_id = ?
+ UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
+ FROM event_relations e
+ INNER JOIN related_events r ON r.relates_to_id = e.event_id
+ WHERE depth <= 3
+ )
+ SELECT relates_to_id FROM related_events
+ ORDER BY depth DESC
+ LIMIT 1
+ ), ?) AND relation_type = 'm.thread' LIMIT 1;
+ """
+
+ def _get_related_thread_id(txn: LoggingTransaction) -> str:
+ txn.execute(sql, (event_id, event_id))
+ row = txn.fetchone()
+ if row:
+ return row[0]
+
+ # If no thread was found, it is part of the main timeline.
+ return MAIN_TIMELINE
+
+ return await self.db_pool.runInteraction(
+ "get_related_thread_id", _get_related_thread_id
+ )
+
class RelationsStore(RelationsWorkerStore):
pass
|