diff options
Diffstat (limited to 'synapse/rest/client/relations.py')
-rw-r--r-- | synapse/rest/client/relations.py | 82 |
1 files changed, 50 insertions, 32 deletions
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 7a25de5c85..9dd59196d9 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -13,13 +13,17 @@ # limitations under the License. import logging +import re from typing import TYPE_CHECKING, Optional, Tuple +from synapse.handlers.relations import ThreadsListInclude from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.types import JsonDict, StreamToken +from synapse.storage.databases.main.relations import ThreadsNextBatch +from synapse.streams.config import PaginationConfig +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -41,9 +45,8 @@ class RelationPaginationServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastores().main + self._store = hs.get_datastores().main self._relations_handler = hs.get_relations_handler() - self._msc3715_enabled = hs.config.experimental.msc3715_enabled async def on_GET( self, @@ -55,49 +58,63 @@ class RelationPaginationServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - limit = parse_integer(request, "limit", default=5) - # Fetch the direction parameter, if provided. - # - # TODO Use PaginationConfig.from_request when the unstable parameter is - # no longer needed. - direction = parse_string(request, "dir", allowed_values=["f", "b"]) - if direction is None: - if self._msc3715_enabled: - direction = parse_string( - request, - "org.matrix.msc3715.dir", - default="b", - allowed_values=["f", "b"], - ) - else: - direction = "b" - from_token_str = parse_string(request, "from") - to_token_str = parse_string(request, "to") - - # Return the relations - from_token = None - if from_token_str: - from_token = await StreamToken.from_string(self.store, from_token_str) - to_token = None - if to_token_str: - to_token = await StreamToken.from_string(self.store, to_token_str) + pagination_config = await PaginationConfig.from_request( + self._store, request, default_limit=5, default_dir="b" + ) # The unstable version of this API returns an extra field for client # compatibility, see https://github.com/matrix-org/synapse/issues/12930. assert request.path is not None include_original_event = request.path.startswith(b"/_matrix/client/unstable/") + # Return the relations result = await self._relations_handler.get_relations( requester=requester, event_id=parent_id, room_id=room_id, + pagin_config=pagination_config, + include_original_event=include_original_event, relation_type=relation_type, event_type=event_type, + ) + + return 200, result + + +class ThreadsServlet(RestServlet): + PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/threads"),) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self._relations_handler = hs.get_relations_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + limit = parse_integer(request, "limit", default=5) + from_token_str = parse_string(request, "from") + include = parse_string( + request, + "include", + default=ThreadsListInclude.all.value, + allowed_values=[v.value for v in ThreadsListInclude], + ) + + # Return the relations + from_token = None + if from_token_str: + from_token = ThreadsNextBatch.from_string(from_token_str) + + result = await self._relations_handler.get_threads( + requester=requester, + room_id=room_id, + include=ThreadsListInclude(include), limit=limit, - direction=direction, from_token=from_token, - to_token=to_token, - include_original_event=include_original_event, ) return 200, result @@ -105,3 +122,4 @@ class RelationPaginationServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationPaginationServlet(hs).register(http_server) + ThreadsServlet(hs).register(http_server) |