diff options
-rw-r--r-- | changelog.d/12293.removal | 1 | ||||
-rw-r--r-- | synapse/handlers/relations.py | 6 | ||||
-rw-r--r-- | synapse/rest/client/relations.py | 170 | ||||
-rw-r--r-- | synapse/storage/databases/main/relations.py | 78 | ||||
-rw-r--r-- | synapse/storage/relations.py | 33 | ||||
-rw-r--r-- | tests/rest/client/test_relations.py | 207 |
6 files changed, 17 insertions, 478 deletions
diff --git a/changelog.d/12293.removal b/changelog.d/12293.removal new file mode 100644 index 0000000000..25214a4b49 --- /dev/null +++ b/changelog.d/12293.removal @@ -0,0 +1 @@ +Remove the unused and unstable `/aggregations` endpoint which was removed from [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 73217d135d..b9497ff3f3 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -193,10 +193,8 @@ class RelationsHandler: annotations = await self._main_store.get_aggregation_groups_for_event( event_id, room_id ) - if annotations.chunk: - aggregations.annotations = await annotations.to_dict( - cast("DataStore", self) - ) + if annotations: + aggregations.annotations = {"chunk": annotations} references = await self._main_store.get_relations_for_event( event_id, event, room_id, RelationTypes.REFERENCE, direction="f" diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index c16078b187..55c96a2af3 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -12,22 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""This class implements the proposed relation APIs from MSC 1849. - -Since the MSC has not been approved all APIs here are unstable and may change at -any time to reflect changes in the MSC. -""" - import logging from typing import TYPE_CHECKING, Optional, Tuple -from synapse.api.constants import RelationTypes -from synapse.api.errors import SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.storage.relations import AggregationPaginationToken from synapse.types import JsonDict, StreamToken if TYPE_CHECKING: @@ -93,166 +84,5 @@ class RelationPaginationServlet(RestServlet): return 200, result -class RelationAggregationPaginationServlet(RestServlet): - """API to paginate aggregation groups of relations, e.g. paginate the - types and counts of the reactions on the events. - - Example request and response: - - GET /rooms/{room_id}/aggregations/{parent_id} - - { - chunk: [ - { - "type": "m.reaction", - "key": "👍", - "count": 3 - } - ] - } - """ - - PATTERNS = client_patterns( - "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)" - "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$", - releases=(), - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastores().main - self.event_handler = hs.get_event_handler() - - async def on_GET( - self, - request: SynapseRequest, - room_id: str, - parent_id: str, - relation_type: Optional[str] = None, - event_type: Optional[str] = None, - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - await self.auth.check_user_in_room_or_world_readable( - room_id, - requester.user.to_string(), - allow_departed_users=True, - ) - - # This checks that a) the event exists and b) the user is allowed to - # view it. - event = await self.event_handler.get_event(requester.user, room_id, parent_id) - if event is None: - raise SynapseError(404, "Unknown parent event.") - - if relation_type not in (RelationTypes.ANNOTATION, None): - raise SynapseError( - 400, f"Relation type must be '{RelationTypes.ANNOTATION}'" - ) - - limit = parse_integer(request, "limit", default=5) - from_token_str = parse_string(request, "from") - to_token_str = parse_string(request, "to") - - # Return the relations - from_token = None - if from_token_str: - from_token = AggregationPaginationToken.from_string(from_token_str) - - to_token = None - if to_token_str: - to_token = AggregationPaginationToken.from_string(to_token_str) - - pagination_chunk = await self.store.get_aggregation_groups_for_event( - event_id=parent_id, - room_id=room_id, - event_type=event_type, - limit=limit, - from_token=from_token, - to_token=to_token, - ) - - return 200, await pagination_chunk.to_dict(self.store) - - -class RelationAggregationGroupPaginationServlet(RestServlet): - """API to paginate within an aggregation group of relations, e.g. paginate - all the 👍 reactions on an event. - - Example request and response: - - GET /rooms/{room_id}/aggregations/{parent_id}/m.annotation/m.reaction/👍 - - { - chunk: [ - { - "type": "m.reaction", - "content": { - "m.relates_to": { - "rel_type": "m.annotation", - "key": "👍" - } - } - }, - ... - ] - } - """ - - PATTERNS = client_patterns( - "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)" - "/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$", - releases=(), - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastores().main - self._relations_handler = hs.get_relations_handler() - - async def on_GET( - self, - request: SynapseRequest, - room_id: str, - parent_id: str, - relation_type: str, - event_type: str, - key: str, - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - if relation_type != RelationTypes.ANNOTATION: - raise SynapseError(400, "Relation type must be 'annotation'") - - limit = parse_integer(request, "limit", default=5) - from_token_str = parse_string(request, "from") - to_token_str = parse_string(request, "to") - - from_token = None - if from_token_str: - from_token = await StreamToken.from_string(self.store, from_token_str) - to_token = None - if to_token_str: - to_token = await StreamToken.from_string(self.store, to_token_str) - - result = await self._relations_handler.get_relations( - requester=requester, - event_id=parent_id, - room_id=room_id, - relation_type=relation_type, - event_type=event_type, - aggregation_key=key, - limit=limit, - from_token=from_token, - to_token=to_token, - ) - - return 200, result - - def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationPaginationServlet(hs).register(http_server) - RelationAggregationPaginationServlet(hs).register(http_server) - RelationAggregationGroupPaginationServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index b2295fd51f..3285450742 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -26,8 +26,6 @@ from typing import ( cast, ) -import attr - from synapse.api.constants import RelationTypes from synapse.events import EventBase from synapse.storage._base import SQLBaseStore @@ -39,8 +37,8 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine -from synapse.storage.relations import AggregationPaginationToken, PaginationChunk -from synapse.types import RoomStreamToken, StreamToken +from synapse.storage.relations import PaginationChunk +from synapse.types import JsonDict, RoomStreamToken, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: @@ -252,15 +250,8 @@ class RelationsWorkerStore(SQLBaseStore): @cached(tree=True) async def get_aggregation_groups_for_event( - self, - event_id: str, - room_id: str, - event_type: Optional[str] = None, - limit: int = 5, - direction: str = "b", - from_token: Optional[AggregationPaginationToken] = None, - to_token: Optional[AggregationPaginationToken] = None, - ) -> PaginationChunk: + self, event_id: str, room_id: str, limit: int = 5 + ) -> List[JsonDict]: """Get a list of annotations on the event, grouped by event type and aggregation key, sorted by count. @@ -270,79 +261,36 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. room_id: The room the event belongs to. - event_type: Only fetch events with this event type, if given. limit: Only fetch the `limit` groups. - direction: Whether to fetch the highest count first (`"b"`) or - the lowest count first (`"f"`). - from_token: Fetch rows from the given token, or from the start if None. - to_token: Fetch rows up to the given token, or up to the end if None. Returns: List of groups of annotations that match. Each row is a dict with `type`, `key` and `count` fields. """ - where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"] - where_args: List[Union[str, int]] = [ + where_args = [ event_id, room_id, RelationTypes.ANNOTATION, + limit, ] - if event_type: - where_clause.append("type = ?") - where_args.append(event_type) - - having_clause = generate_pagination_where_clause( - direction=direction, - column_names=("COUNT(*)", "MAX(stream_ordering)"), - from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type] - to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type] - engine=self.database_engine, - ) - - if direction == "b": - order = "DESC" - else: - order = "ASC" - - if having_clause: - having_clause = "HAVING " + having_clause - else: - having_clause = "" - sql = """ - SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering) + SELECT type, aggregation_key, COUNT(DISTINCT sender) FROM event_relations INNER JOIN events USING (event_id) - WHERE {where_clause} + WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? GROUP BY relation_type, type, aggregation_key - {having_clause} - ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order} + ORDER BY COUNT(*) DESC LIMIT ? - """.format( - where_clause=" AND ".join(where_clause), - order=order, - having_clause=having_clause, - ) + """ def _get_aggregation_groups_for_event_txn( txn: LoggingTransaction, - ) -> PaginationChunk: - txn.execute(sql, where_args + [limit + 1]) - - next_batch = None - events = [] - for row in txn: - events.append({"type": row[0], "key": row[1], "count": row[2]}) - next_batch = AggregationPaginationToken(row[2], row[3]) - - if len(events) <= limit: - next_batch = None + ) -> List[JsonDict]: + txn.execute(sql, where_args) - return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token - ) + return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn] return await self.db_pool.runInteraction( "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index fba270150b..b9d2b46799 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -13,11 +13,10 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional import attr -from synapse.api.errors import SynapseError from synapse.types import JsonDict if TYPE_CHECKING: @@ -52,33 +51,3 @@ class PaginationChunk: d["prev_batch"] = await self.prev_batch.to_string(store) return d - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class AggregationPaginationToken: - """Pagination token for relation aggregation pagination API. - - As the results are order by count and then MAX(stream_ordering) of the - aggregation groups, we can just use them as our pagination token. - - Attributes: - count: The count of relations in the boundary group. - stream: The MAX stream ordering in the boundary group. - """ - - count: int - stream: int - - @staticmethod - def from_string(string: str) -> "AggregationPaginationToken": - try: - c, s = string.split("-") - return AggregationPaginationToken(int(c), int(s)) - except ValueError: - raise SynapseError(400, "Invalid aggregation pagination token") - - async def to_string(self, store: "DataStore") -> str: - return "%d-%d" % (self.count, self.stream) - - def as_tuple(self) -> Tuple[Any, ...]: - return attr.astuple(self) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index fe97a0b3dd..419eef166a 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import urllib.parse from typing import Any, Callable, Dict, List, Optional, Tuple from unittest.mock import patch @@ -145,16 +144,6 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) return channel.json_body["unsigned"].get("m.relations", {}) - def _get_aggregations(self) -> List[JsonDict]: - """Request /aggregations on the parent ID and includes the returned chunk.""" - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - return channel.json_body["chunk"] - def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: """ Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. @@ -264,43 +253,6 @@ class RelationsTestCase(BaseRelationsTestCase): expected_response_code=400, ) - def test_aggregation(self) -> None: - """Test that annotations get correctly aggregated.""" - - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual( - channel.json_body, - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - ) - - def test_aggregation_must_be_annotation(self) -> None: - """Test that aggregations must be annotations.""" - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations" - f"/{self.parent_id}/{RelationTypes.REPLACE}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(400, channel.code, channel.json_body) - def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. @@ -394,15 +346,6 @@ class RelationsTestCase(BaseRelationsTestCase): self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["chunk"], []) - # And when fetching aggregations. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) - # And for bundled aggregations. channel = self.make_request( "GET", @@ -717,15 +660,6 @@ class RelationsTestCase(BaseRelationsTestCase): self.assertEqual(200, channel.code, channel.json_body) self.assertNotIn("m.relations", channel.json_body["unsigned"]) - # But unknown relations can be directly queried. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) - def test_background_update(self) -> None: """Test the event_arbitrary_relations background update.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") @@ -941,131 +875,6 @@ class RelationPaginationTestCase(BaseRelationsTestCase): annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] ) - def test_aggregation_pagination_groups(self) -> None: - """Test that we can paginate annotation groups correctly.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} - for key in itertools.chain.from_iterable( - itertools.repeat(key, num) for key, num in sent_groups.items() - ): - self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key=key, - access_token=access_tokens[idx], - ) - - idx += 1 - idx %= len(access_tokens) - - prev_token: Optional[str] = None - found_groups: Dict[str, int] = {} - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - for groups in channel.json_body["chunk"]: - # We only expect reactions - self.assertEqual(groups["type"], "m.reaction", channel.json_body) - - # We should only see each key once - self.assertNotIn(groups["key"], found_groups, channel.json_body) - - found_groups[groups["key"]] = groups["count"] - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - self.assertEqual(sent_groups, found_groups) - - def test_aggregation_pagination_within_group(self) -> None: - """Test that we can paginate within an annotation group.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - expected_event_ids = [] - for _ in range(10): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key="👍", - access_token=access_tokens[idx], - ) - expected_event_ids.append(channel.json_body["event_id"]) - - idx += 1 - - # Also send a different type of reaction so that we test we don't see it - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - - prev_token = "" - found_event_ids: List[str] = [] - encoded_key = urllib.parse.quote_plus("👍".encode()) - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}" - f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" - f"/m.reaction/{encoded_key}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - class BundledAggregationsTestCase(BaseRelationsTestCase): """ @@ -1453,10 +1262,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]}, ) - # Both relations appear in the aggregation. - chunk = self._get_aggregations() - self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}]) - # Redact one of the reactions. self._redact(to_redact_event_id) @@ -1469,10 +1274,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, ) - # The unredacted aggregation should still exist. - chunk = self._get_aggregations() - self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}]) - def test_redact_relation_thread(self) -> None: """ Test that thread replies are properly handled after the thread reply redacted. @@ -1578,10 +1379,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self.assertEqual(len(event_ids), 1) self.assertIn(RelationTypes.ANNOTATION, relations) - # The aggregation should exist. - chunk = self._get_aggregations() - self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}]) - # Redact the original event. self._redact(self.parent_id) @@ -1594,10 +1391,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}, ) - # There's nothing to aggregate. - chunk = self._get_aggregations() - self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}]) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_redact_parent_thread(self) -> None: """ |