diff options
Diffstat (limited to 'synapse/rest/client')
-rw-r--r-- | synapse/rest/client/v1/login.py | 33 | ||||
-rw-r--r-- | synapse/rest/client/v1/room.py | 167 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/account_validity.py | 7 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/sendtodevice.py | 2 |
4 files changed, 147 insertions, 62 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index cbcb60fe31..11567bf32c 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -44,19 +44,14 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -LoginResponse = TypedDict( - "LoginResponse", - { - "user_id": str, - "access_token": str, - "home_server": str, - "expires_in_ms": Optional[int], - "refresh_token": Optional[str], - "device_id": str, - "well_known": Optional[Dict[str, Any]], - }, - total=False, -) +class LoginResponse(TypedDict, total=False): + user_id: str + access_token: str + home_server: str + expires_in_ms: Optional[int] + refresh_token: Optional[str] + device_id: str + well_known: Optional[Dict[str, Any]] class LoginRestServlet(RestServlet): @@ -121,7 +116,7 @@ class LoginRestServlet(RestServlet): flows.append({"type": LoginRestServlet.CAS_TYPE}) if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: - sso_flow = { + sso_flow: JsonDict = { "type": LoginRestServlet.SSO_TYPE, "identity_providers": [ _get_auth_flow_dict_for_idp( @@ -129,7 +124,7 @@ class LoginRestServlet(RestServlet): ) for idp in self._sso_handler.get_identity_providers().values() ], - } # type: JsonDict + } if self._msc2858_enabled: # backwards-compatibility support for clients which don't @@ -150,9 +145,7 @@ class LoginRestServlet(RestServlet): # login flow types returned. flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - flows.extend( - ({"type": t} for t in self.auth_handler.get_supported_login_types()) - ) + flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types()) flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) @@ -447,7 +440,7 @@ def _get_auth_flow_dict_for_idp( use_unstable_brands: whether we should use brand identifiers suitable for the unstable API """ - e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict + e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name} if idp.idp_icon: e["icon"] = idp.idp_icon if idp.idp_brand: @@ -561,7 +554,7 @@ class SsoRedirectServlet(RestServlet): finish_request(request) return - args = request.args # type: Dict[bytes, List[bytes]] # type: ignore + args: Dict[bytes, List[bytes]] = request.args # type: ignore client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True) sso_url = await self._sso_handler.handle_redirect_request( request, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 92ebe838fd..31a1193cd3 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -29,6 +29,7 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.filtering import Filter +from synapse.appservice import ApplicationService from synapse.events.utils import format_event_for_client_v2 from synapse.http.servlet import ( RestServlet, @@ -47,11 +48,13 @@ from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import ( JsonDict, + Requester, RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID, + create_requester, ) from synapse.util import json_decoder from synapse.util.stringutils import parse_and_validate_server_name, random_string @@ -309,7 +312,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() - async def inherit_depth_from_prev_ids(self, prev_event_ids) -> int: + async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int: ( most_recent_prev_event_id, most_recent_prev_event_depth, @@ -349,6 +352,54 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): return depth + def _create_insertion_event_dict( + self, sender: str, room_id: str, origin_server_ts: int + ): + """Creates an event dict for an "insertion" event with the proper fields + and a random chunk ID. + + Args: + sender: The event author MXID + room_id: The room ID that the event belongs to + origin_server_ts: Timestamp when the event was sent + + Returns: + Tuple of event ID and stream ordering position + """ + + next_chunk_id = random_string(8) + insertion_event = { + "type": EventTypes.MSC2716_INSERTION, + "sender": sender, + "room_id": room_id, + "content": { + EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id, + EventContentFields.MSC2716_HISTORICAL: True, + }, + "origin_server_ts": origin_server_ts, + } + + return insertion_event + + async def _create_requester_for_user_id_from_app_service( + self, user_id: str, app_service: ApplicationService + ) -> Requester: + """Creates a new requester for the given user_id + and validates that the app service is allowed to control + the given user. + + Args: + user_id: The author MXID that the app service is controlling + app_service: The app service that controls the user + + Returns: + Requester object + """ + + await self.auth.validate_appservice_can_control_user_id(app_service, user_id) + + return create_requester(user_id, app_service=app_service) + async def on_POST(self, request, room_id): requester = await self.auth.get_user_by_req(request, allow_guest=False) @@ -414,7 +465,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): if event_dict["type"] == EventTypes.Member: membership = event_dict["content"].get("membership", None) event_id, _ = await self.room_member_handler.update_membership( - requester, + await self._create_requester_for_user_id_from_app_service( + state_event["sender"], requester.app_service + ), target=UserID.from_string(event_dict["state_key"]), room_id=room_id, action=membership, @@ -434,7 +487,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): event, _, ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, + await self._create_requester_for_user_id_from_app_service( + state_event["sender"], requester.app_service + ), event_dict, outlier=True, prev_event_ids=[fake_prev_event_id], @@ -449,37 +504,73 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): events_to_create = body["events"] - # If provided, connect the chunk to the last insertion point - # The chunk ID passed in comes from the chunk_id in the - # "insertion" event from the previous chunk. + prev_event_ids = prev_events_from_query + inherited_depth = await self._inherit_depth_from_prev_ids( + prev_events_from_query + ) + + # Figure out which chunk to connect to. If they passed in + # chunk_id_from_query let's use it. The chunk ID passed in comes + # from the chunk_id in the "insertion" event from the previous chunk. + last_event_in_chunk = events_to_create[-1] + chunk_id_to_connect_to = chunk_id_from_query + base_insertion_event = None if chunk_id_from_query: - last_event_in_chunk = events_to_create[-1] - last_event_in_chunk["content"][ - EventContentFields.MSC2716_CHUNK_ID - ] = chunk_id_from_query + # TODO: Verify the chunk_id_from_query corresponds to an insertion event + pass + # Otherwise, create an insertion event to act as a starting point. + # + # We don't always have an insertion event to start hanging more history + # off of (ideally there would be one in the main DAG, but that's not the + # case if we're wanting to add history to e.g. existing rooms without + # an insertion event), in which case we just create a new insertion event + # that can then get pointed to by a "marker" event later. + else: + base_insertion_event_dict = self._create_insertion_event_dict( + sender=requester.user.to_string(), + room_id=room_id, + origin_server_ts=last_event_in_chunk["origin_server_ts"], + ) + base_insertion_event_dict["prev_events"] = prev_event_ids.copy() + + ( + base_insertion_event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + await self._create_requester_for_user_id_from_app_service( + base_insertion_event_dict["sender"], + requester.app_service, + ), + base_insertion_event_dict, + prev_event_ids=base_insertion_event_dict.get("prev_events"), + auth_event_ids=auth_event_ids, + historical=True, + depth=inherited_depth, + ) + + chunk_id_to_connect_to = base_insertion_event["content"][ + EventContentFields.MSC2716_NEXT_CHUNK_ID + ] - # Add an "insertion" event to the start of each chunk (next to the oldest + # Connect this current chunk to the insertion event from the previous chunk + last_event_in_chunk["content"][ + EventContentFields.MSC2716_CHUNK_ID + ] = chunk_id_to_connect_to + + # Add an "insertion" event to the start of each chunk (next to the oldest-in-time # event in the chunk) so the next chunk can be connected to this one. - next_chunk_id = random_string(64) - insertion_event = { - "type": EventTypes.MSC2716_INSERTION, - "sender": requester.user.to_string(), - "content": { - EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id, - EventContentFields.MSC2716_HISTORICAL: True, - }, + insertion_event = self._create_insertion_event_dict( + sender=requester.user.to_string(), + room_id=room_id, # Since the insertion event is put at the start of the chunk, - # where the oldest event is, copy the origin_server_ts from + # where the oldest-in-time event is, copy the origin_server_ts from # the first event we're inserting - "origin_server_ts": events_to_create[0]["origin_server_ts"], - } + origin_server_ts=events_to_create[0]["origin_server_ts"], + ) # Prepend the insertion event to the start of the chunk events_to_create = [insertion_event] + events_to_create - inherited_depth = await self.inherit_depth_from_prev_ids(prev_events_from_query) - event_ids = [] - prev_event_ids = prev_events_from_query events_to_persist = [] for ev in events_to_create: assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) @@ -498,7 +589,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): } event, context = await self.event_creation_handler.create_event( - requester, + await self._create_requester_for_user_id_from_app_service( + ev["sender"], requester.app_service + ), event_dict, prev_event_ids=event_dict.get("prev_events"), auth_event_ids=auth_event_ids, @@ -528,15 +621,23 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): # where topological_ordering is just depth. for (event, context) in reversed(events_to_persist): ev = await self.event_creation_handler.handle_new_client_event( - requester=requester, + await self._create_requester_for_user_id_from_app_service( + event["sender"], requester.app_service + ), event=event, context=context, ) + # Add the base_insertion_event to the bottom of the list we return + if base_insertion_event is not None: + event_ids.append(base_insertion_event.event_id) + return 200, { "state_events": auth_event_ids, "events": event_ids, - "next_chunk_id": next_chunk_id, + "next_chunk_id": insertion_event["content"][ + EventContentFields.MSC2716_NEXT_CHUNK_ID + ], } def on_GET(self, request, room_id): @@ -682,7 +783,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): server = parse_string(request, "server", default=None) content = parse_json_object_from_request(request) - limit = int(content.get("limit", 100)) # type: Optional[int] + limit: Optional[int] = int(content.get("limit", 100)) since_token = content.get("since", None) search_filter = content.get("filter", None) @@ -828,9 +929,7 @@ class RoomMessageListRestServlet(RestServlet): filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter = Filter( - json_decoder.decode(filter_json) - ) # type: Optional[Filter] + event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) if ( event_filter and event_filter.filter_json.get("event_format", "client") @@ -943,9 +1042,7 @@ class RoomEventContextServlet(RestServlet): filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter = Filter( - json_decoder.decode(filter_json) - ) # type: Optional[Filter] + event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) else: event_filter = None diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 2d1ad3d3fb..3ebe401861 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -14,7 +14,7 @@ import logging -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import SynapseError from synapse.http.server import respond_with_html from synapse.http.servlet import RestServlet @@ -92,11 +92,6 @@ class AccountValiditySendMailServlet(RestServlet): ) async def on_POST(self, request): - if not self.account_validity_renew_by_email_enabled: - raise AuthError( - 403, "Account renewal via email is disabled on this server." - ) - requester = await self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() await self.account_activity_handler.send_renewal_email_to_user(user_id) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index f8dcee603c..d537d811d8 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -59,7 +59,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): requester, message_type, content["messages"] ) - response = (200, {}) # type: Tuple[int, dict] + response: Tuple[int, dict] = (200, {}) return response |