Diffstat (limited to 'synapse/storage/databases')
 synapse/storage/databases/main/relations.py | 78
 1 file changed, 13 insertions(+), 65 deletions(-)
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index b2295fd51f..3285450742 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -26,8 +26,6 @@ from typing import (
     cast,
 )
 
-import attr
-
 from synapse.api.constants import RelationTypes
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
@@ -39,8 +37,8 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
-from synapse.types import RoomStreamToken, StreamToken
+from synapse.storage.relations import PaginationChunk
+from synapse.types import JsonDict, RoomStreamToken, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
 if TYPE_CHECKING:
@@ -252,15 +250,8 @@ class RelationsWorkerStore(SQLBaseStore):
 
     @cached(tree=True)
     async def get_aggregation_groups_for_event(
-        self,
-        event_id: str,
-        room_id: str,
-        event_type: Optional[str] = None,
-        limit: int = 5,
-        direction: str = "b",
-        from_token: Optional[AggregationPaginationToken] = None,
-        to_token: Optional[AggregationPaginationToken] = None,
-    ) -> PaginationChunk:
+        self, event_id: str, room_id: str, limit: int = 5
+    ) -> List[JsonDict]:
         """Get a list of annotations on the event, grouped by event type and
         aggregation key, sorted by count.
 
@@ -270,79 +261,36 @@ class RelationsWorkerStore(SQLBaseStore):
         Args:
             event_id: Fetch events that relate to this event ID.
             room_id: The room the event belongs to.
-            event_type: Only fetch events with this event type, if given.
             limit: Only fetch the `limit` groups.
-            direction: Whether to fetch the highest count first (`"b"`) or
-                the lowest count first (`"f"`).
-            from_token: Fetch rows from the given token, or from the start if None.
-            to_token: Fetch rows up to the given token, or up to the end if None.
 
         Returns:
             List of groups of annotations that match. Each row is a dict with
             `type`, `key` and `count` fields.
         """
 
-        where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
-        where_args: List[Union[str, int]] = [
+        where_args = [
             event_id,
             room_id,
             RelationTypes.ANNOTATION,
+            limit,
         ]
 
-        if event_type:
-            where_clause.append("type = ?")
-            where_args.append(event_type)
-
-        having_clause = generate_pagination_where_clause(
-            direction=direction,
-            column_names=("COUNT(*)", "MAX(stream_ordering)"),
-            from_token=attr.astuple(from_token) if from_token else None,  # type: ignore[arg-type]
-            to_token=attr.astuple(to_token) if to_token else None,  # type: ignore[arg-type]
-            engine=self.database_engine,
-        )
-
-        if direction == "b":
-            order = "DESC"
-        else:
-            order = "ASC"
-
-        if having_clause:
-            having_clause = "HAVING " + having_clause
-        else:
-            having_clause = ""
-
         sql = """
-            SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
+            SELECT type, aggregation_key, COUNT(DISTINCT sender)
             FROM event_relations
             INNER JOIN events USING (event_id)
-            WHERE {where_clause}
+            WHERE relates_to_id = ? AND room_id = ? AND relation_type = ?
             GROUP BY relation_type, type, aggregation_key
-            {having_clause}
-            ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
+            ORDER BY COUNT(*) DESC
             LIMIT ?
-        """.format(
-            where_clause=" AND ".join(where_clause),
-            order=order,
-            having_clause=having_clause,
-        )
+        """
 
         def _get_aggregation_groups_for_event_txn(
             txn: LoggingTransaction,
-        ) -> PaginationChunk:
-            txn.execute(sql, where_args + [limit + 1])
-
-            next_batch = None
-            events = []
-            for row in txn:
-                events.append({"type": row[0], "key": row[1], "count": row[2]})
-                next_batch = AggregationPaginationToken(row[2], row[3])
-
-            if len(events) <= limit:
-                next_batch = None
+        ) -> List[JsonDict]:
+            txn.execute(sql, where_args)
 
-            return PaginationChunk(
-                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
-            )
+            return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn]
 
         return await self.db_pool.runInteraction(
             "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn