diff options
Diffstat (limited to 'synapse/handlers/relations.py')
-rw-r--r-- | synapse/handlers/relations.py | 137 |
1 files changed, 109 insertions, 28 deletions
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 49c9d6e3c6..8fdb9df012 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import enum import logging from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple @@ -20,7 +21,8 @@ from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.tracing import trace -from synapse.storage.databases.main.relations import _RelatedEvent +from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent +from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, Requester, StreamToken, UserID from synapse.visibility import filter_events_for_client @@ -31,6 +33,13 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class ThreadsListInclude(str, enum.Enum): + """Valid values for the 'include' flag of /threads.""" + + all = "all" + participated = "participated" + + @attr.s(slots=True, frozen=True, auto_attribs=True) class _ThreadAggregation: # The latest event in the thread. @@ -72,12 +81,10 @@ class RelationsHandler: requester: Requester, event_id: str, room_id: str, + pagin_config: PaginationConfig, + include_original_event: bool, relation_type: Optional[str] = None, event_type: Optional[str] = None, - limit: int = 5, - direction: str = "b", - from_token: Optional[StreamToken] = None, - to_token: Optional[StreamToken] = None, ) -> JsonDict: """Get related events of a event, ordered by topological ordering. @@ -87,13 +94,10 @@ class RelationsHandler: requester: The user requesting the relations. event_id: Fetch events that relate to this event ID. room_id: The room the event belongs to. + pagin_config: The pagination config rules to apply, if any. + include_original_event: Whether to include the parent event. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. - limit: Only fetch the most recent `limit` events. - direction: Whether to fetch the most recent first (`"b"`) or the - oldest first (`"f"`). - from_token: Fetch rows from the given token, or from the start if None. - to_token: Fetch rows up to the given token, or up to the end if None. Returns: The pagination chunk. @@ -121,10 +125,10 @@ class RelationsHandler: room_id=room_id, relation_type=relation_type, event_type=event_type, - limit=limit, - direction=direction, - from_token=from_token, - to_token=to_token, + limit=pagin_config.limit, + direction=pagin_config.direction, + from_token=pagin_config.from_token, + to_token=pagin_config.to_token, ) events = await self._main_store.get_events_as_list( @@ -138,31 +142,32 @@ class RelationsHandler: is_peeking=(member_event_id is None), ) - now = self._clock.time_msec() - # Do not bundle aggregations when retrieving the original event because - # we want the content before relations are applied to it. - original_event = self._event_serializer.serialize_event( - event, now, bundle_aggregations=None - ) # The relations returned for the requested event do include their # bundled aggregations. aggregations = await self.get_bundled_aggregations( events, requester.user.to_string() ) - serialized_events = self._event_serializer.serialize_events( - events, now, bundle_aggregations=aggregations - ) - return_value = { - "chunk": serialized_events, - "original_event": original_event, + now = self._clock.time_msec() + return_value: JsonDict = { + "chunk": self._event_serializer.serialize_events( + events, now, bundle_aggregations=aggregations + ), } + if include_original_event: + # Do not bundle aggregations when retrieving the original event because + # we want the content before relations are applied to it. + return_value["original_event"] = self._event_serializer.serialize_event( + event, now, bundle_aggregations=None + ) if next_token: return_value["next_batch"] = await next_token.to_string(self._main_store) - if from_token: - return_value["prev_batch"] = await from_token.to_string(self._main_store) + if pagin_config.from_token: + return_value["prev_batch"] = await pagin_config.from_token.to_string( + self._main_store + ) return return_value @@ -482,3 +487,79 @@ class RelationsHandler: results.setdefault(event_id, BundledAggregations()).replace = edit return results + + async def get_threads( + self, + requester: Requester, + room_id: str, + include: ThreadsListInclude, + limit: int = 5, + from_token: Optional[ThreadsNextBatch] = None, + ) -> JsonDict: + """Get related events of a event, ordered by topological ordering. + + Args: + requester: The user requesting the relations. + room_id: The room the event belongs to. + include: One of "all" or "participated" to indicate which threads should + be returned. + limit: Only fetch the most recent `limit` events. + from_token: Fetch rows from the given token, or from the start if None. + + Returns: + The pagination chunk. + """ + + user_id = requester.user.to_string() + + # TODO Properly handle a user leaving a room. + (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable( + room_id, requester, allow_departed_users=True + ) + + # Note that ignored users are not passed into get_relations_for_event + # below. Ignored users are handled in filter_events_for_client (and by + # not passing them in here we should get a better cache hit rate). + thread_roots, next_batch = await self._main_store.get_threads( + room_id=room_id, limit=limit, from_token=from_token + ) + + events = await self._main_store.get_events_as_list(thread_roots) + + if include == ThreadsListInclude.participated: + # Pre-seed thread participation with whether the requester sent the event. + participated = {event.event_id: event.sender == user_id for event in events} + # For events the requester did not send, check the database for whether + # the requester sent a threaded reply. + participated.update( + await self._main_store.get_threads_participated( + [eid for eid, p in participated.items() if not p], + user_id, + ) + ) + + # Limit the returned threads to those the user has participated in. + events = [event for event in events if participated[event.event_id]] + + events = await filter_events_for_client( + self._storage_controllers, + user_id, + events, + is_peeking=(member_event_id is None), + ) + + aggregations = await self.get_bundled_aggregations( + events, requester.user.to_string() + ) + + now = self._clock.time_msec() + serialized_events = self._event_serializer.serialize_events( + events, now, bundle_aggregations=aggregations + ) + + return_value: JsonDict = {"chunk": serialized_events} + + if next_batch: + return_value["next_batch"] = str(next_batch) + + return return_value |