summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/relations.py6
-rw-r--r--synapse/rest/client/relations.py170
-rw-r--r--synapse/storage/databases/main/relations.py78
-rw-r--r--synapse/storage/relations.py33
4 files changed, 16 insertions, 271 deletions
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 73217d135d..b9497ff3f3 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -193,10 +193,8 @@ class RelationsHandler:
         annotations = await self._main_store.get_aggregation_groups_for_event(
             event_id, room_id
         )
-        if annotations.chunk:
-            aggregations.annotations = await annotations.to_dict(
-                cast("DataStore", self)
-            )
+        if annotations:
+            aggregations.annotations = {"chunk": annotations}
 
         references = await self._main_store.get_relations_for_event(
             event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index c16078b187..55c96a2af3 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -12,22 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""This class implements the proposed relation APIs from MSC 1849.
-
-Since the MSC has not been approved all APIs here are unstable and may change at
-any time to reflect changes in the MSC.
-"""
-
 import logging
 from typing import TYPE_CHECKING, Optional, Tuple
 
-from synapse.api.constants import RelationTypes
-from synapse.api.errors import SynapseError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.rest.client._base import client_patterns
-from synapse.storage.relations import AggregationPaginationToken
 from synapse.types import JsonDict, StreamToken
 
 if TYPE_CHECKING:
@@ -93,166 +84,5 @@ class RelationPaginationServlet(RestServlet):
         return 200, result
 
 
-class RelationAggregationPaginationServlet(RestServlet):
-    """API to paginate aggregation groups of relations, e.g. paginate the
-    types and counts of the reactions on the events.
-
-    Example request and response:
-
-        GET /rooms/{room_id}/aggregations/{parent_id}
-
-        {
-            chunk: [
-                {
-                    "type": "m.reaction",
-                    "key": "👍",
-                    "count": 3
-                }
-            ]
-        }
-    """
-
-    PATTERNS = client_patterns(
-        "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
-        "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
-        releases=(),
-    )
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.store = hs.get_datastores().main
-        self.event_handler = hs.get_event_handler()
-
-    async def on_GET(
-        self,
-        request: SynapseRequest,
-        room_id: str,
-        parent_id: str,
-        relation_type: Optional[str] = None,
-        event_type: Optional[str] = None,
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
-        await self.auth.check_user_in_room_or_world_readable(
-            room_id,
-            requester.user.to_string(),
-            allow_departed_users=True,
-        )
-
-        # This checks that a) the event exists and b) the user is allowed to
-        # view it.
-        event = await self.event_handler.get_event(requester.user, room_id, parent_id)
-        if event is None:
-            raise SynapseError(404, "Unknown parent event.")
-
-        if relation_type not in (RelationTypes.ANNOTATION, None):
-            raise SynapseError(
-                400, f"Relation type must be '{RelationTypes.ANNOTATION}'"
-            )
-
-        limit = parse_integer(request, "limit", default=5)
-        from_token_str = parse_string(request, "from")
-        to_token_str = parse_string(request, "to")
-
-        # Return the relations
-        from_token = None
-        if from_token_str:
-            from_token = AggregationPaginationToken.from_string(from_token_str)
-
-        to_token = None
-        if to_token_str:
-            to_token = AggregationPaginationToken.from_string(to_token_str)
-
-        pagination_chunk = await self.store.get_aggregation_groups_for_event(
-            event_id=parent_id,
-            room_id=room_id,
-            event_type=event_type,
-            limit=limit,
-            from_token=from_token,
-            to_token=to_token,
-        )
-
-        return 200, await pagination_chunk.to_dict(self.store)
-
-
-class RelationAggregationGroupPaginationServlet(RestServlet):
-    """API to paginate within an aggregation group of relations, e.g. paginate
-    all the 👍 reactions on an event.
-
-    Example request and response:
-
-        GET /rooms/{room_id}/aggregations/{parent_id}/m.annotation/m.reaction/👍
-
-        {
-            chunk: [
-                {
-                    "type": "m.reaction",
-                    "content": {
-                        "m.relates_to": {
-                            "rel_type": "m.annotation",
-                            "key": "👍"
-                        }
-                    }
-                },
-                ...
-            ]
-        }
-    """
-
-    PATTERNS = client_patterns(
-        "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
-        "/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$",
-        releases=(),
-    )
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.store = hs.get_datastores().main
-        self._relations_handler = hs.get_relations_handler()
-
-    async def on_GET(
-        self,
-        request: SynapseRequest,
-        room_id: str,
-        parent_id: str,
-        relation_type: str,
-        event_type: str,
-        key: str,
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
-        if relation_type != RelationTypes.ANNOTATION:
-            raise SynapseError(400, "Relation type must be 'annotation'")
-
-        limit = parse_integer(request, "limit", default=5)
-        from_token_str = parse_string(request, "from")
-        to_token_str = parse_string(request, "to")
-
-        from_token = None
-        if from_token_str:
-            from_token = await StreamToken.from_string(self.store, from_token_str)
-        to_token = None
-        if to_token_str:
-            to_token = await StreamToken.from_string(self.store, to_token_str)
-
-        result = await self._relations_handler.get_relations(
-            requester=requester,
-            event_id=parent_id,
-            room_id=room_id,
-            relation_type=relation_type,
-            event_type=event_type,
-            aggregation_key=key,
-            limit=limit,
-            from_token=from_token,
-            to_token=to_token,
-        )
-
-        return 200, result
-
-
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     RelationPaginationServlet(hs).register(http_server)
-    RelationAggregationPaginationServlet(hs).register(http_server)
-    RelationAggregationGroupPaginationServlet(hs).register(http_server)
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index b2295fd51f..3285450742 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -26,8 +26,6 @@ from typing import (
     cast,
 )
 
-import attr
-
 from synapse.api.constants import RelationTypes
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
@@ -39,8 +37,8 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
-from synapse.types import RoomStreamToken, StreamToken
+from synapse.storage.relations import PaginationChunk
+from synapse.types import JsonDict, RoomStreamToken, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
 if TYPE_CHECKING:
@@ -252,15 +250,8 @@ class RelationsWorkerStore(SQLBaseStore):
 
     @cached(tree=True)
     async def get_aggregation_groups_for_event(
-        self,
-        event_id: str,
-        room_id: str,
-        event_type: Optional[str] = None,
-        limit: int = 5,
-        direction: str = "b",
-        from_token: Optional[AggregationPaginationToken] = None,
-        to_token: Optional[AggregationPaginationToken] = None,
-    ) -> PaginationChunk:
+        self, event_id: str, room_id: str, limit: int = 5
+    ) -> List[JsonDict]:
         """Get a list of annotations on the event, grouped by event type and
         aggregation key, sorted by count.
 
@@ -270,79 +261,36 @@ class RelationsWorkerStore(SQLBaseStore):
         Args:
             event_id: Fetch events that relate to this event ID.
             room_id: The room the event belongs to.
-            event_type: Only fetch events with this event type, if given.
             limit: Only fetch the `limit` groups.
-            direction: Whether to fetch the highest count first (`"b"`) or
-                the lowest count first (`"f"`).
-            from_token: Fetch rows from the given token, or from the start if None.
-            to_token: Fetch rows up to the given token, or up to the end if None.
 
         Returns:
             List of groups of annotations that match. Each row is a dict with
             `type`, `key` and `count` fields.
         """
 
-        where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
-        where_args: List[Union[str, int]] = [
+        where_args = [
             event_id,
             room_id,
             RelationTypes.ANNOTATION,
+            limit,
         ]
 
-        if event_type:
-            where_clause.append("type = ?")
-            where_args.append(event_type)
-
-        having_clause = generate_pagination_where_clause(
-            direction=direction,
-            column_names=("COUNT(*)", "MAX(stream_ordering)"),
-            from_token=attr.astuple(from_token) if from_token else None,  # type: ignore[arg-type]
-            to_token=attr.astuple(to_token) if to_token else None,  # type: ignore[arg-type]
-            engine=self.database_engine,
-        )
-
-        if direction == "b":
-            order = "DESC"
-        else:
-            order = "ASC"
-
-        if having_clause:
-            having_clause = "HAVING " + having_clause
-        else:
-            having_clause = ""
-
         sql = """
-            SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
+            SELECT type, aggregation_key, COUNT(DISTINCT sender)
             FROM event_relations
             INNER JOIN events USING (event_id)
-            WHERE {where_clause}
+            WHERE relates_to_id = ? AND room_id = ? AND relation_type = ?
             GROUP BY relation_type, type, aggregation_key
-            {having_clause}
-            ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
+            ORDER BY COUNT(*) DESC
             LIMIT ?
-        """.format(
-            where_clause=" AND ".join(where_clause),
-            order=order,
-            having_clause=having_clause,
-        )
+        """
 
         def _get_aggregation_groups_for_event_txn(
             txn: LoggingTransaction,
-        ) -> PaginationChunk:
-            txn.execute(sql, where_args + [limit + 1])
-
-            next_batch = None
-            events = []
-            for row in txn:
-                events.append({"type": row[0], "key": row[1], "count": row[2]})
-                next_batch = AggregationPaginationToken(row[2], row[3])
-
-            if len(events) <= limit:
-                next_batch = None
+        ) -> List[JsonDict]:
+            txn.execute(sql, where_args)
 
-            return PaginationChunk(
-                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
-            )
+            return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn]
 
         return await self.db_pool.runInteraction(
             "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index fba270150b..b9d2b46799 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -13,11 +13,10 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import attr
 
-from synapse.api.errors import SynapseError
 from synapse.types import JsonDict
 
 if TYPE_CHECKING:
@@ -52,33 +51,3 @@ class PaginationChunk:
             d["prev_batch"] = await self.prev_batch.to_string(store)
 
         return d
-
-
-@attr.s(frozen=True, slots=True, auto_attribs=True)
-class AggregationPaginationToken:
-    """Pagination token for relation aggregation pagination API.
-
-    As the results are order by count and then MAX(stream_ordering) of the
-    aggregation groups, we can just use them as our pagination token.
-
-    Attributes:
-        count: The count of relations in the boundary group.
-        stream: The MAX stream ordering in the boundary group.
-    """
-
-    count: int
-    stream: int
-
-    @staticmethod
-    def from_string(string: str) -> "AggregationPaginationToken":
-        try:
-            c, s = string.split("-")
-            return AggregationPaginationToken(int(c), int(s))
-        except ValueError:
-            raise SynapseError(400, "Invalid aggregation pagination token")
-
-    async def to_string(self, store: "DataStore") -> str:
-        return "%d-%d" % (self.count, self.stream)
-
-    def as_tuple(self) -> Tuple[Any, ...]:
-        return attr.astuple(self)