summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <patrickc@matrix.org>2022-07-25 15:01:06 -0400
committerPatrick Cloke <patrickc@matrix.org>2022-07-27 11:46:22 -0400
commite9a649ec314122d34d1b408c83800e1f1e263554 (patch)
treed29cd47c03635234d7140189cb4ddcda527a4406
parentImplement MSC3848: Introduce errcodes for specific event sending failures (#1... (diff)
downloadsynapse-e9a649ec314122d34d1b408c83800e1f1e263554.tar.xz
Add an API for listing threads in a room.
-rw-r--r--changelog.d/13394.feature1
-rw-r--r--synapse/config/experimental.py3
-rw-r--r--synapse/handlers/relations.py63
-rw-r--r--synapse/rest/client/relations.py44
-rw-r--r--synapse/storage/databases/main/events.py7
-rw-r--r--synapse/storage/databases/main/relations.py87
-rw-r--r--tests/rest/client/test_relations.py74
7 files changed, 277 insertions, 2 deletions
diff --git a/changelog.d/13394.feature b/changelog.d/13394.feature
new file mode 100644
index 0000000000..68de079cf3
--- /dev/null
+++ b/changelog.d/13394.feature
@@ -0,0 +1 @@
+Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API.
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 1902222d7b..5613c8e038 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -93,3 +93,6 @@ class ExperimentalConfig(Config):
 
         # MSC3848: Introduce errcodes for specific event sending failures
         self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)
+
+        # MSC3856: Threads list API
+        self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False)
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 8f797e3ae9..af83ac1a7b 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -480,3 +480,66 @@ class RelationsHandler:
             results.setdefault(event_id, BundledAggregations()).replace = edit
 
         return results
+
+    async def get_threads(
+        self,
+        requester: Requester,
+        room_id: str,
+        limit: int = 5,
+        from_token: Optional[StreamToken] = None,
+        to_token: Optional[StreamToken] = None,
+    ) -> JsonDict:
+        """Get related events of a event, ordered by topological ordering.
+
+        Args:
+            requester: The user requesting the relations.
+            room_id: The room the event belongs to.
+            limit: Only fetch the most recent `limit` events.
+            from_token: Fetch rows from the given token, or from the start if None.
+            to_token: Fetch rows up to the given token, or up to the end if None.
+
+        Returns:
+            The pagination chunk.
+        """
+
+        user_id = requester.user.to_string()
+
+        # TODO Properly handle a user leaving a room.
+        (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
+            room_id, user_id, allow_departed_users=True
+        )
+
+        # Note that ignored users are not passed into get_relations_for_event
+        # below. Ignored users are handled in filter_events_for_client (and by
+        # not passing them in here we should get a better cache hit rate).
+        thread_roots, next_token = await self._main_store.get_threads(
+            room_id=room_id, limit=limit, from_token=from_token, to_token=to_token
+        )
+
+        events = await self._main_store.get_events_as_list(thread_roots)
+
+        events = await filter_events_for_client(
+            self._storage_controllers,
+            user_id,
+            events,
+            is_peeking=(member_event_id is None),
+        )
+
+        now = self._clock.time_msec()
+
+        aggregations = await self.get_bundled_aggregations(
+            events, requester.user.to_string()
+        )
+        serialized_events = self._event_serializer.serialize_events(
+            events, now, bundle_aggregations=aggregations
+        )
+
+        return_value: JsonDict = {"chunk": serialized_events}
+
+        if next_token:
+            return_value["next_batch"] = await next_token.to_string(self._main_store)
+
+        if from_token:
+            return_value["prev_batch"] = await from_token.to_string(self._main_store)
+
+        return return_value
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index ce97080013..faa962e3a8 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+import re
 from typing import TYPE_CHECKING, Optional, Tuple
 
 from synapse.http.server import HttpServer
@@ -91,5 +92,48 @@ class RelationPaginationServlet(RestServlet):
         return 200, result
 
 
+class ThreadsServlet(RestServlet):
+    PATTERNS = (
+        re.compile(
+            "^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P<room_id>[^/]*)/threads"
+        ),
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastores().main
+        self._relations_handler = hs.get_relations_handler()
+
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+
+        limit = parse_integer(request, "limit", default=5)
+        from_token_str = parse_string(request, "from")
+        to_token_str = parse_string(request, "to")
+
+        # Return the relations
+        from_token = None
+        if from_token_str:
+            from_token = await StreamToken.from_string(self.store, from_token_str)
+        to_token = None
+        if to_token_str:
+            to_token = await StreamToken.from_string(self.store, to_token_str)
+
+        result = await self._relations_handler.get_threads(
+            requester=requester,
+            room_id=room_id,
+            limit=limit,
+            from_token=from_token,
+            to_token=to_token,
+        )
+
+        return 200, result
+
+
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     RelationPaginationServlet(hs).register(http_server)
+    if hs.config.experimental.msc3856_enabled:
+        ThreadsServlet(hs).register(http_server)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1f600f1190..cf18dd179c 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1594,7 +1594,7 @@ class PersistEventsStore:
                 )
 
                 # Remove from relations table.
-                self._handle_redact_relations(txn, event.redacts)
+                self._handle_redact_relations(txn, event.room_id, event.redacts)
 
         # Update the event_forward_extremities, event_backward_extremities and
         # event_edges tables.
@@ -1909,6 +1909,7 @@ class PersistEventsStore:
                 self.store.get_thread_participated.invalidate,
                 (relation.parent_id, event.sender),
             )
+            txn.call_after(self.store.get_threads.invalidate, (event.room_id,))
 
     def _handle_insertion_event(
         self, txn: LoggingTransaction, event: EventBase
@@ -2033,13 +2034,14 @@ class PersistEventsStore:
         txn.execute(sql, (batch_id,))
 
     def _handle_redact_relations(
-        self, txn: LoggingTransaction, redacted_event_id: str
+        self, txn: LoggingTransaction, room_id: str, redacted_event_id: str
     ) -> None:
         """Handles receiving a redaction and checking whether the redacted event
         has any relations which must be removed from the database.
 
         Args:
             txn
+            room_id: The room ID of the event that was redacted.
             redacted_event_id: The event that was redacted.
         """
 
@@ -2068,6 +2070,7 @@ class PersistEventsStore:
             self.store._invalidate_cache_and_stream(
                 txn, self.store.get_thread_participated, (redacted_relates_to,)
             )
+            txn.call_after(self.store.get_threads.invalidate, (room_id,))
             self.store._invalidate_cache_and_stream(
                 txn,
                 self.store.get_mutual_event_relations_for_rel_type,
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 7bd27790eb..57b2f7c188 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -814,6 +814,93 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_event_relations", _get_event_relations
         )
 
+    @cached(tree=True)
+    async def get_threads(
+        self,
+        room_id: str,
+        limit: int = 5,
+        from_token: Optional[StreamToken] = None,
+        to_token: Optional[StreamToken] = None,
+    ) -> Tuple[List[str], Optional[StreamToken]]:
+        """Get a list of thread IDs, ordered by topological ordering of their
+        latest reply.
+
+        Args:
+            room_id: The room the event belongs to.
+            limit: Only fetch the most recent `limit` threads.
+            from_token: Fetch rows from the given token, or from the start if None.
+            to_token: Fetch rows up to the given token, or up to the end if None.
+
+        Returns:
+            A tuple of:
+                A list of thread root event IDs.
+
+                The next stream token, if one exists.
+        """
+        pagination_clause = generate_pagination_where_clause(
+            direction="b",
+            column_names=("topological_ordering", "stream_ordering"),
+            from_token=from_token.room_key.as_historical_tuple()
+            if from_token
+            else None,
+            to_token=to_token.room_key.as_historical_tuple() if to_token else None,
+            engine=self.database_engine,
+        )
+
+        if pagination_clause:
+            pagination_clause = "AND " + pagination_clause
+
+        sql = f"""
+            SELECT relates_to_id, MAX(topological_ordering), MAX(stream_ordering)
+            FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE
+                room_id = ? AND
+                relation_type = '{RelationTypes.THREAD}'
+                {pagination_clause}
+            GROUP BY relates_to_id
+            ORDER BY MAX(topological_ordering) DESC, MAX(stream_ordering) DESC
+            LIMIT ?
+        """
+
+        def _get_threads_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[str], Optional[StreamToken]]:
+            txn.execute(sql, [room_id, limit + 1])
+
+            last_topo_id = None
+            last_stream_id = None
+            thread_ids = []
+            for thread_id, topo_id, stream_id in txn:
+                thread_ids.append(thread_id)
+                last_topo_id = topo_id
+                last_stream_id = stream_id
+
+            # If there are more events, generate the next pagination key.
+            next_token = None
+            if len(thread_ids) > limit and last_topo_id and last_stream_id:
+                next_key = RoomStreamToken(last_topo_id, last_stream_id)
+                if from_token:
+                    next_token = from_token.copy_and_replace(
+                        StreamKeyType.ROOM, next_key
+                    )
+                else:
+                    next_token = StreamToken(
+                        room_key=next_key,
+                        presence_key=0,
+                        typing_key=0,
+                        receipt_key=0,
+                        account_data_key=0,
+                        push_rules_key=0,
+                        to_device_key=0,
+                        device_list_key=0,
+                        groups_key=0,
+                    )
+
+            return thread_ids[:limit], next_token
+
+        return await self.db_pool.runInteraction("get_threads", _get_threads_txn)
+
 
 class RelationsStore(RelationsWorkerStore):
     pass
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index ad03eee17b..0666bec479 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1677,3 +1677,77 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
             relations[RelationTypes.THREAD]["latest_event"]["event_id"],
             related_event_id,
         )
+
+
+class ThreadsTestCase(BaseRelationsTestCase):
+    @unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
+    def test_threads(self) -> None:
+        """Create threads and ensure the ordering is due to their latest event."""
+        # Create 2 threads.
+        thread_1 = self.parent_id
+        res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
+        thread_2 = res["event_id"]
+
+        self._send_relation(RelationTypes.THREAD, "m.room.test")
+        self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+        # Request the threads in the room.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_2, thread_1])
+
+        # Update the first thread, the ordering should swap.
+        self._send_relation(RelationTypes.THREAD, "m.room.test")
+
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_1, thread_2])
+
+    @unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
+    def test_pagination(self) -> None:
+        """Create threads and paginate through them."""
+        # Create 2 threads.
+        thread_1 = self.parent_id
+        res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
+        thread_2 = res["event_id"]
+
+        self._send_relation(RelationTypes.THREAD, "m.room.test")
+        self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+        # Request the threads in the room.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_2])
+
+        # Make sure next_batch has something in it that looks like it could be a
+        # valid token.
+        next_batch = channel.json_body.get("next_batch")
+        self.assertIsInstance(next_batch, str, channel.json_body)
+
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1&from={next_batch}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_1], channel.json_body)
+
+        self.assertNotIn("next_batch", channel.json_body, channel.json_body)
+
+    # XXX Test ignoring users.