diff options
Diffstat (limited to 'synapse/storage/databases/main/relations.py')
-rw-r--r-- | synapse/storage/databases/main/relations.py | 87 |
1 files changed, 87 insertions, 0 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 7bd27790eb..57b2f7c188 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -814,6 +814,93 @@ class RelationsWorkerStore(SQLBaseStore): "get_event_relations", _get_event_relations ) + @cached(tree=True) + async def get_threads( + self, + room_id: str, + limit: int = 5, + from_token: Optional[StreamToken] = None, + to_token: Optional[StreamToken] = None, + ) -> Tuple[List[str], Optional[StreamToken]]: + """Get a list of thread IDs, ordered by topological ordering of their + latest reply. + + Args: + room_id: The room the event belongs to. + limit: Only fetch the most recent `limit` threads. + from_token: Fetch rows from the given token, or from the start if None. + to_token: Fetch rows up to the given token, or up to the end if None. + + Returns: + A tuple of: + A list of thread root event IDs. + + The next stream token, if one exists. + """ + pagination_clause = generate_pagination_where_clause( + direction="b", + column_names=("topological_ordering", "stream_ordering"), + from_token=from_token.room_key.as_historical_tuple() + if from_token + else None, + to_token=to_token.room_key.as_historical_tuple() if to_token else None, + engine=self.database_engine, + ) + + if pagination_clause: + pagination_clause = "AND " + pagination_clause + + sql = f""" + SELECT relates_to_id, MAX(topological_ordering), MAX(stream_ordering) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE + room_id = ? AND + relation_type = '{RelationTypes.THREAD}' + {pagination_clause} + GROUP BY relates_to_id + ORDER BY MAX(topological_ordering) DESC, MAX(stream_ordering) DESC + LIMIT ? + """ + + def _get_threads_txn( + txn: LoggingTransaction, + ) -> Tuple[List[str], Optional[StreamToken]]: + txn.execute(sql, [room_id, limit + 1]) + + last_topo_id = None + last_stream_id = None + thread_ids = [] + for thread_id, topo_id, stream_id in txn: + thread_ids.append(thread_id) + last_topo_id = topo_id + last_stream_id = stream_id + + # If there are more events, generate the next pagination key. + next_token = None + if len(thread_ids) > limit and last_topo_id and last_stream_id: + next_key = RoomStreamToken(last_topo_id, last_stream_id) + if from_token: + next_token = from_token.copy_and_replace( + StreamKeyType.ROOM, next_key + ) + else: + next_token = StreamToken( + room_key=next_key, + presence_key=0, + typing_key=0, + receipt_key=0, + account_data_key=0, + push_rules_key=0, + to_device_key=0, + device_list_key=0, + groups_key=0, + ) + + return thread_ids[:limit], next_token + + return await self.db_pool.runInteraction("get_threads", _get_threads_txn) + class RelationsStore(RelationsWorkerStore): pass |