diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 9a244e0bc6..5bbc30bd4b 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -94,3 +94,6 @@ class ExperimentalConfig(Config):
# MSC3848: Introduce errcodes for specific event sending failures
self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)
+
+ # MSC3856: Threads list API
+ self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False)
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 72d25df8c8..1c92834b34 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import enum
import logging
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
@@ -31,6 +32,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class ThreadsListInclude(str, enum.Enum):
+ """Valid values for the 'include' flag of /threads."""
+
+ all = "all"
+ participated = "participated"
+
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
# The latest event in the thread.
@@ -482,3 +490,84 @@ class RelationsHandler:
results.setdefault(event_id, BundledAggregations()).replace = edit
return results
+
+ async def get_threads(
+ self,
+ requester: Requester,
+ room_id: str,
+ include: ThreadsListInclude,
+ limit: int = 5,
+ from_token: Optional[StreamToken] = None,
+ to_token: Optional[StreamToken] = None,
+ ) -> JsonDict:
+ """Get related events of a event, ordered by topological ordering.
+
+ Args:
+ requester: The user requesting the relations.
+ room_id: The room the event belongs to.
+ include: One of "all" or "participated" to indicate which threads should
+ be returned.
+ limit: Only fetch the most recent `limit` events.
+ from_token: Fetch rows from the given token, or from the start if None.
+ to_token: Fetch rows up to the given token, or up to the end if None.
+
+ Returns:
+ The pagination chunk.
+ """
+
+ user_id = requester.user.to_string()
+
+ # TODO Properly handle a user leaving a room.
+ (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
+ room_id, user_id, allow_departed_users=True
+ )
+
+ # Note that ignored users are not passed into get_threads
+ # below. Ignored users are handled in filter_events_for_client (and by
+ # not passing them in here we should get a better cache hit rate).
+ thread_roots, next_token = await self._main_store.get_threads(
+ room_id=room_id, limit=limit, from_token=from_token, to_token=to_token
+ )
+
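+ # Fetch the thread root events themselves; get_events_as_list omits any
+ # roots that cannot be found (e.g. because they have been purged).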
+ events = await self._main_store.get_events_as_list(thread_roots)
+
+ if include == ThreadsListInclude.participated:
+ # Pre-seed thread participation with whether the requester sent the event.
+ participated = {event.event_id: event.sender == user_id for event in events}
+ # For events the requester did not send, check the database for whether
+ # the requester sent a threaded reply.
+ participated.update(
+ await self._main_store.get_threads_participated(
+ [eid for eid, p in participated.items() if not p],
+ user_id,
+ )
+ )
+
+ # Limit the returned threads to those the user has participated in.
+ events = [event for event in events if participated[event.event_id]]
+
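+ # Filter out events the requester cannot see, e.g. due to history
+ # visibility; a None membership event means the requester is peeking at
+ # a world-readable room.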
+ events = await filter_events_for_client(
+ self._storage_controllers,
+ user_id,
+ events,
+ is_peeking=(member_event_id is None),
+ )
+
+ now = self._clock.time_msec()
+
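+ # Bundle aggregations (including each thread's summary) into the thread
+ # roots before serializing them.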
+ aggregations = await self.get_bundled_aggregations(events, user_id)
+ serialized_events = self._event_serializer.serialize_events(
+ events, now, bundle_aggregations=aggregations
+ )
+
+ return_value: JsonDict = {"chunk": serialized_events}
+
+ if next_token:
+ return_value["next_batch"] = await next_token.to_string(self._main_store)
+
+ if from_token:
+ return_value["prev_batch"] = await from_token.to_string(self._main_store)
+
+ return return_value
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index ce97080013..d787aeaae1 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -13,8 +13,10 @@
# limitations under the License.
import logging
+import re
from typing import TYPE_CHECKING, Optional, Tuple
+from synapse.handlers.relations import ThreadsListInclude
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
@@ -91,5 +93,55 @@ class RelationPaginationServlet(RestServlet):
return 200, result
+class ThreadsServlet(RestServlet):
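+ # An unstable endpoint, prefixed with the MSC number, to be used until
+ # MSC3856 is stabilised.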
+ PATTERNS = (
+ re.compile(
+ "^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P<room_id>[^/]*)/threads"
+ ),
+ )
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastores().main
+ self._relations_handler = hs.get_relations_handler()
+
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+
+ limit = parse_integer(request, "limit", default=5)
+ from_token_str = parse_string(request, "from")
+ to_token_str = parse_string(request, "to")
+ include = parse_string(
+ request,
+ "include",
+ default=ThreadsListInclude.all.value,
+ allowed_values=[v.value for v in ThreadsListInclude],
+ )
+
+ # Parse the pagination tokens, if given.
+ from_token = None
+ if from_token_str:
+ from_token = await StreamToken.from_string(self.store, from_token_str)
+ to_token = None
+ if to_token_str:
+ to_token = await StreamToken.from_string(self.store, to_token_str)
+
+ result = await self._relations_handler.get_threads(
+ requester=requester,
+ room_id=room_id,
+ include=ThreadsListInclude(include),
+ limit=limit,
+ from_token=from_token,
+ to_token=to_token,
+ )
+
+ return 200, result
+
+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RelationPaginationServlet(hs).register(http_server)
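+ # Only expose the threads endpoint when the experimental MSC3856 flag
+ # is enabled.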
+ if hs.config.experimental.msc3856_enabled:
+ ThreadsServlet(hs).register(http_server)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index da84820b76..2e53005daa 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1594,7 +1594,7 @@ class PersistEventsStore:
)
# Remove from relations table.
- self._handle_redact_relations(txn, event.redacts)
+ self._handle_redact_relations(txn, event.room_id, event.redacts)
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
@@ -1909,6 +1909,7 @@ class PersistEventsStore:
self.store.get_thread_participated.invalidate,
(relation.parent_id, event.sender),
)
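+ # A new thread reply changes the ordering of the room's threads, so the
+ # cached threads list is stale.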
+ txn.call_after(self.store.get_threads.invalidate, (event.room_id,))
def _handle_insertion_event(
self, txn: LoggingTransaction, event: EventBase
@@ -2033,13 +2034,14 @@ class PersistEventsStore:
txn.execute(sql, (batch_id,))
def _handle_redact_relations(
- self, txn: LoggingTransaction, redacted_event_id: str
+ self, txn: LoggingTransaction, room_id: str, redacted_event_id: str
) -> None:
"""Handles receiving a redaction and checking whether the redacted event
has any relations which must be removed from the database.
Args:
txn
+ room_id: The room ID of the event that was redacted.
redacted_event_id: The event that was redacted.
"""
@@ -2068,6 +2070,7 @@ class PersistEventsStore:
self.store._invalidate_cache_and_stream(
txn, self.store.get_thread_participated, (redacted_relates_to,)
)
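+ # The redacted event may have been the latest reply in a thread, so
+ # invalidate the room's cached threads list (streamed to workers, to
+ # match the other invalidations here).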
+ self.store._invalidate_cache_and_stream(
+ txn, self.store.get_threads, (room_id,)
+ )
self.store._invalidate_cache_and_stream(
txn,
self.store.get_mutual_event_relations_for_rel_type,
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 7bd27790eb..57b2f7c188 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -814,6 +814,93 @@ class RelationsWorkerStore(SQLBaseStore):
"get_event_relations", _get_event_relations
)
+ @cached(tree=True)
+ async def get_threads(
+ self,
+ room_id: str,
+ limit: int = 5,
+ from_token: Optional[StreamToken] = None,
+ to_token: Optional[StreamToken] = None,
+ ) -> Tuple[List[str], Optional[StreamToken]]:
+ """Get a list of thread IDs, ordered by topological ordering of their
+ latest reply.
+
+ Args:
+ room_id: The room the event belongs to.
+ limit: Only fetch the most recent `limit` threads.
+ from_token: Fetch rows from the given token, or from the start if None.
+ to_token: Fetch rows up to the given token, or up to the end if None.
+
+ Returns:
+ A tuple of:
+ A list of thread root event IDs.
+
+ The next stream token, if one exists.
+ """
+ pagination_clause = generate_pagination_where_clause(
+ direction="b",
+ column_names=("topological_ordering", "stream_ordering"),
+ from_token=from_token.room_key.as_historical_tuple()
+ if from_token
+ else None,
+ to_token=to_token.room_key.as_historical_tuple() if to_token else None,
+ engine=self.database_engine,
+ )
+
+ if pagination_clause:
+ pagination_clause = "AND " + pagination_clause
+
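+ # Find the latest reply in each thread by grouping relations by their
+ # root (relates_to_id) and taking the maximum ordering per group. Note
+ # that the MAX() of each column is taken independently, so the pair may
+ # not correspond to a single event.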
+ sql = f"""
+ SELECT relates_to_id, MAX(topological_ordering), MAX(stream_ordering)
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ room_id = ? AND
+ relation_type = '{RelationTypes.THREAD}'
+ {pagination_clause}
+ GROUP BY relates_to_id
+ ORDER BY MAX(topological_ordering) DESC, MAX(stream_ordering) DESC
+ LIMIT ?
+ """
+
+ def _get_threads_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[str], Optional[StreamToken]]:
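+ # Fetch one more row than requested to detect whether a next page exists.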
+ txn.execute(sql, [room_id, limit + 1])
+
+ last_topo_id = None
+ last_stream_id = None
+ thread_ids = []
+ for thread_id, topo_id, stream_id in txn:
+ thread_ids.append(thread_id)
+ last_topo_id = topo_id
+ last_stream_id = stream_id
+
+ # If there are more threads than requested, generate the next pagination key.
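+ # The key comes from the extra (limit + 1)-th row: backwards pagination
+ # treats the "from" token as inclusive, so that row will be the first
+ # result of the next page.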
+ next_token = None
+ if len(thread_ids) > limit and last_topo_id and last_stream_id:
+ next_key = RoomStreamToken(last_topo_id, last_stream_id)
+ if from_token:
+ next_token = from_token.copy_and_replace(
+ StreamKeyType.ROOM, next_key
+ )
+ else:
+ next_token = StreamToken(
+ room_key=next_key,
+ presence_key=0,
+ typing_key=0,
+ receipt_key=0,
+ account_data_key=0,
+ push_rules_key=0,
+ to_device_key=0,
+ device_list_key=0,
+ groups_key=0,
+ )
+
+ return thread_ids[:limit], next_token
+
+ return await self.db_pool.runInteraction("get_threads", _get_threads_txn)
+
class RelationsStore(RelationsWorkerStore):
pass