diff options
Diffstat (limited to 'synapse/rest/client/relations.py')
-rw-r--r-- | synapse/rest/client/relations.py | 138 |
1 files changed, 13 insertions, 125 deletions
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 5815650ee6..8cf5ebaa07 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -19,28 +19,20 @@ any time to reflect changes in the MSC. """ import logging -from typing import TYPE_CHECKING, Awaitable, Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple -from synapse.api.constants import EventTypes, RelationTypes -from synapse.api.errors import ShadowBanError, SynapseError +from synapse.api.constants import RelationTypes +from synapse.api.errors import SynapseError from synapse.http.server import HttpServer -from synapse.http.servlet import ( - RestServlet, - parse_integer, - parse_json_object_from_request, - parse_string, -) +from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest -from synapse.rest.client.transactions import HttpTransactionCache +from synapse.rest.client._base import client_patterns from synapse.storage.relations import ( AggregationPaginationToken, PaginationChunk, RelationPaginationToken, ) from synapse.types import JsonDict -from synapse.util.stringutils import random_string - -from ._base import client_patterns if TYPE_CHECKING: from synapse.server import HomeServer @@ -48,112 +40,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class RelationSendServlet(RestServlet): - """Helper API for sending events that have relation data. - - Example API shape to send a 👍 reaction to a room: - - POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D - {} - - { - "event_id": "$foobar" - } - """ - - PATTERN = ( - "/rooms/(?P<room_id>[^/]*)/send_relation" - "/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.event_creation_handler = hs.get_event_creation_handler() - self.txns = HttpTransactionCache(hs) - - def register(self, http_server: HttpServer) -> None: - http_server.register_paths( - "POST", - client_patterns(self.PATTERN + "$", releases=()), - self.on_PUT_or_POST, - self.__class__.__name__, - ) - http_server.register_paths( - "PUT", - client_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()), - self.on_PUT, - self.__class__.__name__, - ) - - def on_PUT( - self, - request: SynapseRequest, - room_id: str, - parent_id: str, - relation_type: str, - event_type: str, - txn_id: Optional[str] = None, - ) -> Awaitable[Tuple[int, JsonDict]]: - return self.txns.fetch_or_execute_request( - request, - self.on_PUT_or_POST, - request, - room_id, - parent_id, - relation_type, - event_type, - txn_id, - ) - - async def on_PUT_or_POST( - self, - request: SynapseRequest, - room_id: str, - parent_id: str, - relation_type: str, - event_type: str, - txn_id: Optional[str] = None, - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - if event_type == EventTypes.Member: - # Add relations to a membership is meaningless, so we just deny it - # at the CS API rather than trying to handle it correctly. - raise SynapseError(400, "Cannot send member events with relations") - - content = parse_json_object_from_request(request) - - aggregation_key = parse_string(request, "key", encoding="utf-8") - - content["m.relates_to"] = { - "event_id": parent_id, - "rel_type": relation_type, - } - if aggregation_key is not None: - content["m.relates_to"]["key"] = aggregation_key - - event_dict = { - "type": event_type, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - } - - try: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict=event_dict, txn_id=txn_id - ) - event_id = event.event_id - except ShadowBanError: - event_id = "$" + random_string(43) - - return 200, {"event_id": event_id} - - class RelationPaginationServlet(RestServlet): """API to paginate relations on an event by topological ordering, optionally filtered by relation type and event type. @@ -227,13 +113,16 @@ class RelationPaginationServlet(RestServlet): now = self.clock.time_msec() # Do not bundle aggregations when retrieving the original event because # we want the content before relations are applied to it. - original_event = await self._event_serializer.serialize_event( - event, now, bundle_aggregations=False + original_event = self._event_serializer.serialize_event( + event, now, bundle_aggregations=None ) # The relations returned for the requested event do include their # bundled aggregations. - serialized_events = await self._event_serializer.serialize_events( - events, now, bundle_aggregations=True + aggregations = await self.store.get_bundled_aggregations( + events, requester.user.to_string() + ) + serialized_events = self._event_serializer.serialize_events( + events, now, bundle_aggregations=aggregations ) return_value = pagination_chunk.to_dict() @@ -422,7 +311,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): ) now = self.clock.time_msec() - serialized_events = await self._event_serializer.serialize_events(events, now) + serialized_events = self._event_serializer.serialize_events(events, now) return_value = result.to_dict() return_value["chunk"] = serialized_events @@ -431,7 +320,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - RelationSendServlet(hs).register(http_server) RelationPaginationServlet(hs).register(http_server) RelationAggregationPaginationServlet(hs).register(http_server) RelationAggregationGroupPaginationServlet(hs).register(http_server) |