summary refs log tree commit diff
path: root/synapse/handlers/relations.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/relations.py')
-rw-r--r--synapse/handlers/relations.py137
1 files changed, 109 insertions, 28 deletions
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 49c9d6e3c6..8fdb9df012 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import logging
 from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
 
@@ -20,7 +21,8 @@ from synapse.api.constants import RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase, relation_from_event
 from synapse.logging.tracing import trace
-from synapse.storage.databases.main.relations import _RelatedEvent
+from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
+from synapse.streams.config import PaginationConfig
 from synapse.types import JsonDict, Requester, StreamToken, UserID
 from synapse.visibility import filter_events_for_client
 
@@ -31,6 +33,13 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class ThreadsListInclude(str, enum.Enum):
+    """Valid values for the 'include' flag of /threads."""
+
+    all = "all"
+    participated = "participated"
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _ThreadAggregation:
     # The latest event in the thread.
@@ -72,12 +81,10 @@ class RelationsHandler:
         requester: Requester,
         event_id: str,
         room_id: str,
+        pagin_config: PaginationConfig,
+        include_original_event: bool,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
-        limit: int = 5,
-        direction: str = "b",
-        from_token: Optional[StreamToken] = None,
-        to_token: Optional[StreamToken] = None,
     ) -> JsonDict:
         """Get related events of a event, ordered by topological ordering.
 
@@ -87,13 +94,10 @@ class RelationsHandler:
             requester: The user requesting the relations.
             event_id: Fetch events that relate to this event ID.
             room_id: The room the event belongs to.
+            pagin_config: The pagination config rules to apply, if any.
+            include_original_event: Whether to include the parent event.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
-            limit: Only fetch the most recent `limit` events.
-            direction: Whether to fetch the most recent first (`"b"`) or the
-                oldest first (`"f"`).
-            from_token: Fetch rows from the given token, or from the start if None.
-            to_token: Fetch rows up to the given token, or up to the end if None.
 
         Returns:
             The pagination chunk.
@@ -121,10 +125,10 @@ class RelationsHandler:
             room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
-            limit=limit,
-            direction=direction,
-            from_token=from_token,
-            to_token=to_token,
+            limit=pagin_config.limit,
+            direction=pagin_config.direction,
+            from_token=pagin_config.from_token,
+            to_token=pagin_config.to_token,
         )
 
         events = await self._main_store.get_events_as_list(
@@ -138,31 +142,32 @@ class RelationsHandler:
             is_peeking=(member_event_id is None),
         )
 
-        now = self._clock.time_msec()
-        # Do not bundle aggregations when retrieving the original event because
-        # we want the content before relations are applied to it.
-        original_event = self._event_serializer.serialize_event(
-            event, now, bundle_aggregations=None
-        )
         # The relations returned for the requested event do include their
         # bundled aggregations.
         aggregations = await self.get_bundled_aggregations(
             events, requester.user.to_string()
         )
-        serialized_events = self._event_serializer.serialize_events(
-            events, now, bundle_aggregations=aggregations
-        )
 
-        return_value = {
-            "chunk": serialized_events,
-            "original_event": original_event,
+        now = self._clock.time_msec()
+        return_value: JsonDict = {
+            "chunk": self._event_serializer.serialize_events(
+                events, now, bundle_aggregations=aggregations
+            ),
         }
+        if include_original_event:
+            # Do not bundle aggregations when retrieving the original event because
+            # we want the content before relations are applied to it.
+            return_value["original_event"] = self._event_serializer.serialize_event(
+                event, now, bundle_aggregations=None
+            )
 
         if next_token:
             return_value["next_batch"] = await next_token.to_string(self._main_store)
 
-        if from_token:
-            return_value["prev_batch"] = await from_token.to_string(self._main_store)
+        if pagin_config.from_token:
+            return_value["prev_batch"] = await pagin_config.from_token.to_string(
+                self._main_store
+            )
 
         return return_value
 
@@ -482,3 +487,79 @@ class RelationsHandler:
             results.setdefault(event_id, BundledAggregations()).replace = edit
 
         return results
+
+    async def get_threads(
+        self,
+        requester: Requester,
+        room_id: str,
+        include: ThreadsListInclude,
+        limit: int = 5,
+        from_token: Optional[ThreadsNextBatch] = None,
+    ) -> JsonDict:
+        """Get related events of a event, ordered by topological ordering.
+
+        Args:
+            requester: The user requesting the relations.
+            room_id: The room the event belongs to.
+            include: One of "all" or "participated" to indicate which threads should
+                be returned.
+            limit: Only fetch the most recent `limit` events.
+            from_token: Fetch rows from the given token, or from the start if None.
+
+        Returns:
+            The pagination chunk.
+        """
+
+        user_id = requester.user.to_string()
+
+        # TODO Properly handle a user leaving a room.
+        (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
+            room_id, requester, allow_departed_users=True
+        )
+
+        # Note that ignored users are not passed into get_relations_for_event
+        # below. Ignored users are handled in filter_events_for_client (and by
+        # not passing them in here we should get a better cache hit rate).
+        thread_roots, next_batch = await self._main_store.get_threads(
+            room_id=room_id, limit=limit, from_token=from_token
+        )
+
+        events = await self._main_store.get_events_as_list(thread_roots)
+
+        if include == ThreadsListInclude.participated:
+            # Pre-seed thread participation with whether the requester sent the event.
+            participated = {event.event_id: event.sender == user_id for event in events}
+            # For events the requester did not send, check the database for whether
+            # the requester sent a threaded reply.
+            participated.update(
+                await self._main_store.get_threads_participated(
+                    [eid for eid, p in participated.items() if not p],
+                    user_id,
+                )
+            )
+
+            # Limit the returned threads to those the user has participated in.
+            events = [event for event in events if participated[event.event_id]]
+
+        events = await filter_events_for_client(
+            self._storage_controllers,
+            user_id,
+            events,
+            is_peeking=(member_event_id is None),
+        )
+
+        aggregations = await self.get_bundled_aggregations(
+            events, requester.user.to_string()
+        )
+
+        now = self._clock.time_msec()
+        serialized_events = self._event_serializer.serialize_events(
+            events, now, bundle_aggregations=aggregations
+        )
+
+        return_value: JsonDict = {"chunk": serialized_events}
+
+        if next_batch:
+            return_value["next_batch"] = str(next_batch)
+
+        return return_value