Recursively fetch the thread for receipts & notifications. (#13824)
Consider an event to be part of a thread if you can follow a
chain of relations up to a thread root.
Part of MSC3773 & MSC3771.
1 files changed, 36 insertions, 0 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 898947af95..154385b1e8 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -832,6 +832,42 @@ class RelationsWorkerStore(SQLBaseStore):
"get_event_relations", _get_event_relations
)
+ @cached()
+ async def get_thread_id(self, event_id: str) -> Optional[str]:
+ """
+ Get the thread ID for an event. This considers multi-level relations,
+ e.g. an annotation to an event which is part of a thread.
+
+ Args:
+ event_id: The event ID to fetch the thread ID for.
+
+ Returns:
+ The event ID of the root event in the thread, if this event is part
+ of a thread. None, otherwise.
+ """
+ # Since event relations form a tree, we should only ever find 0 or 1
+ # results from the below query.
+ sql = """
+ WITH RECURSIVE related_events AS (
+ SELECT event_id, relates_to_id, relation_type
+ FROM event_relations
+ WHERE event_id = ?
+ UNION SELECT e.event_id, e.relates_to_id, e.relation_type
+ FROM event_relations e
+ INNER JOIN related_events r ON r.relates_to_id = e.event_id
+ ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread';
+ """
+
+ def _get_thread_id(txn: LoggingTransaction) -> Optional[str]:
+ txn.execute(sql, (event_id,))
+ # TODO Should we ensure there's only a single result here?
+ row = txn.fetchone()
+ if row:
+ return row[0]
+ return None
+
+ return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)
+
class RelationsStore(RelationsWorkerStore):
pass
|