diff options
Diffstat (limited to 'synapse/rest')
-rw-r--r-- | synapse/rest/client/relations.py | 75 |
1 files changed, 8 insertions, 67 deletions
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index d9a6be43f7..c16078b187 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -51,9 +51,7 @@ class RelationPaginationServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.clock = hs.get_clock() - self._event_serializer = hs.get_event_client_serializer() - self.event_handler = hs.get_event_handler() + self._relations_handler = hs.get_relations_handler() async def on_GET( self, @@ -65,16 +63,6 @@ class RelationPaginationServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - await self.auth.check_user_in_room_or_world_readable( - room_id, requester.user.to_string(), allow_departed_users=True - ) - - # This gets the original event and checks that a) the event exists and - # b) the user is allowed to view it. - event = await self.event_handler.get_event(requester.user, room_id, parent_id) - if event is None: - raise SynapseError(404, "Unknown parent event.") - limit = parse_integer(request, "limit", default=5) direction = parse_string( request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"] @@ -90,9 +78,9 @@ class RelationPaginationServlet(RestServlet): if to_token_str: to_token = await StreamToken.from_string(self.store, to_token_str) - pagination_chunk = await self.store.get_relations_for_event( + result = await self._relations_handler.get_relations( + requester=requester, event_id=parent_id, - event=event, room_id=room_id, relation_type=relation_type, event_type=event_type, @@ -102,30 +90,7 @@ class RelationPaginationServlet(RestServlet): to_token=to_token, ) - events = await self.store.get_events_as_list( - [c["event_id"] for c in pagination_chunk.chunk] - ) - - now = self.clock.time_msec() - # Do not bundle aggregations when retrieving the original event because - # we want the content before relations are applied to it. - original_event = self._event_serializer.serialize_event( - event, now, bundle_aggregations=None - ) - # The relations returned for the requested event do include their - # bundled aggregations. - aggregations = await self.store.get_bundled_aggregations( - events, requester.user.to_string() - ) - serialized_events = self._event_serializer.serialize_events( - events, now, bundle_aggregations=aggregations - ) - - return_value = await pagination_chunk.to_dict(self.store) - return_value["chunk"] = serialized_events - return_value["original_event"] = original_event - - return 200, return_value + return 200, result class RelationAggregationPaginationServlet(RestServlet): @@ -245,9 +210,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.clock = hs.get_clock() - self._event_serializer = hs.get_event_client_serializer() - self.event_handler = hs.get_event_handler() + self._relations_handler = hs.get_relations_handler() async def on_GET( self, @@ -260,18 +223,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - await self.auth.check_user_in_room_or_world_readable( - room_id, - requester.user.to_string(), - allow_departed_users=True, - ) - - # This checks that a) the event exists and b) the user is allowed to - # view it. - event = await self.event_handler.get_event(requester.user, room_id, parent_id) - if event is None: - raise SynapseError(404, "Unknown parent event.") - if relation_type != RelationTypes.ANNOTATION: raise SynapseError(400, "Relation type must be 'annotation'") @@ -286,9 +237,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet): if to_token_str: to_token = await StreamToken.from_string(self.store, to_token_str) - result = await self.store.get_relations_for_event( + result = await self._relations_handler.get_relations( + requester=requester, event_id=parent_id, - event=event, room_id=room_id, relation_type=relation_type, event_type=event_type, @@ -298,17 +249,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): to_token=to_token, ) - events = await self.store.get_events_as_list( - [c["event_id"] for c in result.chunk] - ) - - now = self.clock.time_msec() - serialized_events = self._event_serializer.serialize_events(events, now) - - return_value = await result.to_dict(self.store) - return_value["chunk"] = serialized_events - - return 200, return_value + return 200, result def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: |