diff options
author | Ben Banfield-Zanin <benbz@matrix.org> | 2021-02-16 13:33:20 +0000 |
---|---|---|
committer | Ben Banfield-Zanin <benbz@matrix.org> | 2021-02-16 13:33:20 +0000 |
commit | dcf1b9c276e22bb6f5200fc029301c4d40e87a1f (patch) | |
tree | 1f5badce24645d99534133a7a989069906088fff /synapse/rest | |
parent | Merge remote-tracking branch 'origin/release-v1.24.0' into bbz/info-mainline-... (diff) | |
parent | Fixup CHANGES (diff) | |
download | synapse-dcf1b9c276e22bb6f5200fc029301c4d40e87a1f.tar.xz |
Merge remote-tracking branch 'origin/release-v1.27.0' into bbz/info-mainline-1.27.0 github/bbz/info-mainline-1.27.0 bbz/info-mainline-1.27.0
Diffstat (limited to 'synapse/rest')
39 files changed, 1575 insertions, 576 deletions
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 55ddebb4fe..f5c5d164f9 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018-2019 New Vector Ltd +# Copyright 2020, 2021 The Matrix.org Foundation C.I.C. + # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,10 +38,13 @@ from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_medi from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet from synapse.rest.admin.rooms import ( DeleteRoomRestServlet, + ForwardExtremitiesRestServlet, JoinRoomAliasServlet, ListRoomRestServlet, + MakeRoomAdminRestServlet, RoomMembersRestServlet, RoomRestServlet, + RoomStateRestServlet, ShutdownRoomRestServlet, ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet @@ -50,6 +55,7 @@ from synapse.rest.admin.users import ( PushersRestServlet, ResetPasswordRestServlet, SearchUsersRestServlet, + ShadowBanRestServlet, UserAdminServlet, UserMediaRestServlet, UserMembershipRestServlet, @@ -208,6 +214,7 @@ def register_servlets(hs, http_server): """ register_servlets_for_client_rest_resource(hs, http_server) ListRoomRestServlet(hs).register(http_server) + RoomStateRestServlet(hs).register(http_server) RoomRestServlet(hs).register(http_server) RoomMembersRestServlet(hs).register(http_server) DeleteRoomRestServlet(hs).register(http_server) @@ -228,6 +235,9 @@ def register_servlets(hs, http_server): EventReportDetailRestServlet(hs).register(http_server) EventReportsRestServlet(hs).register(http_server) PushersRestServlet(hs).register(http_server) + MakeRoomAdminRestServlet(hs).register(http_server) + ShadowBanRestServlet(hs).register(http_server) + ForwardExtremitiesRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index c82b4f87d6..8720b1401f 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -15,6 +15,9 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple + +from twisted.web.http import Request from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_boolean, parse_integer @@ -23,6 +26,10 @@ from synapse.rest.admin._base import ( assert_requester_is_admin, assert_user_is_admin, ) +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet): admin_patterns("/quarantine_media/(?P<room_id>[^/]+)") ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request, room_id: str): + async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet): PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request, user_id: str): + async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet): "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_POST(self, request, server_name: str, media_id: str): + async def on_POST( + self, request: Request, server_name: str, media_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet): return 200, {} +class ProtectMediaByID(RestServlet): + """Protect local media from being quarantined. + """ + + PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)") + + def __init__(self, hs: "HomeServer"): + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + logging.info("Protecting local media by ID: %s", media_id) + + # Quarantine this media id + await self.store.mark_local_media_as_safe(media_id) + + return 200, {} + + class ListMediaInRoom(RestServlet): """Lists all of the media in a given room. """ PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_GET(self, request, room_id): + async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) is_admin = await self.auth.is_server_admin(requester.user) if not is_admin: @@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet): class PurgeMediaCacheRestServlet(RestServlet): PATTERNS = admin_patterns("/purge_media_cache") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.media_repository = hs.get_media_repository() self.auth = hs.get_auth() - async def on_POST(self, request): + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) before_ts = parse_integer(request, "before_ts", required=True) @@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet): PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() - async def on_DELETE(self, request, server_name: str, media_id: str): + async def on_DELETE( + self, request: Request, server_name: str, media_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) if self.server_name != server_name: @@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet): PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() self.server_name = hs.hostname self.media_repository = hs.get_media_repository() - async def on_POST(self, request, server_name: str): + async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) before_ts = parse_integer(request, "before_ts", required=True) @@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet): return 200, {"deleted_media": deleted_media, "total": total} -def register_servlets_for_media_repo(hs, http_server): +def register_servlets_for_media_repo(hs: "HomeServer", http_server): """ Media repo specific APIs. """ @@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server): QuarantineMediaInRoom(hs).register(http_server) QuarantineMediaByID(hs).register(http_server) QuarantineMediaByUser(hs).register(http_server) + ProtectMediaByID(hs).register(http_server) ListMediaInRoom(hs).register(http_server) DeleteMediaByID(hs).register(http_server) DeleteMediaByDateSize(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 25f89e4685..3e57e6a4d0 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,10 +14,10 @@ # limitations under the License. import logging from http import HTTPStatus -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple -from synapse.api.constants import EventTypes, JoinRules -from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -25,13 +25,18 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.rest.admin._base import ( admin_patterns, assert_requester_is_admin, assert_user_is_admin, ) from synapse.storage.databases.main.room import RoomSortOrder -from synapse.types import RoomAlias, RoomID, UserID, create_requester +from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester + +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -45,12 +50,14 @@ class ShutdownRoomRestServlet(RestServlet): PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.room_shutdown_handler = hs.get_room_shutdown_handler() - async def on_POST(self, request, room_id): + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -86,13 +93,15 @@ class DeleteRoomRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_handler = hs.get_pagination_handler() - async def on_POST(self, request, room_id): + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -146,12 +155,12 @@ class ListRoomRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -236,19 +245,24 @@ class RoomRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) ret = await self.store.get_room_with_stats(room_id) if not ret: raise NotFoundError("Room not found") - return 200, ret + members = await self.store.get_users_in_room(room_id) + ret["joined_local_devices"] = await self.store.count_devices_by_users(members) + + return (200, ret) class RoomMembersRestServlet(RestServlet): @@ -258,12 +272,14 @@ class RoomMembersRestServlet(RestServlet): PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) ret = await self.store.get_room(room_id) @@ -276,18 +292,59 @@ class RoomMembersRestServlet(RestServlet): return 200, ret +class RoomStateRestServlet(RestServlet): + """ + Get full state within a room. + """ + + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self._event_serializer = hs.get_event_client_serializer() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + ret = await self.store.get_room(room_id) + if not ret: + raise NotFoundError("Room not found") + + event_ids = await self.store.get_current_state_ids(room_id) + events = await self.store.get_events(event_ids.values()) + now = self.clock.time_msec() + room_state = await self._event_serializer.serialize_events( + events.values(), + now, + # We don't bother bundling aggregations in when asked for state + # events, as clients won't use them. + bundle_aggregations=False, + ) + ret = {"state": room_state} + + return 200, ret + + class JoinRoomAliasServlet(RestServlet): PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.room_member_handler = hs.get_room_member_handler() self.admin_handler = hs.get_admin_handler() self.state_handler = hs.get_state_handler() - async def on_POST(self, request, room_identifier): + async def on_POST( + self, request: SynapseRequest, room_identifier: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) @@ -314,7 +371,6 @@ class JoinRoomAliasServlet(RestServlet): handler = self.room_member_handler room_alias = RoomAlias.from_string(room_identifier) room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias) - room_id = room_id.to_string() else: raise SynapseError( 400, "%s was not legal room ID or room alias" % (room_identifier,) @@ -351,3 +407,201 @@ class JoinRoomAliasServlet(RestServlet): ) return 200, {"room_id": room_id} + + +class MakeRoomAdminRestServlet(RestServlet): + """Allows a server admin to get power in a room if a local user has power in + a room. Will also invite the user if they're not in the room and it's a + private room. Can specify another user (rather than the admin user) to be + granted power, e.g.: + + POST/_synapse/admin/v1/rooms/<room_id_or_alias>/make_room_admin + { + "user_id": "@foo:example.com" + } + """ + + PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.room_member_handler = hs.get_room_member_handler() + self.event_creation_handler = hs.get_event_creation_handler() + self.state_handler = hs.get_state_handler() + self.is_mine_id = hs.is_mine_id + + async def on_POST(self, request, room_identifier): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + content = parse_json_object_from_request(request, allow_empty_body=True) + + # Resolve to a room ID, if necessary. + if RoomID.is_valid(room_identifier): + room_id = room_identifier + elif RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias) + room_id = room_id.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + + # Which user to grant room admin rights to. + user_to_add = content.get("user_id", requester.user.to_string()) + + # Figure out which local users currently have power in the room, if any. + room_state = await self.state_handler.get_current_state(room_id) + if not room_state: + raise SynapseError(400, "Server not in room") + + create_event = room_state[(EventTypes.Create, "")] + power_levels = room_state.get((EventTypes.PowerLevels, "")) + + if power_levels is not None: + # We pick the local user with the highest power. + user_power = power_levels.content.get("users", {}) + admin_users = [ + user_id for user_id in user_power if self.is_mine_id(user_id) + ] + admin_users.sort(key=lambda user: user_power[user]) + + if not admin_users: + raise SynapseError(400, "No local admin user in room") + + admin_user_id = None + + for admin_user in reversed(admin_users): + if room_state.get((EventTypes.Member, admin_user)): + admin_user_id = admin_user + break + + if not admin_user_id: + raise SynapseError( + 400, "No local admin user in room", + ) + + pl_content = power_levels.content + else: + # If there is no power level events then the creator has rights. + pl_content = {} + admin_user_id = create_event.sender + if not self.is_mine_id(admin_user_id): + raise SynapseError( + 400, "No local admin user in room", + ) + + # Grant the user power equal to the room admin by attempting to send an + # updated power level event. + new_pl_content = dict(pl_content) + new_pl_content["users"] = dict(pl_content.get("users", {})) + new_pl_content["users"][user_to_add] = new_pl_content["users"][admin_user_id] + + fake_requester = create_requester( + admin_user_id, authenticated_entity=requester.authenticated_entity, + ) + + try: + await self.event_creation_handler.create_and_send_nonmember_event( + fake_requester, + event_dict={ + "content": new_pl_content, + "sender": admin_user_id, + "type": EventTypes.PowerLevels, + "state_key": "", + "room_id": room_id, + }, + ) + except AuthError: + # The admin user we found turned out not to have enough power. + raise SynapseError( + 400, "No local admin user in room with power to update power levels." + ) + + # Now we check if the user we're granting admin rights to is already in + # the room. If not and it's not a public room we invite them. + member_event = room_state.get((EventTypes.Member, user_to_add)) + is_joined = False + if member_event: + is_joined = member_event.content["membership"] in ( + Membership.JOIN, + Membership.INVITE, + ) + + if is_joined: + return 200, {} + + join_rules = room_state.get((EventTypes.JoinRules, "")) + is_public = False + if join_rules: + is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC + + if is_public: + return 200, {} + + await self.room_member_handler.update_membership( + fake_requester, + target=UserID.from_string(user_to_add), + room_id=room_id, + action=Membership.INVITE, + ) + + return 200, {} + + +class ForwardExtremitiesRestServlet(RestServlet): + """Allows a server admin to get or clear forward extremities. + + Clearing does not require restarting the server. + + Clear forward extremities: + DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities + + Get forward_extremities: + GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities + """ + + PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.room_member_handler = hs.get_room_member_handler() + self.store = hs.get_datastore() + + async def resolve_room_id(self, room_identifier: str) -> str: + """Resolve to a room ID, if necessary.""" + if RoomID.is_valid(room_identifier): + resolved_room_id = room_identifier + elif RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias) + resolved_room_id = room_id.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + if not resolved_room_id: + raise SynapseError( + 400, "Unknown room ID or room alias %s" % room_identifier + ) + return resolved_room_id + + async def on_DELETE(self, request, room_identifier): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + room_id = await self.resolve_room_id(room_identifier) + + deleted_count = await self.store.delete_forward_extremities_for_room(room_id) + return 200, {"deleted": deleted_count} + + async def on_GET(self, request, room_identifier): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + room_id = await self.resolve_room_id(room_identifier) + + extremities = await self.store.get_forward_extremities_for_room(room_id) + return 200, {"count": len(extremities), "results": extremities} diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index b0ff5e1ead..68c3c64a0d 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -42,17 +42,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -_GET_PUSHERS_ALLOWED_KEYS = { - "app_display_name", - "app_id", - "data", - "device_display_name", - "kind", - "lang", - "profile_tag", - "pushkey", -} - class UsersRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$") @@ -94,17 +83,32 @@ class UsersRestServletV2(RestServlet): The parameter `deactivated` can be used to include deactivated users. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) + + if start < 0: + raise SynapseError( + 400, + "Query parameter from must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + 400, + "Query parameter limit must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + user_id = parse_string(request, "user_id", default=None) name = parse_string(request, "name", default=None) guests = parse_boolean(request, "guests", default=True) @@ -114,7 +118,7 @@ class UsersRestServletV2(RestServlet): start, limit, user_id, name, guests, deactivated ) ret = {"users": users, "total": total} - if len(users) >= limit: + if (start + limit) < total: ret["next_token"] = str(start + len(users)) return 200, ret @@ -255,7 +259,7 @@ class UserRestServletV2(RestServlet): if deactivate and not user["deactivated"]: await self.deactivate_account_handler.deactivate_account( - target_user.to_string(), False + target_user.to_string(), False, requester, by_admin=True ) elif not deactivate and user["deactivated"]: if "password" not in body: @@ -320,9 +324,9 @@ class UserRestServletV2(RestServlet): data={}, ) - if "avatar_url" in body and type(body["avatar_url"]) == str: + if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( - user_id, requester, body["avatar_url"], True + target_user, requester, body["avatar_url"], True ) ret = await self.admin_handler.get_user(target_user) @@ -420,6 +424,9 @@ class UserRegisterServlet(RestServlet): if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: raise SynapseError(400, "Invalid user type") + if "mac" not in body: + raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON) + got_mac = body["mac"] want_mac_builder = hmac.new( @@ -494,12 +501,22 @@ class WhoisRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet): PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self._deactivate_account_handler = hs.get_deactivate_account_handler() self.auth = hs.get_auth() + self.is_mine = hs.is_mine + self.store = hs.get_datastore() + + async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + if not self.is_mine(UserID.from_string(target_user_id)): + raise SynapseError(400, "Can only deactivate local users") + + if not await self.store.get_user_by_id(target_user_id): + raise NotFoundError("User not found") - async def on_POST(self, request, target_user_id): - await assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request, allow_empty_body=True) erase = body.get("erase", False) if not isinstance(erase, bool): @@ -509,10 +526,8 @@ class DeactivateAccountRestServlet(RestServlet): Codes.BAD_JSON, ) - UserID.from_string(target_user_id) - result = await self._deactivate_account_handler.deactivate_account( - target_user_id, erase + target_user_id, erase, requester, by_admin=True ) if result: id_server_unbind_result = "success" @@ -722,13 +737,6 @@ class UserMembershipRestServlet(RestServlet): async def on_GET(self, request, user_id): await assert_requester_is_admin(self.auth, request) - if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only lookup local users") - - user = await self.store.get_user_by_id(user_id) - if user is None: - raise NotFoundError("Unknown user") - room_ids = await self.store.get_rooms_for_user(user_id) ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} return 200, ret @@ -767,10 +775,7 @@ class PushersRestServlet(RestServlet): pushers = await self.store.get_pushers_by_user_id(user_id) - filtered_pushers = [ - {k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS} - for p in pushers - ] + filtered_pushers = [p.as_dict() for p in pushers] return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)} @@ -885,3 +890,39 @@ class UserTokenRestServlet(RestServlet): ) return 200, {"access_token": token} + + +class ShadowBanRestServlet(RestServlet): + """An admin API for shadow-banning a user. + + A shadow-banned users receives successful responses to their client-server + API requests, but the events are not propagated into rooms. + + Shadow-banning a user should be used as a tool of last resort and may lead + to confusing or broken behaviour for the client. + + Example: + + POST /_synapse/admin/v1/users/@test:example.com/shadow_ban + {} + + 200 OK + {} + """ + + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_POST(self, request, user_id): + await assert_requester_is_admin(self.auth, request) + + if not self.hs.is_mine_id(user_id): + raise SynapseError(400, "Only local users can be shadow-banned") + + await self.store.set_shadow_banned(UserID.from_string(user_id), True) + + return 200, {} diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index d7ae148214..0fb9419e58 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -14,12 +14,13 @@ # limitations under the License. import logging -from typing import Awaitable, Callable, Dict, Optional +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.appservice import ApplicationService -from synapse.http.server import finish_request +from synapse.handlers.sso import SsoIdentityProvider +from synapse.http.server import HttpServer, finish_request from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, @@ -30,6 +31,9 @@ from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import JsonDict, UserID +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -42,7 +46,7 @@ class LoginRestServlet(RestServlet): JWT_TYPE_DEPRECATED = "m.login.jwt" APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs @@ -57,11 +61,14 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled + self._msc2858_enabled = hs.config.experimental.msc2858_enabled self.auth = hs.get_auth() self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() + self._sso_handler = hs.get_sso_handler() + self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( clock=hs.get_clock(), @@ -86,8 +93,17 @@ class LoginRestServlet(RestServlet): flows.append({"type": LoginRestServlet.CAS_TYPE}) if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: - flows.append({"type": LoginRestServlet.SSO_TYPE}) - # While its valid for us to advertise this login type generally, + sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict + + if self._msc2858_enabled: + sso_flow["org.matrix.msc2858.identity_providers"] = [ + _get_auth_flow_dict_for_idp(idp) + for idp in self._sso_handler.get_identity_providers().values() + ] + + flows.append(sso_flow) + + # While it's valid for us to advertise this login type generally, # synapse currently only gives out these tokens as part of the # SSO login flow. # Generally we don't want to advertise login flows that clients @@ -105,22 +121,27 @@ class LoginRestServlet(RestServlet): return 200, {"flows": flows} async def on_POST(self, request: SynapseRequest): - self._address_ratelimiter.ratelimit(request.getClientIP()) - login_submission = parse_json_object_from_request(request) try: if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: appservice = self.auth.get_appservice_by_req(request) + + if appservice.is_rate_limited(): + self._address_ratelimiter.ratelimit(request.getClientIP()) + result = await self._do_appservice_login(login_submission, appservice) elif self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): + self._address_ratelimiter.ratelimit(request.getClientIP()) result = await self._do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: + self._address_ratelimiter.ratelimit(request.getClientIP()) result = await self._do_token_login(login_submission) else: + self._address_ratelimiter.ratelimit(request.getClientIP()) result = await self._do_other_login(login_submission) except KeyError: raise SynapseError(400, "Missing JSON keys.") @@ -159,7 +180,9 @@ class LoginRestServlet(RestServlet): if not appservice.is_interested_in_user(qualified_user_id): raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN) - return await self._complete_login(qualified_user_id, login_submission) + return await self._complete_login( + qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited() + ) async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]: """Handle non-token/saml/jwt logins @@ -194,6 +217,7 @@ class LoginRestServlet(RestServlet): login_submission: JsonDict, callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None, create_non_existent_users: bool = False, + ratelimit: bool = True, ) -> Dict[str, str]: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on @@ -208,6 +232,7 @@ class LoginRestServlet(RestServlet): callback: Callback function to run after login. create_non_existent_users: Whether to create the user if they don't exist. Defaults to False. + ratelimit: Whether to ratelimit the login request. Returns: result: Dictionary of account information after successful login. @@ -216,7 +241,8 @@ class LoginRestServlet(RestServlet): # Before we actually log them in we check if they've already logged in # too often. This happens here rather than before as we don't # necessarily know the user before now. - self._account_ratelimiter.ratelimit(user_id.lower()) + if ratelimit: + self._account_ratelimiter.ratelimit(user_id.lower()) if create_non_existent_users: canonical_uid = await self.auth_handler.check_user_exists(user_id) @@ -298,48 +324,63 @@ class LoginRestServlet(RestServlet): return result -class BaseSSORedirectServlet(RestServlet): - """Common base class for /login/sso/redirect impls""" - - PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) +def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: + """Return an entry for the login flow dict + + Returns an entry suitable for inclusion in "identity_providers" in the + response to GET /_matrix/client/r0/login + """ + e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict + if idp.idp_icon: + e["icon"] = idp.idp_icon + if idp.idp_brand: + e["brand"] = idp.idp_brand + return e + + +class SsoRedirectServlet(RestServlet): + PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True) + + def __init__(self, hs: "HomeServer"): + # make sure that the relevant handlers are instantiated, so that they + # register themselves with the main SSOHandler. + if hs.config.cas_enabled: + hs.get_cas_handler() + if hs.config.saml2_enabled: + hs.get_saml_handler() + if hs.config.oidc_enabled: + hs.get_oidc_handler() + self._sso_handler = hs.get_sso_handler() + self._msc2858_enabled = hs.config.experimental.msc2858_enabled + + def register(self, http_server: HttpServer) -> None: + super().register(http_server) + if self._msc2858_enabled: + # expose additional endpoint for MSC2858 support + http_server.register_paths( + "GET", + client_patterns( + "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$", + releases=(), + unstable=True, + ), + self.on_GET, + self.__class__.__name__, + ) - async def on_GET(self, request: SynapseRequest): - args = request.args - if b"redirectUrl" not in args: - return 400, "Redirect URL not specified for SSO auth" - client_redirect_url = args[b"redirectUrl"][0] - sso_url = await self.get_sso_url(request, client_redirect_url) + async def on_GET( + self, request: SynapseRequest, idp_id: Optional[str] = None + ) -> None: + client_redirect_url = parse_string( + request, "redirectUrl", required=True, encoding=None + ) + sso_url = await self._sso_handler.handle_redirect_request( + request, client_redirect_url, idp_id, + ) + logger.info("Redirecting to %s", sso_url) request.redirect(sso_url) finish_request(request) - async def get_sso_url( - self, request: SynapseRequest, client_redirect_url: bytes - ) -> bytes: - """Get the URL to redirect to, to perform SSO auth - - Args: - request: The client request to redirect. - client_redirect_url: the URL that we should redirect the - client to when everything is done - - Returns: - URL to redirect to - """ - # to be implemented by subclasses - raise NotImplementedError() - - -class CasRedirectServlet(BaseSSORedirectServlet): - def __init__(self, hs): - self._cas_handler = hs.get_cas_handler() - - async def get_sso_url( - self, request: SynapseRequest, client_redirect_url: bytes - ) -> bytes: - return self._cas_handler.get_redirect_url( - {"redirectUrl": client_redirect_url} - ).encode("ascii") - class CasTicketServlet(RestServlet): PATTERNS = client_patterns("/login/cas/ticket", v1=True) @@ -366,40 +407,8 @@ class CasTicketServlet(RestServlet): ) -class SAMLRedirectServlet(BaseSSORedirectServlet): - PATTERNS = client_patterns("/login/sso/redirect", v1=True) - - def __init__(self, hs): - self._saml_handler = hs.get_saml_handler() - - async def get_sso_url( - self, request: SynapseRequest, client_redirect_url: bytes - ) -> bytes: - return self._saml_handler.handle_redirect_request(client_redirect_url) - - -class OIDCRedirectServlet(BaseSSORedirectServlet): - """Implementation for /login/sso/redirect for the OIDC login flow.""" - - PATTERNS = client_patterns("/login/sso/redirect", v1=True) - - def __init__(self, hs): - self._oidc_handler = hs.get_oidc_handler() - - async def get_sso_url( - self, request: SynapseRequest, client_redirect_url: bytes - ) -> bytes: - return await self._oidc_handler.handle_redirect_request( - request, client_redirect_url - ) - - def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) + SsoRedirectServlet(hs).register(http_server) if hs.config.cas_enabled: - CasRedirectServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server) - elif hs.config.saml2_enabled: - SAMLRedirectServlet(hs).register(http_server) - elif hs.config.oidc_enabled: - OIDCRedirectServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 8fe83f321a..89823fcc39 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -28,17 +28,6 @@ from synapse.rest.client.v2_alpha._base import client_patterns logger = logging.getLogger(__name__) -ALLOWED_KEYS = { - "app_display_name", - "app_id", - "data", - "device_display_name", - "kind", - "lang", - "profile_tag", - "pushkey", -} - class PushersRestServlet(RestServlet): PATTERNS = client_patterns("/pushers$", v1=True) @@ -54,9 +43,7 @@ class PushersRestServlet(RestServlet): pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) - filtered_pushers = [ - {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers - ] + filtered_pushers = [p.as_dict() for p in pushers] return 200, {"pushers": filtered_pushers} diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 93c06afe27..f95627ee61 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -46,7 +46,7 @@ from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from synapse.util import json_decoder -from synapse.util.stringutils import random_string +from synapse.util.stringutils import parse_and_validate_server_name, random_string if TYPE_CHECKING: import synapse.server @@ -347,8 +347,6 @@ class PublicRoomListRestServlet(TransactionRestServlet): # provided. if server: raise e - else: - pass limit = parse_integer(request, "limit", 0) since_token = parse_string(request, "since", None) @@ -359,6 +357,14 @@ class PublicRoomListRestServlet(TransactionRestServlet): handler = self.hs.get_room_list_handler() if server and server != self.hs.config.server_name: + # Ensure the server is valid. + try: + parse_and_validate_server_name(server) + except ValueError: + raise SynapseError( + 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM, + ) + try: data = await handler.get_remote_public_room_list( server, limit=limit, since_token=since_token @@ -402,6 +408,14 @@ class PublicRoomListRestServlet(TransactionRestServlet): handler = self.hs.get_room_list_handler() if server and server != self.hs.config.server_name: + # Ensure the server is valid. + try: + parse_and_validate_server_name(server) + except ValueError: + raise SynapseError( + 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM, + ) + try: data = await handler.get_remote_public_room_list( server, @@ -963,25 +977,28 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): ) -def register_servlets(hs, http_server): +def register_servlets(hs, http_server, is_worker=False): RoomStateEventRestServlet(hs).register(http_server) - RoomCreateRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server) JoinedRoomMemberListRestServlet(hs).register(http_server) RoomMessageListRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) - RoomForgetRestServlet(hs).register(http_server) RoomMembershipRestServlet(hs).register(http_server) RoomSendEventRestServlet(hs).register(http_server) PublicRoomListRestServlet(hs).register(http_server) RoomStateRestServlet(hs).register(http_server) RoomRedactEventRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server) - SearchRestServlet(hs).register(http_server) - JoinedRoomsRestServlet(hs).register(http_server) - RoomEventServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server) - RoomAliasListServlet(hs).register(http_server) + + # Some servlets only get registered for the main process. + if not is_worker: + RoomCreateRestServlet(hs).register(http_server) + RoomForgetRestServlet(hs).register(http_server) + SearchRestServlet(hs).register(http_server) + JoinedRoomsRestServlet(hs).register(http_server) + RoomEventServlet(hs).register(http_server) + RoomAliasListServlet(hs).register(http_server) def register_deprecated_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index e0feebea94..b67c1702ca 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -20,9 +20,6 @@ from http import HTTPStatus from typing import TYPE_CHECKING from urllib.parse import urlparse -if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer - from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, @@ -31,6 +28,7 @@ from synapse.api.errors import ( ThreepidValidationError, ) from synapse.config.emailconfig import ThreepidBehaviour +from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http.server import finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, @@ -46,13 +44,17 @@ from synapse.util.threepids import canonicalise_email, check_3pid_allowed from ._base import client_patterns, interactive_auth_handler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + + logger = logging.getLogger(__name__) class EmailPasswordRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/password/email/requestToken$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.datastore = hs.get_datastore() @@ -101,6 +103,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + # The email will be sent to the stored address. # This avoids a potential account hijack by requesting a password reset to # an email address which is controlled by the attacker but which, after @@ -189,11 +193,7 @@ class PasswordRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) try: params, session_id = await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "modify your account password", + requester, request, body, "modify your account password", ) except InteractiveAuthIncompleteError as e: # The user needs to provide more steps to complete auth, but @@ -204,7 +204,9 @@ class PasswordRestServlet(RestServlet): if new_password: password_hash = await self.auth_handler.hash(new_password) await self.auth_handler.set_session_data( - e.session_id, "password_hash", password_hash + e.session_id, + UIAuthSessionDataConstants.PASSWORD_HASH, + password_hash, ) raise user_id = requester.user.to_string() @@ -215,7 +217,6 @@ class PasswordRestServlet(RestServlet): [[LoginType.EMAIL_IDENTITY]], request, body, - self.hs.get_ip_from_request(request), "modify your account password", ) except InteractiveAuthIncompleteError as e: @@ -227,7 +228,9 @@ class PasswordRestServlet(RestServlet): if new_password: password_hash = await self.auth_handler.hash(new_password) await self.auth_handler.set_session_data( - e.session_id, "password_hash", password_hash + e.session_id, + UIAuthSessionDataConstants.PASSWORD_HASH, + password_hash, ) raise @@ -254,14 +257,18 @@ class PasswordRestServlet(RestServlet): logger.error("Auth succeeded but no known type! %r", result.keys()) raise SynapseError(500, "", Codes.UNKNOWN) - # If we have a password in this request, prefer it. Otherwise, there - # must be a password hash from an earlier request. + # If we have a password in this request, prefer it. Otherwise, use the + # password hash from an earlier request. if new_password: password_hash = await self.auth_handler.hash(new_password) - else: + elif session_id is not None: password_hash = await self.auth_handler.get_session_data( - session_id, "password_hash", None + session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None ) + else: + # UI validation was skipped, but the request did not include a new + # password. + password_hash = None if not password_hash: raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) @@ -300,19 +307,18 @@ class DeactivateAccountRestServlet(RestServlet): # allow ASes to deactivate their own users if requester.app_service: await self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase + requester.user.to_string(), erase, requester ) return 200, {} await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "deactivate your account", + requester, request, body, "deactivate your account", ) result = await self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase, id_server=body.get("id_server") + requester.user.to_string(), + erase, + requester, + id_server=body.get("id_server"), ) if result: id_server_unbind_result = "success" @@ -375,6 +381,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + if next_link: # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) @@ -426,7 +434,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): class MsisdnThreepidRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs super().__init__() self.store = self.hs.get_datastore() @@ -454,6 +462,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests( + request, "msisdn", msisdn + ) + if next_link: # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) @@ -691,11 +703,7 @@ class ThreepidAddRestServlet(RestServlet): assert_valid_client_secret(client_secret) await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "add a third-party identifier to your account", + requester, request, body, "add a third-party identifier to your account", ) validation_session = await self.identity_handler.validate_threepid_session( diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index 87a5b1b86b..3f28c0bc3e 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() - self.notifier = hs.get_notifier() - self._is_worker = hs.config.worker_app is not None + self.handler = hs.get_account_data_handler() async def on_PUT(self, request, user_id, account_data_type): - if self._is_worker: - raise Exception("Cannot handle PUT /account_data on worker") - requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") body = parse_json_object_from_request(request) - max_id = await self.store.add_account_data_for_user( - user_id, account_data_type, body - ) - - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) + await self.handler.add_account_data_for_user(user_id, account_data_type, body) return 200, {} @@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() - self.notifier = hs.get_notifier() - self._is_worker = hs.config.worker_app is not None + self.handler = hs.get_account_data_handler() async def on_PUT(self, request, user_id, room_id, account_data_type): - if self._is_worker: - raise Exception("Cannot handle PUT /account_data on worker") - requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") @@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet): " Use /rooms/!roomId:server.name/read_markers", ) - max_id = await self.store.add_account_data_to_room( + await self.handler.add_account_data_to_room( user_id, room_id, account_data_type, body ) - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) - return 200, {} async def on_GET(self, request, user_id, room_id, account_data_type): diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index fab077747f..75ece1c911 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.api.constants import LoginType from synapse.api.errors import SynapseError @@ -23,6 +24,9 @@ from synapse.http.servlet import RestServlet, parse_string from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -35,28 +39,12 @@ class AuthRestServlet(RestServlet): PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() - - # SSO configuration. - self._cas_enabled = hs.config.cas_enabled - if self._cas_enabled: - self._cas_handler = hs.get_cas_handler() - self._cas_server_url = hs.config.cas_server_url - self._cas_service_url = hs.config.cas_service_url - self._saml_enabled = hs.config.saml2_enabled - if self._saml_enabled: - self._saml_handler = hs.get_saml_handler() - self._oidc_enabled = hs.config.oidc_enabled - if self._oidc_enabled: - self._oidc_handler = hs.get_oidc_handler() - self._cas_server_url = hs.config.cas_server_url - self._cas_service_url = hs.config.cas_service_url - self.recaptcha_template = hs.config.recaptcha_template self.terms_template = hs.config.terms_template self.success_template = hs.config.fallback_success_template @@ -85,32 +73,7 @@ class AuthRestServlet(RestServlet): elif stagetype == LoginType.SSO: # Display a confirmation page which prompts the user to # re-authenticate with their SSO provider. - if self._cas_enabled: - # Generate a request to CAS that redirects back to an endpoint - # to verify the successful authentication. - sso_redirect_url = self._cas_handler.get_redirect_url( - {"session": session}, - ) - - elif self._saml_enabled: - # Some SAML identity providers (e.g. Google) require a - # RelayState parameter on requests. It is not necessary here, so - # pass in a dummy redirect URL (which will never get used). - client_redirect_url = b"unused" - sso_redirect_url = self._saml_handler.handle_redirect_request( - client_redirect_url, session - ) - - elif self._oidc_enabled: - client_redirect_url = b"" - sso_redirect_url = await self._oidc_handler.handle_redirect_request( - request, client_redirect_url, session - ) - - else: - raise SynapseError(400, "Homeserver not configured for SSO.") - - html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) + html = await self.auth_handler.start_sso_ui_auth(request, session) else: raise SynapseError(404, "Unknown auth stage type") @@ -134,7 +97,7 @@ class AuthRestServlet(RestServlet): authdict = {"response": response, "session": session} success = await self.auth_handler.add_oob_auth( - LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request) + LoginType.RECAPTCHA, authdict, request.getClientIP() ) if success: @@ -150,7 +113,7 @@ class AuthRestServlet(RestServlet): authdict = {"session": session} success = await self.auth_handler.add_oob_auth( - LoginType.TERMS, authdict, self.hs.get_ip_from_request(request) + LoginType.TERMS, authdict, request.getClientIP() ) if success: diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index af117cb27c..314e01dfe4 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -83,11 +83,7 @@ class DeleteDevicesRestServlet(RestServlet): assert_params_in_dict(body, ["devices"]) await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "remove device(s) from your account", + requester, request, body, "remove device(s) from your account", ) await self.device_handler.delete_devices( @@ -133,11 +129,7 @@ class DeviceRestServlet(RestServlet): raise await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "remove a device from your account", + requester, request, body, "remove a device from your account", ) await self.device_handler.delete_device(requester.user.to_string(), device_id) diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 75215a3779..28b55f27ad 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +from functools import wraps from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -25,6 +26,22 @@ from ._base import client_patterns logger = logging.getLogger(__name__) +def _validate_group_id(f): + """Wrapper to validate the form of the group ID. + + Can be applied to any on_FOO methods that accepts a group ID as a URL parameter. + """ + + @wraps(f) + def wrapper(self, request, group_id, *args, **kwargs): + if not GroupID.is_valid(group_id): + raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) + + return f(self, request, group_id, *args, **kwargs) + + return wrapper + + class GroupServlet(RestServlet): """Get the group profile """ @@ -37,6 +54,7 @@ class GroupServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -47,6 +65,7 @@ class GroupServlet(RestServlet): return 200, group_description + @_validate_group_id async def on_POST(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -71,6 +90,7 @@ class GroupSummaryServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -102,6 +122,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, category_id, room_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -117,6 +138,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): return 200, resp + @_validate_group_id async def on_DELETE(self, request, group_id, category_id, room_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -142,6 +164,7 @@ class GroupCategoryServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id, category_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -152,6 +175,7 @@ class GroupCategoryServlet(RestServlet): return 200, category + @_validate_group_id async def on_PUT(self, request, group_id, category_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -163,6 +187,7 @@ class GroupCategoryServlet(RestServlet): return 200, resp + @_validate_group_id async def on_DELETE(self, request, group_id, category_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -186,6 +211,7 @@ class GroupCategoriesServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -209,6 +235,7 @@ class GroupRoleServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id, role_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -219,6 +246,7 @@ class GroupRoleServlet(RestServlet): return 200, category + @_validate_group_id async def on_PUT(self, request, group_id, role_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -230,6 +258,7 @@ class GroupRoleServlet(RestServlet): return 200, resp + @_validate_group_id async def on_DELETE(self, request, group_id, role_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -253,6 +282,7 @@ class GroupRolesServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -284,6 +314,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, role_id, user_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -299,6 +330,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): return 200, resp + @_validate_group_id async def on_DELETE(self, request, group_id, role_id, user_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -322,13 +354,11 @@ class GroupRoomServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - if not GroupID.is_valid(group_id): - raise SynapseError(400, "%s was not legal group ID" % (group_id,)) - result = await self.groups_handler.get_rooms_in_group( group_id, requester_user_id ) @@ -348,6 +378,7 @@ class GroupUsersServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -371,6 +402,7 @@ class GroupInvitedUsersServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -393,6 +425,7 @@ class GroupSettingJoinPolicyServlet(RestServlet): self.auth = hs.get_auth() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -449,6 +482,7 @@ class GroupAdminRoomsServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, room_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -460,6 +494,7 @@ class GroupAdminRoomsServlet(RestServlet): return 200, result + @_validate_group_id async def on_DELETE(self, request, group_id, room_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -486,6 +521,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, room_id, config_key): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -514,6 +550,7 @@ class GroupAdminUsersInviteServlet(RestServlet): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id + @_validate_group_id async def on_PUT(self, request, group_id, user_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -541,6 +578,7 @@ class GroupAdminUsersKickServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id, user_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -590,6 +628,7 @@ class GroupSelfLeaveServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -614,6 +653,7 @@ class GroupSelfJoinServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -638,6 +678,7 @@ class GroupSelfAcceptInviteServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -662,6 +703,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() + @_validate_group_id async def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index b91996c738..a6134ead8a 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -271,11 +271,7 @@ class SigningKeyUploadServlet(RestServlet): body = parse_json_object_from_request(request) await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "add a device signing key to your account", + requester, request, body, "add a device signing key to your account", ) result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 5374d2c1b6..f0675abd32 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -38,6 +38,7 @@ from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.config.registration import RegistrationConfig from synapse.config.server import is_threepid_reserved from synapse.handlers.auth import AuthHandler +from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http.server import finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, @@ -125,6 +126,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email ) @@ -204,6 +207,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests( + request, "msisdn", msisdn + ) + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "msisdn", msisdn ) @@ -353,7 +360,7 @@ class UsernameAvailabilityRestServlet(RestServlet): 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) - ip = self.hs.get_ip_from_request(request) + ip = request.getClientIP() with self.ratelimiter.ratelimit(ip) as wait_deferred: await wait_deferred @@ -451,7 +458,7 @@ class RegisterRestServlet(RestServlet): # == Normal User Registration == (everyone else) if not self._registration_enabled: - raise SynapseError(403, "Registration has been disabled") + raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN) # For regular registration, convert the provided username to lowercase # before attempting to register it. This should mean that people who try @@ -494,11 +501,11 @@ class RegisterRestServlet(RestServlet): # user here. We carry on and go through the auth checks though, # for paranoia. registered_user_id = await self.auth_handler.get_session_data( - session_id, "registered_user_id", None + session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None ) # Extract the previously-hashed password from the session. password_hash = await self.auth_handler.get_session_data( - session_id, "password_hash", None + session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None ) # Ensure that the username is valid. @@ -513,11 +520,7 @@ class RegisterRestServlet(RestServlet): # not this will raise a user-interactive auth error. try: auth_result, params, session_id = await self.auth_handler.check_ui_auth( - self._registration_flows, - request, - body, - self.hs.get_ip_from_request(request), - "register a new account", + self._registration_flows, request, body, "register a new account", ) except InteractiveAuthIncompleteError as e: # The user needs to provide more steps to complete auth. @@ -532,7 +535,9 @@ class RegisterRestServlet(RestServlet): if not password_hash and password: password_hash = await self.auth_handler.hash(password) await self.auth_handler.set_session_data( - e.session_id, "password_hash", password_hash + e.session_id, + UIAuthSessionDataConstants.PASSWORD_HASH, + password_hash, ) raise @@ -635,7 +640,9 @@ class RegisterRestServlet(RestServlet): # Remember that the user account has been registered (and the user # ID it was registered with, since it might not have been specified). await self.auth_handler.set_session_data( - session_id, "registered_user_id", registered_user_id + session_id, + UIAuthSessionDataConstants.REGISTERED_USER_ID, + registered_user_id, ) registered = True @@ -657,9 +664,13 @@ class RegisterRestServlet(RestServlet): user_id = await self.registration_handler.appservice_register( username, as_token ) - return await self._create_registration_details(user_id, body) + return await self._create_registration_details( + user_id, body, is_appservice_ghost=True, + ) - async def _create_registration_details(self, user_id, params): + async def _create_registration_details( + self, user_id, params, is_appservice_ghost=False + ): """Complete registration of newly-registered user Allocates device_id if one was not given; also creates access_token. @@ -676,7 +687,11 @@ class RegisterRestServlet(RestServlet): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") device_id, access_token = await self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest=False + user_id, + device_id, + initial_display_name, + is_guest=False, + is_appservice_ghost=is_appservice_ghost, ) result.update({"access_token": access_token, "device_id": device_id}) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index bc4f43639a..a3dee14ed4 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -17,7 +17,7 @@ import logging from typing import Tuple from synapse.http import servlet -from synapse.http.servlet import parse_json_object_from_request +from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request from synapse.logging.opentracing import set_tag, trace from synapse.rest.client.transactions import HttpTransactionCache @@ -54,6 +54,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) + assert_params_in_dict(content, ("messages",)) sender_user_id = requester.user.to_string() diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index bf3a79db44..a97cd66c52 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -58,8 +58,7 @@ class TagServlet(RestServlet): def __init__(self, hs): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.notifier = hs.get_notifier() + self.handler = hs.get_account_data_handler() async def on_PUT(self, request, user_id, room_id, tag): requester = await self.auth.get_user_by_req(request) @@ -68,9 +67,7 @@ class TagServlet(RestServlet): body = parse_json_object_from_request(request) - max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body) - - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) + await self.handler.add_tag_to_room(user_id, room_id, tag, body) return 200, {} @@ -79,9 +76,7 @@ class TagServlet(RestServlet): if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") - max_id = await self.store.remove_tag_from_room(user_id, room_id, tag) - - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) + await self.handler.remove_tag_from_room(user_id, room_id, tag) return 200, {} diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index b3e4d5612e..8b9ef26cf2 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -100,6 +100,7 @@ class ConsentResource(DirectServeHtmlResource): consent_template_directory = hs.config.user_consent_template_dir + # TODO: switch to synapse.util.templates.build_jinja_env loader = jinja2.FileSystemLoader(consent_template_directory) self._jinja_env = jinja2.Environment( loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"]) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index f843f02454..c57ac22e58 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Dict, Set +from typing import Dict from signedjson.sign import sign_json @@ -142,12 +142,13 @@ class RemoteKey(DirectServeJsonResource): time_now_ms = self.clock.time_msec() - cache_misses = {} # type: Dict[str, Set[str]] + # Note that the value is unused. + cache_misses = {} # type: Dict[str, Dict[str, int]] for (server_name, key_id, from_server), results in cached.items(): results = [(result["ts_added_ms"], result) for result in results] if not results and key_id is not None: - cache_misses.setdefault(server_name, set()).add(key_id) + cache_misses.setdefault(server_name, {})[key_id] = 0 continue if key_id is not None: @@ -201,7 +202,7 @@ class RemoteKey(DirectServeJsonResource): ) if miss: - cache_misses.setdefault(server_name, set()).add(key_id) + cache_misses.setdefault(server_name, {})[key_id] = 0 # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(most_recent_result["key_json"])) else: diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 67aa993f19..f71a03a12d 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019 New Vector Ltd +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,10 +17,11 @@ import logging import os import urllib -from typing import Awaitable +from typing import Awaitable, Dict, Generator, List, Optional, Tuple from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender +from twisted.web.http import Request from synapse.api.errors import Codes, SynapseError, cs_error from synapse.http.server import finish_request, respond_with_json @@ -46,7 +47,7 @@ TEXT_CONTENT_TYPES = [ ] -def parse_media_id(request): +def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: try: # This allows users to append e.g. /test.png to the URL. Useful for # clients that parse the URL to see content type. @@ -69,7 +70,7 @@ def parse_media_id(request): ) -def respond_404(request): +def respond_404(request: Request) -> None: respond_with_json( request, 404, @@ -79,8 +80,12 @@ def respond_404(request): async def respond_with_file( - request, media_type, file_path, file_size=None, upload_name=None -): + request: Request, + media_type: str, + file_path: str, + file_size: Optional[int] = None, + upload_name: Optional[str] = None, +) -> None: logger.debug("Responding with %r", file_path) if os.path.isfile(file_path): @@ -98,15 +103,20 @@ async def respond_with_file( respond_404(request) -def add_file_headers(request, media_type, file_size, upload_name): +def add_file_headers( + request: Request, + media_type: str, + file_size: Optional[int], + upload_name: Optional[str], +) -> None: """Adds the correct response headers in preparation for responding with the media. Args: - request (twisted.web.http.Request) - media_type (str): The media/content type. - file_size (int): Size in bytes of the media, if known. - upload_name (str): The name of the requested file, if any. + request + media_type: The media/content type. + file_size: Size in bytes of the media, if known. + upload_name: The name of the requested file, if any. """ def _quote(x): @@ -153,7 +163,13 @@ def add_file_headers(request, media_type, file_size, upload_name): # select private. don't bother setting Expires as all our # clients are smart enough to be happy with Cache-Control request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") - request.setHeader(b"Content-Length", b"%d" % (file_size,)) + if file_size is not None: + request.setHeader(b"Content-Length", b"%d" % (file_size,)) + + # Tell web crawlers to not index, archive, or follow links in media. This + # should help to prevent things in the media repo from showing up in web + # search results. + request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex") # separators as defined in RFC2616. SP and HT are handled separately. @@ -179,7 +195,7 @@ _FILENAME_SEPARATOR_CHARS = { } -def _can_encode_filename_as_token(x): +def _can_encode_filename_as_token(x: str) -> bool: for c in x: # from RFC2616: # @@ -201,17 +217,21 @@ def _can_encode_filename_as_token(x): async def respond_with_responder( - request, responder, media_type, file_size, upload_name=None -): + request: Request, + responder: "Optional[Responder]", + media_type: str, + file_size: Optional[int], + upload_name: Optional[str] = None, +) -> None: """Responds to the request with given responder. If responder is None then returns 404. Args: - request (twisted.web.http.Request) - responder (Responder|None) - media_type (str): The media/content type. - file_size (int|None): Size in bytes of the media. If not known it should be None - upload_name (str|None): The name of the requested file, if any. + request + responder + media_type: The media/content type. + file_size: Size in bytes of the media. If not known it should be None + upload_name: The name of the requested file, if any. """ if request._disconnected: logger.warning( @@ -280,6 +300,7 @@ class FileInfo: thumbnail_height (int) thumbnail_method (str) thumbnail_type (str): Content type of thumbnail, e.g. image/png + thumbnail_length (int): The size of the media file, in bytes. """ def __init__( @@ -292,6 +313,7 @@ class FileInfo: thumbnail_height=None, thumbnail_method=None, thumbnail_type=None, + thumbnail_length=None, ): self.server_name = server_name self.file_id = file_id @@ -301,24 +323,25 @@ class FileInfo: self.thumbnail_height = thumbnail_height self.thumbnail_method = thumbnail_method self.thumbnail_type = thumbnail_type + self.thumbnail_length = thumbnail_length -def get_filename_from_headers(headers): +def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: """ Get the filename of the downloaded file by inspecting the Content-Disposition HTTP header. Args: - headers (dict[bytes, list[bytes]]): The HTTP request headers. + headers: The HTTP request headers. Returns: - A Unicode string of the filename, or None. + The filename, or None. """ content_disposition = headers.get(b"Content-Disposition", [b""]) # No header, bail out. if not content_disposition[0]: - return + return None _, params = _parse_header(content_disposition[0]) @@ -351,17 +374,16 @@ def get_filename_from_headers(headers): return upload_name -def _parse_header(line): +def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]: """Parse a Content-type like header. Cargo-culted from `cgi`, but works on bytes rather than strings. Args: - line (bytes): header to be parsed + line: header to be parsed Returns: - Tuple[bytes, dict[bytes, bytes]]: - the main content-type, followed by the parameter dictionary + The main content-type, followed by the parameter dictionary """ parts = _parseparam(b";" + line) key = next(parts) @@ -381,16 +403,16 @@ def _parse_header(line): return key, pdict -def _parseparam(s): +def _parseparam(s: bytes) -> Generator[bytes, None, None]: """Generator which splits the input on ;, respecting double-quoted sequences Cargo-culted from `cgi`, but works on bytes rather than strings. Args: - s (bytes): header to be parsed + s: header to be parsed Returns: - Iterable[bytes]: the split input + The split input """ while s[:1] == b";": s = s[1:] diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py index 68dd2a1c8a..4e4c6971f7 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018 Will Hunt <will@half-shot.uk> +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,22 +15,29 @@ # limitations under the License. # +from typing import TYPE_CHECKING + +from twisted.web.http import Request + from synapse.http.server import DirectServeJsonResource, respond_with_json +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + class MediaConfigResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() config = hs.get_config() self.clock = hs.get_clock() self.auth = hs.get_auth() self.limits_dict = {"m.upload.size": config.max_upload_size} - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: await self.auth.get_user_by_req(request) respond_with_json(request, 200, self.limits_dict, send_cors=True) - async def _async_render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request: Request) -> None: respond_with_json(request, 200, {}, send_cors=True) diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index d3d8457303..3ed219ae43 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,24 +14,31 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request -import synapse.http.servlet from synapse.http.server import DirectServeJsonResource, set_cors_headers +from synapse.http.servlet import parse_boolean from ._base import parse_media_id, respond_404 +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) class DownloadResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo): + def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): super().__init__() self.media_repo = media_repo self.server_name = hs.hostname - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: set_cors_headers(request) request.setHeader( b"Content-Security-Policy", @@ -49,9 +57,7 @@ class DownloadResource(DirectServeJsonResource): if server_name == self.server_name: await self.media_repo.get_local_media(request, media_id, name) else: - allow_remote = synapse.http.servlet.parse_boolean( - request, "allow_remote", default=True - ) + allow_remote = parse_boolean(request, "allow_remote", default=True) if not allow_remote: logger.info( "Rejecting request for remote media %s/%s due to allow_remote", diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index 9e079f672f..7792f26e78 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,11 +17,12 @@ import functools import os import re +from typing import Callable, List NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") -def _wrap_in_base_path(func): +def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]": """Takes a function that returns a relative path and turns it into an absolute path based on the location of the primary media store """ @@ -41,12 +43,18 @@ class MediaFilePaths: to write to the backup media store (when one is configured) """ - def __init__(self, primary_base_path): + def __init__(self, primary_base_path: str): self.base_path = primary_base_path def default_thumbnail_rel( - self, default_top_level, default_sub_type, width, height, content_type, method - ): + self, + default_top_level: str, + default_sub_type: str, + width: int, + height: int, + content_type: str, + method: str, + ) -> str: top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( @@ -55,12 +63,14 @@ class MediaFilePaths: default_thumbnail = _wrap_in_base_path(default_thumbnail_rel) - def local_media_filepath_rel(self, media_id): + def local_media_filepath_rel(self, media_id: str) -> str: return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:]) local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) - def local_media_thumbnail_rel(self, media_id, width, height, content_type, method): + def local_media_thumbnail_rel( + self, media_id: str, width: int, height: int, content_type: str, method: str + ) -> str: top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( @@ -86,7 +96,7 @@ class MediaFilePaths: media_id[4:], ) - def remote_media_filepath_rel(self, server_name, file_id): + def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str: return os.path.join( "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] ) @@ -94,8 +104,14 @@ class MediaFilePaths: remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) def remote_media_thumbnail_rel( - self, server_name, file_id, width, height, content_type, method - ): + self, + server_name: str, + file_id: str, + width: int, + height: int, + content_type: str, + method: str, + ) -> str: top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( @@ -113,7 +129,7 @@ class MediaFilePaths: # Should be removed after some time, when most of the thumbnails are stored # using the new path. def remote_media_thumbnail_rel_legacy( - self, server_name, file_id, width, height, content_type + self, server_name: str, file_id: str, width: int, height: int, content_type: str ): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) @@ -126,7 +142,7 @@ class MediaFilePaths: file_name, ) - def remote_media_thumbnail_dir(self, server_name, file_id): + def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str: return os.path.join( self.base_path, "remote_thumbnail", @@ -136,7 +152,7 @@ class MediaFilePaths: file_id[4:], ) - def url_cache_filepath_rel(self, media_id): + def url_cache_filepath_rel(self, media_id: str) -> str: if NEW_FORMAT_ID_RE.match(media_id): # Media id is of the form <DATE><RANDOM_STRING> # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -146,7 +162,7 @@ class MediaFilePaths: url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) - def url_cache_filepath_dirs_to_delete(self, media_id): + def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id file" if NEW_FORMAT_ID_RE.match(media_id): return [os.path.join(self.base_path, "url_cache", media_id[:10])] @@ -156,7 +172,9 @@ class MediaFilePaths: os.path.join(self.base_path, "url_cache", media_id[0:2]), ] - def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method): + def url_cache_thumbnail_rel( + self, media_id: str, width: int, height: int, content_type: str, method: str + ) -> str: # Media id is of the form <DATE><RANDOM_STRING> # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -178,7 +196,7 @@ class MediaFilePaths: url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) - def url_cache_thumbnail_directory(self, media_id): + def url_cache_thumbnail_directory(self, media_id: str) -> str: # Media id is of the form <DATE><RANDOM_STRING> # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -195,7 +213,7 @@ class MediaFilePaths: media_id[4:], ) - def url_cache_thumbnail_dirs_to_delete(self, media_id): + def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id thumbnails" # Media id is of the form <DATE><RANDOM_STRING> # E.g.: 2017-09-28-fsdRDt24DS234dsf diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 9cac74ebd8..4c9946a616 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import errno import logging import os import shutil -from typing import IO, Dict, List, Optional, Tuple +from io import BytesIO +from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple import twisted.internet.error import twisted.web.http @@ -56,6 +56,9 @@ from .thumbnail_resource import ThumbnailResource from .thumbnailer import Thumbnailer, ThumbnailError from .upload_resource import UploadResource +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -63,26 +66,26 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 class MediaRepository: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() - self.client = hs.get_http_client() + self.client = hs.get_federation_http_client() self.clock = hs.get_clock() self.server_name = hs.hostname self.store = hs.get_datastore() self.max_upload_size = hs.config.max_upload_size self.max_image_pixels = hs.config.max_image_pixels - self.primary_base_path = hs.config.media_store_path - self.filepaths = MediaFilePaths(self.primary_base_path) + self.primary_base_path = hs.config.media_store_path # type: str + self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements self.remote_media_linearizer = Linearizer(name="media_remote") - self.recently_accessed_remotes = set() - self.recently_accessed_locals = set() + self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]] + self.recently_accessed_locals = set() # type: Set[str] self.federation_domain_whitelist = hs.config.federation_domain_whitelist @@ -113,7 +116,7 @@ class MediaRepository: "update_recently_accessed_media", self._update_recently_accessed ) - async def _update_recently_accessed(self): + async def _update_recently_accessed(self) -> None: remote_media = self.recently_accessed_remotes self.recently_accessed_remotes = set() @@ -124,12 +127,12 @@ class MediaRepository: local_media, remote_media, self.clock.time_msec() ) - def mark_recently_accessed(self, server_name, media_id): + def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None: """Mark the given media as recently accessed. Args: - server_name (str|None): Origin server of media, or None if local - media_id (str): The media ID of the content + server_name: Origin server of media, or None if local + media_id: The media ID of the content """ if server_name: self.recently_accessed_remotes.add((server_name, media_id)) @@ -459,7 +462,14 @@ class MediaRepository: def _get_thumbnail_requirements(self, media_type): return self.thumbnail_requirements.get(media_type, ()) - def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type): + def _generate_thumbnail( + self, + thumbnailer: Thumbnailer, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + ) -> Optional[BytesIO]: m_width = thumbnailer.width m_height = thumbnailer.height @@ -470,22 +480,20 @@ class MediaRepository: m_height, self.max_image_pixels, ) - return + return None if thumbnailer.transpose_method is not None: m_width, m_height = thumbnailer.transpose() if t_method == "crop": - t_byte_source = thumbnailer.crop(t_width, t_height, t_type) + return thumbnailer.crop(t_width, t_height, t_type) elif t_method == "scale": t_width, t_height = thumbnailer.aspect(t_width, t_height) t_width = min(m_width, t_width) t_height = min(m_height, t_height) - t_byte_source = thumbnailer.scale(t_width, t_height, t_type) - else: - t_byte_source = None + return thumbnailer.scale(t_width, t_height, t_type) - return t_byte_source + return None async def generate_local_exact_thumbnail( self, @@ -776,7 +784,7 @@ class MediaRepository: return {"width": m_width, "height": m_height} - async def delete_old_remote_media(self, before_ts): + async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]: old_media = await self.store.get_remote_media_before(before_ts) deleted = 0 @@ -928,7 +936,7 @@ class MediaRepositoryResource(Resource): within a given rectangle. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): # If we're not configured to use it, raise if we somehow got here. if not hs.config.can_load_media_repo: raise ConfigError("Synapse is not configured to use a media repo.") diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 268e0c8f50..89cdd605aa 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vecotr Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,6 +18,8 @@ import os import shutil from typing import IO, TYPE_CHECKING, Any, Optional, Sequence +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender from synapse.logging.context import defer_to_thread, make_deferred_yieldable @@ -270,7 +272,7 @@ class MediaStorage: return self.filepaths.local_media_filepath_rel(file_info.file_id) -def _write_file_synchronously(source, dest): +def _write_file_synchronously(source: IO, dest: IO) -> None: """Write `source` to the file like `dest` synchronously. Should be called from a thread. @@ -286,14 +288,14 @@ class FileResponder(Responder): """Wraps an open file that can be sent to a request. Args: - open_file (file): A file like object to be streamed ot the client, + open_file: A file like object to be streamed ot the client, is closed when finished streaming. """ - def __init__(self, open_file): + def __init__(self, open_file: IO): self.open_file = open_file - def write_to_consumer(self, consumer): + def write_to_consumer(self, consumer: IConsumer) -> Deferred: return make_deferred_yieldable( FileSender().beginFileTransfer(self.open_file, consumer) ) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index dce6c4d168..bf3be653aa 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import datetime import errno import fnmatch @@ -23,12 +23,13 @@ import re import shutil import sys import traceback -from typing import Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union from urllib import parse as urlparse import attr from twisted.internet.error import DNSLookupError +from twisted.web.http import Request from synapse.api.errors import Codes, SynapseError from synapse.http.client import SimpleHttpClient @@ -41,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers +from synapse.rest.media.v1.media_storage import MediaStorage from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache @@ -48,6 +50,12 @@ from synapse.util.stringutils import random_string from ._base import FileInfo +if TYPE_CHECKING: + from lxml import etree + + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I) @@ -119,7 +127,12 @@ class OEmbedError(Exception): class PreviewUrlResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo, media_storage): + def __init__( + self, + hs: "HomeServer", + media_repo: "MediaRepository", + media_storage: MediaStorage, + ): super().__init__() self.auth = hs.get_auth() @@ -166,11 +179,11 @@ class PreviewUrlResource(DirectServeJsonResource): self._start_expire_url_cache_data, 10 * 1000 ) - async def _async_render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request: Request) -> None: request.setHeader(b"Allow", b"OPTIONS, GET") respond_with_json(request, 200, {}, send_cors=True) - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: # XXX: if get_user_by_req fails, what should we do in an async render? requester = await self.auth.get_user_by_req(request) @@ -373,7 +386,7 @@ class PreviewUrlResource(DirectServeJsonResource): """ Check whether the URL should be downloaded as oEmbed content instead. - Params: + Args: url: The URL to check. Returns: @@ -390,7 +403,7 @@ class PreviewUrlResource(DirectServeJsonResource): """ Request content from an oEmbed endpoint. - Params: + Args: endpoint: The oEmbed API endpoint. url: The URL to pass to the API. @@ -450,7 +463,7 @@ class PreviewUrlResource(DirectServeJsonResource): logger.warning("Error downloading oEmbed metadata from %s: %r", url, e) raise OEmbedError() from e - async def _download_url(self, url: str, user): + async def _download_url(self, url: str, user: str) -> Dict[str, Any]: # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? @@ -580,7 +593,7 @@ class PreviewUrlResource(DirectServeJsonResource): "expire_url_cache_data", self._expire_url_cache_data ) - async def _expire_url_cache_data(self): + async def _expire_url_cache_data(self) -> None: """Clean up expired url cache content, media and thumbnails. """ # TODO: Delete from backup media store @@ -676,24 +689,54 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("No media removed from url cache") -def decode_and_calc_og(body, media_uri, request_encoding=None): +def decode_and_calc_og( + body: bytes, media_uri: str, request_encoding: Optional[str] = None +) -> Dict[str, Optional[str]]: + """ + Calculate metadata for an HTML document. + + This uses lxml to parse the HTML document into the OG response. If errors + occur during processing of the document, an empty response is returned. + + Args: + body: The HTML document, as bytes. + media_url: The URI used to download the body. + request_encoding: The character encoding of the body, as a string. + + Returns: + The OG response as a dictionary. + """ + # If there's no body, nothing useful is going to be found. + if not body: + return {} + from lxml import etree + # Create an HTML parser. If this fails, log and return no metadata. try: parser = etree.HTMLParser(recover=True, encoding=request_encoding) - tree = etree.fromstring(body, parser) - og = _calc_og(tree, media_uri) + except LookupError: + # blindly consider the encoding as utf-8. + parser = etree.HTMLParser(recover=True, encoding="utf-8") + except Exception as e: + logger.warning("Unable to create HTML parser: %s" % (e,)) + return {} + + def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]: + # Attempt to parse the body. If this fails, log and return no metadata. + tree = etree.fromstring(body_attempt, parser) + return _calc_og(tree, media_uri) + + # Attempt to parse the body. If this fails, log and return no metadata. + try: + return _attempt_calc_og(body) except UnicodeDecodeError: # blindly try decoding the body as utf-8, which seems to fix # the charset mismatches on https://google.com - parser = etree.HTMLParser(recover=True, encoding=request_encoding) - tree = etree.fromstring(body.decode("utf-8", "ignore"), parser) - og = _calc_og(tree, media_uri) - - return og + return _attempt_calc_og(body.decode("utf-8", "ignore")) -def _calc_og(tree, media_uri): +def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]: # suck our tree into lxml and define our OG response. # if we see any image URLs in the OG response, then spider them @@ -797,7 +840,9 @@ def _calc_og(tree, media_uri): for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE) ) og["og:description"] = summarize_paragraphs(text_nodes) - else: + elif og["og:description"]: + # This must be a non-empty string at this point. + assert isinstance(og["og:description"], str) og["og:description"] = summarize_paragraphs([og["og:description"]]) # TODO: delete the url downloads to stop diskfilling, @@ -805,7 +850,9 @@ def _calc_og(tree, media_uri): return og -def _iterate_over_text(tree, *tags_to_ignore): +def _iterate_over_text( + tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]] +) -> Generator[str, None, None]: """Iterate over the tree returning text nodes in a depth first fashion, skipping text nodes inside certain tags. """ @@ -839,32 +886,32 @@ def _iterate_over_text(tree, *tags_to_ignore): ) -def _rebase_url(url, base): - base = list(urlparse.urlparse(base)) - url = list(urlparse.urlparse(url)) - if not url[0]: # fix up schema - url[0] = base[0] or "http" - if not url[1]: # fix up hostname - url[1] = base[1] - if not url[2].startswith("/"): - url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2] - return urlparse.urlunparse(url) +def _rebase_url(url: str, base: str) -> str: + base_parts = list(urlparse.urlparse(base)) + url_parts = list(urlparse.urlparse(url)) + if not url_parts[0]: # fix up schema + url_parts[0] = base_parts[0] or "http" + if not url_parts[1]: # fix up hostname + url_parts[1] = base_parts[1] + if not url_parts[2].startswith("/"): + url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2] + return urlparse.urlunparse(url_parts) -def _is_media(content_type): - if content_type.lower().startswith("image/"): - return True +def _is_media(content_type: str) -> bool: + return content_type.lower().startswith("image/") -def _is_html(content_type): +def _is_html(content_type: str) -> bool: content_type = content_type.lower() - if content_type.startswith("text/html") or content_type.startswith( + return content_type.startswith("text/html") or content_type.startswith( "application/xhtml" - ): - return True + ) -def summarize_paragraphs(text_nodes, min_size=200, max_size=500): +def summarize_paragraphs( + text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500 +) -> Optional[str]: # Try to get a summary of between 200 and 500 words, respecting # first paragraph and then word boundaries. # TODO: Respect sentences? diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 18c9ed48d6..e92006faa9 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,27 +13,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect +import abc import logging import os import shutil -from typing import Optional +from typing import TYPE_CHECKING, Optional from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background +from synapse.util.async_helpers import maybe_awaitable from ._base import FileInfo, Responder from .media_storage import FileResponder logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer -class StorageProvider: + +class StorageProvider(metaclass=abc.ABCMeta): """A storage provider is a service that can store uploaded media and retrieve them. """ - async def store_file(self, path: str, file_info: FileInfo): + @abc.abstractmethod + async def store_file(self, path: str, file_info: FileInfo) -> None: """Store the file described by file_info. The actual contents can be retrieved by reading the file in file_info.upload_path. @@ -42,6 +47,7 @@ class StorageProvider: file_info: The metadata of the file. """ + @abc.abstractmethod async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: """Attempt to fetch the file described by file_info and stream it into writer. @@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider): self.store_synchronous = store_synchronous self.store_remote = store_remote - def __str__(self): + def __str__(self) -> str: return "StorageProviderWrapper[%s]" % (self.backend,) - async def store_file(self, path, file_info): + async def store_file(self, path: str, file_info: FileInfo) -> None: if not file_info.server_name and not self.store_local: return None @@ -91,39 +97,34 @@ class StorageProviderWrapper(StorageProvider): if self.store_synchronous: # store_file is supposed to return an Awaitable, but guard # against improper implementations. - result = self.backend.store_file(path, file_info) - if inspect.isawaitable(result): - return await result + await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore else: # TODO: Handle errors. async def store(): try: - result = self.backend.store_file(path, file_info) - if inspect.isawaitable(result): - return await result + return await maybe_awaitable( + self.backend.store_file(path, file_info) + ) except Exception: logger.exception("Error storing file") run_in_background(store) - return None - async def fetch(self, path, file_info): + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: # store_file is supposed to return an Awaitable, but guard # against improper implementations. - result = self.backend.fetch(path, file_info) - if inspect.isawaitable(result): - return await result + return await maybe_awaitable(self.backend.fetch(path, file_info)) class FileStorageProviderBackend(StorageProvider): """A storage provider that stores files in a directory on a filesystem. Args: - hs (HomeServer) + hs config: The config returned by `parse_config`. """ - def __init__(self, hs, config): + def __init__(self, hs: "HomeServer", config: str): self.hs = hs self.cache_directory = hs.config.media_store_path self.base_directory = config @@ -131,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider): def __str__(self): return "FileStorageProviderBackend[%s]" % (self.base_directory,) - async def store_file(self, path, file_info): + async def store_file(self, path: str, file_info: FileInfo) -> None: """See StorageProvider.store_file""" primary_fname = os.path.join(self.cache_directory, path) @@ -141,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider): if not os.path.exists(dirname): os.makedirs(dirname) - return await defer_to_thread( + await defer_to_thread( self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname ) - async def fetch(self, path, file_info): + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: """See StorageProvider.fetch""" backup_fname = os.path.join(self.base_directory, path) if os.path.isfile(backup_fname): return FileResponder(open(backup_fname, "rb")) + return None + @staticmethod - def parse_config(config): + def parse_config(config: dict) -> str: """Called on startup to parse config supplied. This should parse the config and raise if there is a problem. diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 30421b663a..d653a58be9 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,10 +16,14 @@ import logging +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from twisted.web.http import Request from synapse.api.errors import SynapseError from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.servlet import parse_integer, parse_string +from synapse.rest.media.v1.media_storage import MediaStorage from ._base import ( FileInfo, @@ -28,13 +33,22 @@ from ._base import ( respond_with_responder, ) +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) class ThumbnailResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo, media_storage): + def __init__( + self, + hs: "HomeServer", + media_repo: "MediaRepository", + media_storage: MediaStorage, + ): super().__init__() self.store = hs.get_datastore() @@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource): self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.server_name = hs.hostname - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: Request) -> None: set_cors_headers(request) server_name, media_id, _ = parse_media_id(request) width = parse_integer(request, "width", required=True) @@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource): self.media_repo.mark_recently_accessed(server_name, media_id) async def _respond_local_thumbnail( - self, request, media_id, width, height, method, m_type - ): + self, + request: Request, + media_id: str, + width: int, + height: int, + method: str, + m_type: str, + ) -> None: media_info = await self.store.get_local_media(media_id) if not media_info: @@ -86,41 +106,27 @@ class ThumbnailResource(DirectServeJsonResource): return thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) - - if thumbnail_infos: - thumbnail_info = self._select_thumbnail( - width, height, method, m_type, thumbnail_infos - ) - - file_info = FileInfo( - server_name=None, - file_id=media_id, - url_cache=media_info["url_cache"], - thumbnail=True, - thumbnail_width=thumbnail_info["thumbnail_width"], - thumbnail_height=thumbnail_info["thumbnail_height"], - thumbnail_type=thumbnail_info["thumbnail_type"], - thumbnail_method=thumbnail_info["thumbnail_method"], - ) - - t_type = file_info.thumbnail_type - t_length = thumbnail_info["thumbnail_length"] - - responder = await self.media_storage.fetch_media(file_info) - await respond_with_responder(request, responder, t_type, t_length) - else: - logger.info("Couldn't find any generated thumbnails") - respond_404(request) + await self._select_and_respond_with_thumbnail( + request, + width, + height, + method, + m_type, + thumbnail_infos, + media_id, + url_cache=media_info["url_cache"], + server_name=None, + ) async def _select_or_generate_local_thumbnail( self, - request, - media_id, - desired_width, - desired_height, - desired_method, - desired_type, - ): + request: Request, + media_id: str, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, + ) -> None: media_info = await self.store.get_local_media(media_id) if not media_info: @@ -178,14 +184,14 @@ class ThumbnailResource(DirectServeJsonResource): async def _select_or_generate_remote_thumbnail( self, - request, - server_name, - media_id, - desired_width, - desired_height, - desired_method, - desired_type, - ): + request: Request, + server_name: str, + media_id: str, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, + ) -> None: media_info = await self.media_repo.get_remote_media_info(server_name, media_id) thumbnail_infos = await self.store.get_remote_media_thumbnails( @@ -239,8 +245,15 @@ class ThumbnailResource(DirectServeJsonResource): raise SynapseError(400, "Failed to generate thumbnail.") async def _respond_remote_thumbnail( - self, request, server_name, media_id, width, height, method, m_type - ): + self, + request: Request, + server_name: str, + media_id: str, + width: int, + height: int, + method: str, + m_type: str, + ) -> None: # TODO: Don't download the whole remote file # We should proxy the thumbnail from the remote server instead of # downloading the remote file and generating our own thumbnails. @@ -249,97 +262,185 @@ class ThumbnailResource(DirectServeJsonResource): thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) + await self._select_and_respond_with_thumbnail( + request, + width, + height, + method, + m_type, + thumbnail_infos, + media_info["filesystem_id"], + url_cache=None, + server_name=server_name, + ) + async def _select_and_respond_with_thumbnail( + self, + request: Request, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, + thumbnail_infos: List[Dict[str, Any]], + file_id: str, + url_cache: Optional[str] = None, + server_name: Optional[str] = None, + ) -> None: + """ + Respond to a request with an appropriate thumbnail from the previously generated thumbnails. + + Args: + request: The incoming request. + desired_width: The desired width, the returned thumbnail may be larger than this. + desired_height: The desired height, the returned thumbnail may be larger than this. + desired_method: The desired method used to generate the thumbnail. + desired_type: The desired content-type of the thumbnail. + thumbnail_infos: A list of dictionaries of candidate thumbnails. + file_id: The ID of the media that a thumbnail is being requested for. + url_cache: The URL cache value. + server_name: The server name, if this is a remote thumbnail. + """ if thumbnail_infos: - thumbnail_info = self._select_thumbnail( - width, height, method, m_type, thumbnail_infos + file_info = self._select_thumbnail( + desired_width, + desired_height, + desired_method, + desired_type, + thumbnail_infos, + file_id, + url_cache, + server_name, ) - file_info = FileInfo( - server_name=server_name, - file_id=media_info["filesystem_id"], - thumbnail=True, - thumbnail_width=thumbnail_info["thumbnail_width"], - thumbnail_height=thumbnail_info["thumbnail_height"], - thumbnail_type=thumbnail_info["thumbnail_type"], - thumbnail_method=thumbnail_info["thumbnail_method"], - ) - - t_type = file_info.thumbnail_type - t_length = thumbnail_info["thumbnail_length"] + if not file_info: + logger.info("Couldn't find a thumbnail matching the desired inputs") + respond_404(request) + return responder = await self.media_storage.fetch_media(file_info) - await respond_with_responder(request, responder, t_type, t_length) + await respond_with_responder( + request, responder, file_info.thumbnail_type, file_info.thumbnail_length + ) else: logger.info("Failed to find any generated thumbnails") respond_404(request) def _select_thumbnail( self, - desired_width, - desired_height, - desired_method, - desired_type, - thumbnail_infos, - ): + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, + thumbnail_infos: List[Dict[str, Any]], + file_id: str, + url_cache: Optional[str], + server_name: Optional[str], + ) -> Optional[FileInfo]: + """ + Choose an appropriate thumbnail from the previously generated thumbnails. + + Args: + desired_width: The desired width, the returned thumbnail may be larger than this. + desired_height: The desired height, the returned thumbnail may be larger than this. + desired_method: The desired method used to generate the thumbnail. + desired_type: The desired content-type of the thumbnail. + thumbnail_infos: A list of dictionaries of candidate thumbnails. + file_id: The ID of the media that a thumbnail is being requested for. + url_cache: The URL cache value. + server_name: The server name, if this is a remote thumbnail. + + Returns: + The thumbnail which best matches the desired parameters. + """ + desired_method = desired_method.lower() + + # The chosen thumbnail. + thumbnail_info = None + d_w = desired_width d_h = desired_height - if desired_method.lower() == "crop": + if desired_method == "crop": + # Thumbnails that match equal or larger sizes of desired width/height. crop_info_list = [] + # Other thumbnails. crop_info_list2 = [] for info in thumbnail_infos: + # Skip thumbnails generated with different methods. + if info["thumbnail_method"] != "crop": + continue + t_w = info["thumbnail_width"] t_h = info["thumbnail_height"] - t_method = info["thumbnail_method"] - if t_method == "crop": - aspect_quality = abs(d_w * t_h - d_h * t_w) - min_quality = 0 if d_w <= t_w and d_h <= t_h else 1 - size_quality = abs((d_w - t_w) * (d_h - t_h)) - type_quality = desired_type != info["thumbnail_type"] - length_quality = info["thumbnail_length"] - if t_w >= d_w or t_h >= d_h: - crop_info_list.append( - ( - aspect_quality, - min_quality, - size_quality, - type_quality, - length_quality, - info, - ) + aspect_quality = abs(d_w * t_h - d_h * t_w) + min_quality = 0 if d_w <= t_w and d_h <= t_h else 1 + size_quality = abs((d_w - t_w) * (d_h - t_h)) + type_quality = desired_type != info["thumbnail_type"] + length_quality = info["thumbnail_length"] + if t_w >= d_w or t_h >= d_h: + crop_info_list.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, ) - else: - crop_info_list2.append( - ( - aspect_quality, - min_quality, - size_quality, - type_quality, - length_quality, - info, - ) + ) + else: + crop_info_list2.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, ) + ) if crop_info_list: - return min(crop_info_list)[-1] - else: - return min(crop_info_list2)[-1] - else: + thumbnail_info = min(crop_info_list)[-1] + elif crop_info_list2: + thumbnail_info = min(crop_info_list2)[-1] + elif desired_method == "scale": + # Thumbnails that match equal or larger sizes of desired width/height. info_list = [] + # Other thumbnails. info_list2 = [] + for info in thumbnail_infos: + # Skip thumbnails generated with different methods. + if info["thumbnail_method"] != "scale": + continue + t_w = info["thumbnail_width"] t_h = info["thumbnail_height"] - t_method = info["thumbnail_method"] size_quality = abs((d_w - t_w) * (d_h - t_h)) type_quality = desired_type != info["thumbnail_type"] length_quality = info["thumbnail_length"] - if t_method == "scale" and (t_w >= d_w or t_h >= d_h): + if t_w >= d_w or t_h >= d_h: info_list.append((size_quality, type_quality, length_quality, info)) - elif t_method == "scale": + else: info_list2.append( (size_quality, type_quality, length_quality, info) ) if info_list: - return min(info_list)[-1] - else: - return min(info_list2)[-1] + thumbnail_info = min(info_list)[-1] + elif info_list2: + thumbnail_info = min(info_list2)[-1] + + if thumbnail_info: + return FileInfo( + file_id=file_id, + url_cache=url_cache, + server_name=server_name, + thumbnail=True, + thumbnail_width=thumbnail_info["thumbnail_width"], + thumbnail_height=thumbnail_info["thumbnail_height"], + thumbnail_type=thumbnail_info["thumbnail_type"], + thumbnail_method=thumbnail_info["thumbnail_method"], + thumbnail_length=thumbnail_info["thumbnail_length"], + ) + + # No matching thumbnail was found. + return None diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 32a8e4f960..07903e4017 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +15,7 @@ # limitations under the License. import logging from io import BytesIO +from typing import Tuple from PIL import Image @@ -39,7 +41,7 @@ class Thumbnailer: FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} - def __init__(self, input_path): + def __init__(self, input_path: str): try: self.image = Image.open(input_path) except OSError as e: @@ -59,11 +61,11 @@ class Thumbnailer: # A lot of parsing errors can happen when parsing EXIF logger.info("Error parsing image EXIF information: %s", e) - def transpose(self): + def transpose(self) -> Tuple[int, int]: """Transpose the image using its EXIF Orientation tag Returns: - Tuple[int, int]: (width, height) containing the new image size in pixels. + A tuple containing the new image size in pixels as (width, height). """ if self.transpose_method is not None: self.image = self.image.transpose(self.transpose_method) @@ -73,7 +75,7 @@ class Thumbnailer: self.image.info["exif"] = None return self.image.size - def aspect(self, max_width, max_height): + def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]: """Calculate the largest size that preserves aspect ratio which fits within the given rectangle:: @@ -91,7 +93,7 @@ class Thumbnailer: else: return (max_height * self.width) // self.height, max_height - def _resize(self, width, height): + def _resize(self, width: int, height: int) -> Image: # 1-bit or 8-bit color palette images need converting to RGB # otherwise they will be scaled using nearest neighbour which # looks awful @@ -99,7 +101,7 @@ class Thumbnailer: self.image = self.image.convert("RGB") return self.image.resize((width, height), Image.ANTIALIAS) - def scale(self, width, height, output_type): + def scale(self, width: int, height: int, output_type: str) -> BytesIO: """Rescales the image to the given dimensions. Returns: @@ -108,7 +110,7 @@ class Thumbnailer: scaled = self._resize(width, height) return self._encode_image(scaled, output_type) - def crop(self, width, height, output_type): + def crop(self, width: int, height: int, output_type: str) -> BytesIO: """Rescales and crops the image to the given dimensions preserving aspect:: (w_in / h_in) = (w_scaled / h_scaled) @@ -136,7 +138,7 @@ class Thumbnailer: cropped = scaled_image.crop((crop_left, 0, crop_right, height)) return self._encode_image(cropped, output_type) - def _encode_image(self, output_image, output_type): + def _encode_image(self, output_image: Image, output_type: str) -> BytesIO: output_bytes_io = BytesIO() fmt = self.FORMATS[output_type] if fmt == "JPEG": diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index d76f7389e1..6da76ae994 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,18 +15,25 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request from synapse.api.errors import Codes, SynapseError from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_string +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.rest.media.v1.media_repository import MediaRepository + logger = logging.getLogger(__name__) class UploadResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs, media_repo): + def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): super().__init__() self.media_repo = media_repo @@ -37,14 +45,14 @@ class UploadResource(DirectServeJsonResource): self.max_upload_size = hs.config.max_upload_size self.clock = hs.get_clock() - async def _async_render_OPTIONS(self, request): + async def _async_render_OPTIONS(self, request: Request) -> None: respond_with_json(request, 200, {}, send_cors=True) - async def _async_render_POST(self, request): + async def _async_render_POST(self, request: Request) -> None: requester = await self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point - content_length = request.getHeader(b"Content-Length").decode("ascii") + content_length = request.getHeader("Content-Length") if content_length is None: raise SynapseError(msg="Request must specify a Content-Length", code=400) if int(content_length) > self.max_upload_size: diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index c0b733488b..e5ef515090 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2020 The Matrix.org Foundation C.I.C. +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,3 +12,55 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from typing import TYPE_CHECKING, Mapping + +from twisted.web.resource import Resource + +from synapse.rest.synapse.client.new_user_consent import NewUserConsentResource +from synapse.rest.synapse.client.pick_idp import PickIdpResource +from synapse.rest.synapse.client.pick_username import pick_username_resource +from synapse.rest.synapse.client.sso_register import SsoRegisterResource + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resource]: + """Builds a resource tree to include synapse-specific client resources + + These are resources which should be loaded on all workers which expose a C-S API: + ie, the main process, and any generic workers so configured. + + Returns: + map from path to Resource. + """ + resources = { + # SSO bits. These are always loaded, whether or not SSO login is actually + # enabled (they just won't work very well if it's not) + "/_synapse/client/pick_idp": PickIdpResource(hs), + "/_synapse/client/pick_username": pick_username_resource(hs), + "/_synapse/client/new_user_consent": NewUserConsentResource(hs), + "/_synapse/client/sso_register": SsoRegisterResource(hs), + } + + # provider-specific SSO bits. Only load these if they are enabled, since they + # rely on optional dependencies. + if hs.config.oidc_enabled: + from synapse.rest.synapse.client.oidc import OIDCResource + + resources["/_synapse/client/oidc"] = OIDCResource(hs) + + if hs.config.saml2_enabled: + from synapse.rest.synapse.client.saml2 import SAML2Resource + + res = SAML2Resource(hs) + resources["/_synapse/client/saml2"] = res + + # This is also mounted under '/_matrix' for backwards-compatibility. + resources["/_matrix/saml2"] = res + + return resources + + +__all__ = ["build_synapse_client_resource_tree"] diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py new file mode 100644 index 0000000000..b2e0f93810 --- /dev/null +++ b/synapse/rest/synapse/client/new_user_consent.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request + +from synapse.api.errors import SynapseError +from synapse.handlers.sso import get_username_mapping_session_cookie_from_request +from synapse.http.server import DirectServeHtmlResource, respond_with_html +from synapse.http.servlet import parse_string +from synapse.types import UserID +from synapse.util.templates import build_jinja_env + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class NewUserConsentResource(DirectServeHtmlResource): + """A resource which collects consent to the server's terms from a new user + + This resource gets mounted at /_synapse/client/new_user_consent, and is shown + when we are automatically creating a new user due to an SSO login. + + It shows a template which prompts the user to go and read the Ts and Cs, and click + a clickybox if they have done so. + """ + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + self._server_name = hs.hostname + self._consent_version = hs.config.consent.user_consent_version + + def template_search_dirs(): + if hs.config.sso.sso_template_dir: + yield hs.config.sso.sso_template_dir + yield hs.config.sso.default_template_dir + + self._jinja_env = build_jinja_env(template_search_dirs(), hs.config) + + async def _async_render_GET(self, request: Request) -> None: + try: + session_id = get_username_mapping_session_cookie_from_request(request) + session = self._sso_handler.get_mapping_session(session_id) + except SynapseError as e: + logger.warning("Error fetching session: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + + user_id = UserID(session.chosen_localpart, self._server_name) + user_profile = { + "display_name": session.display_name, + } + + template_params = { + "user_id": user_id.to_string(), + "user_profile": user_profile, + "consent_version": self._consent_version, + "terms_url": "/_matrix/consent?v=%s" % (self._consent_version,), + } + + template = self._jinja_env.get_template("sso_new_user_consent.html") + html = template.render(template_params) + respond_with_html(request, 200, html) + + async def _async_render_POST(self, request: Request): + try: + session_id = get_username_mapping_session_cookie_from_request(request) + except SynapseError as e: + logger.warning("Error fetching session cookie: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + + try: + accepted_version = parse_string(request, "accepted_version", required=True) + except SynapseError as e: + self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code) + return + + await self._sso_handler.handle_terms_accepted( + request, session_id, accepted_version + ) diff --git a/synapse/rest/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py index d958dd65bb..64c0deb75d 100644 --- a/synapse/rest/oidc/__init__.py +++ b/synapse/rest/synapse/client/oidc/__init__.py @@ -12,11 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import logging from twisted.web.resource import Resource -from synapse.rest.oidc.callback_resource import OIDCCallbackResource +from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource logger = logging.getLogger(__name__) @@ -25,3 +26,6 @@ class OIDCResource(Resource): def __init__(self, hs): Resource.__init__(self) self.putChild(b"callback", OIDCCallbackResource(hs)) + + +__all__ = ["OIDCResource"] diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/synapse/client/oidc/callback_resource.py index f7a0bc4bdb..f7a0bc4bdb 100644 --- a/synapse/rest/oidc/callback_resource.py +++ b/synapse/rest/synapse/client/oidc/callback_resource.py diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py new file mode 100644 index 0000000000..9550b82998 --- /dev/null +++ b/synapse/rest/synapse/client/pick_idp.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING + +from synapse.http.server import ( + DirectServeHtmlResource, + finish_request, + respond_with_html, +) +from synapse.http.servlet import parse_string +from synapse.http.site import SynapseRequest + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class PickIdpResource(DirectServeHtmlResource): + """IdP picker resource. + + This resource gets mounted under /_synapse/client/pick_idp. It serves an HTML page + which prompts the user to choose an Identity Provider from the list. + """ + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + self._sso_login_idp_picker_template = ( + hs.config.sso.sso_login_idp_picker_template + ) + self._server_name = hs.hostname + + async def _async_render_GET(self, request: SynapseRequest) -> None: + client_redirect_url = parse_string( + request, "redirectUrl", required=True, encoding="utf-8" + ) + idp = parse_string(request, "idp", required=False) + + # if we need to pick an IdP, do so + if not idp: + return await self._serve_id_picker(request, client_redirect_url) + + # otherwise, redirect to the IdP's redirect URI + providers = self._sso_handler.get_identity_providers() + auth_provider = providers.get(idp) + if not auth_provider: + logger.info("Unknown idp %r", idp) + self._sso_handler.render_error( + request, "unknown_idp", "Unknown identity provider ID" + ) + return + + sso_url = await auth_provider.handle_redirect_request( + request, client_redirect_url.encode("utf8") + ) + logger.info("Redirecting to %s", sso_url) + request.redirect(sso_url) + finish_request(request) + + async def _serve_id_picker( + self, request: SynapseRequest, client_redirect_url: str + ) -> None: + # otherwise, serve up the IdP picker + providers = self._sso_handler.get_identity_providers() + html = self._sso_login_idp_picker_template.render( + redirect_url=client_redirect_url, + server_name=self._server_name, + providers=providers.values(), + ) + respond_with_html(request, 200, html) diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py new file mode 100644 index 0000000000..96077cfcd1 --- /dev/null +++ b/synapse/rest/synapse/client/pick_username.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, List + +from twisted.web.http import Request +from twisted.web.resource import Resource + +from synapse.api.errors import SynapseError +from synapse.handlers.sso import get_username_mapping_session_cookie_from_request +from synapse.http.server import ( + DirectServeHtmlResource, + DirectServeJsonResource, + respond_with_html, +) +from synapse.http.servlet import parse_boolean, parse_string +from synapse.http.site import SynapseRequest +from synapse.util.templates import build_jinja_env + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +def pick_username_resource(hs: "HomeServer") -> Resource: + """Factory method to generate the username picker resource. + + This resource gets mounted under /_synapse/client/pick_username and has two + children: + + * "account_details": renders the form and handles the POSTed response + * "check": a JSON endpoint which checks if a userid is free. + """ + + res = Resource() + res.putChild(b"account_details", AccountDetailsResource(hs)) + res.putChild(b"check", AvailabilityCheckResource(hs)) + + return res + + +class AvailabilityCheckResource(DirectServeJsonResource): + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + + async def _async_render_GET(self, request: Request): + localpart = parse_string(request, "username", required=True) + + session_id = get_username_mapping_session_cookie_from_request(request) + + is_available = await self._sso_handler.check_username_availability( + localpart, session_id + ) + return 200, {"available": is_available} + + +class AccountDetailsResource(DirectServeHtmlResource): + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + + def template_search_dirs(): + if hs.config.sso.sso_template_dir: + yield hs.config.sso.sso_template_dir + yield hs.config.sso.default_template_dir + + self._jinja_env = build_jinja_env(template_search_dirs(), hs.config) + + async def _async_render_GET(self, request: Request) -> None: + try: + session_id = get_username_mapping_session_cookie_from_request(request) + session = self._sso_handler.get_mapping_session(session_id) + except SynapseError as e: + logger.warning("Error fetching session: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + + idp_id = session.auth_provider_id + template_params = { + "idp": self._sso_handler.get_identity_providers()[idp_id], + "user_attributes": { + "display_name": session.display_name, + "emails": session.emails, + }, + } + + template = self._jinja_env.get_template("sso_auth_account_details.html") + html = template.render(template_params) + respond_with_html(request, 200, html) + + async def _async_render_POST(self, request: SynapseRequest): + try: + session_id = get_username_mapping_session_cookie_from_request(request) + except SynapseError as e: + logger.warning("Error fetching session cookie: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + + try: + localpart = parse_string(request, "username", required=True) + use_display_name = parse_boolean(request, "use_display_name", default=False) + + try: + emails_to_use = [ + val.decode("utf-8") for val in request.args.get(b"use_email", []) + ] # type: List[str] + except ValueError: + raise SynapseError(400, "Query parameter use_email must be utf-8") + except SynapseError as e: + logger.warning("[session %s] bad param: %s", session_id, e) + self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code) + return + + await self._sso_handler.handle_submit_username_request( + request, session_id, localpart, use_display_name, emails_to_use + ) diff --git a/synapse/rest/saml2/__init__.py b/synapse/rest/synapse/client/saml2/__init__.py index 68da37ca6a..3e8235ee1e 100644 --- a/synapse/rest/saml2/__init__.py +++ b/synapse/rest/synapse/client/saml2/__init__.py @@ -12,12 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import logging from twisted.web.resource import Resource -from synapse.rest.saml2.metadata_resource import SAML2MetadataResource -from synapse.rest.saml2.response_resource import SAML2ResponseResource +from synapse.rest.synapse.client.saml2.metadata_resource import SAML2MetadataResource +from synapse.rest.synapse.client.saml2.response_resource import SAML2ResponseResource logger = logging.getLogger(__name__) @@ -27,3 +28,6 @@ class SAML2Resource(Resource): Resource.__init__(self) self.putChild(b"metadata.xml", SAML2MetadataResource(hs)) self.putChild(b"authn_response", SAML2ResponseResource(hs)) + + +__all__ = ["SAML2Resource"] diff --git a/synapse/rest/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py index 1e8526e22e..1e8526e22e 100644 --- a/synapse/rest/saml2/metadata_resource.py +++ b/synapse/rest/synapse/client/saml2/metadata_resource.py diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py index f6668fb5e3..f6668fb5e3 100644 --- a/synapse/rest/saml2/response_resource.py +++ b/synapse/rest/synapse/client/saml2/response_resource.py diff --git a/synapse/rest/synapse/client/sso_register.py b/synapse/rest/synapse/client/sso_register.py new file mode 100644 index 0000000000..dfefeb7796 --- /dev/null +++ b/synapse/rest/synapse/client/sso_register.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request + +from synapse.api.errors import SynapseError +from synapse.handlers.sso import get_username_mapping_session_cookie_from_request +from synapse.http.server import DirectServeHtmlResource + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class SsoRegisterResource(DirectServeHtmlResource): + """A resource which completes SSO registration + + This resource gets mounted at /_synapse/client/sso_register, and is shown + after we collect username and/or consent for a new SSO user. It (finally) registers + the user, and confirms redirect to the client + """ + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + + async def _async_render_GET(self, request: Request) -> None: + try: + session_id = get_username_mapping_session_cookie_from_request(request) + except SynapseError as e: + logger.warning("Error fetching session cookie: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + await self._sso_handler.register_sso_user(request, session_id) |