summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/relations.py36
1 files changed, 36 insertions, 0 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 898947af95..154385b1e8 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -832,6 +832,42 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_event_relations", _get_event_relations
         )
 
+    @cached()
+    async def get_thread_id(self, event_id: str) -> Optional[str]:
+        """
+        Get the thread ID for an event. This considers multi-level relations,
+        e.g. an annotation to an event which is part of a thread.
+
+        Args:
+            event_id: The event ID to fetch the thread ID for.
+
+        Returns:
+            The event ID of the root event in the thread, if this event is part
+            of a thread. None, otherwise.
+        """
+        # Since event relations form a tree, we should only ever find 0 or 1
+        # results from the below query.
+        sql = """
+            WITH RECURSIVE related_events AS (
+                SELECT event_id, relates_to_id, relation_type
+                FROM event_relations
+                WHERE event_id = ?
+                UNION SELECT e.event_id, e.relates_to_id, e.relation_type
+                FROM event_relations e
+                INNER JOIN related_events r ON r.relates_to_id = e.event_id
+            ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread';
+        """
+
+        def _get_thread_id(txn: LoggingTransaction) -> Optional[str]:
+            txn.execute(sql, (event_id,))
+            # TODO Should we ensure there's only a single result here?
+            row = txn.fetchone()
+            if row:
+                return row[0]
+            return None
+
+        return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)
+
 
 class RelationsStore(RelationsWorkerStore):
     pass