summary refs log tree commit diff
path: root/synapse/rest/client
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client')
-rw-r--r--synapse/rest/client/login.py88
-rw-r--r--synapse/rest/client/register.py20
-rw-r--r--synapse/rest/client/relations.py16
-rw-r--r--synapse/rest/client/room.py67
-rw-r--r--synapse/rest/client/sync.py6
5 files changed, 136 insertions, 61 deletions
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 67e03dca04..f9994658c4 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -14,7 +14,17 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 from typing_extensions import TypedDict
 
@@ -28,7 +38,6 @@ from synapse.http.server import HttpServer, finish_request
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
-    parse_boolean,
     parse_bytes_from_args,
     parse_json_object_from_request,
     parse_string,
@@ -63,7 +72,7 @@ class LoginRestServlet(RestServlet):
     JWT_TYPE_DEPRECATED = "m.login.jwt"
     APPSERVICE_TYPE = "m.login.application_service"
     APPSERVICE_TYPE_UNSTABLE = "uk.half-shot.msc2778.login.application_service"
-    REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token"
+    REFRESH_TOKEN_PARAM = "refresh_token"
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
@@ -81,7 +90,7 @@ class LoginRestServlet(RestServlet):
         self.saml2_enabled = hs.config.saml2.saml2_enabled
         self.cas_enabled = hs.config.cas.cas_enabled
         self.oidc_enabled = hs.config.oidc.oidc_enabled
-        self._msc2918_enabled = (
+        self._refresh_tokens_enabled = (
             hs.config.registration.refreshable_access_token_lifetime is not None
         )
 
@@ -154,14 +163,16 @@ class LoginRestServlet(RestServlet):
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
         login_submission = parse_json_object_from_request(request)
 
-        if self._msc2918_enabled:
-            # Check if this login should also issue a refresh token, as per
-            # MSC2918
-            should_issue_refresh_token = parse_boolean(
-                request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False
-            )
-        else:
-            should_issue_refresh_token = False
+        # Check to see if the client requested a refresh token.
+        client_requested_refresh_token = login_submission.get(
+            LoginRestServlet.REFRESH_TOKEN_PARAM, False
+        )
+        if not isinstance(client_requested_refresh_token, bool):
+            raise SynapseError(400, "`refresh_token` should be true or false.")
+
+        should_issue_refresh_token = (
+            self._refresh_tokens_enabled and client_requested_refresh_token
+        )
 
         try:
             if login_submission["type"] in (
@@ -291,6 +302,7 @@ class LoginRestServlet(RestServlet):
         ratelimit: bool = True,
         auth_provider_id: Optional[str] = None,
         should_issue_refresh_token: bool = False,
+        auth_provider_session_id: Optional[str] = None,
     ) -> LoginResponse:
         """Called when we've successfully authed the user and now need to
         actually login them in (e.g. create devices). This gets called on
@@ -306,10 +318,10 @@ class LoginRestServlet(RestServlet):
             create_non_existent_users: Whether to create the user if they don't
                 exist. Defaults to False.
             ratelimit: Whether to ratelimit the login request.
-            auth_provider_id: The SSO IdP the user used, if any (just used for the
-                prometheus metrics).
+            auth_provider_id: The SSO IdP the user used, if any.
             should_issue_refresh_token: True if this login should issue
                 a refresh token alongside the access token.
+            auth_provider_session_id: The session ID got during login from the SSO IdP.
 
         Returns:
             result: Dictionary of account information after successful login.
@@ -342,6 +354,7 @@ class LoginRestServlet(RestServlet):
             initial_display_name,
             auth_provider_id=auth_provider_id,
             should_issue_refresh_token=should_issue_refresh_token,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
         result = LoginResponse(
@@ -387,6 +400,7 @@ class LoginRestServlet(RestServlet):
             self.auth_handler._sso_login_callback,
             auth_provider_id=res.auth_provider_id,
             should_issue_refresh_token=should_issue_refresh_token,
+            auth_provider_session_id=res.auth_provider_session_id,
         )
 
     async def _do_jwt_login(
@@ -448,9 +462,7 @@ def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
 
 
 class RefreshTokenServlet(RestServlet):
-    PATTERNS = client_patterns(
-        "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True
-    )
+    PATTERNS = (re.compile("^/_matrix/client/v1/refresh$"),)
 
     def __init__(self, hs: "HomeServer"):
         self._auth_handler = hs.get_auth_handler()
@@ -458,6 +470,7 @@ class RefreshTokenServlet(RestServlet):
         self.refreshable_access_token_lifetime = (
             hs.config.registration.refreshable_access_token_lifetime
         )
+        self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         refresh_submission = parse_json_object_from_request(request)
@@ -467,29 +480,40 @@ class RefreshTokenServlet(RestServlet):
         if not isinstance(token, str):
             raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM)
 
-        valid_until_ms = (
-            self._clock.time_msec() + self.refreshable_access_token_lifetime
-        )
-        access_token, refresh_token = await self._auth_handler.refresh_token(
-            token, valid_until_ms
-        )
-        expires_in_ms = valid_until_ms - self._clock.time_msec()
-        return (
-            200,
-            {
-                "access_token": access_token,
-                "refresh_token": refresh_token,
-                "expires_in_ms": expires_in_ms,
-            },
+        now = self._clock.time_msec()
+        access_valid_until_ms = None
+        if self.refreshable_access_token_lifetime is not None:
+            access_valid_until_ms = now + self.refreshable_access_token_lifetime
+        refresh_valid_until_ms = None
+        if self.refresh_token_lifetime is not None:
+            refresh_valid_until_ms = now + self.refresh_token_lifetime
+
+        (
+            access_token,
+            refresh_token,
+            actual_access_token_expiry,
+        ) = await self._auth_handler.refresh_token(
+            token, access_valid_until_ms, refresh_valid_until_ms
         )
 
+        response: Dict[str, Union[str, int]] = {
+            "access_token": access_token,
+            "refresh_token": refresh_token,
+        }
+
+        # expires_in_ms is only present if the token expires
+        if actual_access_token_expiry is not None:
+            response["expires_in_ms"] = actual_access_token_expiry - now
+
+        return 200, response
+
 
 class SsoRedirectServlet(RestServlet):
     PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
         re.compile(
             "^"
             + CLIENT_API_PREFIX
-            + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
+            + "/(r0|v3)/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
         )
     ]
 
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index d2b11e39d9..8b56c76aed 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -41,7 +41,6 @@ from synapse.http.server import HttpServer, finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
-    parse_boolean,
     parse_json_object_from_request,
     parse_string,
 )
@@ -420,7 +419,7 @@ class RegisterRestServlet(RestServlet):
         self.password_policy_handler = hs.get_password_policy_handler()
         self.clock = hs.get_clock()
         self._registration_enabled = self.hs.config.registration.enable_registration
-        self._msc2918_enabled = (
+        self._refresh_tokens_enabled = (
             hs.config.registration.refreshable_access_token_lifetime is not None
         )
 
@@ -446,14 +445,15 @@ class RegisterRestServlet(RestServlet):
                 f"Do not understand membership kind: {kind}",
             )
 
-        if self._msc2918_enabled:
-            # Check if this registration should also issue a refresh token, as
-            # per MSC2918
-            should_issue_refresh_token = parse_boolean(
-                request, name="org.matrix.msc2918.refresh_token", default=False
-            )
-        else:
-            should_issue_refresh_token = False
+        # Check if the clients wishes for this registration to issue a refresh
+        # token.
+        client_requested_refresh_tokens = body.get("refresh_token", False)
+        if not isinstance(client_requested_refresh_tokens, bool):
+            raise SynapseError(400, "`refresh_token` should be true or false.")
+
+        should_issue_refresh_token = (
+            self._refresh_tokens_enabled and client_requested_refresh_tokens
+        )
 
         # Pull out the provided username and do basic sanity checks early since
         # the auth layer will store these in sessions.
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 45e9f1dd90..fc4e6921c5 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -224,18 +224,14 @@ class RelationPaginationServlet(RestServlet):
         )
 
         now = self.clock.time_msec()
-        # We set bundle_relations to False when retrieving the original
-        # event because we want the content before relations were applied to
-        # it.
+        # Do not bundle aggregations when retrieving the original event because
+        # we want the content before relations are applied to it.
         original_event = await self._event_serializer.serialize_event(
-            event, now, bundle_relations=False
-        )
-        # Similarly, we don't allow relations to be applied to relations, so we
-        # return the original relations without any aggregations on top of them
-        # here.
-        serialized_events = await self._event_serializer.serialize_events(
-            events, now, bundle_relations=False
+            event, now, bundle_aggregations=False
         )
+        # The relations returned for the requested event do include their
+        # bundled aggregations.
+        serialized_events = await self._event_serializer.serialize_events(events, now)
 
         return_value = pagination_chunk.to_dict()
         return_value["chunk"] = serialized_events
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 955d4e8641..f48e2e6ca2 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -716,10 +716,7 @@ class RoomEventContextServlet(RestServlet):
             results["events_after"], time_now
         )
         results["state"] = await self._event_serializer.serialize_events(
-            results["state"],
-            time_now,
-            # No need to bundle aggregations for state events
-            bundle_relations=False,
+            results["state"], time_now
         )
 
         return 200, results
@@ -1070,6 +1067,62 @@ def register_txn_path(
         )
 
 
+class TimestampLookupRestServlet(RestServlet):
+    """
+    API endpoint to fetch the `event_id` of the closest event to the given
+    timestamp (`ts` query parameter) in the given direction (`dir` query
+    parameter).
+
+    Useful for cases like jump to date so you can start paginating messages from
+    a given date in the archive.
+
+    `ts` is a timestamp in milliseconds where we will find the closest event in
+    the given direction.
+
+    `dir` can be `f` or `b` to indicate forwards and backwards in time from the
+    given timestamp.
+
+    GET /_matrix/client/unstable/org.matrix.msc3030/rooms/<roomID>/timestamp_to_event?ts=<timestamp>&dir=<direction>
+    {
+        "event_id": ...
+    }
+    """
+
+    PATTERNS = (
+        re.compile(
+            "^/_matrix/client/unstable/org.matrix.msc3030"
+            "/rooms/(?P<room_id>[^/]*)/timestamp_to_event$"
+        ),
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._auth = hs.get_auth()
+        self._store = hs.get_datastore()
+        self.timestamp_lookup_handler = hs.get_timestamp_lookup_handler()
+
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self._auth.get_user_by_req(request)
+        await self._auth.check_user_in_room(room_id, requester.user.to_string())
+
+        timestamp = parse_integer(request, "ts", required=True)
+        direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
+
+        (
+            event_id,
+            origin_server_ts,
+        ) = await self.timestamp_lookup_handler.get_event_for_timestamp(
+            requester, room_id, timestamp, direction
+        )
+
+        return 200, {
+            "event_id": event_id,
+            "origin_server_ts": origin_server_ts,
+        }
+
+
 class RoomSpaceSummaryRestServlet(RestServlet):
     PATTERNS = (
         re.compile(
@@ -1140,7 +1193,7 @@ class RoomSpaceSummaryRestServlet(RestServlet):
 class RoomHierarchyRestServlet(RestServlet):
     PATTERNS = (
         re.compile(
-            "^/_matrix/client/unstable/org.matrix.msc2946"
+            "^/_matrix/client/(v1|unstable/org.matrix.msc2946)"
             "/rooms/(?P<room_id>[^/]*)/hierarchy$"
         ),
     )
@@ -1168,7 +1221,7 @@ class RoomHierarchyRestServlet(RestServlet):
             )
 
         return 200, await self._room_summary_handler.get_room_hierarchy(
-            requester.user.to_string(),
+            requester,
             room_id,
             suggested_only=parse_boolean(request, "suggested_only", default=False),
             max_depth=max_depth,
@@ -1239,6 +1292,8 @@ def register_servlets(
     RoomAliasListServlet(hs).register(http_server)
     SearchRestServlet(hs).register(http_server)
     RoomCreateRestServlet(hs).register(http_server)
+    if hs.config.experimental.msc3030_enabled:
+        TimestampLookupRestServlet(hs).register(http_server)
 
     # Some servlets only get registered for the main process.
     if not is_worker:
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index b6a2485732..88e4f5e063 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -520,9 +520,9 @@ class SyncRestServlet(RestServlet):
             return self._event_serializer.serialize_events(
                 events,
                 time_now=time_now,
-                # We don't bundle "live" events, as otherwise clients
-                # will end up double counting annotations.
-                bundle_relations=False,
+                # Don't bother to bundle aggregations if the timeline is unlimited,
+                # as clients will have all the necessary information.
+                bundle_aggregations=room.timeline.limited,
                 token_id=token_id,
                 event_format=event_formatter,
                 only_event_fields=only_fields,