diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index b2295fd51f..3285450742 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -26,8 +26,6 @@ from typing import (
cast,
)
-import attr
-
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
@@ -39,8 +37,8 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
-from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
-from synapse.types import RoomStreamToken, StreamToken
+from synapse.storage.relations import PaginationChunk
+from synapse.types import JsonDict, RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@@ -252,15 +250,8 @@ class RelationsWorkerStore(SQLBaseStore):
@cached(tree=True)
async def get_aggregation_groups_for_event(
- self,
- event_id: str,
- room_id: str,
- event_type: Optional[str] = None,
- limit: int = 5,
- direction: str = "b",
- from_token: Optional[AggregationPaginationToken] = None,
- to_token: Optional[AggregationPaginationToken] = None,
- ) -> PaginationChunk:
+ self, event_id: str, room_id: str, limit: int = 5
+ ) -> List[JsonDict]:
"""Get a list of annotations on the event, grouped by event type and
aggregation key, sorted by count.
@@ -270,79 +261,36 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
- event_type: Only fetch events with this event type, if given.
limit: Only fetch the `limit` groups.
- direction: Whether to fetch the highest count first (`"b"`) or
- the lowest count first (`"f"`).
- from_token: Fetch rows from the given token, or from the start if None.
- to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
List of groups of annotations that match. Each row is a dict with
`type`, `key` and `count` fields.
"""
- where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
- where_args: List[Union[str, int]] = [
+ where_args = [
event_id,
room_id,
RelationTypes.ANNOTATION,
+ limit,
]
- if event_type:
- where_clause.append("type = ?")
- where_args.append(event_type)
-
- having_clause = generate_pagination_where_clause(
- direction=direction,
- column_names=("COUNT(*)", "MAX(stream_ordering)"),
- from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
- to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
- engine=self.database_engine,
- )
-
- if direction == "b":
- order = "DESC"
- else:
- order = "ASC"
-
- if having_clause:
- having_clause = "HAVING " + having_clause
- else:
- having_clause = ""
-
sql = """
- SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
+ SELECT type, aggregation_key, COUNT(DISTINCT sender)
FROM event_relations
INNER JOIN events USING (event_id)
- WHERE {where_clause}
+ WHERE relates_to_id = ? AND room_id = ? AND relation_type = ?
GROUP BY relation_type, type, aggregation_key
- {having_clause}
- ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
+ ORDER BY COUNT(*) DESC
LIMIT ?
- """.format(
- where_clause=" AND ".join(where_clause),
- order=order,
- having_clause=having_clause,
- )
+ """
def _get_aggregation_groups_for_event_txn(
txn: LoggingTransaction,
- ) -> PaginationChunk:
- txn.execute(sql, where_args + [limit + 1])
-
- next_batch = None
- events = []
- for row in txn:
- events.append({"type": row[0], "key": row[1], "count": row[2]})
- next_batch = AggregationPaginationToken(row[2], row[3])
-
- if len(events) <= limit:
- next_batch = None
+ ) -> List[JsonDict]:
+ txn.execute(sql, where_args)
- return PaginationChunk(
- chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
- )
+ return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn]
return await self.db_pool.runInteraction(
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
|