summary refs log tree commit diff
path: root/synapse/rest/client
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client')
-rw-r--r--synapse/rest/client/v1/login.py33
-rw-r--r--synapse/rest/client/v1/room.py167
-rw-r--r--synapse/rest/client/v2_alpha/account_validity.py7
-rw-r--r--synapse/rest/client/v2_alpha/sendtodevice.py2
4 files changed, 147 insertions, 62 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index cbcb60fe31..11567bf32c 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -44,19 +44,14 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-LoginResponse = TypedDict(
-    "LoginResponse",
-    {
-        "user_id": str,
-        "access_token": str,
-        "home_server": str,
-        "expires_in_ms": Optional[int],
-        "refresh_token": Optional[str],
-        "device_id": str,
-        "well_known": Optional[Dict[str, Any]],
-    },
-    total=False,
-)
+class LoginResponse(TypedDict, total=False):
+    user_id: str
+    access_token: str
+    home_server: str
+    expires_in_ms: Optional[int]
+    refresh_token: Optional[str]
+    device_id: str
+    well_known: Optional[Dict[str, Any]]
 
 
 class LoginRestServlet(RestServlet):
@@ -121,7 +116,7 @@ class LoginRestServlet(RestServlet):
             flows.append({"type": LoginRestServlet.CAS_TYPE})
 
         if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
-            sso_flow = {
+            sso_flow: JsonDict = {
                 "type": LoginRestServlet.SSO_TYPE,
                 "identity_providers": [
                     _get_auth_flow_dict_for_idp(
@@ -129,7 +124,7 @@ class LoginRestServlet(RestServlet):
                     )
                     for idp in self._sso_handler.get_identity_providers().values()
                 ],
-            }  # type: JsonDict
+            }
 
             if self._msc2858_enabled:
                 # backwards-compatibility support for clients which don't
@@ -150,9 +145,7 @@ class LoginRestServlet(RestServlet):
             # login flow types returned.
             flows.append({"type": LoginRestServlet.TOKEN_TYPE})
 
-        flows.extend(
-            ({"type": t} for t in self.auth_handler.get_supported_login_types())
-        )
+        flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types())
 
         flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
 
@@ -447,7 +440,7 @@ def _get_auth_flow_dict_for_idp(
         use_unstable_brands: whether we should use brand identifiers suitable
            for the unstable API
     """
-    e = {"id": idp.idp_id, "name": idp.idp_name}  # type: JsonDict
+    e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name}
     if idp.idp_icon:
         e["icon"] = idp.idp_icon
     if idp.idp_brand:
@@ -561,7 +554,7 @@ class SsoRedirectServlet(RestServlet):
             finish_request(request)
             return
 
-        args = request.args  # type: Dict[bytes, List[bytes]]  # type: ignore
+        args: Dict[bytes, List[bytes]] = request.args  # type: ignore
         client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True)
         sso_url = await self._sso_handler.handle_redirect_request(
             request,
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 92ebe838fd..31a1193cd3 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -29,6 +29,7 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.api.filtering import Filter
+from synapse.appservice import ApplicationService
 from synapse.events.utils import format_event_for_client_v2
 from synapse.http.servlet import (
     RestServlet,
@@ -47,11 +48,13 @@ from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
 from synapse.types import (
     JsonDict,
+    Requester,
     RoomAlias,
     RoomID,
     StreamToken,
     ThirdPartyInstanceID,
     UserID,
+    create_requester,
 )
 from synapse.util import json_decoder
 from synapse.util.stringutils import parse_and_validate_server_name, random_string
@@ -309,7 +312,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
         self.room_member_handler = hs.get_room_member_handler()
         self.auth = hs.get_auth()
 
-    async def inherit_depth_from_prev_ids(self, prev_event_ids) -> int:
+    async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int:
         (
             most_recent_prev_event_id,
             most_recent_prev_event_depth,
@@ -349,6 +352,54 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
 
         return depth
 
+    def _create_insertion_event_dict(
+        self, sender: str, room_id: str, origin_server_ts: int
+    ):
+        """Creates an event dict for an "insertion" event with the proper fields
+        and a random chunk ID.
+
+        Args:
+            sender: The event author MXID
+            room_id: The room ID that the event belongs to
+            origin_server_ts: Timestamp when the event was sent
+
+        Returns:
+            Tuple of event ID and stream ordering position
+        """
+
+        next_chunk_id = random_string(8)
+        insertion_event = {
+            "type": EventTypes.MSC2716_INSERTION,
+            "sender": sender,
+            "room_id": room_id,
+            "content": {
+                EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id,
+                EventContentFields.MSC2716_HISTORICAL: True,
+            },
+            "origin_server_ts": origin_server_ts,
+        }
+
+        return insertion_event
+
+    async def _create_requester_for_user_id_from_app_service(
+        self, user_id: str, app_service: ApplicationService
+    ) -> Requester:
+        """Creates a new requester for the given user_id
+        and validates that the app service is allowed to control
+        the given user.
+
+        Args:
+            user_id: The author MXID that the app service is controlling
+            app_service: The app service that controls the user
+
+        Returns:
+            Requester object
+        """
+
+        await self.auth.validate_appservice_can_control_user_id(app_service, user_id)
+
+        return create_requester(user_id, app_service=app_service)
+
     async def on_POST(self, request, room_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
 
@@ -414,7 +465,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
             if event_dict["type"] == EventTypes.Member:
                 membership = event_dict["content"].get("membership", None)
                 event_id, _ = await self.room_member_handler.update_membership(
-                    requester,
+                    await self._create_requester_for_user_id_from_app_service(
+                        state_event["sender"], requester.app_service
+                    ),
                     target=UserID.from_string(event_dict["state_key"]),
                     room_id=room_id,
                     action=membership,
@@ -434,7 +487,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
                     event,
                     _,
                 ) = await self.event_creation_handler.create_and_send_nonmember_event(
-                    requester,
+                    await self._create_requester_for_user_id_from_app_service(
+                        state_event["sender"], requester.app_service
+                    ),
                     event_dict,
                     outlier=True,
                     prev_event_ids=[fake_prev_event_id],
@@ -449,37 +504,73 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
 
         events_to_create = body["events"]
 
-        # If provided, connect the chunk to the last insertion point
-        # The chunk ID passed in comes from the chunk_id in the
-        # "insertion" event from the previous chunk.
+        prev_event_ids = prev_events_from_query
+        inherited_depth = await self._inherit_depth_from_prev_ids(
+            prev_events_from_query
+        )
+
+        # Figure out which chunk to connect to. If they passed in
+        # chunk_id_from_query let's use it. The chunk ID passed in comes
+        # from the chunk_id in the "insertion" event from the previous chunk.
+        last_event_in_chunk = events_to_create[-1]
+        chunk_id_to_connect_to = chunk_id_from_query
+        base_insertion_event = None
         if chunk_id_from_query:
-            last_event_in_chunk = events_to_create[-1]
-            last_event_in_chunk["content"][
-                EventContentFields.MSC2716_CHUNK_ID
-            ] = chunk_id_from_query
+            # TODO: Verify the chunk_id_from_query corresponds to an insertion event
+            pass
+        # Otherwise, create an insertion event to act as a starting point.
+        #
+        # We don't always have an insertion event to start hanging more history
+        # off of (ideally there would be one in the main DAG, but that's not the
+        # case if we're wanting to add history to e.g. existing rooms without
+        # an insertion event), in which case we just create a new insertion event
+        # that can then get pointed to by a "marker" event later.
+        else:
+            base_insertion_event_dict = self._create_insertion_event_dict(
+                sender=requester.user.to_string(),
+                room_id=room_id,
+                origin_server_ts=last_event_in_chunk["origin_server_ts"],
+            )
+            base_insertion_event_dict["prev_events"] = prev_event_ids.copy()
+
+            (
+                base_insertion_event,
+                _,
+            ) = await self.event_creation_handler.create_and_send_nonmember_event(
+                await self._create_requester_for_user_id_from_app_service(
+                    base_insertion_event_dict["sender"],
+                    requester.app_service,
+                ),
+                base_insertion_event_dict,
+                prev_event_ids=base_insertion_event_dict.get("prev_events"),
+                auth_event_ids=auth_event_ids,
+                historical=True,
+                depth=inherited_depth,
+            )
+
+            chunk_id_to_connect_to = base_insertion_event["content"][
+                EventContentFields.MSC2716_NEXT_CHUNK_ID
+            ]
 
-        # Add an "insertion" event to the start of each chunk (next to the oldest
+        # Connect this current chunk to the insertion event from the previous chunk
+        last_event_in_chunk["content"][
+            EventContentFields.MSC2716_CHUNK_ID
+        ] = chunk_id_to_connect_to
+
+        # Add an "insertion" event to the start of each chunk (next to the oldest-in-time
         # event in the chunk) so the next chunk can be connected to this one.
-        next_chunk_id = random_string(64)
-        insertion_event = {
-            "type": EventTypes.MSC2716_INSERTION,
-            "sender": requester.user.to_string(),
-            "content": {
-                EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id,
-                EventContentFields.MSC2716_HISTORICAL: True,
-            },
+        insertion_event = self._create_insertion_event_dict(
+            sender=requester.user.to_string(),
+            room_id=room_id,
             # Since the insertion event is put at the start of the chunk,
-            # where the oldest event is, copy the origin_server_ts from
+            # where the oldest-in-time event is, copy the origin_server_ts from
             # the first event we're inserting
-            "origin_server_ts": events_to_create[0]["origin_server_ts"],
-        }
+            origin_server_ts=events_to_create[0]["origin_server_ts"],
+        )
         # Prepend the insertion event to the start of the chunk
         events_to_create = [insertion_event] + events_to_create
 
-        inherited_depth = await self.inherit_depth_from_prev_ids(prev_events_from_query)
-
         event_ids = []
-        prev_event_ids = prev_events_from_query
         events_to_persist = []
         for ev in events_to_create:
             assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"])
@@ -498,7 +589,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
             }
 
             event, context = await self.event_creation_handler.create_event(
-                requester,
+                await self._create_requester_for_user_id_from_app_service(
+                    ev["sender"], requester.app_service
+                ),
                 event_dict,
                 prev_event_ids=event_dict.get("prev_events"),
                 auth_event_ids=auth_event_ids,
@@ -528,15 +621,23 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
         # where topological_ordering is just depth.
         for (event, context) in reversed(events_to_persist):
             ev = await self.event_creation_handler.handle_new_client_event(
-                requester=requester,
+                await self._create_requester_for_user_id_from_app_service(
+                    event["sender"], requester.app_service
+                ),
                 event=event,
                 context=context,
             )
 
+        # Add the base_insertion_event to the bottom of the list we return
+        if base_insertion_event is not None:
+            event_ids.append(base_insertion_event.event_id)
+
         return 200, {
             "state_events": auth_event_ids,
             "events": event_ids,
-            "next_chunk_id": next_chunk_id,
+            "next_chunk_id": insertion_event["content"][
+                EventContentFields.MSC2716_NEXT_CHUNK_ID
+            ],
         }
 
     def on_GET(self, request, room_id):
@@ -682,7 +783,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
         server = parse_string(request, "server", default=None)
         content = parse_json_object_from_request(request)
 
-        limit = int(content.get("limit", 100))  # type: Optional[int]
+        limit: Optional[int] = int(content.get("limit", 100))
         since_token = content.get("since", None)
         search_filter = content.get("filter", None)
 
@@ -828,9 +929,7 @@ class RoomMessageListRestServlet(RestServlet):
         filter_str = parse_string(request, "filter", encoding="utf-8")
         if filter_str:
             filter_json = urlparse.unquote(filter_str)
-            event_filter = Filter(
-                json_decoder.decode(filter_json)
-            )  # type: Optional[Filter]
+            event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
             if (
                 event_filter
                 and event_filter.filter_json.get("event_format", "client")
@@ -943,9 +1042,7 @@ class RoomEventContextServlet(RestServlet):
         filter_str = parse_string(request, "filter", encoding="utf-8")
         if filter_str:
             filter_json = urlparse.unquote(filter_str)
-            event_filter = Filter(
-                json_decoder.decode(filter_json)
-            )  # type: Optional[Filter]
+            event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
         else:
             event_filter = None
 
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index 2d1ad3d3fb..3ebe401861 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -14,7 +14,7 @@
 
 import logging
 
-from synapse.api.errors import AuthError, SynapseError
+from synapse.api.errors import SynapseError
 from synapse.http.server import respond_with_html
 from synapse.http.servlet import RestServlet
 
@@ -92,11 +92,6 @@ class AccountValiditySendMailServlet(RestServlet):
         )
 
     async def on_POST(self, request):
-        if not self.account_validity_renew_by_email_enabled:
-            raise AuthError(
-                403, "Account renewal via email is disabled on this server."
-            )
-
         requester = await self.auth.get_user_by_req(request, allow_expired=True)
         user_id = requester.user.to_string()
         await self.account_activity_handler.send_renewal_email_to_user(user_id)
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index f8dcee603c..d537d811d8 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -59,7 +59,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
             requester, message_type, content["messages"]
         )
 
-        response = (200, {})  # type: Tuple[int, dict]
+        response: Tuple[int, dict] = (200, {})
         return response