diff options
Diffstat (limited to 'synapse/storage/databases/main/relations.py')
-rw-r--r-- | synapse/storage/databases/main/relations.py | 78 |
1 files changed, 13 insertions, 65 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index b2295fd51f..3285450742 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -26,8 +26,6 @@ from typing import ( cast, ) -import attr - from synapse.api.constants import RelationTypes from synapse.events import EventBase from synapse.storage._base import SQLBaseStore @@ -39,8 +37,8 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine -from synapse.storage.relations import AggregationPaginationToken, PaginationChunk -from synapse.types import RoomStreamToken, StreamToken +from synapse.storage.relations import PaginationChunk +from synapse.types import JsonDict, RoomStreamToken, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: @@ -252,15 +250,8 @@ class RelationsWorkerStore(SQLBaseStore): @cached(tree=True) async def get_aggregation_groups_for_event( - self, - event_id: str, - room_id: str, - event_type: Optional[str] = None, - limit: int = 5, - direction: str = "b", - from_token: Optional[AggregationPaginationToken] = None, - to_token: Optional[AggregationPaginationToken] = None, - ) -> PaginationChunk: + self, event_id: str, room_id: str, limit: int = 5 + ) -> List[JsonDict]: """Get a list of annotations on the event, grouped by event type and aggregation key, sorted by count. @@ -270,79 +261,36 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. room_id: The room the event belongs to. - event_type: Only fetch events with this event type, if given. limit: Only fetch the `limit` groups. - direction: Whether to fetch the highest count first (`"b"`) or - the lowest count first (`"f"`). - from_token: Fetch rows from the given token, or from the start if None. - to_token: Fetch rows up to the given token, or up to the end if None. Returns: List of groups of annotations that match. Each row is a dict with `type`, `key` and `count` fields. """ - where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"] - where_args: List[Union[str, int]] = [ + where_args = [ event_id, room_id, RelationTypes.ANNOTATION, + limit, ] - if event_type: - where_clause.append("type = ?") - where_args.append(event_type) - - having_clause = generate_pagination_where_clause( - direction=direction, - column_names=("COUNT(*)", "MAX(stream_ordering)"), - from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type] - to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type] - engine=self.database_engine, - ) - - if direction == "b": - order = "DESC" - else: - order = "ASC" - - if having_clause: - having_clause = "HAVING " + having_clause - else: - having_clause = "" - sql = """ - SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering) + SELECT type, aggregation_key, COUNT(DISTINCT sender) FROM event_relations INNER JOIN events USING (event_id) - WHERE {where_clause} + WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? GROUP BY relation_type, type, aggregation_key - {having_clause} - ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order} + ORDER BY COUNT(*) DESC LIMIT ? - """.format( - where_clause=" AND ".join(where_clause), - order=order, - having_clause=having_clause, - ) + """ def _get_aggregation_groups_for_event_txn( txn: LoggingTransaction, - ) -> PaginationChunk: - txn.execute(sql, where_args + [limit + 1]) - - next_batch = None - events = [] - for row in txn: - events.append({"type": row[0], "key": row[1], "count": row[2]}) - next_batch = AggregationPaginationToken(row[2], row[3]) - - if len(events) <= limit: - next_batch = None + ) -> List[JsonDict]: + txn.execute(sql, where_args) - return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token - ) + return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn] return await self.db_pool.runInteraction( "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn |