diff options
Diffstat (limited to 'synapse/rest')
30 files changed, 298 insertions, 3626 deletions
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 14c4e6ebbb..2e19e055d3 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -108,8 +108,7 @@ class ClientRestResource(JsonResource): if is_main_process: logout.register_servlets(hs, client_resource) sync.register_servlets(hs, client_resource) - if is_main_process: - filter.register_servlets(hs, client_resource) + filter.register_servlets(hs, client_resource) account.register_servlets(hs, client_resource) register.register_servlets(hs, client_resource) if is_main_process: @@ -140,7 +139,7 @@ class ClientRestResource(JsonResource): relations.register_servlets(hs, client_resource) if is_main_process: password_policy.register_servlets(hs, client_resource) - knock.register_servlets(hs, client_resource) + knock.register_servlets(hs, client_resource) # moving to /_synapse/admin if is_main_process: diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index a3beb74e2c..c546ef7e23 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -53,11 +53,11 @@ class EventReportsRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports$") def __init__(self, hs: "HomeServer"): - self.auth = hs.get_auth() - self.store = hs.get_datastores().main + self._auth = hs.get_auth() + self._store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self.auth, request) + await assert_requester_is_admin(self._auth, request) start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) @@ -79,7 +79,7 @@ class EventReportsRestServlet(RestServlet): errcode=Codes.INVALID_PARAM, ) - event_reports, total = await self.store.get_event_reports_paginate( + event_reports, total = await self._store.get_event_reports_paginate( start, limit, direction, user_id, room_id ) ret = {"event_reports": event_reports, "total": total} @@ -108,13 +108,13 @@ class EventReportDetailRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$") def __init__(self, hs: "HomeServer"): - self.auth = hs.get_auth() - self.store = hs.get_datastores().main + self._auth = hs.get_auth() + self._store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, report_id: str ) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self.auth, request) + await assert_requester_is_admin(self._auth, request) message = ( "The report_id parameter must be a string representing a positive integer." @@ -131,8 +131,33 @@ class EventReportDetailRestServlet(RestServlet): HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM ) - ret = await self.store.get_event_report(resolved_report_id) + ret = await self._store.get_event_report(resolved_report_id) if not ret: raise NotFoundError("Event report not found") return HTTPStatus.OK, ret + + async def on_DELETE( + self, request: SynapseRequest, report_id: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + message = ( + "The report_id parameter must be a string representing a positive integer." + ) + try: + resolved_report_id = int(report_id) + except ValueError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) + + if resolved_report_id < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) + + if await self._store.delete_event_report(resolved_report_id): + return HTTPStatus.OK, {} + + raise NotFoundError("Event report not found") diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 1d6e4982d7..4de56bf13f 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -75,7 +75,6 @@ class RoomRestV2Servlet(RestServlet): async def on_DELETE( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) await assert_user_is_admin(self._auth, requester) @@ -144,7 +143,6 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self._auth, request) if not RoomID.is_valid(room_id): @@ -181,7 +179,6 @@ class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, delete_id: str ) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self._auth, request) delete_status = self._pagination_handler.get_delete_status(delete_id) @@ -438,7 +435,6 @@ class RoomStateRestServlet(RestServlet): class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): - PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)$") def __init__(self, hs: "HomeServer"): diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 0c0bf540b9..357e9a574d 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -304,13 +304,20 @@ class UserRestServletV2(RestServlet): # remove old threepids for medium, address in del_threepids: try: - await self.auth_handler.delete_threepid( - user_id, medium, address, None + # Attempt to remove any known bindings of this third-party ID + # and user ID from identity servers. + await self.hs.get_identity_handler().try_unbind_threepid( + user_id, medium, address, id_server=None ) except Exception: logger.exception("Failed to remove threepids") raise SynapseError(500, "Failed to remove threepids") + # Delete the local association of this user ID and third-party ID. + await self.auth_handler.delete_local_threepid( + user_id, medium, address + ) + # add new threepids current_time = self.hs.get_clock().time_msec() for medium, address in add_threepids: @@ -683,8 +690,12 @@ class AccountValidityRenewServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if self.account_activity_handler.on_legacy_admin_request_callback: - expiration_ts = await ( - self.account_activity_handler.on_legacy_admin_request_callback(request) + expiration_ts = ( + await ( + self.account_activity_handler.on_legacy_admin_request_callback( + request + ) + ) ) else: body = parse_json_object_from_request(request) diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 662f5bf762..484d7440a4 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -768,7 +768,9 @@ class ThreepidDeleteRestServlet(RestServlet): user_id = requester.user.to_string() try: - ret = await self.auth_handler.delete_threepid( + # Attempt to remove any known bindings of this third-party ID + # and user ID from identity servers. + ret = await self.hs.get_identity_handler().try_unbind_threepid( user_id, body.medium, body.address, body.id_server ) except Exception: @@ -783,6 +785,11 @@ class ThreepidDeleteRestServlet(RestServlet): else: id_server_unbind_result = "no-support" + # Delete the local association of this user ID and third-party ID. + await self.auth_handler.delete_local_threepid( + user_id, body.medium, body.address + ) + return 200, {"id_server_unbind_result": id_server_unbind_result} diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index eb77337044..276a1b405d 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -97,7 +97,6 @@ class AuthRestServlet(RestServlet): return None async def on_POST(self, request: Request, stagetype: str) -> None: - session = parse_string(request, "session") if not session: raise SynapseError(400, "No session supplied") diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 486c6dbbc5..dab4a77f7e 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -255,7 +255,7 @@ class DehydratedDeviceServlet(RestServlet): """ - PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=()) + PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device$", releases=()) def __init__(self, hs: "HomeServer"): super().__init__() diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 782e7d14e8..694d77d287 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -17,6 +17,7 @@ import logging from typing import TYPE_CHECKING, Dict, List, Tuple, Union from synapse.api.errors import SynapseError +from synapse.events.utils import SerializeEventConfig from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_string from synapse.http.site import SynapseRequest @@ -43,9 +44,8 @@ class EventStreamRestServlet(RestServlet): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - is_guest = requester.is_guest args: Dict[bytes, List[bytes]] = request.args # type: ignore - if is_guest: + if requester.is_guest: if b"room_id" not in args: raise SynapseError(400, "Guest users must specify room_id param") room_id = parse_string(request, "room_id") @@ -63,13 +63,12 @@ class EventStreamRestServlet(RestServlet): as_client_event = b"raw" not in args chunk = await self.event_stream_handler.get_stream( - requester.user.to_string(), + requester, pagin_config, timeout=timeout, as_client_event=as_client_event, - affect_presence=(not is_guest), + affect_presence=(not requester.is_guest), room_id=room_id, - is_guest=is_guest, ) return 200, chunk @@ -91,9 +90,12 @@ class EventRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) event = await self.event_handler.get_event(requester.user, None, event_id) - time_now = self.clock.time_msec() if event: - result = self._event_serializer.serialize_event(event, time_now) + result = self._event_serializer.serialize_event( + event, + self.clock.time_msec(), + config=SerializeEventConfig(requester=requester), + ) return 200, result else: return 404, "Event not found." diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py index cc1c2f9731..236199897c 100644 --- a/synapse/rest/client/filter.py +++ b/synapse/rest/client/filter.py @@ -79,7 +79,6 @@ class CreateFilterRestServlet(RestServlet): async def on_POST( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: - target_user = UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 7873b363c0..32bb8b9a91 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -312,15 +312,29 @@ class SigningKeyUploadServlet(RestServlet): user_id = requester.user.to_string() body = parse_json_object_from_request(request) - await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - "add a device signing key to your account", - # Allow skipping of UI auth since this is frequently called directly - # after login and it is silly to ask users to re-auth immediately. - can_skip_ui_auth=True, - ) + if self.hs.config.experimental.msc3967_enabled: + if await self.e2e_keys_handler.is_cross_signing_set_up_for_user(user_id): + # If we already have a master key then cross signing is set up and we require UIA to reset + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "reset the device signing key on your account", + # Do not allow skipping of UIA auth. + can_skip_ui_auth=False, + ) + # Otherwise we don't require UIA since we are setting up cross signing for first time + else: + # Previous behaviour is to always require UIA but allow it to be skipped + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "add a device signing key to your account", + # Allow skipping of UI auth since this is frequently called directly + # after login and it is silly to ask users to re-auth immediately. + can_skip_ui_auth=True, + ) result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) return 200, result diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py index ad025c8a45..4fa66904ba 100644 --- a/synapse/rest/client/knock.py +++ b/synapse/rest/client/knock.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple from synapse.api.constants import Membership from synapse.api.errors import SynapseError @@ -24,8 +24,6 @@ from synapse.http.servlet import ( parse_strings_from_args, ) from synapse.http.site import SynapseRequest -from synapse.logging.opentracing import set_tag -from synapse.rest.client.transactions import HttpTransactionCache from synapse.types import JsonDict, RoomAlias, RoomID if TYPE_CHECKING: @@ -45,7 +43,6 @@ class KnockRoomAliasServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.txns = HttpTransactionCache(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -53,7 +50,6 @@ class KnockRoomAliasServlet(RestServlet): self, request: SynapseRequest, room_identifier: str, - txn_id: Optional[str] = None, ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -67,7 +63,6 @@ class KnockRoomAliasServlet(RestServlet): # twisted.web.server.Request.args is incorrectly defined as Optional[Any] args: Dict[bytes, List[bytes]] = request.args # type: ignore - remote_room_hosts = parse_strings_from_args( args, "server_name", required=False ) @@ -86,7 +81,6 @@ class KnockRoomAliasServlet(RestServlet): target=requester.user, room_id=room_id, action=Membership.KNOCK, - txn_id=txn_id, third_party_signed=None, remote_room_hosts=remote_room_hosts, content=event_content, @@ -94,15 +88,6 @@ class KnockRoomAliasServlet(RestServlet): return 200, {"room_id": room_id} - def on_PUT( - self, request: SynapseRequest, room_identifier: str, txn_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: - set_tag("txn_id", txn_id) - - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_identifier, txn_id - ) - def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: KnockRoomAliasServlet(hs).register(http_server) diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index 61268e3af1..ea10042569 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -72,6 +72,12 @@ class NotificationsServlet(RestServlet): next_token = None + serialize_options = SerializeEventConfig( + event_format=format_event_for_client_v2_without_room_id, + requester=requester, + ) + now = self.clock.time_msec() + for pa in push_actions: returned_pa = { "room_id": pa.room_id, @@ -81,10 +87,8 @@ class NotificationsServlet(RestServlet): "event": ( self._event_serializer.serialize_event( notif_events[pa.event_id], - self.clock.time_msec(), - config=SerializeEventConfig( - event_format=format_event_for_client_v2_without_room_id - ), + now, + config=serialize_options, ) ), } diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 3cb1e7e375..bce806f2bb 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -628,10 +628,12 @@ class RegisterRestServlet(RestServlet): if not password_hash: raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) - desired_username = await ( - self.password_auth_provider.get_username_for_registration( - auth_result, - params, + desired_username = ( + await ( + self.password_auth_provider.get_username_for_registration( + auth_result, + params, + ) ) ) @@ -682,9 +684,11 @@ class RegisterRestServlet(RestServlet): session_id ) - display_name = await ( - self.password_auth_provider.get_displayname_for_registration( - auth_result, params + display_name = ( + await ( + self.password_auth_provider.get_displayname_for_registration( + auth_result, params + ) ) ) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index d0db85cca7..61e4cf0213 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -37,7 +37,7 @@ from synapse.api.errors import ( UnredactedContentDeletedError, ) from synapse.api.filtering import Filter -from synapse.events.utils import format_event_for_client_v2 +from synapse.events.utils import SerializeEventConfig, format_event_for_client_v2 from synapse.http.server import HttpServer from synapse.http.servlet import ( ResolveRoomIdMixin, @@ -160,11 +160,11 @@ class RoomCreateRestServlet(TransactionRestServlet): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - info, _ = await self._room_creation_handler.create_room( + room_id, _, _ = await self._room_creation_handler.create_room( requester, self.get_room_config(request) ) - return 200, info + return 200, {"room_id": room_id} def get_room_config(self, request: Request) -> JsonDict: user_supplied_config = parse_json_object_from_request(request) @@ -814,11 +814,13 @@ class RoomEventServlet(RestServlet): [event], requester.user.to_string() ) - time_now = self.clock.time_msec() # per MSC2676, /rooms/{roomId}/event/{eventId}, should return the # *original* event, rather than the edited version event_dict = self._event_serializer.serialize_event( - event, time_now, bundle_aggregations=aggregations, apply_edits=False + event, + self.clock.time_msec(), + bundle_aggregations=aggregations, + config=SerializeEventConfig(requester=requester), ) return 200, event_dict @@ -863,24 +865,30 @@ class RoomEventContextServlet(RestServlet): raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() + serializer_options = SerializeEventConfig(requester=requester) results = { "events_before": self._event_serializer.serialize_events( event_context.events_before, time_now, bundle_aggregations=event_context.aggregations, + config=serializer_options, ), "event": self._event_serializer.serialize_event( event_context.event, time_now, bundle_aggregations=event_context.aggregations, + config=serializer_options, ), "events_after": self._event_serializer.serialize_events( event_context.events_after, time_now, bundle_aggregations=event_context.aggregations, + config=serializer_options, ), "state": self._event_serializer.serialize_events( - event_context.state, time_now + event_context.state, + time_now, + config=serializer_options, ), "start": event_context.start, "end": event_context.end, @@ -926,7 +934,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): self.auth = hs.get_auth() def register(self, http_server: HttpServer) -> None: - # /rooms/$roomid/[invite|join|leave] + # /rooms/$roomid/[join|invite|leave|ban|unban|kick] PATTERNS = ( "/rooms/(?P<room_id>[^/]*)/" "(?P<membership_action>join|invite|leave|ban|unban|kick)" @@ -1192,7 +1200,7 @@ class SearchRestServlet(RestServlet): content = parse_json_object_from_request(request) batch = parse_string(request, "next_batch") - results = await self.search_handler.search(requester.user, content, batch) + results = await self.search_handler.search(requester, content, batch) return 200, results diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index 10be4a781b..ef284ecc11 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -15,9 +15,7 @@ import logging import re from http import HTTPStatus -from typing import TYPE_CHECKING, Awaitable, Tuple - -from twisted.web.server import Request +from typing import TYPE_CHECKING, Tuple from synapse.api.constants import EventContentFields from synapse.api.errors import AuthError, Codes, SynapseError @@ -30,7 +28,6 @@ from synapse.http.servlet import ( parse_strings_from_args, ) from synapse.http.site import SynapseRequest -from synapse.rest.client.transactions import HttpTransactionCache from synapse.types import JsonDict if TYPE_CHECKING: @@ -79,7 +76,6 @@ class RoomBatchSendEventRestServlet(RestServlet): self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() self.room_batch_handler = hs.get_room_batch_handler() - self.txns = HttpTransactionCache(hs) async def on_POST( self, request: SynapseRequest, room_id: str @@ -249,16 +245,6 @@ class RoomBatchSendEventRestServlet(RestServlet): return HTTPStatus.OK, response_dict - def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]: - return HTTPStatus.NOT_IMPLEMENTED, "Not implemented" - - def on_PUT( - self, request: SynapseRequest, room_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_id - ) - def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: msc2716_enabled = hs.config.experimental.msc2716_enabled diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index f2013faeb2..e578b26fa3 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -16,7 +16,7 @@ import logging from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union -from synapse.api.constants import EduTypes, Membership, PresenceState +from synapse.api.constants import AccountDataTypes, EduTypes, Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState @@ -38,7 +38,7 @@ from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.logging.opentracing import trace_with_opname -from synapse.types import JsonDict, StreamToken +from synapse.types import JsonDict, Requester, StreamToken from synapse.util import json_decoder from ._base import client_patterns, set_timeline_upper_limit @@ -139,7 +139,28 @@ class SyncRestServlet(RestServlet): device_id, ) - request_key = (user, timeout, since, filter_id, full_state, device_id) + # Stream position of the last ignored users account data event for this user, + # if we're initial syncing. + # We include this in the request key to invalidate an initial sync + # in the response cache once the set of ignored users has changed. + # (We filter out ignored users from timeline events, so our sync response + # is invalid once the set of ignored users changes.) + last_ignore_accdata_streampos: Optional[int] = None + if not since: + # No `since`, so this is an initial sync. + last_ignore_accdata_streampos = await self.store.get_latest_stream_id_for_global_account_data_by_type_for_user( + user.to_string(), AccountDataTypes.IGNORED_USER_LIST + ) + + request_key = ( + user, + timeout, + since, + filter_id, + full_state, + device_id, + last_ignore_accdata_streampos, + ) if filter_id is None: filter_collection = self.filtering.DEFAULT_FILTER_COLLECTION @@ -205,7 +226,7 @@ class SyncRestServlet(RestServlet): # We know that the the requester has an access token since appservices # cannot use sync. response_content = await self.encode_response( - time_now, sync_result, requester.access_token_id, filter_collection + time_now, sync_result, requester, filter_collection ) logger.debug("Event formatting complete") @@ -216,7 +237,7 @@ class SyncRestServlet(RestServlet): self, time_now: int, sync_result: SyncResult, - access_token_id: Optional[int], + requester: Requester, filter: FilterCollection, ) -> JsonDict: logger.debug("Formatting events in sync response") @@ -229,12 +250,12 @@ class SyncRestServlet(RestServlet): serialize_options = SerializeEventConfig( event_format=event_formatter, - token_id=access_token_id, + requester=requester, only_event_fields=filter.event_fields, ) stripped_serialize_options = SerializeEventConfig( event_format=event_formatter, - token_id=access_token_id, + requester=requester, include_stripped_room_state=True, ) diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/config_resource.py index a95804d327..a95804d327 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/config_resource.py diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/download_resource.py index 048a042692..8f270cf4cc 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/download_resource.py @@ -22,11 +22,10 @@ from synapse.http.server import ( ) from synapse.http.servlet import parse_boolean from synapse.http.site import SynapseRequest - -from ._base import parse_media_id, respond_404 +from synapse.media._base import parse_media_id, respond_404 if TYPE_CHECKING: - from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.media.media_repository import MediaRepository from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/media_repository_resource.py b/synapse/rest/media/media_repository_resource.py new file mode 100644 index 0000000000..5ebaa3b032 --- /dev/null +++ b/synapse/rest/media/media_repository_resource.py @@ -0,0 +1,93 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from synapse.config._base import ConfigError +from synapse.http.server import UnrecognizedRequestResource + +from .config_resource import MediaConfigResource +from .download_resource import DownloadResource +from .preview_url_resource import PreviewUrlResource +from .thumbnail_resource import ThumbnailResource +from .upload_resource import UploadResource + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class MediaRepositoryResource(UnrecognizedRequestResource): + """File uploading and downloading. + + Uploads are POSTed to a resource which returns a token which is used to GET + the download:: + + => POST /_matrix/media/r0/upload HTTP/1.1 + Content-Type: <media-type> + Content-Length: <content-length> + + <media> + + <= HTTP/1.1 200 OK + Content-Type: application/json + + { "content_uri": "mxc://<server-name>/<media-id>" } + + => GET /_matrix/media/r0/download/<server-name>/<media-id> HTTP/1.1 + + <= HTTP/1.1 200 OK + Content-Type: <media-type> + Content-Disposition: attachment;filename=<upload-filename> + + <media> + + Clients can get thumbnails by supplying a desired width and height and + thumbnailing method:: + + => GET /_matrix/media/r0/thumbnail/<server_name> + /<media-id>?width=<w>&height=<h>&method=<m> HTTP/1.1 + + <= HTTP/1.1 200 OK + Content-Type: image/jpeg or image/png + + <thumbnail> + + The thumbnail methods are "crop" and "scale". "scale" tries to return an + image where either the width or the height is smaller than the requested + size. The client should then scale and letterbox the image if it needs to + fit within a given rectangle. "crop" tries to return an image where the + width and height are close to the requested size and the aspect matches + the requested size. The client should scale the image if it needs to fit + within a given rectangle. + """ + + def __init__(self, hs: "HomeServer"): + # If we're not configured to use it, raise if we somehow got here. + if not hs.config.media.can_load_media_repo: + raise ConfigError("Synapse is not configured to use a media repo.") + + super().__init__() + media_repo = hs.get_media_repository() + + self.putChild(b"upload", UploadResource(hs, media_repo)) + self.putChild(b"download", DownloadResource(hs, media_repo)) + self.putChild( + b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage) + ) + if hs.config.media.url_preview_enabled: + self.putChild( + b"preview_url", + PreviewUrlResource(hs, media_repo, media_repo.media_storage), + ) + self.putChild(b"config", MediaConfigResource(hs)) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/preview_url_resource.py index a8f6fd6b35..7ada728757 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/preview_url_resource.py @@ -40,21 +40,19 @@ from synapse.http.server import ( from synapse.http.servlet import parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.media._base import FileInfo, get_filename_from_headers +from synapse.media.media_storage import MediaStorage +from synapse.media.oembed import OEmbedProvider +from synapse.media.preview_html import decode_body, parse_html_to_open_graph from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.rest.media.v1._base import get_filename_from_headers -from synapse.rest.media.v1.media_storage import MediaStorage -from synapse.rest.media.v1.oembed import OEmbedProvider -from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph from synapse.types import JsonDict, UserID from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.stringutils import random_string -from ._base import FileInfo - if TYPE_CHECKING: - from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.media.media_repository import MediaRepository from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -163,6 +161,10 @@ class PreviewUrlResource(DirectServeJsonResource): 7. Stores the result in the database cache. 4. Returns the result. + If any additional requests (e.g. from oEmbed autodiscovery, step 5.3 or + image thumbnailing, step 5.4 or 6.4) fails then the URL preview as a whole + does not fail. As much information as possible is returned. + The in-memory cache expires after 1 hour. Expired entries in the database cache (and their associated media files) are @@ -364,16 +366,25 @@ class PreviewUrlResource(DirectServeJsonResource): oembed_url = self._oembed.autodiscover_from_html(tree) og_from_oembed: JsonDict = {} if oembed_url: - oembed_info = await self._handle_url( - oembed_url, user, allow_data_urls=True - ) - ( - og_from_oembed, - author_name, - expiration_ms, - ) = await self._handle_oembed_response( - url, oembed_info, expiration_ms - ) + try: + oembed_info = await self._handle_url( + oembed_url, user, allow_data_urls=True + ) + except Exception as e: + # Fetching the oEmbed info failed, don't block the entire URL preview. + logger.warning( + "oEmbed fetch failed during URL preview: %s errored with %s", + oembed_url, + e, + ) + else: + ( + og_from_oembed, + author_name, + expiration_ms, + ) = await self._handle_oembed_response( + url, oembed_info, expiration_ms + ) # Parse Open Graph information from the HTML in case the oEmbed # response failed or is incomplete. diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py index 5f725c7600..4ee2a0dbda 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/thumbnail_resource.py @@ -27,9 +27,7 @@ from synapse.http.server import ( ) from synapse.http.servlet import parse_integer, parse_string from synapse.http.site import SynapseRequest -from synapse.rest.media.v1.media_storage import MediaStorage - -from ._base import ( +from synapse.media._base import ( FileInfo, ThumbnailInfo, parse_media_id, @@ -37,9 +35,10 @@ from ._base import ( respond_with_file, respond_with_responder, ) +from synapse.media.media_storage import MediaStorage if TYPE_CHECKING: - from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.media.media_repository import MediaRepository from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -69,7 +68,8 @@ class ThumbnailResource(DirectServeJsonResource): width = parse_integer(request, "width", required=True) height = parse_integer(request, "height", required=True) method = parse_string(request, "method", "scale") - m_type = parse_string(request, "type", "image/png") + # TODO Parse the Accept header to get an prioritised list of thumbnail types. + m_type = "image/png" if server_name == self.server_name: if self.dynamic_thumbnails: diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/upload_resource.py index 97548b54e5..697348613b 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/upload_resource.py @@ -20,10 +20,10 @@ from synapse.api.errors import Codes, SynapseError from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_bytes_from_args from synapse.http.site import SynapseRequest -from synapse.rest.media.v1.media_storage import SpamMediaException +from synapse.media.media_storage import SpamMediaException if TYPE_CHECKING: - from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.media.media_repository import MediaRepository from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index d30878f704..88427a5737 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -1,5 +1,4 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# Copyright 2023 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,466 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# -import logging -import os -import urllib -from types import TracebackType -from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type - -import attr - -from twisted.internet.interfaces import IConsumer -from twisted.protocols.basic import FileSender -from twisted.web.server import Request - -from synapse.api.errors import Codes, SynapseError, cs_error -from synapse.http.server import finish_request, respond_with_json -from synapse.http.site import SynapseRequest -from synapse.logging.context import make_deferred_yieldable -from synapse.util.stringutils import is_ascii, parse_and_validate_server_name - -logger = logging.getLogger(__name__) - -# list all text content types that will have the charset default to UTF-8 when -# none is given -TEXT_CONTENT_TYPES = [ - "text/css", - "text/csv", - "text/html", - "text/calendar", - "text/plain", - "text/javascript", - "application/json", - "application/ld+json", - "application/rtf", - "image/svg+xml", - "text/xml", -] - - -def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: - """Parses the server name, media ID and optional file name from the request URI - - Also performs some rough validation on the server name. - - Args: - request: The `Request`. - - Returns: - A tuple containing the parsed server name, media ID and optional file name. - - Raises: - SynapseError(404): if parsing or validation fail for any reason - """ - try: - # The type on postpath seems incorrect in Twisted 21.2.0. - postpath: List[bytes] = request.postpath # type: ignore - assert postpath - - # This allows users to append e.g. /test.png to the URL. Useful for - # clients that parse the URL to see content type. - server_name_bytes, media_id_bytes = postpath[:2] - server_name = server_name_bytes.decode("utf-8") - media_id = media_id_bytes.decode("utf8") - - # Validate the server name, raising if invalid - parse_and_validate_server_name(server_name) - - file_name = None - if len(postpath) > 2: - try: - file_name = urllib.parse.unquote(postpath[-1].decode("utf-8")) - except UnicodeDecodeError: - pass - return server_name, media_id, file_name - except Exception: - raise SynapseError( - 404, "Invalid media id token %r" % (request.postpath,), Codes.UNKNOWN - ) - - -def respond_404(request: SynapseRequest) -> None: - respond_with_json( - request, - 404, - cs_error("Not found %r" % (request.postpath,), code=Codes.NOT_FOUND), - send_cors=True, - ) - - -async def respond_with_file( - request: SynapseRequest, - media_type: str, - file_path: str, - file_size: Optional[int] = None, - upload_name: Optional[str] = None, -) -> None: - logger.debug("Responding with %r", file_path) - - if os.path.isfile(file_path): - if file_size is None: - stat = os.stat(file_path) - file_size = stat.st_size - - add_file_headers(request, media_type, file_size, upload_name) - - with open(file_path, "rb") as f: - await make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) - - finish_request(request) - else: - respond_404(request) - - -def add_file_headers( - request: Request, - media_type: str, - file_size: Optional[int], - upload_name: Optional[str], -) -> None: - """Adds the correct response headers in preparation for responding with the - media. - - Args: - request - media_type: The media/content type. - file_size: Size in bytes of the media, if known. - upload_name: The name of the requested file, if any. - """ - - def _quote(x: str) -> str: - return urllib.parse.quote(x.encode("utf-8")) - - # Default to a UTF-8 charset for text content types. - # ex, uses UTF-8 for 'text/css' but not 'text/css; charset=UTF-16' - if media_type.lower() in TEXT_CONTENT_TYPES: - content_type = media_type + "; charset=UTF-8" - else: - content_type = media_type - - request.setHeader(b"Content-Type", content_type.encode("UTF-8")) - if upload_name: - # RFC6266 section 4.1 [1] defines both `filename` and `filename*`. - # - # `filename` is defined to be a `value`, which is defined by RFC2616 - # section 3.6 [2] to be a `token` or a `quoted-string`, where a `token` - # is (essentially) a single US-ASCII word, and a `quoted-string` is a - # US-ASCII string surrounded by double-quotes, using backslash as an - # escape character. Note that %-encoding is *not* permitted. - # - # `filename*` is defined to be an `ext-value`, which is defined in - # RFC5987 section 3.2.1 [3] to be `charset "'" [ language ] "'" value-chars`, - # where `value-chars` is essentially a %-encoded string in the given charset. - # - # [1]: https://tools.ietf.org/html/rfc6266#section-4.1 - # [2]: https://tools.ietf.org/html/rfc2616#section-3.6 - # [3]: https://tools.ietf.org/html/rfc5987#section-3.2.1 - - # We avoid the quoted-string version of `filename`, because (a) synapse didn't - # correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we - # may as well just do the filename* version. - if _can_encode_filename_as_token(upload_name): - disposition = "inline; filename=%s" % (upload_name,) - else: - disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),) - - request.setHeader(b"Content-Disposition", disposition.encode("ascii")) - - # cache for at least a day. - # XXX: we might want to turn this off for data we don't want to - # recommend caching as it's sensitive or private - or at least - # select private. don't bother setting Expires as all our - # clients are smart enough to be happy with Cache-Control - request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") - if file_size is not None: - request.setHeader(b"Content-Length", b"%d" % (file_size,)) - - # Tell web crawlers to not index, archive, or follow links in media. This - # should help to prevent things in the media repo from showing up in web - # search results. - request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex") - - -# separators as defined in RFC2616. SP and HT are handled separately. -# see _can_encode_filename_as_token. -_FILENAME_SEPARATOR_CHARS = { - "(", - ")", - "<", - ">", - "@", - ",", - ";", - ":", - "\\", - '"', - "/", - "[", - "]", - "?", - "=", - "{", - "}", -} - - -def _can_encode_filename_as_token(x: str) -> bool: - for c in x: - # from RFC2616: - # - # token = 1*<any CHAR except CTLs or separators> - # - # separators = "(" | ")" | "<" | ">" | "@" - # | "," | ";" | ":" | "\" | <"> - # | "/" | "[" | "]" | "?" | "=" - # | "{" | "}" | SP | HT - # - # CHAR = <any US-ASCII character (octets 0 - 127)> - # - # CTL = <any US-ASCII control character - # (octets 0 - 31) and DEL (127)> - # - if ord(c) >= 127 or ord(c) <= 32 or c in _FILENAME_SEPARATOR_CHARS: - return False - return True - - -async def respond_with_responder( - request: SynapseRequest, - responder: "Optional[Responder]", - media_type: str, - file_size: Optional[int], - upload_name: Optional[str] = None, -) -> None: - """Responds to the request with given responder. If responder is None then - returns 404. - - Args: - request - responder - media_type: The media/content type. - file_size: Size in bytes of the media. If not known it should be None - upload_name: The name of the requested file, if any. - """ - if not responder: - respond_404(request) - return - - # If we have a responder we *must* use it as a context manager. - with responder: - if request._disconnected: - logger.warning( - "Not sending response to request %s, already disconnected.", request - ) - return - - logger.debug("Responding to media request with responder %s", responder) - add_file_headers(request, media_type, file_size, upload_name) - try: - - await responder.write_to_consumer(request) - except Exception as e: - # The majority of the time this will be due to the client having gone - # away. Unfortunately, Twisted simply throws a generic exception at us - # in that case. - logger.warning("Failed to write to consumer: %s %s", type(e), e) - - # Unregister the producer, if it has one, so Twisted doesn't complain - if request.producer: - request.unregisterProducer() - - finish_request(request) - - -class Responder: - """Represents a response that can be streamed to the requester. - - Responder is a context manager which *must* be used, so that any resources - held can be cleaned up. - """ - - def write_to_consumer(self, consumer: IConsumer) -> Awaitable: - """Stream response into consumer - - Args: - consumer: The consumer to stream into. - - Returns: - Resolves once the response has finished being written - """ - - def __enter__(self) -> None: - pass - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - pass - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class ThumbnailInfo: - """Details about a generated thumbnail.""" - - width: int - height: int - method: str - # Content type of thumbnail, e.g. image/png - type: str - # The size of the media file, in bytes. - length: Optional[int] = None - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class FileInfo: - """Details about a requested/uploaded file.""" - - # The server name where the media originated from, or None if local. - server_name: Optional[str] - # The local ID of the file. For local files this is the same as the media_id - file_id: str - # If the file is for the url preview cache - url_cache: bool = False - # Whether the file is a thumbnail or not. - thumbnail: Optional[ThumbnailInfo] = None - - # The below properties exist to maintain compatibility with third-party modules. - @property - def thumbnail_width(self) -> Optional[int]: - if not self.thumbnail: - return None - return self.thumbnail.width - - @property - def thumbnail_height(self) -> Optional[int]: - if not self.thumbnail: - return None - return self.thumbnail.height - - @property - def thumbnail_method(self) -> Optional[str]: - if not self.thumbnail: - return None - return self.thumbnail.method - - @property - def thumbnail_type(self) -> Optional[str]: - if not self.thumbnail: - return None - return self.thumbnail.type - - @property - def thumbnail_length(self) -> Optional[int]: - if not self.thumbnail: - return None - return self.thumbnail.length - - -def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: - """ - Get the filename of the downloaded file by inspecting the - Content-Disposition HTTP header. - - Args: - headers: The HTTP request headers. - - Returns: - The filename, or None. - """ - content_disposition = headers.get(b"Content-Disposition", [b""]) - - # No header, bail out. - if not content_disposition[0]: - return None - - _, params = _parse_header(content_disposition[0]) - - upload_name = None - - # First check if there is a valid UTF-8 filename - upload_name_utf8 = params.get(b"filename*", None) - if upload_name_utf8: - if upload_name_utf8.lower().startswith(b"utf-8''"): - upload_name_utf8 = upload_name_utf8[7:] - # We have a filename*= section. This MUST be ASCII, and any UTF-8 - # bytes are %-quoted. - try: - # Once it is decoded, we can then unquote the %-encoded - # parts strictly into a unicode string. - upload_name = urllib.parse.unquote( - upload_name_utf8.decode("ascii"), errors="strict" - ) - except UnicodeDecodeError: - # Incorrect UTF-8. - pass - - # If there isn't check for an ascii name. - if not upload_name: - upload_name_ascii = params.get(b"filename", None) - if upload_name_ascii and is_ascii(upload_name_ascii): - upload_name = upload_name_ascii.decode("ascii") - - # This may be None here, indicating we did not find a matching name. - return upload_name - - -def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]: - """Parse a Content-type like header. - - Cargo-culted from `cgi`, but works on bytes rather than strings. - - Args: - line: header to be parsed - - Returns: - The main content-type, followed by the parameter dictionary - """ - parts = _parseparam(b";" + line) - key = next(parts) - pdict = {} - for p in parts: - i = p.find(b"=") - if i >= 0: - name = p[:i].strip().lower() - value = p[i + 1 :].strip() - - # strip double-quotes - if len(value) >= 2 and value[0:1] == value[-1:] == b'"': - value = value[1:-1] - value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"') - pdict[name] = value - - return key, pdict - - -def _parseparam(s: bytes) -> Generator[bytes, None, None]: - """Generator which splits the input on ;, respecting double-quoted sequences - - Cargo-culted from `cgi`, but works on bytes rather than strings. - - Args: - s: header to be parsed - - Returns: - The split input - """ - while s[:1] == b";": - s = s[1:] - - # look for the next ; - end = s.find(b";") - - # if there is an odd number of " marks between here and the next ;, skip to the - # next ; instead - while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2: - end = s.find(b";", end + 1) - - if end < 0: - end = len(s) - f = s[:end] - yield f.strip() - s = s[end:] +# This exists purely for backwards compatibility with media providers and spam checkers. +from synapse.media._base import FileInfo, Responder # noqa: F401 diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py deleted file mode 100644 index 1f6441c412..0000000000 --- a/synapse/rest/media/v1/filepath.py +++ /dev/null @@ -1,410 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2020-2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -import os -import re -import string -from typing import Any, Callable, List, TypeVar, Union, cast - -NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") - - -F = TypeVar("F", bound=Callable[..., str]) - - -def _wrap_in_base_path(func: F) -> F: - """Takes a function that returns a relative path and turns it into an - absolute path based on the location of the primary media store - """ - - @functools.wraps(func) - def _wrapped(self: "MediaFilePaths", *args: Any, **kwargs: Any) -> str: - path = func(self, *args, **kwargs) - return os.path.join(self.base_path, path) - - return cast(F, _wrapped) - - -GetPathMethod = TypeVar( - "GetPathMethod", bound=Union[Callable[..., str], Callable[..., List[str]]] -) - - -def _wrap_with_jail_check(relative: bool) -> Callable[[GetPathMethod], GetPathMethod]: - """Wraps a path-returning method to check that the returned path(s) do not escape - the media store directory. - - The path-returning method may return either a single path, or a list of paths. - - The check is not expected to ever fail, unless `func` is missing a call to - `_validate_path_component`, or `_validate_path_component` is buggy. - - Args: - relative: A boolean indicating whether the wrapped method returns paths relative - to the media store directory. - - Returns: - A method which will wrap a path-returning method, adding a check to ensure that - the returned path(s) lie within the media store directory. The check will raise - a `ValueError` if it fails. - """ - - def _wrap_with_jail_check_inner(func: GetPathMethod) -> GetPathMethod: - @functools.wraps(func) - def _wrapped( - self: "MediaFilePaths", *args: Any, **kwargs: Any - ) -> Union[str, List[str]]: - path_or_paths = func(self, *args, **kwargs) - - if isinstance(path_or_paths, list): - paths_to_check = path_or_paths - else: - paths_to_check = [path_or_paths] - - for path in paths_to_check: - # Construct the path that will ultimately be used. - # We cannot guess whether `path` is relative to the media store - # directory, since the media store directory may itself be a relative - # path. - if relative: - path = os.path.join(self.base_path, path) - normalized_path = os.path.normpath(path) - - # Now that `normpath` has eliminated `../`s and `./`s from the path, - # `os.path.commonpath` can be used to check whether it lies within the - # media store directory. - if ( - os.path.commonpath([normalized_path, self.normalized_base_path]) - != self.normalized_base_path - ): - # The path resolves to outside the media store directory, - # or `self.base_path` is `.`, which is an unlikely configuration. - raise ValueError(f"Invalid media store path: {path!r}") - - # Note that `os.path.normpath`/`abspath` has a subtle caveat: - # `a/b/c/../c` will normalize to `a/b/c`, but the former refers to a - # different path if `a/b/c` is a symlink. That is, the check above is - # not perfect and may allow a certain restricted subset of untrustworthy - # paths through. Since the check above is secondary to the main - # `_validate_path_component` checks, it's less important for it to be - # perfect. - # - # As an alternative, `os.path.realpath` will resolve symlinks, but - # proves problematic if there are symlinks inside the media store. - # eg. if `url_store/` is symlinked to elsewhere, its canonical path - # won't match that of the main media store directory. - - return path_or_paths - - return cast(GetPathMethod, _wrapped) - - return _wrap_with_jail_check_inner - - -ALLOWED_CHARACTERS = set( - string.ascii_letters - + string.digits - + "_-" - + ".[]:" # Domain names, IPv6 addresses and ports in server names -) -FORBIDDEN_NAMES = { - "", - os.path.curdir, # "." for the current platform - os.path.pardir, # ".." for the current platform -} - - -def _validate_path_component(name: str) -> str: - """Checks that the given string can be safely used as a path component - - Args: - name: The path component to check. - - Returns: - The path component if valid. - - Raises: - ValueError: If `name` cannot be safely used as a path component. - """ - if not ALLOWED_CHARACTERS.issuperset(name) or name in FORBIDDEN_NAMES: - raise ValueError(f"Invalid path component: {name!r}") - - return name - - -class MediaFilePaths: - """Describes where files are stored on disk. - - Most of the functions have a `*_rel` variant which returns a file path that - is relative to the base media store path. This is mainly used when we want - to write to the backup media store (when one is configured) - """ - - def __init__(self, primary_base_path: str): - self.base_path = primary_base_path - self.normalized_base_path = os.path.normpath(self.base_path) - - # Refuse to initialize if paths cannot be validated correctly for the current - # platform. - assert os.path.sep not in ALLOWED_CHARACTERS - assert os.path.altsep not in ALLOWED_CHARACTERS - # On Windows, paths have all sorts of weirdness which `_validate_path_component` - # does not consider. In any case, the remote media store can't work correctly - # for certain homeservers there, since ":"s aren't allowed in paths. - assert os.name == "posix" - - @_wrap_with_jail_check(relative=True) - def local_media_filepath_rel(self, media_id: str) -> str: - return os.path.join( - "local_content", - _validate_path_component(media_id[0:2]), - _validate_path_component(media_id[2:4]), - _validate_path_component(media_id[4:]), - ) - - local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) - - @_wrap_with_jail_check(relative=True) - def local_media_thumbnail_rel( - self, media_id: str, width: int, height: int, content_type: str, method: str - ) -> str: - top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) - return os.path.join( - "local_thumbnails", - _validate_path_component(media_id[0:2]), - _validate_path_component(media_id[2:4]), - _validate_path_component(media_id[4:]), - _validate_path_component(file_name), - ) - - local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) - - @_wrap_with_jail_check(relative=False) - def local_media_thumbnail_dir(self, media_id: str) -> str: - """ - Retrieve the local store path of thumbnails of a given media_id - - Args: - media_id: The media ID to query. - Returns: - Path of local_thumbnails from media_id - """ - return os.path.join( - self.base_path, - "local_thumbnails", - _validate_path_component(media_id[0:2]), - _validate_path_component(media_id[2:4]), - _validate_path_component(media_id[4:]), - ) - - @_wrap_with_jail_check(relative=True) - def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str: - return os.path.join( - "remote_content", - _validate_path_component(server_name), - _validate_path_component(file_id[0:2]), - _validate_path_component(file_id[2:4]), - _validate_path_component(file_id[4:]), - ) - - remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) - - @_wrap_with_jail_check(relative=True) - def remote_media_thumbnail_rel( - self, - server_name: str, - file_id: str, - width: int, - height: int, - content_type: str, - method: str, - ) -> str: - top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) - return os.path.join( - "remote_thumbnail", - _validate_path_component(server_name), - _validate_path_component(file_id[0:2]), - _validate_path_component(file_id[2:4]), - _validate_path_component(file_id[4:]), - _validate_path_component(file_name), - ) - - remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) - - # Legacy path that was used to store thumbnails previously. - # Should be removed after some time, when most of the thumbnails are stored - # using the new path. - @_wrap_with_jail_check(relative=True) - def remote_media_thumbnail_rel_legacy( - self, server_name: str, file_id: str, width: int, height: int, content_type: str - ) -> str: - top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) - return os.path.join( - "remote_thumbnail", - _validate_path_component(server_name), - _validate_path_component(file_id[0:2]), - _validate_path_component(file_id[2:4]), - _validate_path_component(file_id[4:]), - _validate_path_component(file_name), - ) - - @_wrap_with_jail_check(relative=False) - def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str: - return os.path.join( - self.base_path, - "remote_thumbnail", - _validate_path_component(server_name), - _validate_path_component(file_id[0:2]), - _validate_path_component(file_id[2:4]), - _validate_path_component(file_id[4:]), - ) - - @_wrap_with_jail_check(relative=True) - def url_cache_filepath_rel(self, media_id: str) -> str: - if NEW_FORMAT_ID_RE.match(media_id): - # Media id is of the form <DATE><RANDOM_STRING> - # E.g.: 2017-09-28-fsdRDt24DS234dsf - return os.path.join( - "url_cache", - _validate_path_component(media_id[:10]), - _validate_path_component(media_id[11:]), - ) - else: - return os.path.join( - "url_cache", - _validate_path_component(media_id[0:2]), - _validate_path_component(media_id[2:4]), - _validate_path_component(media_id[4:]), - ) - - url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) - - @_wrap_with_jail_check(relative=False) - def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]: - "The dirs to try and remove if we delete the media_id file" - if NEW_FORMAT_ID_RE.match(media_id): - return [ - os.path.join( - self.base_path, "url_cache", _validate_path_component(media_id[:10]) - ) - ] - else: - return [ - os.path.join( - self.base_path, - "url_cache", - _validate_path_component(media_id[0:2]), - _validate_path_component(media_id[2:4]), - ), - os.path.join( - self.base_path, "url_cache", _validate_path_component(media_id[0:2]) - ), - ] - - @_wrap_with_jail_check(relative=True) - def url_cache_thumbnail_rel( - self, media_id: str, width: int, height: int, content_type: str, method: str - ) -> str: - # Media id is of the form <DATE><RANDOM_STRING> - # E.g.: 2017-09-28-fsdRDt24DS234dsf - - top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) - - if NEW_FORMAT_ID_RE.match(media_id): - return os.path.join( - "url_cache_thumbnails", - _validate_path_component(media_id[:10]), - _validate_path_component(media_id[11:]), - _validate_path_component(file_name), - ) - else: - return os.path.join( - "url_cache_thumbnails", - _validate_path_component(media_id[0:2]), - _validate_path_component(media_id[2:4]), - _validate_path_component(media_id[4:]), - _validate_path_component(file_name), - ) - - url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) - - @_wrap_with_jail_check(relative=True) - def url_cache_thumbnail_directory_rel(self, media_id: str) -> str: - # Media id is of the form <DATE><RANDOM_STRING> - # E.g.: 2017-09-28-fsdRDt24DS234dsf - - if NEW_FORMAT_ID_RE.match(media_id): - return os.path.join( - "url_cache_thumbnails", - _validate_path_component(media_id[:10]), - _validate_path_component(media_id[11:]), - ) - else: - return os.path.join( - "url_cache_thumbnails", - _validate_path_component(media_id[0:2]), - _validate_path_component(media_id[2:4]), - _validate_path_component(media_id[4:]), - ) - - url_cache_thumbnail_directory = _wrap_in_base_path( - url_cache_thumbnail_directory_rel - ) - - @_wrap_with_jail_check(relative=False) - def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: - "The dirs to try and remove if we delete the media_id thumbnails" - # Media id is of the form <DATE><RANDOM_STRING> - # E.g.: 2017-09-28-fsdRDt24DS234dsf - if NEW_FORMAT_ID_RE.match(media_id): - return [ - os.path.join( - self.base_path, - "url_cache_thumbnails", - _validate_path_component(media_id[:10]), - _validate_path_component(media_id[11:]), - ), - os.path.join( - self.base_path, - "url_cache_thumbnails", - _validate_path_component(media_id[:10]), - ), - ] - else: - return [ - os.path.join( - self.base_path, - "url_cache_thumbnails", - _validate_path_component(media_id[0:2]), - _validate_path_component(media_id[2:4]), - _validate_path_component(media_id[4:]), - ), - os.path.join( - self.base_path, - "url_cache_thumbnails", - _validate_path_component(media_id[0:2]), - _validate_path_component(media_id[2:4]), - ), - os.path.join( - self.base_path, - "url_cache_thumbnails", - _validate_path_component(media_id[0:2]), - ), - ] diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py deleted file mode 100644 index c70e1837af..0000000000 --- a/synapse/rest/media/v1/media_repository.py +++ /dev/null @@ -1,1112 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018-2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import errno -import logging -import os -import shutil -from io import BytesIO -from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple - -from matrix_common.types.mxc_uri import MXCUri - -import twisted.internet.error -import twisted.web.http -from twisted.internet.defer import Deferred - -from synapse.api.errors import ( - FederationDeniedError, - HttpResponseException, - NotFoundError, - RequestSendFailed, - SynapseError, -) -from synapse.config._base import ConfigError -from synapse.config.repository import ThumbnailRequirement -from synapse.http.server import UnrecognizedRequestResource -from synapse.http.site import SynapseRequest -from synapse.logging.context import defer_to_thread -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import UserID -from synapse.util.async_helpers import Linearizer -from synapse.util.retryutils import NotRetryingDestination -from synapse.util.stringutils import random_string - -from ._base import ( - FileInfo, - Responder, - ThumbnailInfo, - get_filename_from_headers, - respond_404, - respond_with_responder, -) -from .config_resource import MediaConfigResource -from .download_resource import DownloadResource -from .filepath import MediaFilePaths -from .media_storage import MediaStorage -from .preview_url_resource import PreviewUrlResource -from .storage_provider import StorageProviderWrapper -from .thumbnail_resource import ThumbnailResource -from .thumbnailer import Thumbnailer, ThumbnailError -from .upload_resource import UploadResource - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - -# How often to run the background job to update the "recently accessed" -# attribute of local and remote media. -UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 # 1 minute -# How often to run the background job to check for local and remote media -# that should be purged according to the configured media retention settings. -MEDIA_RETENTION_CHECK_PERIOD_MS = 60 * 60 * 1000 # 1 hour - - -class MediaRepository: - def __init__(self, hs: "HomeServer"): - self.hs = hs - self.auth = hs.get_auth() - self.client = hs.get_federation_http_client() - self.clock = hs.get_clock() - self.server_name = hs.hostname - self.store = hs.get_datastores().main - self.max_upload_size = hs.config.media.max_upload_size - self.max_image_pixels = hs.config.media.max_image_pixels - - Thumbnailer.set_limits(self.max_image_pixels) - - self.primary_base_path: str = hs.config.media.media_store_path - self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path) - - self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails - self.thumbnail_requirements = hs.config.media.thumbnail_requirements - - self.remote_media_linearizer = Linearizer(name="media_remote") - - self.recently_accessed_remotes: Set[Tuple[str, str]] = set() - self.recently_accessed_locals: Set[str] = set() - - self.federation_domain_whitelist = ( - hs.config.federation.federation_domain_whitelist - ) - - # List of StorageProviders where we should search for media and - # potentially upload to. - storage_providers = [] - - for ( - clz, - provider_config, - wrapper_config, - ) in hs.config.media.media_storage_providers: - backend = clz(hs, provider_config) - provider = StorageProviderWrapper( - backend, - store_local=wrapper_config.store_local, - store_remote=wrapper_config.store_remote, - store_synchronous=wrapper_config.store_synchronous, - ) - storage_providers.append(provider) - - self.media_storage = MediaStorage( - self.hs, self.primary_base_path, self.filepaths, storage_providers - ) - - self.clock.looping_call( - self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS - ) - - # Media retention configuration options - self._media_retention_local_media_lifetime_ms = ( - hs.config.media.media_retention_local_media_lifetime_ms - ) - self._media_retention_remote_media_lifetime_ms = ( - hs.config.media.media_retention_remote_media_lifetime_ms - ) - - # Check whether local or remote media retention is configured - if ( - hs.config.media.media_retention_local_media_lifetime_ms is not None - or hs.config.media.media_retention_remote_media_lifetime_ms is not None - ): - # Run the background job to apply media retention rules routinely, - # with the duration between runs dictated by the homeserver config. - self.clock.looping_call( - self._start_apply_media_retention_rules, - MEDIA_RETENTION_CHECK_PERIOD_MS, - ) - - def _start_update_recently_accessed(self) -> Deferred: - return run_as_background_process( - "update_recently_accessed_media", self._update_recently_accessed - ) - - def _start_apply_media_retention_rules(self) -> Deferred: - return run_as_background_process( - "apply_media_retention_rules", self._apply_media_retention_rules - ) - - async def _update_recently_accessed(self) -> None: - remote_media = self.recently_accessed_remotes - self.recently_accessed_remotes = set() - - local_media = self.recently_accessed_locals - self.recently_accessed_locals = set() - - await self.store.update_cached_last_access_time( - local_media, remote_media, self.clock.time_msec() - ) - - def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None: - """Mark the given media as recently accessed. - - Args: - server_name: Origin server of media, or None if local - media_id: The media ID of the content - """ - if server_name: - self.recently_accessed_remotes.add((server_name, media_id)) - else: - self.recently_accessed_locals.add(media_id) - - async def create_content( - self, - media_type: str, - upload_name: Optional[str], - content: IO, - content_length: int, - auth_user: UserID, - ) -> MXCUri: - """Store uploaded content for a local user and return the mxc URL - - Args: - media_type: The content type of the file. - upload_name: The name of the file, if provided. - content: A file like object that is the content to store - content_length: The length of the content - auth_user: The user_id of the uploader - - Returns: - The mxc url of the stored content - """ - - media_id = random_string(24) - - file_info = FileInfo(server_name=None, file_id=media_id) - - fname = await self.media_storage.store_file(content, file_info) - - logger.info("Stored local media in file %r", fname) - - await self.store.store_local_media( - media_id=media_id, - media_type=media_type, - time_now_ms=self.clock.time_msec(), - upload_name=upload_name, - media_length=content_length, - user_id=auth_user, - ) - - await self._generate_thumbnails(None, media_id, media_id, media_type) - - return MXCUri(self.server_name, media_id) - - async def get_local_media( - self, request: SynapseRequest, media_id: str, name: Optional[str] - ) -> None: - """Responds to requests for local media, if exists, or returns 404. - - Args: - request: The incoming request. - media_id: The media ID of the content. (This is the same as - the file_id for local content.) - name: Optional name that, if specified, will be used as - the filename in the Content-Disposition header of the response. - - Returns: - Resolves once a response has successfully been written to request - """ - media_info = await self.store.get_local_media(media_id) - if not media_info or media_info["quarantined_by"]: - respond_404(request) - return - - self.mark_recently_accessed(None, media_id) - - media_type = media_info["media_type"] - if not media_type: - media_type = "application/octet-stream" - media_length = media_info["media_length"] - upload_name = name if name else media_info["upload_name"] - url_cache = media_info["url_cache"] - - file_info = FileInfo(None, media_id, url_cache=bool(url_cache)) - - responder = await self.media_storage.fetch_media(file_info) - await respond_with_responder( - request, responder, media_type, media_length, upload_name - ) - - async def get_remote_media( - self, - request: SynapseRequest, - server_name: str, - media_id: str, - name: Optional[str], - ) -> None: - """Respond to requests for remote media. - - Args: - request: The incoming request. - server_name: Remote server_name where the media originated. - media_id: The media ID of the content (as defined by the remote server). - name: Optional name that, if specified, will be used as - the filename in the Content-Disposition header of the response. - - Returns: - Resolves once a response has successfully been written to request - """ - if ( - self.federation_domain_whitelist is not None - and server_name not in self.federation_domain_whitelist - ): - raise FederationDeniedError(server_name) - - self.mark_recently_accessed(server_name, media_id) - - # We linearize here to ensure that we don't try and download remote - # media multiple times concurrently - key = (server_name, media_id) - async with self.remote_media_linearizer.queue(key): - responder, media_info = await self._get_remote_media_impl( - server_name, media_id - ) - - # We deliberately stream the file outside the lock - if responder: - media_type = media_info["media_type"] - media_length = media_info["media_length"] - upload_name = name if name else media_info["upload_name"] - await respond_with_responder( - request, responder, media_type, media_length, upload_name - ) - else: - respond_404(request) - - async def get_remote_media_info(self, server_name: str, media_id: str) -> dict: - """Gets the media info associated with the remote file, downloading - if necessary. - - Args: - server_name: Remote server_name where the media originated. - media_id: The media ID of the content (as defined by the remote server). - - Returns: - The media info of the file - """ - if ( - self.federation_domain_whitelist is not None - and server_name not in self.federation_domain_whitelist - ): - raise FederationDeniedError(server_name) - - # We linearize here to ensure that we don't try and download remote - # media multiple times concurrently - key = (server_name, media_id) - async with self.remote_media_linearizer.queue(key): - responder, media_info = await self._get_remote_media_impl( - server_name, media_id - ) - - # Ensure we actually use the responder so that it releases resources - if responder: - with responder: - pass - - return media_info - - async def _get_remote_media_impl( - self, server_name: str, media_id: str - ) -> Tuple[Optional[Responder], dict]: - """Looks for media in local cache, if not there then attempt to - download from remote server. - - Args: - server_name: Remote server_name where the media originated. - media_id: The media ID of the content (as defined by the - remote server). - - Returns: - A tuple of responder and the media info of the file. - """ - media_info = await self.store.get_cached_remote_media(server_name, media_id) - - # file_id is the ID we use to track the file locally. If we've already - # seen the file then reuse the existing ID, otherwise generate a new - # one. - - # If we have an entry in the DB, try and look for it - if media_info: - file_id = media_info["filesystem_id"] - file_info = FileInfo(server_name, file_id) - - if media_info["quarantined_by"]: - logger.info("Media is quarantined") - raise NotFoundError() - - if not media_info["media_type"]: - media_info["media_type"] = "application/octet-stream" - - responder = await self.media_storage.fetch_media(file_info) - if responder: - return responder, media_info - - # Failed to find the file anywhere, lets download it. - - try: - media_info = await self._download_remote_file( - server_name, - media_id, - ) - except SynapseError: - raise - except Exception as e: - # An exception may be because we downloaded media in another - # process, so let's check if we magically have the media. - media_info = await self.store.get_cached_remote_media(server_name, media_id) - if not media_info: - raise e - - file_id = media_info["filesystem_id"] - if not media_info["media_type"]: - media_info["media_type"] = "application/octet-stream" - file_info = FileInfo(server_name, file_id) - - # We generate thumbnails even if another process downloaded the media - # as a) it's conceivable that the other download request dies before it - # generates thumbnails, but mainly b) we want to be sure the thumbnails - # have finished being generated before responding to the client, - # otherwise they'll request thumbnails and get a 404 if they're not - # ready yet. - await self._generate_thumbnails( - server_name, media_id, file_id, media_info["media_type"] - ) - - responder = await self.media_storage.fetch_media(file_info) - return responder, media_info - - async def _download_remote_file( - self, - server_name: str, - media_id: str, - ) -> dict: - """Attempt to download the remote file from the given server name, - using the given file_id as the local id. - - Args: - server_name: Originating server - media_id: The media ID of the content (as defined by the - remote server). This is different than the file_id, which is - locally generated. - file_id: Local file ID - - Returns: - The media info of the file. - """ - - file_id = random_string(24) - - file_info = FileInfo(server_name=server_name, file_id=file_id) - - with self.media_storage.store_into_file(file_info) as (f, fname, finish): - request_path = "/".join( - ("/_matrix/media/r0/download", server_name, media_id) - ) - try: - length, headers = await self.client.get_file( - server_name, - request_path, - output_stream=f, - max_size=self.max_upload_size, - args={ - # tell the remote server to 404 if it doesn't - # recognise the server_name, to make sure we don't - # end up with a routing loop. - "allow_remote": "false" - }, - ) - except RequestSendFailed as e: - logger.warning( - "Request failed fetching remote media %s/%s: %r", - server_name, - media_id, - e, - ) - raise SynapseError(502, "Failed to fetch remote media") - - except HttpResponseException as e: - logger.warning( - "HTTP error fetching remote media %s/%s: %s", - server_name, - media_id, - e.response, - ) - if e.code == twisted.web.http.NOT_FOUND: - raise e.to_synapse_error() - raise SynapseError(502, "Failed to fetch remote media") - - except SynapseError: - logger.warning( - "Failed to fetch remote media %s/%s", server_name, media_id - ) - raise - except NotRetryingDestination: - logger.warning("Not retrying destination %r", server_name) - raise SynapseError(502, "Failed to fetch remote media") - except Exception: - logger.exception( - "Failed to fetch remote media %s/%s", server_name, media_id - ) - raise SynapseError(502, "Failed to fetch remote media") - - await finish() - - if b"Content-Type" in headers: - media_type = headers[b"Content-Type"][0].decode("ascii") - else: - media_type = "application/octet-stream" - upload_name = get_filename_from_headers(headers) - time_now_ms = self.clock.time_msec() - - # Multiple remote media download requests can race (when using - # multiple media repos), so this may throw a violation constraint - # exception. If it does we'll delete the newly downloaded file from - # disk (as we're in the ctx manager). - # - # However: we've already called `finish()` so we may have also - # written to the storage providers. This is preferable to the - # alternative where we call `finish()` *after* this, where we could - # end up having an entry in the DB but fail to write the files to - # the storage providers. - await self.store.store_cached_remote_media( - origin=server_name, - media_id=media_id, - media_type=media_type, - time_now_ms=self.clock.time_msec(), - upload_name=upload_name, - media_length=length, - filesystem_id=file_id, - ) - - logger.info("Stored remote media in file %r", fname) - - media_info = { - "media_type": media_type, - "media_length": length, - "upload_name": upload_name, - "created_ts": time_now_ms, - "filesystem_id": file_id, - } - - return media_info - - def _get_thumbnail_requirements( - self, media_type: str - ) -> Tuple[ThumbnailRequirement, ...]: - scpos = media_type.find(";") - if scpos > 0: - media_type = media_type[:scpos] - return self.thumbnail_requirements.get(media_type, ()) - - def _generate_thumbnail( - self, - thumbnailer: Thumbnailer, - t_width: int, - t_height: int, - t_method: str, - t_type: str, - ) -> Optional[BytesIO]: - m_width = thumbnailer.width - m_height = thumbnailer.height - - if m_width * m_height >= self.max_image_pixels: - logger.info( - "Image too large to thumbnail %r x %r > %r", - m_width, - m_height, - self.max_image_pixels, - ) - return None - - if thumbnailer.transpose_method is not None: - m_width, m_height = thumbnailer.transpose() - - if t_method == "crop": - return thumbnailer.crop(t_width, t_height, t_type) - elif t_method == "scale": - t_width, t_height = thumbnailer.aspect(t_width, t_height) - t_width = min(m_width, t_width) - t_height = min(m_height, t_height) - return thumbnailer.scale(t_width, t_height, t_type) - - return None - - async def generate_local_exact_thumbnail( - self, - media_id: str, - t_width: int, - t_height: int, - t_method: str, - t_type: str, - url_cache: bool, - ) -> Optional[str]: - input_path = await self.media_storage.ensure_media_is_in_local_cache( - FileInfo(None, media_id, url_cache=url_cache) - ) - - try: - thumbnailer = Thumbnailer(input_path) - except ThumbnailError as e: - logger.warning( - "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s", - media_id, - t_method, - t_type, - e, - ) - return None - - with thumbnailer: - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - self._generate_thumbnail, - thumbnailer, - t_width, - t_height, - t_method, - t_type, - ) - - if t_byte_source: - try: - file_info = FileInfo( - server_name=None, - file_id=media_id, - url_cache=url_cache, - thumbnail=ThumbnailInfo( - width=t_width, - height=t_height, - method=t_method, - type=t_type, - ), - ) - - output_path = await self.media_storage.store_file( - t_byte_source, file_info - ) - finally: - t_byte_source.close() - - logger.info("Stored thumbnail in file %r", output_path) - - t_len = os.path.getsize(output_path) - - await self.store.store_local_thumbnail( - media_id, t_width, t_height, t_type, t_method, t_len - ) - - return output_path - - # Could not generate thumbnail. - return None - - async def generate_remote_exact_thumbnail( - self, - server_name: str, - file_id: str, - media_id: str, - t_width: int, - t_height: int, - t_method: str, - t_type: str, - ) -> Optional[str]: - input_path = await self.media_storage.ensure_media_is_in_local_cache( - FileInfo(server_name, file_id) - ) - - try: - thumbnailer = Thumbnailer(input_path) - except ThumbnailError as e: - logger.warning( - "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s", - media_id, - server_name, - t_method, - t_type, - e, - ) - return None - - with thumbnailer: - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - self._generate_thumbnail, - thumbnailer, - t_width, - t_height, - t_method, - t_type, - ) - - if t_byte_source: - try: - file_info = FileInfo( - server_name=server_name, - file_id=file_id, - thumbnail=ThumbnailInfo( - width=t_width, - height=t_height, - method=t_method, - type=t_type, - ), - ) - - output_path = await self.media_storage.store_file( - t_byte_source, file_info - ) - finally: - t_byte_source.close() - - logger.info("Stored thumbnail in file %r", output_path) - - t_len = os.path.getsize(output_path) - - await self.store.store_remote_media_thumbnail( - server_name, - media_id, - file_id, - t_width, - t_height, - t_type, - t_method, - t_len, - ) - - return output_path - - # Could not generate thumbnail. - return None - - async def _generate_thumbnails( - self, - server_name: Optional[str], - media_id: str, - file_id: str, - media_type: str, - url_cache: bool = False, - ) -> Optional[dict]: - """Generate and store thumbnails for an image. - - Args: - server_name: The server name if remote media, else None if local - media_id: The media ID of the content. (This is the same as - the file_id for local content) - file_id: Local file ID - media_type: The content type of the file - url_cache: If we are thumbnailing images downloaded for the URL cache, - used exclusively by the url previewer - - Returns: - Dict with "width" and "height" keys of original image or None if the - media cannot be thumbnailed. - """ - requirements = self._get_thumbnail_requirements(media_type) - if not requirements: - return None - - input_path = await self.media_storage.ensure_media_is_in_local_cache( - FileInfo(server_name, file_id, url_cache=url_cache) - ) - - try: - thumbnailer = Thumbnailer(input_path) - except ThumbnailError as e: - logger.warning( - "Unable to generate thumbnails for remote media %s from %s of type %s: %s", - media_id, - server_name, - media_type, - e, - ) - return None - - with thumbnailer: - m_width = thumbnailer.width - m_height = thumbnailer.height - - if m_width * m_height >= self.max_image_pixels: - logger.info( - "Image too large to thumbnail %r x %r > %r", - m_width, - m_height, - self.max_image_pixels, - ) - return None - - if thumbnailer.transpose_method is not None: - m_width, m_height = await defer_to_thread( - self.hs.get_reactor(), thumbnailer.transpose - ) - - # We deduplicate the thumbnail sizes by ignoring the cropped versions if - # they have the same dimensions of a scaled one. - thumbnails: Dict[Tuple[int, int, str], str] = {} - for requirement in requirements: - if requirement.method == "crop": - thumbnails.setdefault( - (requirement.width, requirement.height, requirement.media_type), - requirement.method, - ) - elif requirement.method == "scale": - t_width, t_height = thumbnailer.aspect( - requirement.width, requirement.height - ) - t_width = min(m_width, t_width) - t_height = min(m_height, t_height) - thumbnails[ - (t_width, t_height, requirement.media_type) - ] = requirement.method - - # Now we generate the thumbnails for each dimension, store it - for (t_width, t_height, t_type), t_method in thumbnails.items(): - # Generate the thumbnail - if t_method == "crop": - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - thumbnailer.crop, - t_width, - t_height, - t_type, - ) - elif t_method == "scale": - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - thumbnailer.scale, - t_width, - t_height, - t_type, - ) - else: - logger.error("Unrecognized method: %r", t_method) - continue - - if not t_byte_source: - continue - - file_info = FileInfo( - server_name=server_name, - file_id=file_id, - url_cache=url_cache, - thumbnail=ThumbnailInfo( - width=t_width, - height=t_height, - method=t_method, - type=t_type, - ), - ) - - with self.media_storage.store_into_file(file_info) as ( - f, - fname, - finish, - ): - try: - await self.media_storage.write_to_file(t_byte_source, f) - await finish() - finally: - t_byte_source.close() - - t_len = os.path.getsize(fname) - - # Write to database - if server_name: - # Multiple remote media download requests can race (when - # using multiple media repos), so this may throw a violation - # constraint exception. If it does we'll delete the newly - # generated thumbnail from disk (as we're in the ctx - # manager). - # - # However: we've already called `finish()` so we may have - # also written to the storage providers. This is preferable - # to the alternative where we call `finish()` *after* this, - # where we could end up having an entry in the DB but fail - # to write the files to the storage providers. - try: - await self.store.store_remote_media_thumbnail( - server_name, - media_id, - file_id, - t_width, - t_height, - t_type, - t_method, - t_len, - ) - except Exception as e: - thumbnail_exists = ( - await self.store.get_remote_media_thumbnail( - server_name, - media_id, - t_width, - t_height, - t_type, - ) - ) - if not thumbnail_exists: - raise e - else: - await self.store.store_local_thumbnail( - media_id, t_width, t_height, t_type, t_method, t_len - ) - - return {"width": m_width, "height": m_height} - - async def _apply_media_retention_rules(self) -> None: - """ - Purge old local and remote media according to the media retention rules - defined in the homeserver config. - """ - # Purge remote media - if self._media_retention_remote_media_lifetime_ms is not None: - # Calculate a threshold timestamp derived from the configured lifetime. Any - # media that has not been accessed since this timestamp will be removed. - remote_media_threshold_timestamp_ms = ( - self.clock.time_msec() - self._media_retention_remote_media_lifetime_ms - ) - - logger.info( - "Purging remote media last accessed before" - f" {remote_media_threshold_timestamp_ms}" - ) - - await self.delete_old_remote_media( - before_ts=remote_media_threshold_timestamp_ms - ) - - # And now do the same for local media - if self._media_retention_local_media_lifetime_ms is not None: - # This works the same as the remote media threshold - local_media_threshold_timestamp_ms = ( - self.clock.time_msec() - self._media_retention_local_media_lifetime_ms - ) - - logger.info( - "Purging local media last accessed before" - f" {local_media_threshold_timestamp_ms}" - ) - - await self.delete_old_local_media( - before_ts=local_media_threshold_timestamp_ms, - keep_profiles=True, - delete_quarantined_media=False, - delete_protected_media=False, - ) - - async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]: - old_media = await self.store.get_remote_media_ids( - before_ts, include_quarantined_media=False - ) - - deleted = 0 - - for media in old_media: - origin = media["media_origin"] - media_id = media["media_id"] - file_id = media["filesystem_id"] - key = (origin, media_id) - - logger.info("Deleting: %r", key) - - # TODO: Should we delete from the backup store - - async with self.remote_media_linearizer.queue(key): - full_path = self.filepaths.remote_media_filepath(origin, file_id) - try: - os.remove(full_path) - except OSError as e: - logger.warning("Failed to remove file: %r", full_path) - if e.errno == errno.ENOENT: - pass - else: - continue - - thumbnail_dir = self.filepaths.remote_media_thumbnail_dir( - origin, file_id - ) - shutil.rmtree(thumbnail_dir, ignore_errors=True) - - await self.store.delete_remote_media(origin, media_id) - deleted += 1 - - return {"deleted": deleted} - - async def delete_local_media_ids( - self, media_ids: List[str] - ) -> Tuple[List[str], int]: - """ - Delete the given local or remote media ID from this server - - Args: - media_id: The media ID to delete. - Returns: - A tuple of (list of deleted media IDs, total deleted media IDs). - """ - return await self._remove_local_media_from_disk(media_ids) - - async def delete_old_local_media( - self, - before_ts: int, - size_gt: int = 0, - keep_profiles: bool = True, - delete_quarantined_media: bool = False, - delete_protected_media: bool = False, - ) -> Tuple[List[str], int]: - """ - Delete local or remote media from this server by size and timestamp. Removes - media files, any thumbnails and cached URLs. - - Args: - before_ts: Unix timestamp in ms. - Files that were last used before this timestamp will be deleted. - size_gt: Size of the media in bytes. Files that are larger will be deleted. - keep_profiles: Switch to delete also files that are still used in image data - (e.g user profile, room avatar). If false these files will be deleted. - delete_quarantined_media: If True, media marked as quarantined will be deleted. - delete_protected_media: If True, media marked as protected will be deleted. - - Returns: - A tuple of (list of deleted media IDs, total deleted media IDs). - """ - old_media = await self.store.get_local_media_ids( - before_ts, - size_gt, - keep_profiles, - include_quarantined_media=delete_quarantined_media, - include_protected_media=delete_protected_media, - ) - return await self._remove_local_media_from_disk(old_media) - - async def _remove_local_media_from_disk( - self, media_ids: List[str] - ) -> Tuple[List[str], int]: - """ - Delete local or remote media from this server. Removes media files, - any thumbnails and cached URLs. - - Args: - media_ids: List of media_id to delete - Returns: - A tuple of (list of deleted media IDs, total deleted media IDs). - """ - removed_media = [] - for media_id in media_ids: - logger.info("Deleting media with ID '%s'", media_id) - full_path = self.filepaths.local_media_filepath(media_id) - try: - os.remove(full_path) - except OSError as e: - logger.warning("Failed to remove file: %r: %s", full_path, e) - if e.errno == errno.ENOENT: - pass - else: - continue - - thumbnail_dir = self.filepaths.local_media_thumbnail_dir(media_id) - shutil.rmtree(thumbnail_dir, ignore_errors=True) - - await self.store.delete_remote_media(self.server_name, media_id) - - await self.store.delete_url_cache((media_id,)) - await self.store.delete_url_cache_media((media_id,)) - - removed_media.append(media_id) - - return removed_media, len(removed_media) - - -class MediaRepositoryResource(UnrecognizedRequestResource): - """File uploading and downloading. - - Uploads are POSTed to a resource which returns a token which is used to GET - the download:: - - => POST /_matrix/media/r0/upload HTTP/1.1 - Content-Type: <media-type> - Content-Length: <content-length> - - <media> - - <= HTTP/1.1 200 OK - Content-Type: application/json - - { "content_uri": "mxc://<server-name>/<media-id>" } - - => GET /_matrix/media/r0/download/<server-name>/<media-id> HTTP/1.1 - - <= HTTP/1.1 200 OK - Content-Type: <media-type> - Content-Disposition: attachment;filename=<upload-filename> - - <media> - - Clients can get thumbnails by supplying a desired width and height and - thumbnailing method:: - - => GET /_matrix/media/r0/thumbnail/<server_name> - /<media-id>?width=<w>&height=<h>&method=<m> HTTP/1.1 - - <= HTTP/1.1 200 OK - Content-Type: image/jpeg or image/png - - <thumbnail> - - The thumbnail methods are "crop" and "scale". "scale" tries to return an - image where either the width or the height is smaller than the requested - size. The client should then scale and letterbox the image if it needs to - fit within a given rectangle. "crop" tries to return an image where the - width and height are close to the requested size and the aspect matches - the requested size. The client should scale the image if it needs to fit - within a given rectangle. - """ - - def __init__(self, hs: "HomeServer"): - # If we're not configured to use it, raise if we somehow got here. - if not hs.config.media.can_load_media_repo: - raise ConfigError("Synapse is not configured to use a media repo.") - - super().__init__() - media_repo = hs.get_media_repository() - - self.putChild(b"upload", UploadResource(hs, media_repo)) - self.putChild(b"download", DownloadResource(hs, media_repo)) - self.putChild( - b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage) - ) - if hs.config.media.url_preview_enabled: - self.putChild( - b"preview_url", - PreviewUrlResource(hs, media_repo, media_repo.media_storage), - ) - self.putChild(b"config", MediaConfigResource(hs)) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index db25848744..11b0e8e231 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -1,4 +1,4 @@ -# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# Copyright 2023 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,364 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import contextlib -import logging -import os -import shutil -from types import TracebackType -from typing import ( - IO, - TYPE_CHECKING, - Any, - Awaitable, - BinaryIO, - Callable, - Generator, - Optional, - Sequence, - Tuple, - Type, -) - -import attr - -from twisted.internet.defer import Deferred -from twisted.internet.interfaces import IConsumer -from twisted.protocols.basic import FileSender - -import synapse -from synapse.api.errors import NotFoundError -from synapse.logging.context import defer_to_thread, make_deferred_yieldable -from synapse.util import Clock -from synapse.util.file_consumer import BackgroundFileConsumer - -from ._base import FileInfo, Responder -from .filepath import MediaFilePaths - -if TYPE_CHECKING: - from synapse.rest.media.v1.storage_provider import StorageProvider - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class MediaStorage: - """Responsible for storing/fetching files from local sources. - - Args: - hs - local_media_directory: Base path where we store media on disk - filepaths - storage_providers: List of StorageProvider that are used to fetch and store files. - """ - - def __init__( - self, - hs: "HomeServer", - local_media_directory: str, - filepaths: MediaFilePaths, - storage_providers: Sequence["StorageProvider"], - ): - self.hs = hs - self.reactor = hs.get_reactor() - self.local_media_directory = local_media_directory - self.filepaths = filepaths - self.storage_providers = storage_providers - self.spam_checker = hs.get_spam_checker() - self.clock = hs.get_clock() - - async def store_file(self, source: IO, file_info: FileInfo) -> str: - """Write `source` to the on disk media store, and also any other - configured storage providers - - Args: - source: A file like object that should be written - file_info: Info about the file to store - - Returns: - the file path written to in the primary media store - """ - - with self.store_into_file(file_info) as (f, fname, finish_cb): - # Write to the main repository - await self.write_to_file(source, f) - await finish_cb() - - return fname - - async def write_to_file(self, source: IO, output: IO) -> None: - """Asynchronously write the `source` to `output`.""" - await defer_to_thread(self.reactor, _write_file_synchronously, source, output) - - @contextlib.contextmanager - def store_into_file( - self, file_info: FileInfo - ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]: - """Context manager used to get a file like object to write into, as - described by file_info. - - Actually yields a 3-tuple (file, fname, finish_cb), where file is a file - like object that can be written to, fname is the absolute path of file - on disk, and finish_cb is a function that returns an awaitable. - - fname can be used to read the contents from after upload, e.g. to - generate thumbnails. - - finish_cb must be called and waited on after the file has been - successfully been written to. Should not be called if there was an - error. - - Args: - file_info: Info about the file to store - - Example: - - with media_storage.store_into_file(info) as (f, fname, finish_cb): - # .. write into f ... - await finish_cb() - """ - - path = self._file_info_to_path(file_info) - fname = os.path.join(self.local_media_directory, path) - - dirname = os.path.dirname(fname) - os.makedirs(dirname, exist_ok=True) - - finished_called = [False] - - try: - with open(fname, "wb") as f: - - async def finish() -> None: - # Ensure that all writes have been flushed and close the - # file. - f.flush() - f.close() - - spam_check = await self.spam_checker.check_media_file_for_spam( - ReadableFileWrapper(self.clock, fname), file_info - ) - if spam_check != synapse.module_api.NOT_SPAM: - logger.info("Blocking media due to spam checker") - # Note that we'll delete the stored media, due to the - # try/except below. The media also won't be stored in - # the DB. - # We currently ignore any additional field returned by - # the spam-check API. - raise SpamMediaException(errcode=spam_check[0]) - - for provider in self.storage_providers: - await provider.store_file(path, file_info) - - finished_called[0] = True - - yield f, fname, finish - except Exception as e: - try: - os.remove(fname) - except Exception: - pass - - raise e from None - - if not finished_called: - raise Exception("Finished callback not called") - - async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: - """Attempts to fetch media described by file_info from the local cache - and configured storage providers. - - Args: - file_info - - Returns: - Returns a Responder if the file was found, otherwise None. - """ - paths = [self._file_info_to_path(file_info)] - - # fallback for remote thumbnails with no method in the filename - if file_info.thumbnail and file_info.server_name: - paths.append( - self.filepaths.remote_media_thumbnail_rel_legacy( - server_name=file_info.server_name, - file_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - ) - ) - - for path in paths: - local_path = os.path.join(self.local_media_directory, path) - if os.path.exists(local_path): - logger.debug("responding with local file %s", local_path) - return FileResponder(open(local_path, "rb")) - logger.debug("local file %s did not exist", local_path) - - for provider in self.storage_providers: - for path in paths: - res: Any = await provider.fetch(path, file_info) - if res: - logger.debug("Streaming %s from %s", path, provider) - return res - logger.debug("%s not found on %s", path, provider) - - return None - - async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: - """Ensures that the given file is in the local cache. Attempts to - download it from storage providers if it isn't. - - Args: - file_info - - Returns: - Full path to local file - """ - path = self._file_info_to_path(file_info) - local_path = os.path.join(self.local_media_directory, path) - if os.path.exists(local_path): - return local_path - - # Fallback for paths without method names - # Should be removed in the future - if file_info.thumbnail and file_info.server_name: - legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy( - server_name=file_info.server_name, - file_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - ) - legacy_local_path = os.path.join(self.local_media_directory, legacy_path) - if os.path.exists(legacy_local_path): - return legacy_local_path - - dirname = os.path.dirname(local_path) - os.makedirs(dirname, exist_ok=True) - - for provider in self.storage_providers: - res: Any = await provider.fetch(path, file_info) - if res: - with res: - consumer = BackgroundFileConsumer( - open(local_path, "wb"), self.reactor - ) - await res.write_to_consumer(consumer) - await consumer.wait() - return local_path - - raise NotFoundError() - - def _file_info_to_path(self, file_info: FileInfo) -> str: - """Converts file_info into a relative path. - - The path is suitable for storing files under a directory, e.g. used to - store files on local FS under the base media repository directory. - """ - if file_info.url_cache: - if file_info.thumbnail: - return self.filepaths.url_cache_thumbnail_rel( - media_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - method=file_info.thumbnail.method, - ) - return self.filepaths.url_cache_filepath_rel(file_info.file_id) - - if file_info.server_name: - if file_info.thumbnail: - return self.filepaths.remote_media_thumbnail_rel( - server_name=file_info.server_name, - file_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - method=file_info.thumbnail.method, - ) - return self.filepaths.remote_media_filepath_rel( - file_info.server_name, file_info.file_id - ) - - if file_info.thumbnail: - return self.filepaths.local_media_thumbnail_rel( - media_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - method=file_info.thumbnail.method, - ) - return self.filepaths.local_media_filepath_rel(file_info.file_id) - - -def _write_file_synchronously(source: IO, dest: IO) -> None: - """Write `source` to the file like `dest` synchronously. Should be called - from a thread. - - Args: - source: A file like object that's to be written - dest: A file like object to be written to - """ - source.seek(0) # Ensure we read from the start of the file - shutil.copyfileobj(source, dest) - - -class FileResponder(Responder): - """Wraps an open file that can be sent to a request. - - Args: - open_file: A file like object to be streamed ot the client, - is closed when finished streaming. - """ - - def __init__(self, open_file: IO): - self.open_file = open_file - - def write_to_consumer(self, consumer: IConsumer) -> Deferred: - return make_deferred_yieldable( - FileSender().beginFileTransfer(self.open_file, consumer) - ) - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - self.open_file.close() - - -class SpamMediaException(NotFoundError): - """The media was blocked by a spam checker, so we simply 404 the request (in - the same way as if it was quarantined). - """ - - -@attr.s(slots=True, auto_attribs=True) -class ReadableFileWrapper: - """Wrapper that allows reading a file in chunks, yielding to the reactor, - and writing to a callback. - - This is simplified `FileSender` that takes an IO object rather than an - `IConsumer`. - """ - - CHUNK_SIZE = 2**14 - - clock: Clock - path: str - - async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None: - """Reads the file in chunks and calls the callback with each chunk.""" - - with open(self.path, "rb") as file: - while True: - chunk = file.read(self.CHUNK_SIZE) - if not chunk: - break - - callback(chunk) +# - # We yield to the reactor by sleeping for 0 seconds. - await self.clock.sleep(0) +# This exists purely for backwards compatibility with spam checkers. +from synapse.media.media_storage import ReadableFileWrapper # noqa: F401 diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py deleted file mode 100644 index 7592aa5d47..0000000000 --- a/synapse/rest/media/v1/oembed.py +++ /dev/null @@ -1,265 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import html -import logging -import urllib.parse -from typing import TYPE_CHECKING, List, Optional - -import attr - -from synapse.rest.media.v1.preview_html import parse_html_description -from synapse.types import JsonDict -from synapse.util import json_decoder - -if TYPE_CHECKING: - from lxml import etree - - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class OEmbedResult: - # The Open Graph result (converted from the oEmbed result). - open_graph_result: JsonDict - # The author_name of the oEmbed result - author_name: Optional[str] - # Number of milliseconds to cache the content, according to the oEmbed response. - # - # This will be None if no cache-age is provided in the oEmbed response (or - # if the oEmbed response cannot be turned into an Open Graph response). - cache_age: Optional[int] - - -class OEmbedProvider: - """ - A helper for accessing oEmbed content. - - It can be used to check if a URL should be accessed via oEmbed and for - requesting/parsing oEmbed content. - """ - - def __init__(self, hs: "HomeServer"): - self._oembed_patterns = {} - for oembed_endpoint in hs.config.oembed.oembed_patterns: - api_endpoint = oembed_endpoint.api_endpoint - - # Only JSON is supported at the moment. This could be declared in - # the formats field. Otherwise, if the endpoint ends in .xml assume - # it doesn't support JSON. - if ( - oembed_endpoint.formats is not None - and "json" not in oembed_endpoint.formats - ) or api_endpoint.endswith(".xml"): - logger.info( - "Ignoring oEmbed endpoint due to not supporting JSON: %s", - api_endpoint, - ) - continue - - # Iterate through each URL pattern and point it to the endpoint. - for pattern in oembed_endpoint.url_patterns: - self._oembed_patterns[pattern] = api_endpoint - - def get_oembed_url(self, url: str) -> Optional[str]: - """ - Check whether the URL should be downloaded as oEmbed content instead. - - Args: - url: The URL to check. - - Returns: - A URL to use instead or None if the original URL should be used. - """ - for url_pattern, endpoint in self._oembed_patterns.items(): - if url_pattern.fullmatch(url): - # TODO Specify max height / width. - - # Note that only the JSON format is supported, some endpoints want - # this in the URL, others want it as an argument. - endpoint = endpoint.replace("{format}", "json") - - args = {"url": url, "format": "json"} - query_str = urllib.parse.urlencode(args, True) - return f"{endpoint}?{query_str}" - - # No match. - return None - - def autodiscover_from_html(self, tree: "etree.Element") -> Optional[str]: - """ - Search an HTML document for oEmbed autodiscovery information. - - Args: - tree: The parsed HTML body. - - Returns: - The URL to use for oEmbed information, or None if no URL was found. - """ - # Search for link elements with the proper rel and type attributes. - for tag in tree.xpath( - "//link[@rel='alternate'][@type='application/json+oembed']" - ): - if "href" in tag.attrib: - return tag.attrib["href"] - - # Some providers (e.g. Flickr) use alternative instead of alternate. - for tag in tree.xpath( - "//link[@rel='alternative'][@type='application/json+oembed']" - ): - if "href" in tag.attrib: - return tag.attrib["href"] - - return None - - def parse_oembed_response(self, url: str, raw_body: bytes) -> OEmbedResult: - """ - Parse the oEmbed response into an Open Graph response. - - Args: - url: The URL which is being previewed (not the one which was - requested). - raw_body: The oEmbed response as JSON encoded as bytes. - - Returns: - json-encoded Open Graph data - """ - - try: - # oEmbed responses *must* be UTF-8 according to the spec. - oembed = json_decoder.decode(raw_body.decode("utf-8")) - except ValueError: - return OEmbedResult({}, None, None) - - # The version is a required string field, but not always provided, - # or sometimes provided as a float. Be lenient. - oembed_version = oembed.get("version", "1.0") - if oembed_version != "1.0" and oembed_version != 1: - return OEmbedResult({}, None, None) - - # Attempt to parse the cache age, if possible. - try: - cache_age = int(oembed.get("cache_age")) * 1000 - except (TypeError, ValueError): - # If the cache age cannot be parsed (e.g. wrong type or invalid - # string), ignore it. - cache_age = None - - # The oEmbed response converted to Open Graph. - open_graph_response: JsonDict = {"og:url": url} - - title = oembed.get("title") - if title and isinstance(title, str): - # A common WordPress plug-in seems to incorrectly escape entities - # in the oEmbed response. - open_graph_response["og:title"] = html.unescape(title) - - author_name = oembed.get("author_name") - if not isinstance(author_name, str): - author_name = None - - # Use the provider name and as the site. - provider_name = oembed.get("provider_name") - if provider_name and isinstance(provider_name, str): - open_graph_response["og:site_name"] = provider_name - - # If a thumbnail exists, use it. Note that dimensions will be calculated later. - thumbnail_url = oembed.get("thumbnail_url") - if thumbnail_url and isinstance(thumbnail_url, str): - open_graph_response["og:image"] = thumbnail_url - - # Process each type separately. - oembed_type = oembed.get("type") - if oembed_type == "rich": - html_str = oembed.get("html") - if isinstance(html_str, str): - calc_description_and_urls(open_graph_response, html_str) - - elif oembed_type == "photo": - # If this is a photo, use the full image, not the thumbnail. - url = oembed.get("url") - if url and isinstance(url, str): - open_graph_response["og:image"] = url - - elif oembed_type == "video": - open_graph_response["og:type"] = "video.other" - html_str = oembed.get("html") - if html_str and isinstance(html_str, str): - calc_description_and_urls(open_graph_response, oembed["html"]) - for size in ("width", "height"): - val = oembed.get(size) - if type(val) is int: - open_graph_response[f"og:video:{size}"] = val - - elif oembed_type == "link": - open_graph_response["og:type"] = "website" - - else: - logger.warning("Unknown oEmbed type: %s", oembed_type) - - return OEmbedResult(open_graph_response, author_name, cache_age) - - -def _fetch_urls(tree: "etree.Element", tag_name: str) -> List[str]: - results = [] - for tag in tree.xpath("//*/" + tag_name): - if "src" in tag.attrib: - results.append(tag.attrib["src"]) - return results - - -def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) -> None: - """ - Calculate description for an HTML document. - - This uses lxml to convert the HTML document into plaintext. If errors - occur during processing of the document, an empty response is returned. - - Args: - open_graph_response: The current Open Graph summary. This is updated with additional fields. - html_body: The HTML document, as bytes. - - Returns: - The summary - """ - # If there's no body, nothing useful is going to be found. - if not html_body: - return - - from lxml import etree - - # Create an HTML parser. If this fails, log and return no metadata. - parser = etree.HTMLParser(recover=True, encoding="utf-8") - - # Attempt to parse the body. If this fails, log and return no metadata. - tree = etree.fromstring(html_body, parser) - - # The data was successfully parsed, but no tree was found. - if tree is None: - return - - # Attempt to find interesting URLs (images, videos, embeds). - if "og:image" not in open_graph_response: - image_urls = _fetch_urls(tree, "img") - if image_urls: - open_graph_response["og:image"] = image_urls[0] - - video_urls = _fetch_urls(tree, "video") + _fetch_urls(tree, "embed") - if video_urls: - open_graph_response["og:video"] = video_urls[0] - - description = parse_html_description(tree) - if description: - open_graph_response["og:description"] = description diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py deleted file mode 100644 index 516d0434f0..0000000000 --- a/synapse/rest/media/v1/preview_html.py +++ /dev/null @@ -1,501 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import codecs -import logging -import re -from typing import ( - TYPE_CHECKING, - Callable, - Dict, - Generator, - Iterable, - List, - Optional, - Set, - Union, -) - -if TYPE_CHECKING: - from lxml import etree - -logger = logging.getLogger(__name__) - -_charset_match = re.compile( - rb'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I -) -_xml_encoding_match = re.compile( - rb'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I -) -_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) - -# Certain elements aren't meant for display. -ARIA_ROLES_TO_IGNORE = {"directory", "menu", "menubar", "toolbar"} - - -def _normalise_encoding(encoding: str) -> Optional[str]: - """Use the Python codec's name as the normalised entry.""" - try: - return codecs.lookup(encoding).name - except LookupError: - return None - - -def _get_html_media_encodings( - body: bytes, content_type: Optional[str] -) -> Iterable[str]: - """ - Get potential encoding of the body based on the (presumably) HTML body or the content-type header. - - The precedence used for finding a character encoding is: - - 1. <meta> tag with a charset declared. - 2. The XML document's character encoding attribute. - 3. The Content-Type header. - 4. Fallback to utf-8. - 5. Fallback to windows-1252. - - This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector. - - Args: - body: The HTML document, as bytes. - content_type: The Content-Type header. - - Returns: - The character encoding of the body, as a string. - """ - # There's no point in returning an encoding more than once. - attempted_encodings: Set[str] = set() - - # Limit searches to the first 1kb, since it ought to be at the top. - body_start = body[:1024] - - # Check if it has an encoding set in a meta tag. - match = _charset_match.search(body_start) - if match: - encoding = _normalise_encoding(match.group(1).decode("ascii")) - if encoding: - attempted_encodings.add(encoding) - yield encoding - - # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/> - - # Check if it has an XML document with an encoding. - match = _xml_encoding_match.match(body_start) - if match: - encoding = _normalise_encoding(match.group(1).decode("ascii")) - if encoding and encoding not in attempted_encodings: - attempted_encodings.add(encoding) - yield encoding - - # Check the HTTP Content-Type header for a character set. - if content_type: - content_match = _content_type_match.match(content_type) - if content_match: - encoding = _normalise_encoding(content_match.group(1)) - if encoding and encoding not in attempted_encodings: - attempted_encodings.add(encoding) - yield encoding - - # Finally, fallback to UTF-8, then windows-1252. - for fallback in ("utf-8", "cp1252"): - if fallback not in attempted_encodings: - yield fallback - - -def decode_body( - body: bytes, uri: str, content_type: Optional[str] = None -) -> Optional["etree.Element"]: - """ - This uses lxml to parse the HTML document. - - Args: - body: The HTML document, as bytes. - uri: The URI used to download the body. - content_type: The Content-Type header. - - Returns: - The parsed HTML body, or None if an error occurred during processed. - """ - # If there's no body, nothing useful is going to be found. - if not body: - return None - - # The idea here is that multiple encodings are tried until one works. - # Unfortunately the result is never used and then LXML will decode the string - # again with the found encoding. - for encoding in _get_html_media_encodings(body, content_type): - try: - body.decode(encoding) - except Exception: - pass - else: - break - else: - logger.warning("Unable to decode HTML body for %s", uri) - return None - - from lxml import etree - - # Create an HTML parser. - parser = etree.HTMLParser(recover=True, encoding=encoding) - - # Attempt to parse the body. Returns None if the body was successfully - # parsed, but no tree was found. - return etree.fromstring(body, parser) - - -def _get_meta_tags( - tree: "etree.Element", - property: str, - prefix: str, - property_mapper: Optional[Callable[[str], Optional[str]]] = None, -) -> Dict[str, Optional[str]]: - """ - Search for meta tags prefixed with a particular string. - - Args: - tree: The parsed HTML document. - property: The name of the property which contains the tag name, e.g. - "property" for Open Graph. - prefix: The prefix on the property to search for, e.g. "og" for Open Graph. - property_mapper: An optional callable to map the property to the Open Graph - form. Can return None for a key to ignore that key. - - Returns: - A map of tag name to value. - """ - results: Dict[str, Optional[str]] = {} - for tag in tree.xpath( - f"//*/meta[starts-with(@{property}, '{prefix}:')][@content][not(@content='')]" - ): - # if we've got more than 50 tags, someone is taking the piss - if len(results) >= 50: - logger.warning( - "Skipping parsing of Open Graph for page with too many '%s:' tags", - prefix, - ) - return {} - - key = tag.attrib[property] - if property_mapper: - key = property_mapper(key) - # None is a special value used to ignore a value. - if key is None: - continue - - results[key] = tag.attrib["content"] - - return results - - -def _map_twitter_to_open_graph(key: str) -> Optional[str]: - """ - Map a Twitter card property to the analogous Open Graph property. - - Args: - key: The Twitter card property (starts with "twitter:"). - - Returns: - The Open Graph property (starts with "og:") or None to have this property - be ignored. - """ - # Twitter card properties with no analogous Open Graph property. - if key == "twitter:card" or key == "twitter:creator": - return None - if key == "twitter:site": - return "og:site_name" - # Otherwise, swap twitter to og. - return "og" + key[7:] - - -def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: - """ - Parse the HTML document into an Open Graph response. - - This uses lxml to search the HTML document for Open Graph data (or - synthesizes it from the document). - - Args: - tree: The parsed HTML document. - - Returns: - The Open Graph response as a dictionary. - """ - - # Search for Open Graph (og:) meta tags, e.g.: - # - # "og:type" : "video", - # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", - # "og:site_name" : "YouTube", - # "og:video:type" : "application/x-shockwave-flash", - # "og:description" : "Fun stuff happening here", - # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", - # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", - # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", - # "og:video:width" : "1280" - # "og:video:height" : "720", - # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", - - og = _get_meta_tags(tree, "property", "og") - - # TODO: Search for properties specific to the different Open Graph types, - # such as article: meta tags, e.g.: - # - # "article:publisher" : "https://www.facebook.com/thethudonline" /> - # "article:author" content="https://www.facebook.com/thethudonline" /> - # "article:tag" content="baby" /> - # "article:section" content="Breaking News" /> - # "article:published_time" content="2016-03-31T19:58:24+00:00" /> - # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> - - # Search for Twitter Card (twitter:) meta tags, e.g.: - # - # "twitter:site" : "@matrixdotorg" - # "twitter:creator" : "@matrixdotorg" - # - # Twitter cards tags also duplicate Open Graph tags. - # - # See https://developer.twitter.com/en/docs/twitter-for-websites/cards/guides/getting-started - twitter = _get_meta_tags(tree, "name", "twitter", _map_twitter_to_open_graph) - # Merge the Twitter values with the Open Graph values, but do not overwrite - # information from Open Graph tags. - for key, value in twitter.items(): - if key not in og: - og[key] = value - - if "og:title" not in og: - # Attempt to find a title from the title tag, or the biggest header on the page. - title = tree.xpath("((//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1])/text()") - if title: - og["og:title"] = title[0].strip() - else: - og["og:title"] = None - - if "og:image" not in og: - meta_image = tree.xpath( - "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image'][not(@content='')]/@content[1]" - ) - # If a meta image is found, use it. - if meta_image: - og["og:image"] = meta_image[0] - else: - # Try to find images which are larger than 10px by 10px. - # - # TODO: consider inlined CSS styles as well as width & height attribs - images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") - images = sorted( - images, - key=lambda i: ( - -1 * float(i.attrib["width"]) * float(i.attrib["height"]) - ), - ) - # If no images were found, try to find *any* images. - if not images: - images = tree.xpath("//img[@src][1]") - if images: - og["og:image"] = images[0].attrib["src"] - - # Finally, fallback to the favicon if nothing else. - else: - favicons = tree.xpath("//link[@href][contains(@rel, 'icon')]/@href[1]") - if favicons: - og["og:image"] = favicons[0] - - if "og:description" not in og: - # Check the first meta description tag for content. - meta_description = tree.xpath( - "//*/meta[translate(@name, 'DESCRIPTION', 'description')='description'][not(@content='')]/@content[1]" - ) - # If a meta description is found with content, use it. - if meta_description: - og["og:description"] = meta_description[0] - else: - og["og:description"] = parse_html_description(tree) - elif og["og:description"]: - # This must be a non-empty string at this point. - assert isinstance(og["og:description"], str) - og["og:description"] = summarize_paragraphs([og["og:description"]]) - - # TODO: delete the url downloads to stop diskfilling, - # as we only ever cared about its OG - return og - - -def parse_html_description(tree: "etree.Element") -> Optional[str]: - """ - Calculate a text description based on an HTML document. - - Grabs any text nodes which are inside the <body/> tag, unless they are within - an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or - if they are within a <script/>, <svg/> or <style/> tag, or if they are within - a tag whose content is usually only shown to old browsers - (<iframe/>, <video/>, <canvas/>, <picture/>). - - This is a very very very coarse approximation to a plain text render of the page. - - Args: - tree: The parsed HTML document. - - Returns: - The plain text description, or None if one cannot be generated. - """ - # We don't just use XPATH here as that is slow on some machines. - - from lxml import etree - - TAGS_TO_REMOVE = { - "header", - "nav", - "aside", - "footer", - "script", - "noscript", - "style", - "svg", - "iframe", - "video", - "canvas", - "img", - "picture", - etree.Comment, - } - - # Split all the text nodes into paragraphs (by splitting on new - # lines) - text_nodes = ( - re.sub(r"\s+", "\n", el).strip() - for el in _iterate_over_text(tree.find("body"), TAGS_TO_REMOVE) - ) - return summarize_paragraphs(text_nodes) - - -def _iterate_over_text( - tree: Optional["etree.Element"], - tags_to_ignore: Set[Union[str, "etree.Comment"]], - stack_limit: int = 1024, -) -> Generator[str, None, None]: - """Iterate over the tree returning text nodes in a depth first fashion, - skipping text nodes inside certain tags. - - Args: - tree: The parent element to iterate. Can be None if there isn't one. - tags_to_ignore: Set of tags to ignore - stack_limit: Maximum stack size limit for depth-first traversal. - Nodes will be dropped if this limit is hit, which may truncate the - textual result. - Intended to limit the maximum working memory when generating a preview. - """ - - if tree is None: - return - - # This is a stack whose items are elements to iterate over *or* strings - # to be returned. - elements: List[Union[str, "etree.Element"]] = [tree] - while elements: - el = elements.pop() - - if isinstance(el, str): - yield el - elif el.tag not in tags_to_ignore: - # If the element isn't meant for display, ignore it. - if el.get("role") in ARIA_ROLES_TO_IGNORE: - continue - - # el.text is the text before the first child, so we can immediately - # return it if the text exists. - if el.text: - yield el.text - - # We add to the stack all the element's children, interspersed with - # each child's tail text (if it exists). - # - # We iterate in reverse order so that earlier pieces of text appear - # closer to the top of the stack. - for child in el.iterchildren(reversed=True): - if len(elements) > stack_limit: - # We've hit our limit for working memory - break - - if child.tail: - # The tail text of a node is text that comes *after* the node, - # so we always include it even if we ignore the child node. - elements.append(child.tail) - - elements.append(child) - - -def summarize_paragraphs( - text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500 -) -> Optional[str]: - """ - Try to get a summary respecting first paragraph and then word boundaries. - - Args: - text_nodes: The paragraphs to summarize. - min_size: The minimum number of words to include. - max_size: The maximum number of words to include. - - Returns: - A summary of the text nodes, or None if that was not possible. - """ - - # TODO: Respect sentences? - - description = "" - - # Keep adding paragraphs until we get to the MIN_SIZE. - for text_node in text_nodes: - if len(description) < min_size: - text_node = re.sub(r"[\t \r\n]+", " ", text_node) - description += text_node + "\n\n" - else: - break - - description = description.strip() - description = re.sub(r"[\t ]+", " ", description) - description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description) - - # If the concatenation of paragraphs to get above MIN_SIZE - # took us over MAX_SIZE, then we need to truncate mid paragraph - if len(description) > max_size: - new_desc = "" - - # This splits the paragraph into words, but keeping the - # (preceding) whitespace intact so we can easily concat - # words back together. - for match in re.finditer(r"\s*\S+", description): - word = match.group() - - # Keep adding words while the total length is less than - # MAX_SIZE. - if len(word) + len(new_desc) < max_size: - new_desc += word - else: - # At this point the next word *will* take us over - # MAX_SIZE, but we also want to ensure that its not - # a huge word. If it is add it anyway and we'll - # truncate later. - if len(new_desc) < min_size: - new_desc += word - break - - # Double check that we're not over the limit - if len(new_desc) > max_size: - new_desc = new_desc[:max_size] - - # We always add an ellipsis because at the very least - # we chopped mid paragraph. - description = new_desc.strip() + "…" - return description if description else None diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 1c9b71d69c..d7653f30ae 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -1,4 +1,4 @@ -# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# Copyright 2023 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,171 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# -import abc -import logging -import os -import shutil -from typing import TYPE_CHECKING, Callable, Optional - -from synapse.config._base import Config -from synapse.logging.context import defer_to_thread, run_in_background -from synapse.util.async_helpers import maybe_awaitable - -from ._base import FileInfo, Responder -from .media_storage import FileResponder - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class StorageProvider(metaclass=abc.ABCMeta): - """A storage provider is a service that can store uploaded media and - retrieve them. - """ - - @abc.abstractmethod - async def store_file(self, path: str, file_info: FileInfo) -> None: - """Store the file described by file_info. The actual contents can be - retrieved by reading the file in file_info.upload_path. - - Args: - path: Relative path of file in local cache - file_info: The metadata of the file. - """ - - @abc.abstractmethod - async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: - """Attempt to fetch the file described by file_info and stream it - into writer. - - Args: - path: Relative path of file in local cache - file_info: The metadata of the file. - - Returns: - Returns a Responder if the provider has the file, otherwise returns None. - """ - - -class StorageProviderWrapper(StorageProvider): - """Wraps a storage provider and provides various config options - - Args: - backend: The storage provider to wrap. - store_local: Whether to store new local files or not. - store_synchronous: Whether to wait for file to be successfully - uploaded, or todo the upload in the background. - store_remote: Whether remote media should be uploaded - """ - - def __init__( - self, - backend: StorageProvider, - store_local: bool, - store_synchronous: bool, - store_remote: bool, - ): - self.backend = backend - self.store_local = store_local - self.store_synchronous = store_synchronous - self.store_remote = store_remote - - def __str__(self) -> str: - return "StorageProviderWrapper[%s]" % (self.backend,) - - async def store_file(self, path: str, file_info: FileInfo) -> None: - if not file_info.server_name and not self.store_local: - return None - - if file_info.server_name and not self.store_remote: - return None - - if file_info.url_cache: - # The URL preview cache is short lived and not worth offloading or - # backing up. - return None - - if self.store_synchronous: - # store_file is supposed to return an Awaitable, but guard - # against improper implementations. - await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore - else: - # TODO: Handle errors. - async def store() -> None: - try: - return await maybe_awaitable( - self.backend.store_file(path, file_info) - ) - except Exception: - logger.exception("Error storing file") - - run_in_background(store) - - async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: - if file_info.url_cache: - # Files in the URL preview cache definitely aren't stored here, - # so avoid any potentially slow I/O or network access. - return None - - # store_file is supposed to return an Awaitable, but guard - # against improper implementations. - return await maybe_awaitable(self.backend.fetch(path, file_info)) - - -class FileStorageProviderBackend(StorageProvider): - """A storage provider that stores files in a directory on a filesystem. - - Args: - hs - config: The config returned by `parse_config`. - """ - - def __init__(self, hs: "HomeServer", config: str): - self.hs = hs - self.cache_directory = hs.config.media.media_store_path - self.base_directory = config - - def __str__(self) -> str: - return "FileStorageProviderBackend[%s]" % (self.base_directory,) - - async def store_file(self, path: str, file_info: FileInfo) -> None: - """See StorageProvider.store_file""" - - primary_fname = os.path.join(self.cache_directory, path) - backup_fname = os.path.join(self.base_directory, path) - - dirname = os.path.dirname(backup_fname) - os.makedirs(dirname, exist_ok=True) - - # mypy needs help inferring the type of the second parameter, which is generic - shutil_copyfile: Callable[[str, str], str] = shutil.copyfile - await defer_to_thread( - self.hs.get_reactor(), - shutil_copyfile, - primary_fname, - backup_fname, - ) - - async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: - """See StorageProvider.fetch""" - - backup_fname = os.path.join(self.base_directory, path) - if os.path.isfile(backup_fname): - return FileResponder(open(backup_fname, "rb")) - - return None - - @staticmethod - def parse_config(config: dict) -> str: - """Called on startup to parse config supplied. This should parse - the config and raise if there is a problem. - - The returned value is passed into the constructor. - - In this case we only care about a single param, the directory, so let's - just pull that out. - """ - return Config.ensure_directory(config["directory"]) +# This exists purely for backwards compatibility with media providers. +from synapse.media.storage_provider import StorageProvider # noqa: F401 diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py deleted file mode 100644 index 9480cc5763..0000000000 --- a/synapse/rest/media/v1/thumbnailer.py +++ /dev/null @@ -1,222 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2020-2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from io import BytesIO -from types import TracebackType -from typing import Optional, Tuple, Type - -from PIL import Image - -logger = logging.getLogger(__name__) - -EXIF_ORIENTATION_TAG = 0x0112 -EXIF_TRANSPOSE_MAPPINGS = { - 2: Image.FLIP_LEFT_RIGHT, - 3: Image.ROTATE_180, - 4: Image.FLIP_TOP_BOTTOM, - 5: Image.TRANSPOSE, - 6: Image.ROTATE_270, - 7: Image.TRANSVERSE, - 8: Image.ROTATE_90, -} - - -class ThumbnailError(Exception): - """An error occurred generating a thumbnail.""" - - -class Thumbnailer: - - FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} - - @staticmethod - def set_limits(max_image_pixels: int) -> None: - Image.MAX_IMAGE_PIXELS = max_image_pixels - - def __init__(self, input_path: str): - # Have we closed the image? - self._closed = False - - try: - self.image = Image.open(input_path) - except OSError as e: - # If an error occurs opening the image, a thumbnail won't be able to - # be generated. - raise ThumbnailError from e - except Image.DecompressionBombError as e: - # If an image decompression bomb error occurs opening the image, - # then the image exceeds the pixel limit and a thumbnail won't - # be able to be generated. - raise ThumbnailError from e - - self.width, self.height = self.image.size - self.transpose_method = None - try: - # We don't use ImageOps.exif_transpose since it crashes with big EXIF - # - # Ignore safety: Pillow seems to acknowledge that this method is - # "private, experimental, but generally widely used". Pillow 6 - # includes a public getexif() method (no underscore) that we might - # consider using instead when we can bump that dependency. - # - # At the time of writing, Debian buster (currently oldstable) - # provides version 5.4.1. It's expected to EOL in mid-2022, see - # https://wiki.debian.org/DebianReleases#Production_Releases - image_exif = self.image._getexif() # type: ignore - if image_exif is not None: - image_orientation = image_exif.get(EXIF_ORIENTATION_TAG) - assert type(image_orientation) is int - self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation) - except Exception as e: - # A lot of parsing errors can happen when parsing EXIF - logger.info("Error parsing image EXIF information: %s", e) - - def transpose(self) -> Tuple[int, int]: - """Transpose the image using its EXIF Orientation tag - - Returns: - A tuple containing the new image size in pixels as (width, height). - """ - if self.transpose_method is not None: - # Safety: `transpose` takes an int rather than e.g. an IntEnum. - # self.transpose_method is set above to be a value in - # EXIF_TRANSPOSE_MAPPINGS, and that only contains correct values. - with self.image: - self.image = self.image.transpose(self.transpose_method) # type: ignore[arg-type] - self.width, self.height = self.image.size - self.transpose_method = None - # We don't need EXIF any more - self.image.info["exif"] = None - return self.image.size - - def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]: - """Calculate the largest size that preserves aspect ratio which - fits within the given rectangle:: - - (w_in / h_in) = (w_out / h_out) - w_out = max(min(w_max, h_max * (w_in / h_in)), 1) - h_out = max(min(h_max, w_max * (h_in / w_in)), 1) - - Args: - max_width: The largest possible width. - max_height: The largest possible height. - """ - - if max_width * self.height < max_height * self.width: - return max_width, max((max_width * self.height) // self.width, 1) - else: - return max((max_height * self.width) // self.height, 1), max_height - - def _resize(self, width: int, height: int) -> Image.Image: - # 1-bit or 8-bit color palette images need converting to RGB - # otherwise they will be scaled using nearest neighbour which - # looks awful. - # - # If the image has transparency, use RGBA instead. - if self.image.mode in ["1", "L", "P"]: - if self.image.info.get("transparency", None) is not None: - with self.image: - self.image = self.image.convert("RGBA") - else: - with self.image: - self.image = self.image.convert("RGB") - return self.image.resize((width, height), Image.ANTIALIAS) - - def scale(self, width: int, height: int, output_type: str) -> BytesIO: - """Rescales the image to the given dimensions. - - Returns: - The bytes of the encoded image ready to be written to disk - """ - with self._resize(width, height) as scaled: - return self._encode_image(scaled, output_type) - - def crop(self, width: int, height: int, output_type: str) -> BytesIO: - """Rescales and crops the image to the given dimensions preserving - aspect:: - (w_in / h_in) = (w_scaled / h_scaled) - w_scaled = max(w_out, h_out * (w_in / h_in)) - h_scaled = max(h_out, w_out * (h_in / w_in)) - - Args: - max_width: The largest possible width. - max_height: The largest possible height. - - Returns: - The bytes of the encoded image ready to be written to disk - """ - if width * self.height > height * self.width: - scaled_width = width - scaled_height = (width * self.height) // self.width - crop_top = (scaled_height - height) // 2 - crop_bottom = height + crop_top - crop = (0, crop_top, width, crop_bottom) - else: - scaled_width = (height * self.width) // self.height - scaled_height = height - crop_left = (scaled_width - width) // 2 - crop_right = width + crop_left - crop = (crop_left, 0, crop_right, height) - - with self._resize(scaled_width, scaled_height) as scaled_image: - with scaled_image.crop(crop) as cropped: - return self._encode_image(cropped, output_type) - - def _encode_image(self, output_image: Image.Image, output_type: str) -> BytesIO: - output_bytes_io = BytesIO() - fmt = self.FORMATS[output_type] - if fmt == "JPEG": - output_image = output_image.convert("RGB") - output_image.save(output_bytes_io, fmt, quality=80) - return output_bytes_io - - def close(self) -> None: - """Closes the underlying image file. - - Once closed no other functions can be called. - - Can be called multiple times. - """ - - if self._closed: - return - - self._closed = True - - # Since we run this on the finalizer then we need to handle `__init__` - # raising an exception before it can define `self.image`. - image = getattr(self, "image", None) - if image is None: - return - - image.close() - - def __enter__(self) -> "Thumbnailer": - """Make `Thumbnailer` a context manager that calls `close` on - `__exit__`. - """ - return self - - def __exit__( - self, - type: Optional[Type[BaseException]], - value: Optional[BaseException], - traceback: Optional[TracebackType], - ) -> None: - self.close() - - def __del__(self) -> None: - # Make sure we actually do close the image, rather than leak data. - self.close() |