summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13394.feature1
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py2
-rw-r--r--synapse/config/experimental.py3
-rw-r--r--synapse/handlers/relations.py86
-rw-r--r--synapse/rest/client/relations.py50
-rw-r--r--synapse/storage/databases/main/cache.py1
-rw-r--r--synapse/storage/databases/main/events.py38
-rw-r--r--synapse/storage/databases/main/relations.py166
-rw-r--r--synapse/storage/schema/main/delta/73/09threads_table.sql30
-rw-r--r--tests/rest/client/test_relations.py151
10 files changed, 522 insertions, 6 deletions
diff --git a/changelog.d/13394.feature b/changelog.d/13394.feature
new file mode 100644
index 0000000000..68de079cf3
--- /dev/null
+++ b/changelog.d/13394.feature
@@ -0,0 +1 @@
+Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API.
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 5fa599e70e..d850e54e17 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -72,6 +72,7 @@ from synapse.storage.databases.main.registration import (
     RegistrationBackgroundUpdateStore,
     find_max_generated_user_id_localpart,
 )
+from synapse.storage.databases.main.relations import RelationsWorkerStore
 from synapse.storage.databases.main.room import RoomBackgroundUpdateStore
 from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore
 from synapse.storage.databases.main.search import SearchBackgroundUpdateStore
@@ -206,6 +207,7 @@ class Store(
     PusherWorkerStore,
     PresenceBackgroundUpdateStore,
     ReceiptsBackgroundUpdateStore,
+    RelationsWorkerStore,
 ):
     def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
         return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index f44655516e..1860006536 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -101,6 +101,9 @@ class ExperimentalConfig(Config):
         # MSC3848: Introduce errcodes for specific event sending failures
         self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)
 
+        # MSC3856: Threads list API
+        self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False)
+
         # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
         self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
 
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index cc5e45c241..1fdd7a10bc 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import logging
 from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
 
@@ -20,7 +21,7 @@ from synapse.api.constants import RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase, relation_from_event
 from synapse.logging.opentracing import trace
-from synapse.storage.databases.main.relations import _RelatedEvent
+from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
 from synapse.streams.config import PaginationConfig
 from synapse.types import JsonDict, Requester, StreamToken, UserID
 from synapse.visibility import filter_events_for_client
@@ -32,6 +33,13 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class ThreadsListInclude(str, enum.Enum):
+    """Valid values for the 'include' flag of /threads."""
+
+    all = "all"
+    participated = "participated"
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _ThreadAggregation:
     # The latest event in the thread.
@@ -482,3 +490,79 @@ class RelationsHandler:
             results.setdefault(event_id, BundledAggregations()).replace = edit
 
         return results
+
+    async def get_threads(
+        self,
+        requester: Requester,
+        room_id: str,
+        include: ThreadsListInclude,
+        limit: int = 5,
+        from_token: Optional[ThreadsNextBatch] = None,
+    ) -> JsonDict:
+        """Get related events of a event, ordered by topological ordering.
+
+        Args:
+            requester: The user requesting the relations.
+            room_id: The room the event belongs to.
+            include: One of "all" or "participated" to indicate which threads should
+                be returned.
+            limit: Only fetch the most recent `limit` events.
+            from_token: Fetch rows from the given token, or from the start if None.
+
+        Returns:
+            The pagination chunk.
+        """
+
+        user_id = requester.user.to_string()
+
+        # TODO Properly handle a user leaving a room.
+        (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
+            room_id, requester, allow_departed_users=True
+        )
+
+        # Note that ignored users are not passed into get_relations_for_event
+        # below. Ignored users are handled in filter_events_for_client (and by
+        # not passing them in here we should get a better cache hit rate).
+        thread_roots, next_batch = await self._main_store.get_threads(
+            room_id=room_id, limit=limit, from_token=from_token
+        )
+
+        events = await self._main_store.get_events_as_list(thread_roots)
+
+        if include == ThreadsListInclude.participated:
+            # Pre-seed thread participation with whether the requester sent the event.
+            participated = {event.event_id: event.sender == user_id for event in events}
+            # For events the requester did not send, check the database for whether
+            # the requester sent a threaded reply.
+            participated.update(
+                await self._main_store.get_threads_participated(
+                    [eid for eid, p in participated.items() if not p],
+                    user_id,
+                )
+            )
+
+            # Limit the returned threads to those the user has participated in.
+            events = [event for event in events if participated[event.event_id]]
+
+        events = await filter_events_for_client(
+            self._storage_controllers,
+            user_id,
+            events,
+            is_peeking=(member_event_id is None),
+        )
+
+        aggregations = await self.get_bundled_aggregations(
+            events, requester.user.to_string()
+        )
+
+        now = self._clock.time_msec()
+        serialized_events = self._event_serializer.serialize_events(
+            events, now, bundle_aggregations=aggregations
+        )
+
+        return_value: JsonDict = {"chunk": serialized_events}
+
+        if next_batch:
+            return_value["next_batch"] = str(next_batch)
+
+        return return_value
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index b31ce5a0d3..d1aa1947a5 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -13,12 +13,15 @@
 # limitations under the License.
 
 import logging
+import re
 from typing import TYPE_CHECKING, Optional, Tuple
 
+from synapse.handlers.relations import ThreadsListInclude
 from synapse.http.server import HttpServer
-from synapse.http.servlet import RestServlet
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.rest.client._base import client_patterns
+from synapse.storage.databases.main.relations import ThreadsNextBatch
 from synapse.streams.config import PaginationConfig
 from synapse.types import JsonDict
 
@@ -78,5 +81,50 @@ class RelationPaginationServlet(RestServlet):
         return 200, result
 
 
+class ThreadsServlet(RestServlet):
+    PATTERNS = (
+        re.compile(
+            "^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P<room_id>[^/]*)/threads"
+        ),
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastores().main
+        self._relations_handler = hs.get_relations_handler()
+
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+
+        limit = parse_integer(request, "limit", default=5)
+        from_token_str = parse_string(request, "from")
+        include = parse_string(
+            request,
+            "include",
+            default=ThreadsListInclude.all.value,
+            allowed_values=[v.value for v in ThreadsListInclude],
+        )
+
+        # Return the relations
+        from_token = None
+        if from_token_str:
+            from_token = ThreadsNextBatch.from_string(from_token_str)
+
+        result = await self._relations_handler.get_threads(
+            requester=requester,
+            room_id=room_id,
+            include=ThreadsListInclude(include),
+            limit=limit,
+            from_token=from_token,
+        )
+
+        return 200, result
+
+
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     RelationPaginationServlet(hs).register(http_server)
+    if hs.config.experimental.msc3856_enabled:
+        ThreadsServlet(hs).register(http_server)
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index a9f25a5904..0ce3156c9c 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -259,6 +259,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
             self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
             self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,))
+            self._attempt_to_invalidate_cache("get_threads", (room_id,))
 
     async def invalidate_cache_and_stream(
         self, cache_name: str, keys: Tuple[Any, ...]
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 060fe71454..6698cbf664 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -35,7 +35,7 @@ import attr
 from prometheus_client import Counter
 
 import synapse.metrics
-from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, relation_from_event
@@ -1616,7 +1616,7 @@ class PersistEventsStore:
                 )
 
                 # Remove from relations table.
-                self._handle_redact_relations(txn, event.redacts)
+                self._handle_redact_relations(txn, event.room_id, event.redacts)
 
         # Update the event_forward_extremities, event_backward_extremities and
         # event_edges tables.
@@ -1866,6 +1866,34 @@ class PersistEventsStore:
             },
         )
 
+        if relation.rel_type == RelationTypes.THREAD:
+            # Upsert into the threads table, but only overwrite the value if the
+            # new event is of a later topological order OR if the topological
+            # ordering is equal, but the stream ordering is later.
+            sql = """
+            INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering)
+            VALUES (?, ?, ?, ?, ?)
+            ON CONFLICT (room_id, thread_id)
+            DO UPDATE SET
+                latest_event_id = excluded.latest_event_id,
+                topological_ordering = excluded.topological_ordering,
+                stream_ordering = excluded.stream_ordering
+            WHERE
+                threads.topological_ordering <= excluded.topological_ordering AND
+                threads.stream_ordering < excluded.stream_ordering
+            """
+
+            txn.execute(
+                sql,
+                (
+                    event.room_id,
+                    relation.parent_id,
+                    event.event_id,
+                    event.depth,
+                    event.internal_metadata.stream_ordering,
+                ),
+            )
+
     def _handle_insertion_event(
         self, txn: LoggingTransaction, event: EventBase
     ) -> None:
@@ -1989,13 +2017,14 @@ class PersistEventsStore:
         txn.execute(sql, (batch_id,))
 
     def _handle_redact_relations(
-        self, txn: LoggingTransaction, redacted_event_id: str
+        self, txn: LoggingTransaction, room_id: str, redacted_event_id: str
     ) -> None:
         """Handles receiving a redaction and checking whether the redacted event
         has any relations which must be removed from the database.
 
         Args:
             txn
+            room_id: The room ID of the event that was redacted.
             redacted_event_id: The event that was redacted.
         """
 
@@ -2024,6 +2053,9 @@ class PersistEventsStore:
             self.store._invalidate_cache_and_stream(
                 txn, self.store.get_thread_participated, (redacted_relates_to,)
             )
+            self.store._invalidate_cache_and_stream(
+                txn, self.store.get_threads, (room_id,)
+            )
 
         self.db_pool.simple_delete_txn(
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index e7fbf950e6..ac9b96ab44 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -14,6 +14,7 @@
 
 import logging
 from typing import (
+    TYPE_CHECKING,
     Collection,
     Dict,
     FrozenSet,
@@ -29,18 +30,47 @@ from typing import (
 import attr
 
 from synapse.api.constants import MAIN_TIMELINE, RelationTypes
+from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+    make_in_list_sql_clause,
+)
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
 from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThreadsNextBatch:
+    topological_ordering: int
+    stream_ordering: int
+
+    def __str__(self) -> str:
+        return f"{self.topological_ordering}_{self.stream_ordering}"
+
+    @classmethod
+    def from_string(cls, string: str) -> "ThreadsNextBatch":
+        """
+        Creates a ThreadsNextBatch from its textual representation.
+        """
+        try:
+            keys = (int(s) for s in string.split("_"))
+            return cls(*keys)
+        except Exception:
+            raise SynapseError(400, "Invalid threads token")
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class _RelatedEvent:
     """
     Contains enough information about a related event in order to properly filter
@@ -56,6 +86,76 @@ class _RelatedEvent:
 
 
 class RelationsWorkerStore(SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        self.db_pool.updates.register_background_update_handler(
+            "threads_backfill", self._backfill_threads
+        )
+
+    async def _backfill_threads(self, progress: JsonDict, batch_size: int) -> int:
+        """Backfill the threads table."""
+
+        def threads_backfill_txn(txn: LoggingTransaction) -> int:
+            last_thread_id = progress.get("last_thread_id", "")
+
+            # Get the latest event in each thread by topo ordering / stream ordering.
+            #
+            # Note that the MAX(event_id) is needed to abide by the rules of group by,
+            # but doesn't actually do anything since there should only be a single event
+            # ID per topo/stream ordering pair.
+            sql = f"""
+            SELECT room_id, relates_to_id, MAX(topological_ordering), MAX(stream_ordering), MAX(event_id)
+            FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE
+                relates_to_id > ? AND
+                relation_type = '{RelationTypes.THREAD}'
+            GROUP BY room_id, relates_to_id
+            ORDER BY relates_to_id
+            LIMIT ?
+            """
+            txn.execute(sql, (last_thread_id, batch_size))
+
+            # No more rows to process.
+            rows = txn.fetchall()
+            if not rows:
+                return 0
+
+            # Insert the rows into the threads table. If a matching thread already exists,
+            # assume it is from a newer event.
+            sql = """
+            INSERT INTO threads (room_id, thread_id, topological_ordering, stream_ordering, latest_event_id)
+            VALUES %s
+            ON CONFLICT (room_id, thread_id)
+            DO NOTHING
+            """
+            if isinstance(txn.database_engine, PostgresEngine):
+                txn.execute_values(sql % ("?",), rows, fetch=False)
+            else:
+                txn.execute_batch(sql % ("?, ?, ?, ?, ?",), rows)
+
+            # Mark the progress.
+            self.db_pool.updates._background_update_progress_txn(
+                txn, "threads_backfill", {"last_thread_id": rows[-1][1]}
+            )
+
+            return txn.rowcount
+
+        result = await self.db_pool.runInteraction(
+            "threads_backfill", threads_backfill_txn
+        )
+
+        if not result:
+            await self.db_pool.updates._end_background_update("threads_backfill")
+
+        return result
+
     @cached(uncached_args=("event",), tree=True)
     async def get_relations_for_event(
         self,
@@ -776,6 +876,70 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )
 
+    @cached(tree=True)
+    async def get_threads(
+        self,
+        room_id: str,
+        limit: int = 5,
+        from_token: Optional[ThreadsNextBatch] = None,
+    ) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
+        """Get a list of thread IDs, ordered by topological ordering of their
+        latest reply.
+
+        Args:
+            room_id: The room the event belongs to.
+            limit: Only fetch the most recent `limit` threads.
+            from_token: Fetch rows from a previous next_batch, or from the start if None.
+
+        Returns:
+            A tuple of:
+                A list of thread root event IDs.
+
+                The next_batch, if one exists.
+        """
+        # Generate the pagination clause, if necessary.
+        #
+        # Find any threads where the latest reply is equal / before the last
+        # thread's topo ordering and earlier in stream ordering.
+        pagination_clause = ""
+        pagination_args: tuple = ()
+        if from_token:
+            pagination_clause = "AND topological_ordering <= ? AND stream_ordering < ?"
+            pagination_args = (
+                from_token.topological_ordering,
+                from_token.stream_ordering,
+            )
+
+        sql = f"""
+            SELECT thread_id, topological_ordering, stream_ordering
+            FROM threads
+            WHERE
+                room_id = ?
+                {pagination_clause}
+            ORDER BY topological_ordering DESC, stream_ordering DESC
+            LIMIT ?
+        """
+
+        def _get_threads_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
+            txn.execute(sql, (room_id, *pagination_args, limit + 1))
+
+            rows = cast(List[Tuple[str, int, int]], txn.fetchall())
+            thread_ids = [r[0] for r in rows]
+
+            # If there are more events, generate the next pagination key from the
+            # last thread which will be returned.
+            next_token = None
+            if len(thread_ids) > limit:
+                last_topo_id = rows[-2][1]
+                last_stream_id = rows[-2][2]
+                next_token = ThreadsNextBatch(last_topo_id, last_stream_id)
+
+            return thread_ids[:limit], next_token
+
+        return await self.db_pool.runInteraction("get_threads", _get_threads_txn)
+
     @cached()
     async def get_thread_id(self, event_id: str) -> str:
         """
diff --git a/synapse/storage/schema/main/delta/73/09threads_table.sql b/synapse/storage/schema/main/delta/73/09threads_table.sql
new file mode 100644
index 0000000000..aa7c5e9a2e
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/09threads_table.sql
@@ -0,0 +1,30 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE threads (
+    room_id TEXT NOT NULL,
+    -- The event ID of the root event in the thread.
+    thread_id TEXT NOT NULL,
+    -- The latest event ID and corresponding topo / stream ordering.
+    latest_event_id TEXT NOT NULL,
+    topological_ordering BIGINT NOT NULL,
+    stream_ordering BIGINT NOT NULL,
+    CONSTRAINT threads_uniqueness UNIQUE (room_id, thread_id)
+);
+
+CREATE INDEX threads_ordering_idx ON threads(room_id, topological_ordering, stream_ordering);
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (7309, 'threads_backfill', '{}');
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 988cdb746d..d595295e2c 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1707,3 +1707,154 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
             relations[RelationTypes.THREAD]["latest_event"]["event_id"],
             related_event_id,
         )
+
+
+class ThreadsTestCase(BaseRelationsTestCase):
+    @unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
+    def test_threads(self) -> None:
+        """Create threads and ensure the ordering is due to their latest event."""
+        # Create 2 threads.
+        thread_1 = self.parent_id
+        res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
+        thread_2 = res["event_id"]
+
+        self._send_relation(RelationTypes.THREAD, "m.room.test")
+        self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+        # Request the threads in the room.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_2, thread_1])
+
+        # Update the first thread, the ordering should swap.
+        self._send_relation(RelationTypes.THREAD, "m.room.test")
+
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_1, thread_2])
+
+    @unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
+    def test_pagination(self) -> None:
+        """Create threads and paginate through them."""
+        # Create 2 threads.
+        thread_1 = self.parent_id
+        res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
+        thread_2 = res["event_id"]
+
+        self._send_relation(RelationTypes.THREAD, "m.room.test")
+        self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+        # Request the threads in the room.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_2])
+
+        # Make sure next_batch has something in it that looks like it could be a
+        # valid token.
+        next_batch = channel.json_body.get("next_batch")
+        self.assertIsInstance(next_batch, str, channel.json_body)
+
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1&from={next_batch}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_1], channel.json_body)
+
+        self.assertNotIn("next_batch", channel.json_body, channel.json_body)
+
+    @unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
+    def test_include(self) -> None:
+        """Filtering threads to all or participated in should work."""
+        # Thread 1 has the user as the root event.
+        thread_1 = self.parent_id
+        self._send_relation(
+            RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+        )
+
+        # Thread 2 has the user replying.
+        res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
+        thread_2 = res["event_id"]
+        self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+        # Thread 3 has the user not participating in.
+        res = self.helper.send(self.room, body="Another thread!", tok=self.user2_token)
+        thread_3 = res["event_id"]
+        self._send_relation(
+            RelationTypes.THREAD,
+            "m.room.test",
+            access_token=self.user2_token,
+            parent_id=thread_3,
+        )
+
+        # All threads in the room.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(
+            thread_roots, [thread_3, thread_2, thread_1], channel.json_body
+        )
+
+        # Only participated threads.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?include=participated",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body)
+
+    @unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
+    def test_ignored_user(self) -> None:
+        """Events from ignored users should be ignored."""
+        # Thread 1 has a reply from an ignored user.
+        thread_1 = self.parent_id
+        self._send_relation(
+            RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+        )
+
+        # Thread 2 is created by an ignored user.
+        res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
+        thread_2 = res["event_id"]
+        self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+        # Ignore user2.
+        self.get_success(
+            self.store.add_account_data_for_user(
+                self.user_id,
+                AccountDataTypes.IGNORED_USER_LIST,
+                {"ignored_users": {self.user2_id: {}}},
+            )
+        )
+
+        # Only thread 1 is returned.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+        self.assertEqual(thread_roots, [thread_1], channel.json_body)