diff options
Diffstat (limited to 'synapse/rest')
-rw-r--r-- | synapse/rest/__init__.py | 2 | ||||
-rw-r--r-- | synapse/rest/admin/__init__.py | 7 | ||||
-rw-r--r-- | synapse/rest/admin/rooms.py | 94 | ||||
-rw-r--r-- | synapse/rest/admin/users.py | 38 | ||||
-rw-r--r-- | synapse/rest/client/v1/login.py | 178 | ||||
-rw-r--r-- | synapse/rest/client/v1/room.py | 33 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/account.py | 57 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/auth.py | 90 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/devices.py | 12 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/keys.py | 6 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/password_policy.py | 58 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/register.py | 44 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/room_keys.py | 2 | ||||
-rw-r--r-- | synapse/rest/media/v1/download_resource.py | 3 | ||||
-rw-r--r-- | synapse/rest/media/v1/media_repository.py | 110 | ||||
-rw-r--r-- | synapse/rest/media/v1/preview_url_resource.py | 54 | ||||
-rw-r--r-- | synapse/rest/media/v1/thumbnail_resource.py | 54 |
17 files changed, 481 insertions, 361 deletions
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 4a1fc2ec2b..46e458e95b 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -41,6 +41,7 @@ from synapse.rest.client.v2_alpha import ( keys, notifications, openid, + password_policy, read_marker, receipts, register, @@ -118,6 +119,7 @@ class ClientRestResource(JsonResource): capabilities.register_servlets(hs, client_resource) account_validity.register_servlets(hs, client_resource) relations.register_servlets(hs, client_resource) + password_policy.register_servlets(hs, client_resource) # moving to /_synapse/admin synapse.rest.admin.register_servlets_for_client_rest_resource( diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 42cc2b062a..ed70d448a1 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -29,7 +29,11 @@ from synapse.rest.admin._base import ( from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet -from synapse.rest.admin.rooms import ListRoomRestServlet, ShutdownRoomRestServlet +from synapse.rest.admin.rooms import ( + JoinRoomAliasServlet, + ListRoomRestServlet, + ShutdownRoomRestServlet, +) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.rest.admin.users import ( AccountValidityRenewServlet, @@ -189,6 +193,7 @@ def register_servlets(hs, http_server): """ register_servlets_for_client_rest_resource(hs, http_server) ListRoomRestServlet(hs).register(http_server) + JoinRoomAliasServlet(hs).register(http_server) PurgeRoomServlet(hs).register(http_server) SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f9b8c0a4f0..d1bdb64111 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import List, Optional -from synapse.api.constants import Membership -from synapse.api.errors import Codes, SynapseError +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -29,7 +30,7 @@ from synapse.rest.admin._base import ( historical_admin_path_patterns, ) from synapse.storage.data_stores.main.room import RoomSortOrder -from synapse.types import create_requester +from synapse.types import RoomAlias, RoomID, UserID, create_requester from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -182,10 +183,23 @@ class ListRoomRestServlet(RestServlet): # Extract query parameters start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) - order_by = parse_string(request, "order_by", default="alphabetical") + order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value) if order_by not in ( RoomSortOrder.ALPHABETICAL.value, RoomSortOrder.SIZE.value, + RoomSortOrder.NAME.value, + RoomSortOrder.CANONICAL_ALIAS.value, + RoomSortOrder.JOINED_MEMBERS.value, + RoomSortOrder.JOINED_LOCAL_MEMBERS.value, + RoomSortOrder.VERSION.value, + RoomSortOrder.CREATOR.value, + RoomSortOrder.ENCRYPTION.value, + RoomSortOrder.FEDERATABLE.value, + RoomSortOrder.PUBLIC.value, + RoomSortOrder.JOIN_RULES.value, + RoomSortOrder.GUEST_ACCESS.value, + RoomSortOrder.HISTORY_VISIBILITY.value, + RoomSortOrder.STATE_EVENTS.value, ): raise SynapseError( 400, @@ -237,3 +251,75 @@ class ListRoomRestServlet(RestServlet): response["prev_batch"] = 0 return 200, response + + +class JoinRoomAliasServlet(RestServlet): + + PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.room_member_handler = hs.get_room_member_handler() + self.admin_handler = hs.get_handlers().admin_handler + self.state_handler = hs.get_state_handler() + + async def on_POST(self, request, room_identifier): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + content = parse_json_object_from_request(request) + + assert_params_in_dict(content, ["user_id"]) + target_user = UserID.from_string(content["user_id"]) + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "This endpoint can only be used with local users") + + if not await self.admin_handler.get_user(target_user): + raise NotFoundError("User not found") + + if RoomID.is_valid(room_identifier): + room_id = room_identifier + try: + remote_room_hosts = [ + x.decode("ascii") for x in request.args[b"server_name"] + ] # type: Optional[List[str]] + except Exception: + remote_room_hosts = None + elif RoomAlias.is_valid(room_identifier): + handler = self.room_member_handler + room_alias = RoomAlias.from_string(room_identifier) + room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias) + room_id = room_id.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + + fake_requester = create_requester(target_user) + + # send invite if room has "JoinRules.INVITE" + room_state = await self.state_handler.get_current_state(room_id) + join_rules_event = room_state.get((EventTypes.JoinRules, "")) + if join_rules_event: + if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC): + await self.room_member_handler.update_membership( + requester=requester, + target=fake_requester.user, + room_id=room_id, + action="invite", + remote_room_hosts=remote_room_hosts, + ratelimit=False, + ) + + await self.room_member_handler.update_membership( + requester=fake_requester, + target=fake_requester.user, + room_id=room_id, + action="join", + remote_room_hosts=remote_room_hosts, + ratelimit=False, + ) + + return 200, {"room_id": room_id} diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 8551ac19b8..326682fbdb 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -94,10 +94,10 @@ class UsersRestServletV2(RestServlet): guests = parse_boolean(request, "guests", default=True) deactivated = parse_boolean(request, "deactivated", default=False) - users = await self.store.get_users_paginate( + users, total = await self.store.get_users_paginate( start, limit, user_id, guests, deactivated ) - ret = {"users": users} + ret = {"users": users, "total": total} if len(users) >= limit: ret["next_token"] = str(start + len(users)) @@ -199,7 +199,7 @@ class UserRestServletV2(RestServlet): user_id, threepid["medium"], threepid["address"], current_time ) - if "avatar_url" in body: + if "avatar_url" in body and type(body["avatar_url"]) == str: await self.profile_handler.set_avatar_url( target_user, requester, body["avatar_url"], True ) @@ -243,11 +243,11 @@ class UserRestServletV2(RestServlet): else: # create user password = body.get("password") - if password is not None and ( - not isinstance(body["password"], text_type) - or len(body["password"]) > 512 - ): - raise SynapseError(400, "Invalid password") + password_hash = None + if password is not None: + if not isinstance(password, text_type) or len(password) > 512: + raise SynapseError(400, "Invalid password") + password_hash = await self.auth_handler.hash(password) admin = body.get("admin", None) user_type = body.get("user_type", None) @@ -259,7 +259,7 @@ class UserRestServletV2(RestServlet): user_id = await self.registration_handler.register_user( localpart=target_user.localpart, - password=password, + password_hash=password_hash, admin=bool(admin), default_display_name=displayname, user_type=user_type, @@ -276,7 +276,7 @@ class UserRestServletV2(RestServlet): user_id, threepid["medium"], threepid["address"], current_time ) - if "avatar_url" in body: + if "avatar_url" in body and type(body["avatar_url"]) == str: await self.profile_handler.set_avatar_url( user_id, requester, body["avatar_url"], True ) @@ -298,7 +298,7 @@ class UserRegisterServlet(RestServlet): NONCE_TIMEOUT = 60 def __init__(self, hs): - self.handlers = hs.get_handlers() + self.auth_handler = hs.get_auth_handler() self.reactor = hs.get_reactor() self.nonces = {} self.hs = hs @@ -362,16 +362,16 @@ class UserRegisterServlet(RestServlet): 400, "password must be specified", errcode=Codes.BAD_JSON ) else: - if ( - not isinstance(body["password"], text_type) - or len(body["password"]) > 512 - ): + password = body["password"] + if not isinstance(password, text_type) or len(password) > 512: raise SynapseError(400, "Invalid password") - password = body["password"].encode("utf-8") - if b"\x00" in password: + password_bytes = password.encode("utf-8") + if b"\x00" in password_bytes: raise SynapseError(400, "Invalid password") + password_hash = await self.auth_handler.hash(password) + admin = body.get("admin", None) user_type = body.get("user_type", None) @@ -388,7 +388,7 @@ class UserRegisterServlet(RestServlet): want_mac_builder.update(b"\x00") want_mac_builder.update(username) want_mac_builder.update(b"\x00") - want_mac_builder.update(password) + want_mac_builder.update(password_bytes) want_mac_builder.update(b"\x00") want_mac_builder.update(b"admin" if admin else b"notadmin") if user_type: @@ -407,7 +407,7 @@ class UserRegisterServlet(RestServlet): user_id = await register.registration_handler.register_user( localpart=body["username"].lower(), - password=body["password"], + password_hash=password_hash, admin=bool(admin), user_type=user_type, ) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index d0d4999795..4de2f97d06 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -14,11 +14,6 @@ # limitations under the License. import logging -import xml.etree.ElementTree as ET - -from six.moves import urllib - -from twisted.web.client import PartialDownloadError from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter @@ -28,10 +23,10 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) -from synapse.push.mailer import load_jinja2_templates +from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder -from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.types import UserID from synapse.util.msisdn import phone_number_to_msisdn logger = logging.getLogger(__name__) @@ -402,7 +397,7 @@ class BaseSSORedirectServlet(RestServlet): PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) - def on_GET(self, request): + def on_GET(self, request: SynapseRequest): args = request.args if b"redirectUrl" not in args: return 400, "Redirect URL not specified for SSO auth" @@ -411,15 +406,15 @@ class BaseSSORedirectServlet(RestServlet): request.redirect(sso_url) finish_request(request) - def get_sso_url(self, client_redirect_url): + def get_sso_url(self, client_redirect_url: bytes) -> bytes: """Get the URL to redirect to, to perform SSO auth Args: - client_redirect_url (bytes): the URL that we should redirect the + client_redirect_url: the URL that we should redirect the client to when everything is done Returns: - bytes: URL to redirect to + URL to redirect to """ # to be implemented by subclasses raise NotImplementedError() @@ -427,19 +422,12 @@ class BaseSSORedirectServlet(RestServlet): class CasRedirectServlet(BaseSSORedirectServlet): def __init__(self, hs): - super(CasRedirectServlet, self).__init__() - self.cas_server_url = hs.config.cas_server_url.encode("ascii") - self.cas_service_url = hs.config.cas_service_url.encode("ascii") + self._cas_handler = hs.get_cas_handler() - def get_sso_url(self, client_redirect_url): - client_redirect_url_param = urllib.parse.urlencode( - {b"redirectUrl": client_redirect_url} + def get_sso_url(self, client_redirect_url: bytes) -> bytes: + return self._cas_handler.get_redirect_url( + {"redirectUrl": client_redirect_url} ).encode("ascii") - hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket" - service_param = urllib.parse.urlencode( - {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)} - ).encode("ascii") - return b"%s/login?%s" % (self.cas_server_url, service_param) class CasTicketServlet(RestServlet): @@ -447,81 +435,25 @@ class CasTicketServlet(RestServlet): def __init__(self, hs): super(CasTicketServlet, self).__init__() - self.cas_server_url = hs.config.cas_server_url - self.cas_service_url = hs.config.cas_service_url - self.cas_displayname_attribute = hs.config.cas_displayname_attribute - self.cas_required_attributes = hs.config.cas_required_attributes - self._sso_auth_handler = SSOAuthHandler(hs) - self._http_client = hs.get_proxied_http_client() - - async def on_GET(self, request): - client_redirect_url = parse_string(request, "redirectUrl", required=True) - uri = self.cas_server_url + "/proxyValidate" - args = { - "ticket": parse_string(request, "ticket", required=True), - "service": self.cas_service_url, - } - try: - body = await self._http_client.get_raw(uri, args) - except PartialDownloadError as pde: - # Twisted raises this error if the connection is closed, - # even if that's being used old-http style to signal end-of-data - body = pde.response - result = await self.handle_cas_response(request, body, client_redirect_url) - return result + self._cas_handler = hs.get_cas_handler() - def handle_cas_response(self, request, cas_response_body, client_redirect_url): - user, attributes = self.parse_cas_response(cas_response_body) - displayname = attributes.pop(self.cas_displayname_attribute, None) + async def on_GET(self, request: SynapseRequest) -> None: + client_redirect_url = parse_string(request, "redirectUrl") + ticket = parse_string(request, "ticket", required=True) - for required_attribute, required_value in self.cas_required_attributes.items(): - # If required attribute was not in CAS Response - Forbidden - if required_attribute not in attributes: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + # Maybe get a session ID (if this ticket is from user interactive + # authentication). + session = parse_string(request, "session") - # Also need to check value - if required_value is not None: - actual_value = attributes[required_attribute] - # If required attribute value does not match expected - Forbidden - if required_value != actual_value: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) + # Either client_redirect_url or session must be provided. + if not client_redirect_url and not session: + message = "Missing string query parameter redirectUrl or session" + raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) - return self._sso_auth_handler.on_successful_auth( - user, request, client_redirect_url, displayname + await self._cas_handler.handle_ticket( + request, ticket, client_redirect_url, session ) - def parse_cas_response(self, cas_response_body): - user = None - attributes = {} - try: - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise Exception("root of CAS response is not serviceResponse") - success = root[0].tag.endswith("authenticationSuccess") - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - if child.tag.endswith("attributes"): - for attribute in child: - # ElementTree library expands the namespace in - # attribute tags to the full URL of the namespace. - # We don't care about namespace here and it will always - # be encased in curly braces, so we remove them. - tag = attribute.tag - if "}" in tag: - tag = tag.split("}")[1] - attributes[tag] = attribute.text - if user is None: - raise Exception("CAS response does not contain user") - except Exception: - logger.exception("Error parsing CAS response") - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - if not success: - raise LoginError( - 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED - ) - return user, attributes - class SAMLRedirectServlet(BaseSSORedirectServlet): PATTERNS = client_patterns("/login/sso/redirect", v1=True) @@ -529,72 +461,10 @@ class SAMLRedirectServlet(BaseSSORedirectServlet): def __init__(self, hs): self._saml_handler = hs.get_saml_handler() - def get_sso_url(self, client_redirect_url): + def get_sso_url(self, client_redirect_url: bytes) -> bytes: return self._saml_handler.handle_redirect_request(client_redirect_url) -class SSOAuthHandler(object): - """ - Utility class for Resources and Servlets which handle the response from a SSO - service - - Args: - hs (synapse.server.HomeServer) - """ - - def __init__(self, hs): - self._hostname = hs.hostname - self._auth_handler = hs.get_auth_handler() - self._registration_handler = hs.get_registration_handler() - self._macaroon_gen = hs.get_macaroon_generator() - - # Load the redirect page HTML template - self._template = load_jinja2_templates( - hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"], - )[0] - - self._server_name = hs.config.server_name - - # cast to tuple for use with str.startswith - self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) - - async def on_successful_auth( - self, username, request, client_redirect_url, user_display_name=None - ): - """Called once the user has successfully authenticated with the SSO. - - Registers the user if necessary, and then returns a redirect (with - a login token) to the client. - - Args: - username (unicode|bytes): the remote user id. We'll map this onto - something sane for a MXID localpath. - - request (SynapseRequest): the incoming request from the browser. We'll - respond to it with a redirect. - - client_redirect_url (unicode): the redirect_url the client gave us when - it first started the process. - - user_display_name (unicode|None): if set, and we have to register a new user, - we will set their displayname to this. - - Returns: - Deferred[none]: Completes once we have handled the request. - """ - localpart = map_username_to_mxid_localpart(username) - user_id = UserID(localpart, self._hostname).to_string() - registered_user_id = await self._auth_handler.check_user_exists(user_id) - if not registered_user_id: - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=user_display_name - ) - - self._auth_handler.complete_sso_login( - registered_user_id, request, client_redirect_url - ) - - def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) if hs.config.cas_enabled: diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index bffd43de5f..6b5830cc3f 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -27,6 +27,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, + HttpResponseException, InvalidClientCredentialsError, SynapseError, ) @@ -364,10 +365,13 @@ class PublicRoomListRestServlet(TransactionRestServlet): limit = None handler = self.hs.get_room_list_handler() - if server: - data = await handler.get_remote_public_room_list( - server, limit=limit, since_token=since_token - ) + if server and server != self.hs.config.server_name: + try: + data = await handler.get_remote_public_room_list( + server, limit=limit, since_token=since_token + ) + except HttpResponseException as e: + raise e.to_synapse_error() else: data = await handler.get_local_public_room_list( limit=limit, since_token=since_token @@ -404,15 +408,18 @@ class PublicRoomListRestServlet(TransactionRestServlet): limit = None handler = self.hs.get_room_list_handler() - if server: - data = await handler.get_remote_public_room_list( - server, - limit=limit, - since_token=since_token, - search_filter=search_filter, - include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) + if server and server != self.hs.config.server_name: + try: + data = await handler.get_remote_public_room_list( + server, + limit=limit, + since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) + except HttpResponseException as e: + raise e.to_synapse_error() else: data = await handler.get_local_public_room_list( limit=limit, diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 631cc74cb4..1bd0234779 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -30,7 +30,7 @@ from synapse.http.servlet import ( ) from synapse.push.mailer import Mailer, load_jinja2_templates from synapse.util.msisdn import phone_number_to_msisdn -from synapse.util.stringutils import assert_valid_client_secret +from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.threepids import check_3pid_allowed from ._base import client_patterns, interactive_auth_handler @@ -100,6 +100,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): ) if existing_user_id is None: + if self.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + return 200, {"sid": random_string(16)} + raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: @@ -234,13 +239,21 @@ class PasswordRestServlet(RestServlet): if self.auth.has_access_token(request): requester = await self.auth.get_user_by_req(request) params = await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, + request, + body, + self.hs.get_ip_from_request(request), + "modify your account password", ) user_id = requester.user.to_string() else: requester = None result, params, _ = await self.auth_handler.check_auth( - [[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request) + [[LoginType.EMAIL_IDENTITY]], + request, + body, + self.hs.get_ip_from_request(request), + "modify your account password", ) if LoginType.EMAIL_IDENTITY in result: @@ -308,7 +321,11 @@ class DeactivateAccountRestServlet(RestServlet): return 200, {} await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, + request, + body, + self.hs.get_ip_from_request(request), + "deactivate your account", ) result = await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase, id_server=body.get("id_server") @@ -378,6 +395,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): ) if existing_user_id is not None: + if self.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + return 200, {"sid": random_string(16)} + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: @@ -441,6 +463,11 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) if existing_user_id is not None: + if self.hs.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + return 200, {"sid": random_string(16)} + raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) if not self.hs.config.account_threepid_delegate_msisdn: @@ -602,6 +629,11 @@ class ThreepidRestServlet(RestServlet): return 200, {"threepids": threepids} async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -646,6 +678,11 @@ class ThreepidAddRestServlet(RestServlet): @interactive_auth_handler async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -656,7 +693,11 @@ class ThreepidAddRestServlet(RestServlet): assert_valid_client_secret(client_secret) await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, + request, + body, + self.hs.get_ip_from_request(request), + "add a third-party identifier to your account", ) validation_session = await self.identity_handler.validate_threepid_session( @@ -741,10 +782,16 @@ class ThreepidDeleteRestServlet(RestServlet): def __init__(self, hs): super(ThreepidDeleteRestServlet, self).__init__() + self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 50e080673b..24dd3d3e96 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -130,7 +130,17 @@ class AuthRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() - def on_GET(self, request, stagetype): + # SSO configuration. + self._saml_enabled = hs.config.saml2_enabled + if self._saml_enabled: + self._saml_handler = hs.get_saml_handler() + self._cas_enabled = hs.config.cas_enabled + if self._cas_enabled: + self._cas_handler = hs.get_cas_handler() + self._cas_server_url = hs.config.cas_server_url + self._cas_service_url = hs.config.cas_service_url + + async def on_GET(self, request, stagetype): session = parse_string(request, "session") if not session: raise SynapseError(400, "No session supplied") @@ -142,14 +152,6 @@ class AuthRestServlet(RestServlet): % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), "sitekey": self.hs.config.recaptcha_public_key, } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - return None elif stagetype == LoginType.TERMS: html = TERMS_TEMPLATE % { "session": session, @@ -158,17 +160,41 @@ class AuthRestServlet(RestServlet): "myurl": "%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - return None + + elif stagetype == LoginType.SSO: + # Display a confirmation page which prompts the user to + # re-authenticate with their SSO provider. + if self._cas_enabled: + # Generate a request to CAS that redirects back to an endpoint + # to verify the successful authentication. + sso_redirect_url = self._cas_handler.get_redirect_url( + {"session": session}, + ) + + elif self._saml_enabled: + client_redirect_url = "" + sso_redirect_url = self._saml_handler.handle_redirect_request( + client_redirect_url, session + ) + + else: + raise SynapseError(400, "Homeserver not configured for SSO.") + + html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) + else: raise SynapseError(404, "Unknown auth stage type") + # Render the HTML and return. + html_bytes = html.encode("utf8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + finish_request(request) + return None + async def on_POST(self, request, stagetype): session = parse_string(request, "session") @@ -196,15 +222,6 @@ class AuthRestServlet(RestServlet): % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), "sitekey": self.hs.config.recaptcha_public_key, } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - - return None elif stagetype == LoginType.TERMS: authdict = {"session": session} @@ -225,17 +242,22 @@ class AuthRestServlet(RestServlet): "myurl": "%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - return None + elif stagetype == LoginType.SSO: + # The SSO fallback workflow should not post here, + raise SynapseError(404, "Fallback SSO auth does not support POST requests.") else: raise SynapseError(404, "Unknown auth stage type") + # Render the HTML and return. + html_bytes = html.encode("utf8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + finish_request(request) + return None + def on_OPTIONS(self, _): return 200, {} diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 94ff73f384..c0714fcfb1 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -81,7 +81,11 @@ class DeleteDevicesRestServlet(RestServlet): assert_params_in_dict(body, ["devices"]) await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, + request, + body, + self.hs.get_ip_from_request(request), + "remove device(s) from your account", ) await self.device_handler.delete_devices( @@ -127,7 +131,11 @@ class DeviceRestServlet(RestServlet): raise await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, + request, + body, + self.hs.get_ip_from_request(request), + "remove a device from your account", ) await self.device_handler.delete_device(requester.user.to_string(), device_id) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index f7ed4daf90..8f41a3edbf 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -263,7 +263,11 @@ class SigningKeyUploadServlet(RestServlet): body = parse_json_object_from_request(request) await self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request) + requester, + request, + body, + self.hs.get_ip_from_request(request), + "add a device signing key to your account", ) result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py new file mode 100644 index 0000000000..968403cca4 --- /dev/null +++ b/synapse/rest/client/v2_alpha/password_policy.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.http.servlet import RestServlet + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class PasswordPolicyServlet(RestServlet): + PATTERNS = client_patterns("/password_policy$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(PasswordPolicyServlet, self).__init__() + + self.policy = hs.config.password_policy + self.enabled = hs.config.password_policy_enabled + + def on_GET(self, request): + if not self.enabled or not self.policy: + return (200, {}) + + policy = {} + + for param in [ + "minimum_length", + "require_digit", + "require_symbol", + "require_lowercase", + "require_uppercase", + ]: + if param in self.policy: + policy["m.%s" % param] = self.policy[param] + + return (200, policy) + + +def register_servlets(hs, http_server): + PasswordPolicyServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index a09189b1b4..c26927f27b 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -49,7 +49,7 @@ from synapse.http.servlet import ( from synapse.push.mailer import load_jinja2_templates from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter -from synapse.util.stringutils import assert_valid_client_secret +from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.threepids import check_3pid_allowed from ._base import client_patterns, interactive_auth_handler @@ -135,6 +135,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): ) if existing_user_id is not None: + if self.hs.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + return 200, {"sid": random_string(16)} + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: @@ -202,6 +207,11 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) if existing_user_id is not None: + if self.hs.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + return 200, {"sid": random_string(16)} + raise SynapseError( 400, "Phone number is already in use", Codes.THREEPID_IN_USE ) @@ -373,6 +383,7 @@ class RegisterRestServlet(RestServlet): self.room_member_handler = hs.get_room_member_handler() self.macaroon_gen = hs.get_macaroon_generator() self.ratelimiter = hs.get_registration_ratelimiter() + self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() self._registration_flows = _calculate_registration_flows( @@ -415,11 +426,16 @@ class RegisterRestServlet(RestServlet): # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the username/password provided to us. if "password" in body: - if ( - not isinstance(body["password"], string_types) - or len(body["password"]) > 512 - ): + password = body.pop("password") + if not isinstance(password, string_types) or len(password) > 512: raise SynapseError(400, "Invalid password") + self.password_policy_handler.validate_password(password) + + # If the password is valid, hash it and store it back on the request. + # This ensures the hashed password is handled everywhere. + if "password_hash" in body: + raise SynapseError(400, "Unexpected property: password_hash") + body["password_hash"] = await self.auth_handler.hash(password) desired_username = None if "username" in body: @@ -472,7 +488,7 @@ class RegisterRestServlet(RestServlet): guest_access_token = body.get("guest_access_token", None) - if "initial_device_display_name" in body and "password" not in body: + if "initial_device_display_name" in body and "password_hash" not in body: # ignore 'initial_device_display_name' if sent without # a password to work around a client bug where it sent # the 'initial_device_display_name' param alone, wiping out @@ -487,7 +503,7 @@ class RegisterRestServlet(RestServlet): # registered a user for this session, so we could just return the # user here. We carry on and go through the auth checks though, # for paranoia. - registered_user_id = self.auth_handler.get_session_data( + registered_user_id = await self.auth_handler.get_session_data( session_id, "registered_user_id", None ) @@ -499,7 +515,11 @@ class RegisterRestServlet(RestServlet): ) auth_result, params, session_id = await self.auth_handler.check_auth( - self._registration_flows, body, self.hs.get_ip_from_request(request) + self._registration_flows, + request, + body, + self.hs.get_ip_from_request(request), + "register a new account", ) # Check that we're not trying to register a denied 3pid. @@ -530,11 +550,11 @@ class RegisterRestServlet(RestServlet): registered = False else: # NB: This may be from the auth handler and NOT from the POST - assert_params_in_dict(params, ["password"]) + assert_params_in_dict(params, ["password_hash"]) desired_username = params.get("username", None) guest_access_token = params.get("guest_access_token", None) - new_password = params.get("password", None) + new_password_hash = params.get("password_hash", None) if desired_username is not None: desired_username = desired_username.lower() @@ -567,7 +587,7 @@ class RegisterRestServlet(RestServlet): registered_user_id = await self.registration_handler.register_user( localpart=desired_username, - password=new_password, + password_hash=new_password_hash, guest_access_token=guest_access_token, threepid=threepid, address=client_addr, @@ -582,7 +602,7 @@ class RegisterRestServlet(RestServlet): # remember that we've now registered that user account, and with # what user ID (since the user may not have specified) - self.auth_handler.set_session_data( + await self.auth_handler.set_session_data( session_id, "registered_user_id", registered_user_id ) diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 38952a1d27..59529707df 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -188,7 +188,7 @@ class RoomKeysServlet(RestServlet): """ requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - version = parse_string(request, "version") + version = parse_string(request, "version", required=True) room_keys = await self.e2e_room_keys_handler.get_room_keys( user_id, version, room_id, session_id diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index 66a01559e1..24d3ae5bbc 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -50,6 +50,9 @@ class DownloadResource(DirectServeResource): b" media-src 'self';" b" object-src 'self';", ) + request.setHeader( + b"Referrer-Policy", b"no-referrer", + ) server_name, media_id, name = parse_media_id(request) if server_name == self.server_name: await self.media_repo.get_local_media(request, media_id, name) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 490b1b45a8..fd10d42f2f 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -24,7 +24,6 @@ from six import iteritems import twisted.internet.error import twisted.web.http -from twisted.internet import defer from twisted.web.resource import Resource from synapse.api.errors import ( @@ -114,15 +113,14 @@ class MediaRepository(object): "update_recently_accessed_media", self._update_recently_accessed ) - @defer.inlineCallbacks - def _update_recently_accessed(self): + async def _update_recently_accessed(self): remote_media = self.recently_accessed_remotes self.recently_accessed_remotes = set() local_media = self.recently_accessed_locals self.recently_accessed_locals = set() - yield self.store.update_cached_last_access_time( + await self.store.update_cached_last_access_time( local_media, remote_media, self.clock.time_msec() ) @@ -138,8 +136,7 @@ class MediaRepository(object): else: self.recently_accessed_locals.add(media_id) - @defer.inlineCallbacks - def create_content( + async def create_content( self, media_type, upload_name, content, content_length, auth_user ): """Store uploaded content for a local user and return the mxc URL @@ -158,11 +155,11 @@ class MediaRepository(object): file_info = FileInfo(server_name=None, file_id=media_id) - fname = yield self.media_storage.store_file(content, file_info) + fname = await self.media_storage.store_file(content, file_info) logger.info("Stored local media in file %r", fname) - yield self.store.store_local_media( + await self.store.store_local_media( media_id=media_id, media_type=media_type, time_now_ms=self.clock.time_msec(), @@ -171,12 +168,11 @@ class MediaRepository(object): user_id=auth_user, ) - yield self._generate_thumbnails(None, media_id, media_id, media_type) + await self._generate_thumbnails(None, media_id, media_id, media_type) return "mxc://%s/%s" % (self.server_name, media_id) - @defer.inlineCallbacks - def get_local_media(self, request, media_id, name): + async def get_local_media(self, request, media_id, name): """Responds to reqests for local media, if exists, or returns 404. Args: @@ -190,7 +186,7 @@ class MediaRepository(object): Deferred: Resolves once a response has successfully been written to request """ - media_info = yield self.store.get_local_media(media_id) + media_info = await self.store.get_local_media(media_id) if not media_info or media_info["quarantined_by"]: respond_404(request) return @@ -204,13 +200,12 @@ class MediaRepository(object): file_info = FileInfo(None, media_id, url_cache=url_cache) - responder = yield self.media_storage.fetch_media(file_info) - yield respond_with_responder( + responder = await self.media_storage.fetch_media(file_info) + await respond_with_responder( request, responder, media_type, media_length, upload_name ) - @defer.inlineCallbacks - def get_remote_media(self, request, server_name, media_id, name): + async def get_remote_media(self, request, server_name, media_id, name): """Respond to requests for remote media. Args: @@ -236,8 +231,8 @@ class MediaRepository(object): # We linearize here to ensure that we don't try and download remote # media multiple times concurrently key = (server_name, media_id) - with (yield self.remote_media_linearizer.queue(key)): - responder, media_info = yield self._get_remote_media_impl( + with (await self.remote_media_linearizer.queue(key)): + responder, media_info = await self._get_remote_media_impl( server_name, media_id ) @@ -246,14 +241,13 @@ class MediaRepository(object): media_type = media_info["media_type"] media_length = media_info["media_length"] upload_name = name if name else media_info["upload_name"] - yield respond_with_responder( + await respond_with_responder( request, responder, media_type, media_length, upload_name ) else: respond_404(request) - @defer.inlineCallbacks - def get_remote_media_info(self, server_name, media_id): + async def get_remote_media_info(self, server_name, media_id): """Gets the media info associated with the remote file, downloading if necessary. @@ -274,8 +268,8 @@ class MediaRepository(object): # We linearize here to ensure that we don't try and download remote # media multiple times concurrently key = (server_name, media_id) - with (yield self.remote_media_linearizer.queue(key)): - responder, media_info = yield self._get_remote_media_impl( + with (await self.remote_media_linearizer.queue(key)): + responder, media_info = await self._get_remote_media_impl( server_name, media_id ) @@ -286,8 +280,7 @@ class MediaRepository(object): return media_info - @defer.inlineCallbacks - def _get_remote_media_impl(self, server_name, media_id): + async def _get_remote_media_impl(self, server_name, media_id): """Looks for media in local cache, if not there then attempt to download from remote server. @@ -299,7 +292,7 @@ class MediaRepository(object): Returns: Deferred[(Responder, media_info)] """ - media_info = yield self.store.get_cached_remote_media(server_name, media_id) + media_info = await self.store.get_cached_remote_media(server_name, media_id) # file_id is the ID we use to track the file locally. If we've already # seen the file then reuse the existing ID, otherwise genereate a new @@ -317,19 +310,18 @@ class MediaRepository(object): logger.info("Media is quarantined") raise NotFoundError() - responder = yield self.media_storage.fetch_media(file_info) + responder = await self.media_storage.fetch_media(file_info) if responder: return responder, media_info # Failed to find the file anywhere, lets download it. - media_info = yield self._download_remote_file(server_name, media_id, file_id) + media_info = await self._download_remote_file(server_name, media_id, file_id) - responder = yield self.media_storage.fetch_media(file_info) + responder = await self.media_storage.fetch_media(file_info) return responder, media_info - @defer.inlineCallbacks - def _download_remote_file(self, server_name, media_id, file_id): + async def _download_remote_file(self, server_name, media_id, file_id): """Attempt to download the remote file from the given server name, using the given file_id as the local id. @@ -351,7 +343,7 @@ class MediaRepository(object): ("/_matrix/media/v1/download", server_name, media_id) ) try: - length, headers = yield self.client.get_file( + length, headers = await self.client.get_file( server_name, request_path, output_stream=f, @@ -397,7 +389,7 @@ class MediaRepository(object): ) raise SynapseError(502, "Failed to fetch remote media") - yield finish() + await finish() media_type = headers[b"Content-Type"][0].decode("ascii") upload_name = get_filename_from_headers(headers) @@ -405,7 +397,7 @@ class MediaRepository(object): logger.info("Stored remote media in file %r", fname) - yield self.store.store_cached_remote_media( + await self.store.store_cached_remote_media( origin=server_name, media_id=media_id, media_type=media_type, @@ -423,7 +415,7 @@ class MediaRepository(object): "filesystem_id": file_id, } - yield self._generate_thumbnails(server_name, media_id, file_id, media_type) + await self._generate_thumbnails(server_name, media_id, file_id, media_type) return media_info @@ -458,16 +450,15 @@ class MediaRepository(object): return t_byte_source - @defer.inlineCallbacks - def generate_local_exact_thumbnail( + async def generate_local_exact_thumbnail( self, media_id, t_width, t_height, t_method, t_type, url_cache ): - input_path = yield self.media_storage.ensure_media_is_in_local_cache( + input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(None, media_id, url_cache=url_cache) ) thumbnailer = Thumbnailer(input_path) - t_byte_source = yield defer_to_thread( + t_byte_source = await defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, thumbnailer, @@ -490,7 +481,7 @@ class MediaRepository(object): thumbnail_type=t_type, ) - output_path = yield self.media_storage.store_file( + output_path = await self.media_storage.store_file( t_byte_source, file_info ) finally: @@ -500,22 +491,21 @@ class MediaRepository(object): t_len = os.path.getsize(output_path) - yield self.store.store_local_thumbnail( + await self.store.store_local_thumbnail( media_id, t_width, t_height, t_type, t_method, t_len ) return output_path - @defer.inlineCallbacks - def generate_remote_exact_thumbnail( + async def generate_remote_exact_thumbnail( self, server_name, file_id, media_id, t_width, t_height, t_method, t_type ): - input_path = yield self.media_storage.ensure_media_is_in_local_cache( + input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=False) ) thumbnailer = Thumbnailer(input_path) - t_byte_source = yield defer_to_thread( + t_byte_source = await defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, thumbnailer, @@ -537,7 +527,7 @@ class MediaRepository(object): thumbnail_type=t_type, ) - output_path = yield self.media_storage.store_file( + output_path = await self.media_storage.store_file( t_byte_source, file_info ) finally: @@ -547,7 +537,7 @@ class MediaRepository(object): t_len = os.path.getsize(output_path) - yield self.store.store_remote_media_thumbnail( + await self.store.store_remote_media_thumbnail( server_name, media_id, file_id, @@ -560,8 +550,7 @@ class MediaRepository(object): return output_path - @defer.inlineCallbacks - def _generate_thumbnails( + async def _generate_thumbnails( self, server_name, media_id, file_id, media_type, url_cache=False ): """Generate and store thumbnails for an image. @@ -582,7 +571,7 @@ class MediaRepository(object): if not requirements: return - input_path = yield self.media_storage.ensure_media_is_in_local_cache( + input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=url_cache) ) @@ -600,7 +589,7 @@ class MediaRepository(object): return if thumbnailer.transpose_method is not None: - m_width, m_height = yield defer_to_thread( + m_width, m_height = await defer_to_thread( self.hs.get_reactor(), thumbnailer.transpose ) @@ -620,11 +609,11 @@ class MediaRepository(object): for (t_width, t_height, t_type), t_method in iteritems(thumbnails): # Generate the thumbnail if t_method == "crop": - t_byte_source = yield defer_to_thread( + t_byte_source = await defer_to_thread( self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type ) elif t_method == "scale": - t_byte_source = yield defer_to_thread( + t_byte_source = await defer_to_thread( self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type ) else: @@ -646,7 +635,7 @@ class MediaRepository(object): url_cache=url_cache, ) - output_path = yield self.media_storage.store_file( + output_path = await self.media_storage.store_file( t_byte_source, file_info ) finally: @@ -656,7 +645,7 @@ class MediaRepository(object): # Write to database if server_name: - yield self.store.store_remote_media_thumbnail( + await self.store.store_remote_media_thumbnail( server_name, media_id, file_id, @@ -667,15 +656,14 @@ class MediaRepository(object): t_len, ) else: - yield self.store.store_local_thumbnail( + await self.store.store_local_thumbnail( media_id, t_width, t_height, t_type, t_method, t_len ) return {"width": m_width, "height": m_height} - @defer.inlineCallbacks - def delete_old_remote_media(self, before_ts): - old_media = yield self.store.get_remote_media_before(before_ts) + async def delete_old_remote_media(self, before_ts): + old_media = await self.store.get_remote_media_before(before_ts) deleted = 0 @@ -689,7 +677,7 @@ class MediaRepository(object): # TODO: Should we delete from the backup store - with (yield self.remote_media_linearizer.queue(key)): + with (await self.remote_media_linearizer.queue(key)): full_path = self.filepaths.remote_media_filepath(origin, file_id) try: os.remove(full_path) @@ -705,7 +693,7 @@ class MediaRepository(object): ) shutil.rmtree(thumbnail_dir, ignore_errors=True) - yield self.store.delete_remote_media(origin, media_id) + await self.store.delete_remote_media(origin, media_id) deleted += 1 return {"deleted": deleted} diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 07e395cfd1..f206605727 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -86,6 +86,7 @@ class PreviewUrlResource(DirectServeResource): self.media_storage = media_storage self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist + self.url_preview_accept_language = hs.config.url_preview_accept_language # memory cache mapping urls to an ObservableDeferred returning # JSON-encoded OG metadata @@ -165,8 +166,7 @@ class PreviewUrlResource(DirectServeResource): og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe)) respond_with_json_bytes(request, 200, og, send_cors=True) - @defer.inlineCallbacks - def _do_preview(self, url, user, ts): + async def _do_preview(self, url, user, ts): """Check the db, and download the URL and build a preview Args: @@ -179,7 +179,7 @@ class PreviewUrlResource(DirectServeResource): """ # check the URL cache in the DB (which will also provide us with # historical previews, if we have any) - cache_result = yield self.store.get_url_cache(url, ts) + cache_result = await self.store.get_url_cache(url, ts) if ( cache_result and cache_result["expires_ts"] > ts @@ -192,13 +192,13 @@ class PreviewUrlResource(DirectServeResource): og = og.encode("utf8") return og - media_info = yield self._download_url(url, user) + media_info = await self._download_url(url, user) logger.debug("got media_info of '%s'", media_info) if _is_media(media_info["media_type"]): file_id = media_info["filesystem_id"] - dims = yield self.media_repo._generate_thumbnails( + dims = await self.media_repo._generate_thumbnails( None, file_id, file_id, media_info["media_type"], url_cache=True ) @@ -248,14 +248,14 @@ class PreviewUrlResource(DirectServeResource): # request itself and benefit from the same caching etc. But for now we # just rely on the caching on the master request to speed things up. if "og:image" in og and og["og:image"]: - image_info = yield self._download_url( + image_info = await self._download_url( _rebase_url(og["og:image"], media_info["uri"]), user ) if _is_media(image_info["media_type"]): # TODO: make sure we don't choke on white-on-transparent images file_id = image_info["filesystem_id"] - dims = yield self.media_repo._generate_thumbnails( + dims = await self.media_repo._generate_thumbnails( None, file_id, file_id, image_info["media_type"], url_cache=True ) if dims: @@ -293,7 +293,7 @@ class PreviewUrlResource(DirectServeResource): jsonog = json.dumps(og) # store OG in history-aware DB cache - yield self.store.store_url_cache( + await self.store.store_url_cache( url, media_info["response_code"], media_info["etag"], @@ -305,8 +305,7 @@ class PreviewUrlResource(DirectServeResource): return jsonog.encode("utf8") - @defer.inlineCallbacks - def _download_url(self, url, user): + async def _download_url(self, url, user): # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? @@ -317,9 +316,12 @@ class PreviewUrlResource(DirectServeResource): with self.media_storage.store_into_file(file_info) as (f, fname, finish): try: - logger.debug("Trying to get url '%s'", url) - length, headers, uri, code = yield self.client.get_file( - url, output_stream=f, max_size=self.max_spider_size + logger.debug("Trying to get preview for url '%s'", url) + length, headers, uri, code = await self.client.get_file( + url, + output_stream=f, + max_size=self.max_spider_size, + headers={"Accept-Language": self.url_preview_accept_language}, ) except SynapseError: # Pass SynapseErrors through directly, so that the servlet @@ -345,7 +347,7 @@ class PreviewUrlResource(DirectServeResource): % (traceback.format_exception_only(sys.exc_info()[0], e),), Codes.UNKNOWN, ) - yield finish() + await finish() try: if b"Content-Type" in headers: @@ -356,7 +358,7 @@ class PreviewUrlResource(DirectServeResource): download_name = get_filename_from_headers(headers) - yield self.store.store_local_media( + await self.store.store_local_media( media_id=file_id, media_type=media_type, time_now_ms=self.clock.time_msec(), @@ -393,22 +395,21 @@ class PreviewUrlResource(DirectServeResource): "expire_url_cache_data", self._expire_url_cache_data ) - @defer.inlineCallbacks - def _expire_url_cache_data(self): + async def _expire_url_cache_data(self): """Clean up expired url cache content, media and thumbnails. """ # TODO: Delete from backup media store now = self.clock.time_msec() - logger.info("Running url preview cache expiry") + logger.debug("Running url preview cache expiry") - if not (yield self.store.db.updates.has_completed_background_updates()): + if not (await self.store.db.updates.has_completed_background_updates()): logger.info("Still running DB updates; skipping expiry") return # First we delete expired url cache entries - media_ids = yield self.store.get_expired_url_cache(now) + media_ids = await self.store.get_expired_url_cache(now) removed_media = [] for media_id in media_ids: @@ -430,17 +431,19 @@ class PreviewUrlResource(DirectServeResource): except Exception: pass - yield self.store.delete_url_cache(removed_media) + await self.store.delete_url_cache(removed_media) if removed_media: logger.info("Deleted %d entries from url cache", len(removed_media)) + else: + logger.debug("No entries removed from url cache") # Now we delete old images associated with the url cache. # These may be cached for a bit on the client (i.e., they # may have a room open with a preview url thing open). # So we wait a couple of days before deleting, just in case. expire_before = now - 2 * 24 * 60 * 60 * 1000 - media_ids = yield self.store.get_url_cache_media_before(expire_before) + media_ids = await self.store.get_url_cache_media_before(expire_before) removed_media = [] for media_id in media_ids: @@ -478,9 +481,12 @@ class PreviewUrlResource(DirectServeResource): except Exception: pass - yield self.store.delete_url_cache_media(removed_media) + await self.store.delete_url_cache_media(removed_media) - logger.info("Deleted %d media from url cache", len(removed_media)) + if removed_media: + logger.info("Deleted %d media from url cache", len(removed_media)) + else: + logger.debug("No media removed from url cache") def decode_and_calc_og(body, media_uri, request_encoding=None): diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index d57480f761..0b87220234 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.http.server import ( DirectServeResource, set_cors_headers, @@ -79,11 +77,10 @@ class ThumbnailResource(DirectServeResource): ) self.media_repo.mark_recently_accessed(server_name, media_id) - @defer.inlineCallbacks - def _respond_local_thumbnail( + async def _respond_local_thumbnail( self, request, media_id, width, height, method, m_type ): - media_info = yield self.store.get_local_media(media_id) + media_info = await self.store.get_local_media(media_id) if not media_info: respond_404(request) @@ -93,7 +90,7 @@ class ThumbnailResource(DirectServeResource): respond_404(request) return - thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) + thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) if thumbnail_infos: thumbnail_info = self._select_thumbnail( @@ -114,14 +111,13 @@ class ThumbnailResource(DirectServeResource): t_type = file_info.thumbnail_type t_length = thumbnail_info["thumbnail_length"] - responder = yield self.media_storage.fetch_media(file_info) - yield respond_with_responder(request, responder, t_type, t_length) + responder = await self.media_storage.fetch_media(file_info) + await respond_with_responder(request, responder, t_type, t_length) else: logger.info("Couldn't find any generated thumbnails") respond_404(request) - @defer.inlineCallbacks - def _select_or_generate_local_thumbnail( + async def _select_or_generate_local_thumbnail( self, request, media_id, @@ -130,7 +126,7 @@ class ThumbnailResource(DirectServeResource): desired_method, desired_type, ): - media_info = yield self.store.get_local_media(media_id) + media_info = await self.store.get_local_media(media_id) if not media_info: respond_404(request) @@ -140,7 +136,7 @@ class ThumbnailResource(DirectServeResource): respond_404(request) return - thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) + thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) for info in thumbnail_infos: t_w = info["thumbnail_width"] == desired_width t_h = info["thumbnail_height"] == desired_height @@ -162,15 +158,15 @@ class ThumbnailResource(DirectServeResource): t_type = file_info.thumbnail_type t_length = info["thumbnail_length"] - responder = yield self.media_storage.fetch_media(file_info) + responder = await self.media_storage.fetch_media(file_info) if responder: - yield respond_with_responder(request, responder, t_type, t_length) + await respond_with_responder(request, responder, t_type, t_length) return logger.debug("We don't have a thumbnail of that size. Generating") # Okay, so we generate one. - file_path = yield self.media_repo.generate_local_exact_thumbnail( + file_path = await self.media_repo.generate_local_exact_thumbnail( media_id, desired_width, desired_height, @@ -180,13 +176,12 @@ class ThumbnailResource(DirectServeResource): ) if file_path: - yield respond_with_file(request, desired_type, file_path) + await respond_with_file(request, desired_type, file_path) else: logger.warning("Failed to generate thumbnail") respond_404(request) - @defer.inlineCallbacks - def _select_or_generate_remote_thumbnail( + async def _select_or_generate_remote_thumbnail( self, request, server_name, @@ -196,9 +191,9 @@ class ThumbnailResource(DirectServeResource): desired_method, desired_type, ): - media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) + media_info = await self.media_repo.get_remote_media_info(server_name, media_id) - thumbnail_infos = yield self.store.get_remote_media_thumbnails( + thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) @@ -224,15 +219,15 @@ class ThumbnailResource(DirectServeResource): t_type = file_info.thumbnail_type t_length = info["thumbnail_length"] - responder = yield self.media_storage.fetch_media(file_info) + responder = await self.media_storage.fetch_media(file_info) if responder: - yield respond_with_responder(request, responder, t_type, t_length) + await respond_with_responder(request, responder, t_type, t_length) return logger.debug("We don't have a thumbnail of that size. Generating") # Okay, so we generate one. - file_path = yield self.media_repo.generate_remote_exact_thumbnail( + file_path = await self.media_repo.generate_remote_exact_thumbnail( server_name, file_id, media_id, @@ -243,21 +238,20 @@ class ThumbnailResource(DirectServeResource): ) if file_path: - yield respond_with_file(request, desired_type, file_path) + await respond_with_file(request, desired_type, file_path) else: logger.warning("Failed to generate thumbnail") respond_404(request) - @defer.inlineCallbacks - def _respond_remote_thumbnail( + async def _respond_remote_thumbnail( self, request, server_name, media_id, width, height, method, m_type ): # TODO: Don't download the whole remote file # We should proxy the thumbnail from the remote server instead of # downloading the remote file and generating our own thumbnails. - media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) + media_info = await self.media_repo.get_remote_media_info(server_name, media_id) - thumbnail_infos = yield self.store.get_remote_media_thumbnails( + thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) @@ -278,8 +272,8 @@ class ThumbnailResource(DirectServeResource): t_type = file_info.thumbnail_type t_length = thumbnail_info["thumbnail_length"] - responder = yield self.media_storage.fetch_media(file_info) - yield respond_with_responder(request, responder, t_type, t_length) + responder = await self.media_storage.fetch_media(file_info) + await respond_with_responder(request, responder, t_type, t_length) else: logger.info("Failed to find any generated thumbnails") respond_404(request) |