summary refs log tree commit diff
path: root/synapse/rest/client/relations.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/relations.py')
-rw-r--r--synapse/rest/client/relations.py76
1 files changed, 53 insertions, 23 deletions
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py

index ce97080013..9dd59196d9 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py
@@ -13,13 +13,17 @@ # limitations under the License. import logging +import re from typing import TYPE_CHECKING, Optional, Tuple +from synapse.handlers.relations import ThreadsListInclude from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.types import JsonDict, StreamToken +from synapse.storage.databases.main.relations import ThreadsNextBatch +from synapse.streams.config import PaginationConfig +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -41,9 +45,8 @@ class RelationPaginationServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastores().main + self._store = hs.get_datastores().main self._relations_handler = hs.get_relations_handler() - self._msc3715_enabled = hs.config.experimental.msc3715_enabled async def on_GET( self, @@ -55,37 +58,63 @@ class RelationPaginationServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) + pagination_config = await PaginationConfig.from_request( + self._store, request, default_limit=5, default_dir="b" + ) + + # The unstable version of this API returns an extra field for client + # compatibility, see https://github.com/matrix-org/synapse/issues/12930. + assert request.path is not None + include_original_event = request.path.startswith(b"/_matrix/client/unstable/") + + # Return the relations + result = await self._relations_handler.get_relations( + requester=requester, + event_id=parent_id, + room_id=room_id, + pagin_config=pagination_config, + include_original_event=include_original_event, + relation_type=relation_type, + event_type=event_type, + ) + + return 200, result + + +class ThreadsServlet(RestServlet): + PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/threads"),) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self._relations_handler = hs.get_relations_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + limit = parse_integer(request, "limit", default=5) - if self._msc3715_enabled: - direction = parse_string( - request, - "org.matrix.msc3715.dir", - default="b", - allowed_values=["f", "b"], - ) - else: - direction = "b" from_token_str = parse_string(request, "from") - to_token_str = parse_string(request, "to") + include = parse_string( + request, + "include", + default=ThreadsListInclude.all.value, + allowed_values=[v.value for v in ThreadsListInclude], + ) # Return the relations from_token = None if from_token_str: - from_token = await StreamToken.from_string(self.store, from_token_str) - to_token = None - if to_token_str: - to_token = await StreamToken.from_string(self.store, to_token_str) + from_token = ThreadsNextBatch.from_string(from_token_str) - result = await self._relations_handler.get_relations( + result = await self._relations_handler.get_threads( requester=requester, - event_id=parent_id, room_id=room_id, - relation_type=relation_type, - event_type=event_type, + include=ThreadsListInclude(include), limit=limit, - direction=direction, from_token=from_token, - to_token=to_token, ) return 200, result @@ -93,3 +122,4 @@ class RelationPaginationServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationPaginationServlet(hs).register(http_server) + ThreadsServlet(hs).register(http_server)