summary refs log tree commit diff
path: root/synapse/storage/databases/main/relations.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/relations.py')
-rw-r--r--synapse/storage/databases/main/relations.py89
1 files changed, 77 insertions, 12 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py

index 2bbf6d6a95..53576ad52f 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py
@@ -13,13 +13,14 @@ # limitations under the License. import logging -from typing import Optional +from typing import List, Optional, Tuple, Union import attr from synapse.api.constants import RelationTypes from synapse.events import EventBase from synapse.storage._base import SQLBaseStore +from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.relations import ( AggregationPaginationToken, @@ -63,7 +64,7 @@ class RelationsWorkerStore(SQLBaseStore): """ where_clause = ["relates_to_id = ?"] - where_args = [event_id] + where_args: List[Union[str, int]] = [event_id] if relation_type is not None: where_clause.append("relation_type = ?") @@ -80,8 +81,8 @@ class RelationsWorkerStore(SQLBaseStore): pagination_clause = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=attr.astuple(from_token) if from_token else None, - to_token=attr.astuple(to_token) if to_token else None, + from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type] + to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type] engine=self.database_engine, ) @@ -106,7 +107,9 @@ class RelationsWorkerStore(SQLBaseStore): order, ) - def _get_recent_references_for_event_txn(txn): + def _get_recent_references_for_event_txn( + txn: LoggingTransaction, + ) -> PaginationChunk: txn.execute(sql, where_args + [limit + 1]) last_topo_id = None @@ -160,7 +163,7 @@ class RelationsWorkerStore(SQLBaseStore): """ where_clause = ["relates_to_id = ?", "relation_type = ?"] - where_args = [event_id, RelationTypes.ANNOTATION] + where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION] if event_type: where_clause.append("type = ?") @@ -169,8 +172,8 @@ class RelationsWorkerStore(SQLBaseStore): having_clause = generate_pagination_where_clause( direction=direction, column_names=("COUNT(*)", "MAX(stream_ordering)"), - from_token=attr.astuple(from_token) if from_token else None, - to_token=attr.astuple(to_token) if to_token else None, + from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type] + to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type] engine=self.database_engine, ) @@ -199,7 +202,9 @@ class RelationsWorkerStore(SQLBaseStore): having_clause=having_clause, ) - def _get_aggregation_groups_for_event_txn(txn): + def _get_aggregation_groups_for_event_txn( + txn: LoggingTransaction, + ) -> PaginationChunk: txn.execute(sql, where_args + [limit + 1]) next_batch = None @@ -254,11 +259,12 @@ class RelationsWorkerStore(SQLBaseStore): LIMIT 1 """ - def _get_applicable_edit_txn(txn): + def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]: txn.execute(sql, (event_id, RelationTypes.REPLACE)) row = txn.fetchone() if row: return row[0] + return None edit_id = await self.db_pool.runInteraction( "get_applicable_edit", _get_applicable_edit_txn @@ -267,7 +273,66 @@ class RelationsWorkerStore(SQLBaseStore): if not edit_id: return None - return await self.get_event(edit_id, allow_none=True) + return await self.get_event(edit_id, allow_none=True) # type: ignore[attr-defined] + + @cached() + async def get_thread_summary( + self, event_id: str + ) -> Tuple[int, Optional[EventBase]]: + """Get the number of threaded replies, the senders of those replies, and + the latest reply (if any) for the given event. + + Args: + event_id: The original event ID + + Returns: + The number of items in the thread and the most recent response, if any. + """ + + def _get_thread_summary_txn( + txn: LoggingTransaction, + ) -> Tuple[int, Optional[str]]: + # Fetch the count of threaded events and the latest event ID. + # TODO Should this only allow m.room.message events. + sql = """ + SELECT event_id + FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id = ? + AND relation_type = ? + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT 1 + """ + + txn.execute(sql, (event_id, RelationTypes.THREAD)) + row = txn.fetchone() + if row is None: + return 0, None + + latest_event_id = row[0] + + sql = """ + SELECT COALESCE(COUNT(event_id), 0) + FROM event_relations + WHERE + relates_to_id = ? + AND relation_type = ? + """ + txn.execute(sql, (event_id, RelationTypes.THREAD)) + count = txn.fetchone()[0] # type: ignore[index] + + return count, latest_event_id + + count, latest_event_id = await self.db_pool.runInteraction( + "get_thread_summary", _get_thread_summary_txn + ) + + latest_event = None + if latest_event_id: + latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined] + + return count, latest_event async def has_user_annotated_event( self, parent_id: str, event_type: str, aggregation_key: str, sender: str @@ -297,7 +362,7 @@ class RelationsWorkerStore(SQLBaseStore): LIMIT 1; """ - def _get_if_user_has_annotated_event(txn): + def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool: txn.execute( sql, (