From 6c68e874b1eafe75db51e46064c0d3af702b4358 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 6 Jan 2022 14:00:34 -0500 Subject: Remove the /send_relation endpoint. (#11682) This was removed from MSC2674 before that was approved and is not used by any known clients. --- synapse/rest/client/relations.py | 125 ++------------------------------------- 1 file changed, 5 insertions(+), 120 deletions(-) (limited to 'synapse/rest/client/relations.py') diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 5815650ee6..3823498012 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -19,28 +19,20 @@ any time to reflect changes in the MSC. """ import logging -from typing import TYPE_CHECKING, Awaitable, Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple -from synapse.api.constants import EventTypes, RelationTypes -from synapse.api.errors import ShadowBanError, SynapseError +from synapse.api.constants import RelationTypes +from synapse.api.errors import SynapseError from synapse.http.server import HttpServer -from synapse.http.servlet import ( - RestServlet, - parse_integer, - parse_json_object_from_request, - parse_string, -) +from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest -from synapse.rest.client.transactions import HttpTransactionCache +from synapse.rest.client._base import client_patterns from synapse.storage.relations import ( AggregationPaginationToken, PaginationChunk, RelationPaginationToken, ) from synapse.types import JsonDict -from synapse.util.stringutils import random_string - -from ._base import client_patterns if TYPE_CHECKING: from synapse.server import HomeServer @@ -48,112 +40,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class RelationSendServlet(RestServlet): - """Helper API for sending events that have relation data. - - Example API shape to send a 👍 reaction to a room: - - POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D - {} - - { - "event_id": "$foobar" - } - """ - - PATTERN = ( - "/rooms/(?P[^/]*)/send_relation" - "/(?P[^/]*)/(?P[^/]*)/(?P[^/]*)" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.event_creation_handler = hs.get_event_creation_handler() - self.txns = HttpTransactionCache(hs) - - def register(self, http_server: HttpServer) -> None: - http_server.register_paths( - "POST", - client_patterns(self.PATTERN + "$", releases=()), - self.on_PUT_or_POST, - self.__class__.__name__, - ) - http_server.register_paths( - "PUT", - client_patterns(self.PATTERN + "/(?P[^/]*)$", releases=()), - self.on_PUT, - self.__class__.__name__, - ) - - def on_PUT( - self, - request: SynapseRequest, - room_id: str, - parent_id: str, - relation_type: str, - event_type: str, - txn_id: Optional[str] = None, - ) -> Awaitable[Tuple[int, JsonDict]]: - return self.txns.fetch_or_execute_request( - request, - self.on_PUT_or_POST, - request, - room_id, - parent_id, - relation_type, - event_type, - txn_id, - ) - - async def on_PUT_or_POST( - self, - request: SynapseRequest, - room_id: str, - parent_id: str, - relation_type: str, - event_type: str, - txn_id: Optional[str] = None, - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - if event_type == EventTypes.Member: - # Add relations to a membership is meaningless, so we just deny it - # at the CS API rather than trying to handle it correctly. - raise SynapseError(400, "Cannot send member events with relations") - - content = parse_json_object_from_request(request) - - aggregation_key = parse_string(request, "key", encoding="utf-8") - - content["m.relates_to"] = { - "event_id": parent_id, - "rel_type": relation_type, - } - if aggregation_key is not None: - content["m.relates_to"]["key"] = aggregation_key - - event_dict = { - "type": event_type, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - } - - try: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict=event_dict, txn_id=txn_id - ) - event_id = event.event_id - except ShadowBanError: - event_id = "$" + random_string(43) - - return 200, {"event_id": event_id} - - class RelationPaginationServlet(RestServlet): """API to paginate relations on an event by topological ordering, optionally filtered by relation type and event type. @@ -431,7 +317,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - RelationSendServlet(hs).register(http_server) RelationPaginationServlet(hs).register(http_server) RelationAggregationPaginationServlet(hs).register(http_server) RelationAggregationGroupPaginationServlet(hs).register(http_server) -- cgit 1.5.1 From 6bf81a7a61d8d5248be5def955104c44fcb78dae Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 7 Jan 2022 09:10:46 -0500 Subject: Bundle aggregations outside of the serialization method. (#11612) This makes the serialization of events synchronous (and it no longer access the database), but we must manually calculate and provide the bundled aggregations. Overall this should cause no change in behavior, but is prep work for other improvements. --- changelog.d/11612.misc | 1 + synapse/events/utils.py | 126 ++++++++------------------- synapse/handlers/events.py | 2 +- synapse/handlers/initial_sync.py | 16 ++-- synapse/handlers/message.py | 2 +- synapse/handlers/pagination.py | 8 +- synapse/handlers/room.py | 10 +++ synapse/handlers/search.py | 10 +-- synapse/rest/admin/rooms.py | 16 ++-- synapse/rest/client/events.py | 2 +- synapse/rest/client/notifications.py | 2 +- synapse/rest/client/relations.py | 11 +-- synapse/rest/client/room.py | 28 +++--- synapse/rest/client/sync.py | 39 +++++---- synapse/server.py | 2 +- synapse/storage/databases/main/relations.py | 128 +++++++++++++++++++++++++++- tests/rest/client/test_retention.py | 2 +- 17 files changed, 249 insertions(+), 156 deletions(-) create mode 100644 changelog.d/11612.misc (limited to 'synapse/rest/client/relations.py') diff --git a/changelog.d/11612.misc b/changelog.d/11612.misc new file mode 100644 index 0000000000..2d886169c5 --- /dev/null +++ b/changelog.d/11612.misc @@ -0,0 +1 @@ +Avoid database access in the JSON serialization process. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 2038e72924..de0e0c1731 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -14,17 +14,7 @@ # limitations under the License. import collections.abc import re -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - List, - Mapping, - Optional, - Union, -) +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union from frozendict import frozendict @@ -32,14 +22,10 @@ from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion from synapse.types import JsonDict -from synapse.util.async_helpers import yieldable_gather_results from synapse.util.frozenutils import unfreeze from . import EventBase -if TYPE_CHECKING: - from synapse.server import HomeServer - # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' # (? JsonDict: """Serializes a single event. @@ -418,66 +399,41 @@ class EventClientSerializer: serialized_event = serialize_event(event, time_now, **kwargs) # Check if there are any bundled aggregations to include with the event. - # - # Do not bundle aggregations if any of the following at true: - # - # * Support is disabled via the configuration or the caller. - # * The event is a state event. - # * The event has been redacted. - if ( - self._msc1849_enabled - and bundle_aggregations - and not event.is_state() - and not event.internal_metadata.is_redacted() - ): - await self._injected_bundled_aggregations(event, time_now, serialized_event) + if bundle_aggregations: + event_aggregations = bundle_aggregations.get(event.event_id) + if event_aggregations: + self._injected_bundled_aggregations( + event, + time_now, + bundle_aggregations[event.event_id], + serialized_event, + ) return serialized_event - async def _injected_bundled_aggregations( - self, event: EventBase, time_now: int, serialized_event: JsonDict + def _injected_bundled_aggregations( + self, + event: EventBase, + time_now: int, + aggregations: JsonDict, + serialized_event: JsonDict, ) -> None: """Potentially injects bundled aggregations into the unsigned portion of the serialized event. Args: event: The event being serialized. time_now: The current time in milliseconds + aggregations: The bundled aggregation to serialize. serialized_event: The serialized event which may be modified. """ - # Do not bundle aggregations for an event which represents an edit or an - # annotation. It does not make sense for them to have related events. - relates_to = event.content.get("m.relates_to") - if isinstance(relates_to, (dict, frozendict)): - relation_type = relates_to.get("rel_type") - if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): - return - - event_id = event.event_id - room_id = event.room_id - - # The bundled aggregations to include. - aggregations = {} - - annotations = await self.store.get_aggregation_groups_for_event( - event_id, room_id - ) - if annotations.chunk: - aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() + # Make a copy in-case the object is cached. + aggregations = aggregations.copy() - references = await self.store.get_relations_for_event( - event_id, room_id, RelationTypes.REFERENCE, direction="f" - ) - if references.chunk: - aggregations[RelationTypes.REFERENCE] = references.to_dict() - - edit = None - if event.type == EventTypes.Message: - edit = await self.store.get_applicable_edit(event_id, room_id) - - if edit: + if RelationTypes.REPLACE in aggregations: # If there is an edit replace the content, preserving existing # relations. + edit = aggregations[RelationTypes.REPLACE] # Ensure we take copies of the edit content, otherwise we risk modifying # the original event. @@ -502,27 +458,19 @@ class EventClientSerializer: } # If this event is the start of a thread, include a summary of the replies. - if self._msc3440_enabled: - ( - thread_count, - latest_thread_event, - ) = await self.store.get_thread_summary(event_id, room_id) - if latest_thread_event: - aggregations[RelationTypes.THREAD] = { - # Don't bundle aggregations as this could recurse forever. - "latest_event": await self.serialize_event( - latest_thread_event, time_now, bundle_aggregations=False - ), - "count": thread_count, - } - - # If any bundled aggregations were found, include them. - if aggregations: - serialized_event["unsigned"].setdefault("m.relations", {}).update( - aggregations + if RelationTypes.THREAD in aggregations: + # Serialize the latest thread event. + latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"] + + # Don't bundle aggregations as this could recurse forever. + aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event( + latest_thread_event, time_now, bundle_aggregations=None ) - async def serialize_events( + # Include the bundled aggregations in the event. + serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations) + + def serialize_events( self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any ) -> List[JsonDict]: """Serializes multiple events. @@ -535,9 +483,9 @@ class EventClientSerializer: Returns: The list of serialized events """ - return await yieldable_gather_results( - self.serialize_event, events, time_now=time_now, **kwargs - ) + return [ + self.serialize_event(event, time_now=time_now, **kwargs) for event in events + ] def copy_power_levels_contents( diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 1b996c420d..a3add8a586 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -119,7 +119,7 @@ class EventStreamHandler: events.extend(to_add) - chunks = await self._event_serializer.serialize_events( + chunks = self._event_serializer.serialize_events( events, time_now, as_client_event=as_client_event, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 601bab67f9..346a06ff49 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -170,7 +170,7 @@ class InitialSyncHandler: d["inviter"] = event.sender invite_event = await self.store.get_event(event.event_id) - d["invite"] = await self._event_serializer.serialize_event( + d["invite"] = self._event_serializer.serialize_event( invite_event, time_now, as_client_event=as_client_event, @@ -222,7 +222,7 @@ class InitialSyncHandler: d["messages"] = { "chunk": ( - await self._event_serializer.serialize_events( + self._event_serializer.serialize_events( messages, time_now=time_now, as_client_event=as_client_event, @@ -232,7 +232,7 @@ class InitialSyncHandler: "end": await end_token.to_string(self.store), } - d["state"] = await self._event_serializer.serialize_events( + d["state"] = self._event_serializer.serialize_events( current_state.values(), time_now=time_now, as_client_event=as_client_event, @@ -376,16 +376,14 @@ class InitialSyncHandler: "messages": { "chunk": ( # Don't bundle aggregations as this is a deprecated API. - await self._event_serializer.serialize_events(messages, time_now) + self._event_serializer.serialize_events(messages, time_now) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), }, "state": ( # Don't bundle aggregations as this is a deprecated API. - await self._event_serializer.serialize_events( - room_state.values(), time_now - ) + self._event_serializer.serialize_events(room_state.values(), time_now) ), "presence": [], "receipts": [], @@ -404,7 +402,7 @@ class InitialSyncHandler: # TODO: These concurrently time_now = self.clock.time_msec() # Don't bundle aggregations as this is a deprecated API. - state = await self._event_serializer.serialize_events( + state = self._event_serializer.serialize_events( current_state.values(), time_now ) @@ -480,7 +478,7 @@ class InitialSyncHandler: "messages": { "chunk": ( # Don't bundle aggregations as this is a deprecated API. - await self._event_serializer.serialize_events(messages, time_now) + self._event_serializer.serialize_events(messages, time_now) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5e3d3886eb..b37250aa38 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -246,7 +246,7 @@ class MessageHandler: room_state = room_state_events[membership_event_id] now = self.clock.time_msec() - events = await self._event_serializer.serialize_events(room_state.values(), now) + events = self._event_serializer.serialize_events(room_state.values(), now) return events async def get_joined_members(self, requester: Requester, room_id: str) -> dict: diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 7469cc55a2..472688f045 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -537,14 +537,16 @@ class PaginationHandler: state_dict = await self.store.get_events(list(state_ids.values())) state = state_dict.values() + aggregations = await self.store.get_bundled_aggregations(events) + time_now = self.clock.time_msec() chunk = { "chunk": ( - await self._event_serializer.serialize_events( + self._event_serializer.serialize_events( events, time_now, - bundle_aggregations=True, + bundle_aggregations=aggregations, as_client_event=as_client_event, ) ), @@ -553,7 +555,7 @@ class PaginationHandler: } if state: - chunk["state"] = await self._event_serializer.serialize_events( + chunk["state"] = self._event_serializer.serialize_events( state, time_now, as_client_event=as_client_event ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3d3a0f6ac3..3d47163f25 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1181,6 +1181,16 @@ class RoomContextHandler: # `filtered` rather than the event we retrieved from the datastore. results["event"] = filtered[0] + # Fetch the aggregations. + aggregations = await self.store.get_bundled_aggregations([results["event"]]) + aggregations.update( + await self.store.get_bundled_aggregations(results["events_before"]) + ) + aggregations.update( + await self.store.get_bundled_aggregations(results["events_after"]) + ) + results["aggregations"] = aggregations + if results["events_after"]: last_event_id = results["events_after"][-1].event_id else: diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index ab7eaab2fb..0b153a6822 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -420,10 +420,10 @@ class SearchHandler: time_now = self.clock.time_msec() for context in contexts.values(): - context["events_before"] = await self._event_serializer.serialize_events( + context["events_before"] = self._event_serializer.serialize_events( context["events_before"], time_now ) - context["events_after"] = await self._event_serializer.serialize_events( + context["events_after"] = self._event_serializer.serialize_events( context["events_after"], time_now ) @@ -441,9 +441,7 @@ class SearchHandler: results.append( { "rank": rank_map[e.event_id], - "result": ( - await self._event_serializer.serialize_event(e, time_now) - ), + "result": self._event_serializer.serialize_event(e, time_now), "context": contexts.get(e.event_id, {}), } ) @@ -457,7 +455,7 @@ class SearchHandler: if state_results: s = {} for room_id, state_events in state_results.items(): - s[room_id] = await self._event_serializer.serialize_events( + s[room_id] = self._event_serializer.serialize_events( state_events, time_now ) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 6030373ebc..2e714ac87b 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -424,7 +424,7 @@ class RoomStateRestServlet(RestServlet): event_ids = await self.store.get_current_state_ids(room_id) events = await self.store.get_events(event_ids.values()) now = self.clock.time_msec() - room_state = await self._event_serializer.serialize_events(events.values(), now) + room_state = self._event_serializer.serialize_events(events.values(), now) ret = {"state": room_state} return HTTPStatus.OK, ret @@ -744,22 +744,22 @@ class RoomEventContextServlet(RestServlet): ) time_now = self.clock.time_msec() - results["events_before"] = await self._event_serializer.serialize_events( + results["events_before"] = self._event_serializer.serialize_events( results["events_before"], time_now, - bundle_aggregations=True, + bundle_aggregations=results["aggregations"], ) - results["event"] = await self._event_serializer.serialize_event( + results["event"] = self._event_serializer.serialize_event( results["event"], time_now, - bundle_aggregations=True, + bundle_aggregations=results["aggregations"], ) - results["events_after"] = await self._event_serializer.serialize_events( + results["events_after"] = self._event_serializer.serialize_events( results["events_after"], time_now, - bundle_aggregations=True, + bundle_aggregations=results["aggregations"], ) - results["state"] = await self._event_serializer.serialize_events( + results["state"] = self._event_serializer.serialize_events( results["state"], time_now ) diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 13b72a045a..672c821061 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -91,7 +91,7 @@ class EventRestServlet(RestServlet): time_now = self.clock.time_msec() if event: - result = await self._event_serializer.serialize_event(event, time_now) + result = self._event_serializer.serialize_event(event, time_now) return 200, result else: return 404, "Event not found." diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index acd0c9e135..8e427a96a3 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -72,7 +72,7 @@ class NotificationsServlet(RestServlet): "actions": pa.actions, "ts": pa.received_ts, "event": ( - await self._event_serializer.serialize_event( + self._event_serializer.serialize_event( notif_events[pa.event_id], self.clock.time_msec(), event_format=format_event_for_client_v2_without_room_id, diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 3823498012..37d949a71e 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -113,13 +113,14 @@ class RelationPaginationServlet(RestServlet): now = self.clock.time_msec() # Do not bundle aggregations when retrieving the original event because # we want the content before relations are applied to it. - original_event = await self._event_serializer.serialize_event( - event, now, bundle_aggregations=False + original_event = self._event_serializer.serialize_event( + event, now, bundle_aggregations=None ) # The relations returned for the requested event do include their # bundled aggregations. - serialized_events = await self._event_serializer.serialize_events( - events, now, bundle_aggregations=True + aggregations = await self.store.get_bundled_aggregations(events) + serialized_events = self._event_serializer.serialize_events( + events, now, bundle_aggregations=aggregations ) return_value = pagination_chunk.to_dict() @@ -308,7 +309,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): ) now = self.clock.time_msec() - serialized_events = await self._event_serializer.serialize_events(events, now) + serialized_events = self._event_serializer.serialize_events(events, now) return_value = result.to_dict() return_value["chunk"] = serialized_events diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 40330749e5..da6014900a 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -642,6 +642,7 @@ class RoomEventServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() + self._store = hs.get_datastore() self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() @@ -660,10 +661,13 @@ class RoomEventServlet(RestServlet): # https://matrix.org/docs/spec/client_server/r0.5.0#get-matrix-client-r0-rooms-roomid-event-eventid raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) - time_now = self.clock.time_msec() if event: - event_dict = await self._event_serializer.serialize_event( - event, time_now, bundle_aggregations=True + # Ensure there are bundled aggregations available. + aggregations = await self._store.get_bundled_aggregations([event]) + + time_now = self.clock.time_msec() + event_dict = self._event_serializer.serialize_event( + event, time_now, bundle_aggregations=aggregations ) return 200, event_dict @@ -708,16 +712,20 @@ class RoomEventContextServlet(RestServlet): raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() - results["events_before"] = await self._event_serializer.serialize_events( - results["events_before"], time_now, bundle_aggregations=True + results["events_before"] = self._event_serializer.serialize_events( + results["events_before"], + time_now, + bundle_aggregations=results["aggregations"], ) - results["event"] = await self._event_serializer.serialize_event( - results["event"], time_now, bundle_aggregations=True + results["event"] = self._event_serializer.serialize_event( + results["event"], time_now, bundle_aggregations=results["aggregations"] ) - results["events_after"] = await self._event_serializer.serialize_events( - results["events_after"], time_now, bundle_aggregations=True + results["events_after"] = self._event_serializer.serialize_events( + results["events_after"], + time_now, + bundle_aggregations=results["aggregations"], ) - results["state"] = await self._event_serializer.serialize_events( + results["state"] = self._event_serializer.serialize_events( results["state"], time_now ) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index e99a943d0d..a3e57e4b20 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -17,7 +17,6 @@ from collections import defaultdict from typing import ( TYPE_CHECKING, Any, - Awaitable, Callable, Dict, Iterable, @@ -395,7 +394,7 @@ class SyncRestServlet(RestServlet): """ invited = {} for room in rooms: - invite = await self._event_serializer.serialize_event( + invite = self._event_serializer.serialize_event( room.invite, time_now, token_id=token_id, @@ -432,7 +431,7 @@ class SyncRestServlet(RestServlet): """ knocked = {} for room in rooms: - knock = await self._event_serializer.serialize_event( + knock = self._event_serializer.serialize_event( room.knock, time_now, token_id=token_id, @@ -525,21 +524,14 @@ class SyncRestServlet(RestServlet): The room, encoded in our response format """ - def serialize(events: Iterable[EventBase]) -> Awaitable[List[JsonDict]]: + def serialize( + events: Iterable[EventBase], + aggregations: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> List[JsonDict]: return self._event_serializer.serialize_events( events, time_now=time_now, - # Don't bother to bundle aggregations if the timeline is unlimited, - # as clients will have all the necessary information. - # bundle_aggregations=room.timeline.limited, - # - # richvdh 2021-12-15: disable this temporarily as it has too high an - # overhead for initialsyncs. We need to figure out a way that the - # bundling can be done *before* the events are stored in the - # SyncResponseCache so that this part can be synchronous. - # - # Ensure to re-enable the test at tests/rest/client/test_relations.py::RelationsTestCase.test_bundled_aggregations. - bundle_aggregations=False, + bundle_aggregations=aggregations, token_id=token_id, event_format=event_formatter, only_event_fields=only_fields, @@ -561,8 +553,21 @@ class SyncRestServlet(RestServlet): event.room_id, ) - serialized_state = await serialize(state_events) - serialized_timeline = await serialize(timeline_events) + serialized_state = serialize(state_events) + # Don't bother to bundle aggregations if the timeline is unlimited, + # as clients will have all the necessary information. + # bundle_aggregations=room.timeline.limited, + # + # richvdh 2021-12-15: disable this temporarily as it has too high an + # overhead for initialsyncs. We need to figure out a way that the + # bundling can be done *before* the events are stored in the + # SyncResponseCache so that this part can be synchronous. + # + # Ensure to re-enable the test at tests/rest/client/test_relations.py::RelationsTestCase.test_bundled_aggregations. + # if room.timeline.limited: + # aggregations = await self.store.get_bundled_aggregations(timeline_events) + aggregations = None + serialized_timeline = serialize(timeline_events, aggregations) account_data = room.account_data diff --git a/synapse/server.py b/synapse/server.py index 185e40e4da..3032f0b738 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -759,7 +759,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_event_client_serializer(self) -> EventClientSerializer: - return EventClientSerializer(self) + return EventClientSerializer() @cache_in_self def get_password_policy_handler(self) -> PasswordPolicyHandler: diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 4ff6aed253..c6c4bd18da 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -13,14 +13,30 @@ # limitations under the License. import logging -from typing import List, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, + cast, +) import attr +from frozendict import frozendict -from synapse.api.constants import RelationTypes +from synapse.api.constants import EventTypes, RelationTypes from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.relations import ( AggregationPaginationToken, @@ -29,10 +45,24 @@ from synapse.storage.relations import ( ) from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class RelationsWorkerStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self._msc1849_enabled = hs.config.experimental.msc1849_enabled + self._msc3440_enabled = hs.config.experimental.msc3440_enabled + @cached(tree=True) async def get_relations_for_event( self, @@ -515,6 +545,98 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) + async def _get_bundled_aggregation_for_event( + self, event: EventBase + ) -> Optional[Dict[str, Any]]: + """Generate bundled aggregations for an event. + + Note that this does not use a cache, but depends on cached methods. + + Args: + event: The event to calculate bundled aggregations for. + + Returns: + The bundled aggregations for an event, if bundled aggregations are + enabled and the event can have bundled aggregations. + """ + # State events and redacted events do not get bundled aggregations. + if event.is_state() or event.internal_metadata.is_redacted(): + return None + + # Do not bundle aggregations for an event which represents an edit or an + # annotation. It does not make sense for them to have related events. + relates_to = event.content.get("m.relates_to") + if isinstance(relates_to, (dict, frozendict)): + relation_type = relates_to.get("rel_type") + if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): + return None + + event_id = event.event_id + room_id = event.room_id + + # The bundled aggregations to include, a mapping of relation type to a + # type-specific value. Some types include the direct return type here + # while others need more processing during serialization. + aggregations: Dict[str, Any] = {} + + annotations = await self.get_aggregation_groups_for_event(event_id, room_id) + if annotations.chunk: + aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() + + references = await self.get_relations_for_event( + event_id, room_id, RelationTypes.REFERENCE, direction="f" + ) + if references.chunk: + aggregations[RelationTypes.REFERENCE] = references.to_dict() + + edit = None + if event.type == EventTypes.Message: + edit = await self.get_applicable_edit(event_id, room_id) + + if edit: + aggregations[RelationTypes.REPLACE] = edit + + # If this event is the start of a thread, include a summary of the replies. + if self._msc3440_enabled: + ( + thread_count, + latest_thread_event, + ) = await self.get_thread_summary(event_id, room_id) + if latest_thread_event: + aggregations[RelationTypes.THREAD] = { + # Don't bundle aggregations as this could recurse forever. + "latest_event": latest_thread_event, + "count": thread_count, + } + + # Store the bundled aggregations in the event metadata for later use. + return aggregations + + async def get_bundled_aggregations( + self, events: Iterable[EventBase] + ) -> Dict[str, Dict[str, Any]]: + """Generate bundled aggregations for events. + + Args: + events: The iterable of events to calculate bundled aggregations for. + + Returns: + A map of event ID to the bundled aggregation for the event. Not all + events may have bundled aggregations in the results. + """ + # If bundled aggregations are disabled, nothing to do. + if not self._msc1849_enabled: + return {} + + # TODO Parallelize. + results = {} + for event in events: + event_result = await self._get_bundled_aggregation_for_event(event) + if event_result is not None: + results[event.event_id] = event_result + + return results + class RelationsStore(RelationsWorkerStore): pass diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index b58452195a..fe5b536d97 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -228,7 +228,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(event) time_now = self.clock.time_msec() - serialized = self.get_success(self.serializer.serialize_event(event, time_now)) + serialized = self.serializer.serialize_event(event, time_now) return serialized -- cgit 1.5.1 From 68acb0a29dcb03a0ecbcebdb95e09c5999598f42 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 18 Jan 2022 11:38:57 -0500 Subject: Include whether the requesting user has participated in a thread. (#11577) Per updates to MSC3440. This is implement as a separate method since it needs to be cached on a per-user basis, instead of a per-thread basis. --- changelog.d/11577.feature | 1 + synapse/handlers/pagination.py | 2 +- synapse/handlers/room.py | 12 ++++-- synapse/handlers/sync.py | 4 +- synapse/rest/client/relations.py | 4 +- synapse/rest/client/room.py | 4 +- synapse/storage/databases/main/events.py | 7 +++ synapse/storage/databases/main/relations.py | 66 ++++++++++++++++++++++++----- tests/rest/client/test_relations.py | 3 ++ 9 files changed, 85 insertions(+), 18 deletions(-) create mode 100644 changelog.d/11577.feature (limited to 'synapse/rest/client/relations.py') diff --git a/changelog.d/11577.feature b/changelog.d/11577.feature new file mode 100644 index 0000000000..f9c8a0d5f4 --- /dev/null +++ b/changelog.d/11577.feature @@ -0,0 +1 @@ +Include whether the requesting user has participated in a thread when generating a summary for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440). diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 472688f045..973f262964 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -537,7 +537,7 @@ class PaginationHandler: state_dict = await self.store.get_events(list(state_ids.values())) state = state_dict.values() - aggregations = await self.store.get_bundled_aggregations(events) + aggregations = await self.store.get_bundled_aggregations(events, user_id) time_now = self.clock.time_msec() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3d47163f25..f963078e59 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1182,12 +1182,18 @@ class RoomContextHandler: results["event"] = filtered[0] # Fetch the aggregations. - aggregations = await self.store.get_bundled_aggregations([results["event"]]) + aggregations = await self.store.get_bundled_aggregations( + [results["event"]], user.to_string() + ) aggregations.update( - await self.store.get_bundled_aggregations(results["events_before"]) + await self.store.get_bundled_aggregations( + results["events_before"], user.to_string() + ) ) aggregations.update( - await self.store.get_bundled_aggregations(results["events_after"]) + await self.store.get_bundled_aggregations( + results["events_after"], user.to_string() + ) ) results["aggregations"] = aggregations diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index e1df9b3106..ffc6b748e8 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -637,7 +637,9 @@ class SyncHandler: # as clients will have all the necessary information. bundled_aggregations = None if limited or newly_joined_room: - bundled_aggregations = await self.store.get_bundled_aggregations(recents) + bundled_aggregations = await self.store.get_bundled_aggregations( + recents, sync_config.user.to_string() + ) return TimelineBatch( events=recents, diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 37d949a71e..8cf5ebaa07 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -118,7 +118,9 @@ class RelationPaginationServlet(RestServlet): ) # The relations returned for the requested event do include their # bundled aggregations. - aggregations = await self.store.get_bundled_aggregations(events) + aggregations = await self.store.get_bundled_aggregations( + events, requester.user.to_string() + ) serialized_events = self._event_serializer.serialize_events( events, now, bundle_aggregations=aggregations ) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index da6014900a..31fd329a38 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -663,7 +663,9 @@ class RoomEventServlet(RestServlet): if event: # Ensure there are bundled aggregations available. - aggregations = await self._store.get_bundled_aggregations([event]) + aggregations = await self._store.get_bundled_aggregations( + [event], requester.user.to_string() + ) time_now = self.clock.time_msec() event_dict = self._event_serializer.serialize_event( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 2be36a741a..7278002322 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1793,6 +1793,13 @@ class PersistEventsStore: txn.call_after( self.store.get_thread_summary.invalidate, (parent_id, event.room_id) ) + # It should be safe to only invalidate the cache if the user has not + # previously participated in the thread, but that's difficult (and + # potentially error-prone) so it is always invalidated. + txn.call_after( + self.store.get_thread_participated.invalidate, + (parent_id, event.room_id, event.sender), + ) def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): """Handles keeping track of insertion events and edges/connections. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index c6c4bd18da..2cb5d06c13 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -384,8 +384,7 @@ class RelationsWorkerStore(SQLBaseStore): async def get_thread_summary( self, event_id: str, room_id: str ) -> Tuple[int, Optional[EventBase]]: - """Get the number of threaded replies, the senders of those replies, and - the latest reply (if any) for the given event. + """Get the number of threaded replies and the latest reply (if any) for the given event. Args: event_id: Summarize the thread related to this event ID. @@ -398,7 +397,7 @@ class RelationsWorkerStore(SQLBaseStore): def _get_thread_summary_txn( txn: LoggingTransaction, ) -> Tuple[int, Optional[str]]: - # Fetch the count of threaded events and the latest event ID. + # Fetch the latest event ID in the thread. # TODO Should this only allow m.room.message events. sql = """ SELECT event_id @@ -419,6 +418,7 @@ class RelationsWorkerStore(SQLBaseStore): latest_event_id = row[0] + # Fetch the number of threaded replies. sql = """ SELECT COUNT(event_id) FROM event_relations @@ -443,6 +443,44 @@ class RelationsWorkerStore(SQLBaseStore): return count, latest_event + @cached() + async def get_thread_participated( + self, event_id: str, room_id: str, user_id: str + ) -> bool: + """Get whether the requesting user participated in a thread. + + This is separate from get_thread_summary since that can be cached across + all users while this value is specific to the requeser. + + Args: + event_id: The thread related to this event ID. + room_id: The room the event belongs to. + user_id: The user requesting the summary. + + Returns: + True if the requesting user participated in the thread, otherwise false. + """ + + def _get_thread_summary_txn(txn: LoggingTransaction) -> bool: + # Fetch whether the requester has participated or not. + sql = """ + SELECT 1 + FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id = ? + AND room_id = ? + AND relation_type = ? + AND sender = ? + """ + + txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id)) + return bool(txn.fetchone()) + + return await self.db_pool.runInteraction( + "get_thread_summary", _get_thread_summary_txn + ) + async def events_have_relations( self, parent_ids: List[str], @@ -546,7 +584,7 @@ class RelationsWorkerStore(SQLBaseStore): ) async def _get_bundled_aggregation_for_event( - self, event: EventBase + self, event: EventBase, user_id: str ) -> Optional[Dict[str, Any]]: """Generate bundled aggregations for an event. @@ -554,6 +592,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event: The event to calculate bundled aggregations for. + user_id: The user requesting the bundled aggregations. Returns: The bundled aggregations for an event, if bundled aggregations are @@ -598,27 +637,32 @@ class RelationsWorkerStore(SQLBaseStore): # If this event is the start of a thread, include a summary of the replies. if self._msc3440_enabled: - ( - thread_count, - latest_thread_event, - ) = await self.get_thread_summary(event_id, room_id) + thread_count, latest_thread_event = await self.get_thread_summary( + event_id, room_id + ) + participated = await self.get_thread_participated( + event_id, room_id, user_id + ) if latest_thread_event: aggregations[RelationTypes.THREAD] = { - # Don't bundle aggregations as this could recurse forever. "latest_event": latest_thread_event, "count": thread_count, + "current_user_participated": participated, } # Store the bundled aggregations in the event metadata for later use. return aggregations async def get_bundled_aggregations( - self, events: Iterable[EventBase] + self, + events: Iterable[EventBase], + user_id: str, ) -> Dict[str, Dict[str, Any]]: """Generate bundled aggregations for events. Args: events: The iterable of events to calculate bundled aggregations for. + user_id: The user requesting the bundled aggregations. Returns: A map of event ID to the bundled aggregation for the event. Not all @@ -631,7 +675,7 @@ class RelationsWorkerStore(SQLBaseStore): # TODO Parallelize. results = {} for event in events: - event_result = await self._get_bundled_aggregation_for_event(event) + event_result = await self._get_bundled_aggregation_for_event(event, user_id) if event_result is not None: results[event.event_id] = event_result diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index ee26751430..4b20ab0e3e 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -515,6 +515,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): 2, actual[RelationTypes.THREAD].get("count"), ) + self.assertTrue( + actual[RelationTypes.THREAD].get("current_user_participated") + ) # The latest thread event has some fields that don't matter. self.assert_dict( { -- cgit 1.5.1