diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 7bd27790eb..57b2f7c188 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -814,6 +814,93 @@ class RelationsWorkerStore(SQLBaseStore):
"get_event_relations", _get_event_relations
)
+ @cached(tree=True)
+ async def get_threads(
+ self,
+ room_id: str,
+ limit: int = 5,
+ from_token: Optional[StreamToken] = None,
+ to_token: Optional[StreamToken] = None,
+ ) -> Tuple[List[str], Optional[StreamToken]]:
+ """Get a list of thread IDs, ordered by topological ordering of their
+ latest reply.
+
+ Args:
+ room_id: The room the event belongs to.
+ limit: Only fetch the most recent `limit` threads.
+ from_token: Fetch rows from the given token, or from the start if None.
+ to_token: Fetch rows up to the given token, or up to the end if None.
+
+ Returns:
+ A tuple of:
+ A list of thread root event IDs.
+
+ The next stream token, if one exists.
+ """
+ pagination_clause = generate_pagination_where_clause(
+ direction="b",
+ column_names=("topological_ordering", "stream_ordering"),
+ from_token=from_token.room_key.as_historical_tuple()
+ if from_token
+ else None,
+ to_token=to_token.room_key.as_historical_tuple() if to_token else None,
+ engine=self.database_engine,
+ )
+
+ if pagination_clause:
+ pagination_clause = "AND " + pagination_clause
+
+ sql = f"""
+ SELECT relates_to_id, MAX(topological_ordering), MAX(stream_ordering)
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ room_id = ? AND
+ relation_type = '{RelationTypes.THREAD}'
+ {pagination_clause}
+ GROUP BY relates_to_id
+ ORDER BY MAX(topological_ordering) DESC, MAX(stream_ordering) DESC
+ LIMIT ?
+ """
+
+ def _get_threads_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[str], Optional[StreamToken]]:
+ txn.execute(sql, [room_id, limit + 1])
+
+ last_topo_id = None
+ last_stream_id = None
+ thread_ids = []
+ for thread_id, topo_id, stream_id in txn:
+ thread_ids.append(thread_id)
+ last_topo_id = topo_id
+ last_stream_id = stream_id
+
+ # If there are more events, generate the next pagination key.
+ next_token = None
+ if len(thread_ids) > limit and last_topo_id and last_stream_id:
+ next_key = RoomStreamToken(last_topo_id, last_stream_id)
+ if from_token:
+ next_token = from_token.copy_and_replace(
+ StreamKeyType.ROOM, next_key
+ )
+ else:
+ next_token = StreamToken(
+ room_key=next_key,
+ presence_key=0,
+ typing_key=0,
+ receipt_key=0,
+ account_data_key=0,
+ push_rules_key=0,
+ to_device_key=0,
+ device_list_key=0,
+ groups_key=0,
+ )
+
+ return thread_ids[:limit], next_token
+
+ return await self.db_pool.runInteraction("get_threads", _get_threads_txn)
+
class RelationsStore(RelationsWorkerStore):
pass
|