author     Mathieu Velten <mathieuv@matrix.org>  2023-03-13 13:59:16 +0100
committer  Mathieu Velten <mathieuv@matrix.org>  2023-03-13 13:59:16 +0100
commit     5980756e0900d23f83926a7223644586183fd0b5 (patch)
tree       a3513d14a88d1bed4d649c4945535888c2d91ef2 /synapse/rest
parent     Update changelog.d/14519.misc (diff)
parent     Bump hiredis from 2.2.1 to 2.2.2 (#15252) (diff)
download   synapse-github/mv/mypy-unused-awaitable.tar.xz

Merge remote-tracking branch 'origin/develop' into mv/mypy-unused-awaitable (github/mv/mypy-unused-awaitable, mv/mypy-unused-awaitable)
Diffstat (limited to 'synapse/rest')
-rw-r--r--  synapse/rest/__init__.py | 5
-rw-r--r--  synapse/rest/admin/event_reports.py | 41
-rw-r--r--  synapse/rest/admin/media.py | 18
-rw-r--r--  synapse/rest/admin/rooms.py | 4
-rw-r--r--  synapse/rest/admin/server_notice_servlet.py | 34
-rw-r--r--  synapse/rest/admin/users.py | 22
-rw-r--r--  synapse/rest/client/account.py | 19
-rw-r--r--  synapse/rest/client/auth.py | 1
-rw-r--r--  synapse/rest/client/devices.py | 2
-rw-r--r--  synapse/rest/client/events.py | 16
-rw-r--r--  synapse/rest/client/filter.py | 1
-rw-r--r--  synapse/rest/client/keys.py | 32
-rw-r--r--  synapse/rest/client/knock.py | 17
-rw-r--r--  synapse/rest/client/notifications.py | 12
-rw-r--r--  synapse/rest/client/register.py | 18
-rw-r--r--  synapse/rest/client/report_event.py | 11
-rw-r--r--  synapse/rest/client/room.py | 198
-rw-r--r--  synapse/rest/client/room_batch.py | 16
-rw-r--r--  synapse/rest/client/room_keys.py | 48
-rw-r--r--  synapse/rest/client/sendtodevice.py | 25
-rw-r--r--  synapse/rest/client/sync.py | 35
-rw-r--r--  synapse/rest/client/tags.py | 4
-rw-r--r--  synapse/rest/client/transactions.py | 55
-rw-r--r--  synapse/rest/media/config_resource.py (renamed from synapse/rest/media/v1/config_resource.py) | 0
-rw-r--r--  synapse/rest/media/download_resource.py (renamed from synapse/rest/media/v1/download_resource.py) | 5
-rw-r--r--  synapse/rest/media/media_repository_resource.py | 93
-rw-r--r--  synapse/rest/media/preview_url_resource.py (renamed from synapse/rest/media/v1/preview_url_resource.py) | 45
-rw-r--r--  synapse/rest/media/thumbnail_resource.py (renamed from synapse/rest/media/v1/thumbnail_resource.py) | 10
-rw-r--r--  synapse/rest/media/upload_resource.py (renamed from synapse/rest/media/v1/upload_resource.py) | 4
-rw-r--r--  synapse/rest/media/v1/_base.py | 468
-rw-r--r--  synapse/rest/media/v1/filepath.py | 410
-rw-r--r--  synapse/rest/media/v1/media_repository.py | 1112
-rw-r--r--  synapse/rest/media/v1/media_storage.py | 366
-rw-r--r--  synapse/rest/media/v1/oembed.py | 265
-rw-r--r--  synapse/rest/media/v1/preview_html.py | 501
-rw-r--r--  synapse/rest/media/v1/storage_provider.py | 172
-rw-r--r--  synapse/rest/media/v1/thumbnailer.py | 222
37 files changed, 534 insertions(+), 3773 deletions(-)
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py

index 14c4e6ebbb..2e19e055d3 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -108,8 +108,7 @@ class ClientRestResource(JsonResource):
         if is_main_process:
             logout.register_servlets(hs, client_resource)
         sync.register_servlets(hs, client_resource)
-        if is_main_process:
-            filter.register_servlets(hs, client_resource)
+        filter.register_servlets(hs, client_resource)
         account.register_servlets(hs, client_resource)
         register.register_servlets(hs, client_resource)
         if is_main_process:
@@ -140,7 +139,7 @@ class ClientRestResource(JsonResource):
             relations.register_servlets(hs, client_resource)
         if is_main_process:
             password_policy.register_servlets(hs, client_resource)
-            knock.register_servlets(hs, client_resource)
+        knock.register_servlets(hs, client_resource)

         # moving to /_synapse/admin
         if is_main_process:
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
index a3beb74e2c..c546ef7e23 100644
--- a/synapse/rest/admin/event_reports.py
+++ b/synapse/rest/admin/event_reports.py
@@ -53,11 +53,11 @@ class EventReportsRestServlet(RestServlet):
     PATTERNS = admin_patterns("/event_reports$")

     def __init__(self, hs: "HomeServer"):
-        self.auth = hs.get_auth()
-        self.store = hs.get_datastores().main
+        self._auth = hs.get_auth()
+        self._store = hs.get_datastores().main

     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        await assert_requester_is_admin(self.auth, request)
+        await assert_requester_is_admin(self._auth, request)

         start = parse_integer(request, "from", default=0)
         limit = parse_integer(request, "limit", default=100)
@@ -79,7 +79,7 @@ class EventReportsRestServlet(RestServlet):
                 errcode=Codes.INVALID_PARAM,
             )

-        event_reports, total = await self.store.get_event_reports_paginate(
+        event_reports, total = await self._store.get_event_reports_paginate(
             start, limit, direction, user_id, room_id
         )
         ret = {"event_reports": event_reports, "total": total}
@@ -108,13 +108,13 @@ class EventReportDetailRestServlet(RestServlet):
     PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$")

     def __init__(self, hs: "HomeServer"):
-        self.auth = hs.get_auth()
-        self.store = hs.get_datastores().main
+        self._auth = hs.get_auth()
+        self._store = hs.get_datastores().main

     async def on_GET(
         self, request: SynapseRequest, report_id: str
     ) -> Tuple[int, JsonDict]:
-        await assert_requester_is_admin(self.auth, request)
+        await assert_requester_is_admin(self._auth, request)

         message = (
             "The report_id parameter must be a string representing a positive integer."
@@ -131,8 +131,33 @@ class EventReportDetailRestServlet(RestServlet):
                 HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
             )

-        ret = await self.store.get_event_report(resolved_report_id)
+        ret = await self._store.get_event_report(resolved_report_id)
         if not ret:
             raise NotFoundError("Event report not found")

         return HTTPStatus.OK, ret
+
+    async def on_DELETE(
+        self, request: SynapseRequest, report_id: str
+    ) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self._auth, request)
+
+        message = (
+            "The report_id parameter must be a string representing a positive integer."
+        )
+        try:
+            resolved_report_id = int(report_id)
+        except ValueError:
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+            )
+
+        if resolved_report_id < 0:
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+            )
+
+        if await self._store.delete_event_report(resolved_report_id):
+            return HTTPStatus.OK, {}
+
+        raise NotFoundError("Event report not found")
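
The new on_DELETE handler gives admins a way to remove an event report. As a rough illustration of calling it (assuming the usual /_synapse/admin/v1 prefix produced by admin_patterns; the base URL, token, and report ID below are placeholders):

    import requests

    BASE_URL = "https://homeserver.example.com"  # placeholder
    ADMIN_TOKEN = "syt_admin_token"  # placeholder
    REPORT_ID = 42  # placeholder

    # Delete a single event report: 200 with an empty body on success,
    # 404 if the report ID is unknown.
    resp = requests.delete(
        f"{BASE_URL}/_synapse/admin/v1/event_reports/{REPORT_ID}",
        headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    )
    print(resp.status_code, resp.json())
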
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 0d072c42a7..c134ccfb3d 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -15,7 +15,7 @@

 import logging
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple

 from synapse.api.constants import Direction
 from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -285,7 +285,12 @@ class DeleteMediaByDateSize(RestServlet):
     timestamp and size.
     """

-    PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/delete$")
+    PATTERNS = [
+        *admin_patterns("/media/delete$"),
+        # This URL kept around for legacy reasons, it is undesirable since it
+        # overlaps with the DeleteMediaByID servlet.
+        *admin_patterns("/media/(?P<server_name>[^/]*)/delete$"),
+    ]

     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
@@ -294,7 +299,7 @@ class DeleteMediaByDateSize(RestServlet):
         self.media_repository = hs.get_media_repository()

     async def on_POST(
-        self, request: SynapseRequest, server_name: str
+        self, request: SynapseRequest, server_name: Optional[str] = None
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)

@@ -322,7 +327,8 @@ class DeleteMediaByDateSize(RestServlet):
                 errcode=Codes.INVALID_PARAM,
             )

-        if self.server_name != server_name:
+        # This check is useless, we keep it for the legacy endpoint only.
+        if server_name is not None and self.server_name != server_name:
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")

         logging.info(
@@ -489,6 +495,8 @@ def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer)
     ProtectMediaByID(hs).register(http_server)
     UnprotectMediaByID(hs).register(http_server)
     ListMediaInRoom(hs).register(http_server)
-    DeleteMediaByID(hs).register(http_server)
+    # XXX DeleteMediaByDateSize must be registered before DeleteMediaByID as
+    # their URL routes overlap.
     DeleteMediaByDateSize(hs).register(http_server)
+    DeleteMediaByID(hs).register(http_server)
     UserMediaRestServlet(hs).register(http_server)
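
The new route drops the redundant server name from the URL. A sketch of calling it (the before_ts query parameter comes from the admin API documentation rather than the hunks above, so treat it as an assumption; the server and token are placeholders):

    import time

    import requests

    BASE_URL = "https://homeserver.example.com"  # placeholder
    ADMIN_TOKEN = "syt_admin_token"  # placeholder

    # Delete local media older than 30 days; before_ts is in milliseconds.
    cutoff_ms = int(time.time() * 1000) - 30 * 24 * 3600 * 1000
    resp = requests.post(
        f"{BASE_URL}/_synapse/admin/v1/media/delete",
        params={"before_ts": cutoff_ms},
        headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    )
    print(resp.json())  # e.g. {"deleted_media": [...], "total": ...}
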
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 1d6e4982d7..4de56bf13f 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -75,7 +75,6 @@ class RoomRestV2Servlet(RestServlet):
     async def on_DELETE(
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
-
         requester = await self._auth.get_user_by_req(request)
         await assert_user_is_admin(self._auth, requester)

@@ -144,7 +143,6 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
     async def on_GET(
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
-
         await assert_requester_is_admin(self._auth, request)

         if not RoomID.is_valid(room_id):
@@ -181,7 +179,6 @@ class DeleteRoomStatusByDeleteIdRestServlet(RestServlet):
     async def on_GET(
         self, request: SynapseRequest, delete_id: str
     ) -> Tuple[int, JsonDict]:
-
         await assert_requester_is_admin(self._auth, request)

         delete_status = self._pagination_handler.get_delete_status(delete_id)
@@ -438,7 +435,6 @@ class RoomStateRestServlet(RestServlet):


 class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
-
     PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)$")

     def __init__(self, hs: "HomeServer"):
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 15da9cd881..7dd1c10b91 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple

 from synapse.api.constants import EventTypes
 from synapse.api.errors import NotFoundError, SynapseError
@@ -23,10 +23,10 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
 )
 from synapse.http.site import SynapseRequest
-from synapse.rest.admin import assert_requester_is_admin
-from synapse.rest.admin._base import admin_patterns
+from synapse.logging.opentracing import set_tag
+from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
 from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, Requester, UserID

 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -70,10 +70,13 @@ class SendServerNoticeServlet(RestServlet):
             self.__class__.__name__,
         )

-    async def on_POST(
-        self, request: SynapseRequest, txn_id: Optional[str] = None
+    async def _do(
+        self,
+        request: SynapseRequest,
+        requester: Requester,
+        txn_id: Optional[str],
     ) -> Tuple[int, JsonDict]:
-        await assert_requester_is_admin(self.auth, request)
+        await assert_user_is_admin(self.auth, requester)
         body = parse_json_object_from_request(request)
         assert_params_in_dict(body, ("user_id", "content"))
         event_type = body.get("type", EventTypes.Message)
@@ -106,9 +109,18 @@ class SendServerNoticeServlet(RestServlet):

         return HTTPStatus.OK, {"event_id": event.event_id}

-    def on_PUT(
+    async def on_POST(
+        self,
+        request: SynapseRequest,
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        return await self._do(request, requester, None)
+
+    async def on_PUT(
         self, request: SynapseRequest, txn_id: str
-    ) -> Awaitable[Tuple[int, JsonDict]]:
-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, txn_id
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        set_tag("txn_id", txn_id)
+        return await self.txns.fetch_or_execute_request(
+            request, requester, self._do, request, requester, txn_id
         )
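
With this refactor, PUT requests are deduplicated per requester by the transaction cache. A sketch of an idempotent server-notice call (the send_server_notice path matches the admin API docs; the server, token, and IDs are placeholders):

    import requests

    BASE_URL = "https://homeserver.example.com"  # placeholder
    ADMIN_TOKEN = "syt_admin_token"  # placeholder

    body = {
        "user_id": "@alice:example.com",
        "content": {"msgtype": "m.text", "body": "Scheduled maintenance at 02:00 UTC"},
    }

    # Retrying the same txn_id returns the cached response instead of
    # sending the notice a second time.
    for _ in range(2):
        resp = requests.put(
            f"{BASE_URL}/_synapse/admin/v1/send_server_notice/mytxn1",
            json=body,
            headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
        )
        print(resp.json())  # same {"event_id": ...} both times
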
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index b9dca8ef3a..357e9a574d 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -304,13 +304,20 @@ class UserRestServletV2(RestServlet):
             # remove old threepids
             for medium, address in del_threepids:
                 try:
-                    await self.auth_handler.delete_threepid(
-                        user_id, medium, address, None
+                    # Attempt to remove any known bindings of this third-party ID
+                    # and user ID from identity servers.
+                    await self.hs.get_identity_handler().try_unbind_threepid(
+                        user_id, medium, address, id_server=None
                     )
                 except Exception:
                     logger.exception("Failed to remove threepids")
                     raise SynapseError(500, "Failed to remove threepids")

+                # Delete the local association of this user ID and third-party ID.
+                await self.auth_handler.delete_local_threepid(
+                    user_id, medium, address
+                )
+
             # add new threepids
             current_time = self.hs.get_clock().time_msec()
             for medium, address in add_threepids:
@@ -683,8 +690,12 @@ class AccountValidityRenewServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)

         if self.account_activity_handler.on_legacy_admin_request_callback:
-            expiration_ts = await (
-                self.account_activity_handler.on_legacy_admin_request_callback(request)
+            expiration_ts = (
+                await (
+                    self.account_activity_handler.on_legacy_admin_request_callback(
+                        request
+                    )
+                )
             )
         else:
             body = parse_json_object_from_request(request)
@@ -1192,7 +1203,8 @@ class AccountDataRestServlet(RestServlet):
         if not await self._store.get_user_by_id(user_id):
             raise NotFoundError("User not found")

-        global_data, by_room_data = await self._store.get_account_data_for_user(user_id)
+        global_data = await self._store.get_global_account_data_for_user(user_id)
+        by_room_data = await self._store.get_room_account_data_for_user(user_id)

         return HTTPStatus.OK, {
             "account_data": {
                 "global": global_data,
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 4373c73662..484d7440a4 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -415,6 +415,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
             request, MsisdnRequestTokenBody
         )
         msisdn = phone_number_to_msisdn(body.country, body.phone_number)
+        logger.info("Request #%s to verify ownership of %s", body.send_attempt, msisdn)

         if not await check_3pid_allowed(self.hs, "msisdn", msisdn):
             raise SynapseError(
@@ -444,6 +445,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
                 await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
                 return 200, {"sid": random_string(16)}

+            logger.info("MSISDN %s is already in use by %s", msisdn, existing_user_id)
             raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)

         if not self.hs.config.registration.account_threepid_delegate_msisdn:
@@ -468,6 +470,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
         threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe(
             body.send_attempt
         )
+        logger.info("MSISDN %s: got response from identity server: %s", msisdn, ret)

         return 200, ret

@@ -734,12 +737,7 @@ class ThreepidUnbindRestServlet(RestServlet):
         # Attempt to unbind the threepid from an identity server. If id_server is None, try to
         # unbind from all identity servers this threepid has been added to in the past
         result = await self.identity_handler.try_unbind_threepid(
-            requester.user.to_string(),
-            {
-                "address": body.address,
-                "medium": body.medium,
-                "id_server": body.id_server,
-            },
+            requester.user.to_string(), body.medium, body.address, body.id_server
         )
         return 200, {"id_server_unbind_result": "success" if result else "no-support"}

@@ -770,7 +768,9 @@ class ThreepidDeleteRestServlet(RestServlet):
         user_id = requester.user.to_string()

         try:
-            ret = await self.auth_handler.delete_threepid(
+            # Attempt to remove any known bindings of this third-party ID
+            # and user ID from identity servers.
+            ret = await self.hs.get_identity_handler().try_unbind_threepid(
                 user_id, body.medium, body.address, body.id_server
             )
         except Exception:
@@ -785,6 +785,11 @@ class ThreepidDeleteRestServlet(RestServlet):
         else:
             id_server_unbind_result = "no-support"

+        # Delete the local association of this user ID and third-party ID.
+        await self.auth_handler.delete_local_threepid(
+            user_id, body.medium, body.address
+        )
+
         return 200, {"id_server_unbind_result": id_server_unbind_result}
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index eb77337044..276a1b405d 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -97,7 +97,6 @@ class AuthRestServlet(RestServlet):
         return None

     async def on_POST(self, request: Request, stagetype: str) -> None:
-
         session = parse_string(request, "session")
         if not session:
             raise SynapseError(400, "No session supplied")
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 486c6dbbc5..dab4a77f7e 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -255,7 +255,7 @@ class DehydratedDeviceServlet(RestServlet):

     """

-    PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=())
+    PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device$", releases=())

     def __init__(self, hs: "HomeServer"):
         super().__init__()
diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py
index 782e7d14e8..694d77d287 100644
--- a/synapse/rest/client/events.py
+++ b/synapse/rest/client/events.py
@@ -17,6 +17,7 @@ import logging
 from typing import TYPE_CHECKING, Dict, List, Tuple, Union

 from synapse.api.errors import SynapseError
+from synapse.events.utils import SerializeEventConfig
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_string
 from synapse.http.site import SynapseRequest
@@ -43,9 +44,8 @@ class EventStreamRestServlet(RestServlet):

     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        is_guest = requester.is_guest
         args: Dict[bytes, List[bytes]] = request.args  # type: ignore
-        if is_guest:
+        if requester.is_guest:
             if b"room_id" not in args:
                 raise SynapseError(400, "Guest users must specify room_id param")
         room_id = parse_string(request, "room_id")
@@ -63,13 +63,12 @@ class EventStreamRestServlet(RestServlet):
         as_client_event = b"raw" not in args

         chunk = await self.event_stream_handler.get_stream(
-            requester.user.to_string(),
+            requester,
             pagin_config,
             timeout=timeout,
             as_client_event=as_client_event,
-            affect_presence=(not is_guest),
+            affect_presence=(not requester.is_guest),
             room_id=room_id,
-            is_guest=is_guest,
         )

         return 200, chunk
@@ -91,9 +90,12 @@ class EventRestServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request)
         event = await self.event_handler.get_event(requester.user, None, event_id)

-        time_now = self.clock.time_msec()
         if event:
-            result = self._event_serializer.serialize_event(event, time_now)
+            result = self._event_serializer.serialize_event(
+                event,
+                self.clock.time_msec(),
+                config=SerializeEventConfig(requester=requester),
+            )
             return 200, result
         else:
             return 404, "Event not found."
diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py
index cc1c2f9731..236199897c 100644
--- a/synapse/rest/client/filter.py
+++ b/synapse/rest/client/filter.py
@@ -79,7 +79,6 @@ class CreateFilterRestServlet(RestServlet):
     async def on_POST(
         self, request: SynapseRequest, user_id: str
     ) -> Tuple[int, JsonDict]:
-
         target_user = UserID.from_string(user_id)
         requester = await self.auth.get_user_by_req(request)
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 7873b363c0..32bb8b9a91 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -312,15 +312,29 @@ class SigningKeyUploadServlet(RestServlet):
         user_id = requester.user.to_string()
         body = parse_json_object_from_request(request)

-        await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            "add a device signing key to your account",
-            # Allow skipping of UI auth since this is frequently called directly
-            # after login and it is silly to ask users to re-auth immediately.
-            can_skip_ui_auth=True,
-        )
+        if self.hs.config.experimental.msc3967_enabled:
+            if await self.e2e_keys_handler.is_cross_signing_set_up_for_user(user_id):
+                # If we already have a master key then cross signing is set up and we require UIA to reset
+                await self.auth_handler.validate_user_via_ui_auth(
+                    requester,
+                    request,
+                    body,
+                    "reset the device signing key on your account",
+                    # Do not allow skipping of UIA auth.
+                    can_skip_ui_auth=False,
+                )
+            # Otherwise we don't require UIA since we are setting up cross signing for first time
+        else:
+            # Previous behaviour is to always require UIA but allow it to be skipped
+            await self.auth_handler.validate_user_via_ui_auth(
+                requester,
+                request,
+                body,
+                "add a device signing key to your account",
+                # Allow skipping of UI auth since this is frequently called directly
+                # after login and it is silly to ask users to re-auth immediately.
+                can_skip_ui_auth=True,
+            )

         result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
         return 200, result
diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py
index ad025c8a45..4fa66904ba 100644
--- a/synapse/rest/client/knock.py
+++ b/synapse/rest/client/knock.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Tuple

 from synapse.api.constants import Membership
 from synapse.api.errors import SynapseError
@@ -24,8 +24,6 @@ from synapse.http.servlet import (
     parse_strings_from_args,
 )
 from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import set_tag
-from synapse.rest.client.transactions import HttpTransactionCache
 from synapse.types import JsonDict, RoomAlias, RoomID

 if TYPE_CHECKING:
@@ -45,7 +43,6 @@ class KnockRoomAliasServlet(RestServlet):

     def __init__(self, hs: "HomeServer"):
         super().__init__()
-        self.txns = HttpTransactionCache(hs)
         self.room_member_handler = hs.get_room_member_handler()
         self.auth = hs.get_auth()

@@ -53,7 +50,6 @@ class KnockRoomAliasServlet(RestServlet):
         self,
         request: SynapseRequest,
         room_identifier: str,
-        txn_id: Optional[str] = None,
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)

@@ -67,7 +63,6 @@ class KnockRoomAliasServlet(RestServlet):

         # twisted.web.server.Request.args is incorrectly defined as Optional[Any]
         args: Dict[bytes, List[bytes]] = request.args  # type: ignore
-
         remote_room_hosts = parse_strings_from_args(
             args, "server_name", required=False
         )
@@ -86,7 +81,6 @@ class KnockRoomAliasServlet(RestServlet):
             target=requester.user,
             room_id=room_id,
             action=Membership.KNOCK,
-            txn_id=txn_id,
             third_party_signed=None,
             remote_room_hosts=remote_room_hosts,
             content=event_content,
@@ -94,15 +88,6 @@ class KnockRoomAliasServlet(RestServlet):

         return 200, {"room_id": room_id}

-    def on_PUT(
-        self, request: SynapseRequest, room_identifier: str, txn_id: str
-    ) -> Awaitable[Tuple[int, JsonDict]]:
-        set_tag("txn_id", txn_id)
-
-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, room_identifier, txn_id
-        )
-

 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     KnockRoomAliasServlet(hs).register(http_server)
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index 61268e3af1..ea10042569 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -72,6 +72,12 @@ class NotificationsServlet(RestServlet):

         next_token = None

+        serialize_options = SerializeEventConfig(
+            event_format=format_event_for_client_v2_without_room_id,
+            requester=requester,
+        )
+        now = self.clock.time_msec()
+
         for pa in push_actions:
             returned_pa = {
                 "room_id": pa.room_id,
@@ -81,10 +87,8 @@ class NotificationsServlet(RestServlet):
                 "event": (
                     self._event_serializer.serialize_event(
                         notif_events[pa.event_id],
-                        self.clock.time_msec(),
-                        config=SerializeEventConfig(
-                            event_format=format_event_for_client_v2_without_room_id
-                        ),
+                        now,
+                        config=serialize_options,
                     )
                 ),
             }
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 3cb1e7e375..bce806f2bb 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -628,10 +628,12 @@ class RegisterRestServlet(RestServlet):
         if not password_hash:
             raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)

-        desired_username = await (
-            self.password_auth_provider.get_username_for_registration(
-                auth_result,
-                params,
+        desired_username = (
+            await (
+                self.password_auth_provider.get_username_for_registration(
+                    auth_result,
+                    params,
+                )
             )
         )

@@ -682,9 +684,11 @@ class RegisterRestServlet(RestServlet):
                 session_id
             )

-            display_name = await (
-                self.password_auth_provider.get_displayname_for_registration(
-                    auth_result, params
+            display_name = (
+                await (
+                    self.password_auth_provider.get_displayname_for_registration(
+                        auth_result, params
+                    )
                 )
             )
diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py
index e2b410cf32..9be5860221 100644
--- a/synapse/rest/client/report_event.py
+++ b/synapse/rest/client/report_event.py
@@ -16,7 +16,7 @@ import logging
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple

-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, NotFoundError, SynapseError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.http.site import SynapseRequest
@@ -39,6 +39,7 @@ class ReportEventRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
+        self._event_handler = self.hs.get_event_handler()

     async def on_POST(
         self, request: SynapseRequest, room_id: str, event_id: str
@@ -61,6 +62,14 @@ class ReportEventRestServlet(RestServlet):
                 Codes.BAD_JSON,
             )

+        event = await self._event_handler.get_event(
+            requester.user, room_id, event_id, show_redacted=False
+        )
+        if event is None:
+            raise NotFoundError(
+                "Unable to report event: it does not exist or you aren't able to see it."
+            )
+
         await self.store.add_event_report(
             room_id=room_id,
             event_id=event_id,
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index e7bec9f02e..3a3abed56a 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -37,7 +37,7 @@ from synapse.api.errors import (
     UnredactedContentDeletedError,
 )
 from synapse.api.filtering import Filter
-from synapse.events.utils import format_event_for_client_v2
+from synapse.events.utils import SerializeEventConfig, format_event_for_client_v2
 from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     ResolveRoomIdMixin,
@@ -57,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.client._base import client_patterns
 from synapse.rest.client.transactions import HttpTransactionCache
 from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID
+from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
 from synapse.types.state import StateFilter
 from synapse.util import json_decoder
 from synapse.util.cancellation import cancellable
@@ -151,20 +151,27 @@ class RoomCreateRestServlet(TransactionRestServlet):
         PATTERNS = "/createRoom"
         register_txn_path(self, PATTERNS, http_server)

-    def on_PUT(
+    async def on_PUT(
         self, request: SynapseRequest, txn_id: str
-    ) -> Awaitable[Tuple[int, JsonDict]]:
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
         set_tag("txn_id", txn_id)
-        return self.txns.fetch_or_execute_request(request, self.on_POST, request)
+        return await self.txns.fetch_or_execute_request(
+            request, requester, self._do, request, requester
+        )

     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
+        return await self._do(request, requester)

-        info, _ = await self._room_creation_handler.create_room(
+    async def _do(
+        self, request: SynapseRequest, requester: Requester
+    ) -> Tuple[int, JsonDict]:
+        room_id, _, _ = await self._room_creation_handler.create_room(
             requester, self.get_room_config(request)
         )

-        return 200, info
+        return 200, {"room_id": room_id}

     def get_room_config(self, request: Request) -> JsonDict:
         user_supplied_config = parse_json_object_from_request(request)
@@ -172,9 +179,9 @@ class RoomCreateRestServlet(TransactionRestServlet):


 # TODO: Needs unit testing for generic events
-class RoomStateEventRestServlet(TransactionRestServlet):
+class RoomStateEventRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
-        super().__init__(hs)
+        super().__init__()
         self.event_creation_handler = hs.get_event_creation_handler()
         self.room_member_handler = hs.get_room_member_handler()
         self.message_handler = hs.get_message_handler()
@@ -324,16 +331,16 @@ class RoomSendEventRestServlet(TransactionRestServlet):
     def register(self, http_server: HttpServer) -> None:
         # /rooms/$roomid/send/$event_type[/$txn_id]
         PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
-        register_txn_path(self, PATTERNS, http_server, with_get=True)
+        register_txn_path(self, PATTERNS, http_server)

-    async def on_POST(
+    async def _do(
         self,
         request: SynapseRequest,
+        requester: Requester,
         room_id: str,
         event_type: str,
-        txn_id: Optional[str] = None,
+        txn_id: Optional[str],
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         content = parse_json_object_from_request(request)

         event_dict: JsonDict = {
@@ -362,18 +369,30 @@ class RoomSendEventRestServlet(TransactionRestServlet):
         set_tag("event_id", event_id)
         return 200, {"event_id": event_id}

-    def on_GET(
-        self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
-    ) -> Tuple[int, str]:
-        return 200, "Not implemented"
+    async def on_POST(
+        self,
+        request: SynapseRequest,
+        room_id: str,
+        event_type: str,
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+        return await self._do(request, requester, room_id, event_type, None)

-    def on_PUT(
+    async def on_PUT(
         self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
-    ) -> Awaitable[Tuple[int, JsonDict]]:
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         set_tag("txn_id", txn_id)

-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, room_id, event_type, txn_id
+        return await self.txns.fetch_or_execute_request(
+            request,
+            requester,
+            self._do,
+            request,
+            requester,
+            room_id,
+            event_type,
+            txn_id,
         )


@@ -389,14 +408,13 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
         PATTERNS = "/join/(?P<room_identifier>[^/]*)"
         register_txn_path(self, PATTERNS, http_server)

-    async def on_POST(
+    async def _do(
         self,
         request: SynapseRequest,
+        requester: Requester,
         room_identifier: str,
-        txn_id: Optional[str] = None,
+        txn_id: Optional[str],
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
         content = parse_json_object_from_request(request, allow_empty_body=True)

         # twisted.web.server.Request.args is incorrectly defined as Optional[Any]
@@ -420,22 +438,31 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):

         return 200, {"room_id": room_id}

-    def on_PUT(
+    async def on_POST(
+        self,
+        request: SynapseRequest,
+        room_identifier: str,
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+        return await self._do(request, requester, room_identifier, None)
+
+    async def on_PUT(
         self, request: SynapseRequest, room_identifier: str, txn_id: str
-    ) -> Awaitable[Tuple[int, JsonDict]]:
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         set_tag("txn_id", txn_id)

-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, room_identifier, txn_id
+        return await self.txns.fetch_or_execute_request(
+            request, requester, self._do, request, requester, room_identifier, txn_id
         )


 # TODO: Needs unit testing
-class PublicRoomListRestServlet(TransactionRestServlet):
+class PublicRoomListRestServlet(RestServlet):
     PATTERNS = client_patterns("/publicRooms$", v1=True)

     def __init__(self, hs: "HomeServer"):
-        super().__init__(hs)
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()

@@ -814,11 +841,13 @@ class RoomEventServlet(RestServlet):
                 [event], requester.user.to_string()
             )

-            time_now = self.clock.time_msec()
             # per MSC2676, /rooms/{roomId}/event/{eventId}, should return the
             # *original* event, rather than the edited version
             event_dict = self._event_serializer.serialize_event(
-                event, time_now, bundle_aggregations=aggregations, apply_edits=False
+                event,
+                self.clock.time_msec(),
+                bundle_aggregations=aggregations,
+                config=SerializeEventConfig(requester=requester),
             )
             return 200, event_dict

@@ -863,24 +892,30 @@ class RoomEventContextServlet(RestServlet):
             raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)

         time_now = self.clock.time_msec()
+        serializer_options = SerializeEventConfig(requester=requester)
         results = {
             "events_before": self._event_serializer.serialize_events(
                 event_context.events_before,
                 time_now,
                 bundle_aggregations=event_context.aggregations,
+                config=serializer_options,
             ),
             "event": self._event_serializer.serialize_event(
                 event_context.event,
                 time_now,
                 bundle_aggregations=event_context.aggregations,
+                config=serializer_options,
             ),
             "events_after": self._event_serializer.serialize_events(
                 event_context.events_after,
                 time_now,
                 bundle_aggregations=event_context.aggregations,
+                config=serializer_options,
             ),
             "state": self._event_serializer.serialize_events(
-                event_context.state, time_now
+                event_context.state,
+                time_now,
+                config=serializer_options,
             ),
             "start": event_context.start,
             "end": event_context.end,
@@ -899,22 +934,25 @@ class RoomForgetRestServlet(TransactionRestServlet):
         PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
         register_txn_path(self, PATTERNS, http_server)

-    async def on_POST(
-        self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=False)
-
+    async def _do(self, requester: Requester, room_id: str) -> Tuple[int, JsonDict]:
         await self.room_member_handler.forget(user=requester.user, room_id=room_id)

         return 200, {}

-    def on_PUT(
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=False)
+        return await self._do(requester, room_id)
+
+    async def on_PUT(
         self, request: SynapseRequest, room_id: str, txn_id: str
-    ) -> Awaitable[Tuple[int, JsonDict]]:
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=False)
         set_tag("txn_id", txn_id)

-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, room_id, txn_id
+        return await self.txns.fetch_or_execute_request(
+            request, requester, self._do, requester, room_id
         )


@@ -926,22 +964,21 @@ class RoomMembershipRestServlet(TransactionRestServlet):
         self.auth = hs.get_auth()

     def register(self, http_server: HttpServer) -> None:
-        # /rooms/$roomid/[invite|join|leave]
+        # /rooms/$roomid/[join|invite|leave|ban|unban|kick]
         PATTERNS = (
             "/rooms/(?P<room_id>[^/]*)/"
             "(?P<membership_action>join|invite|leave|ban|unban|kick)"
         )
         register_txn_path(self, PATTERNS, http_server)

-    async def on_POST(
+    async def _do(
         self,
         request: SynapseRequest,
+        requester: Requester,
         room_id: str,
         membership_action: str,
-        txn_id: Optional[str] = None,
+        txn_id: Optional[str],
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
         if requester.is_guest and membership_action not in {
             Membership.JOIN,
             Membership.LEAVE,
@@ -1006,13 +1043,30 @@ class RoomMembershipRestServlet(TransactionRestServlet):

         return 200, return_value

-    def on_PUT(
+    async def on_POST(
+        self,
+        request: SynapseRequest,
+        room_id: str,
+        membership_action: str,
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+        return await self._do(request, requester, room_id, membership_action, None)
+
+    async def on_PUT(
         self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
-    ) -> Awaitable[Tuple[int, JsonDict]]:
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         set_tag("txn_id", txn_id)

-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, room_id, membership_action, txn_id
+        return await self.txns.fetch_or_execute_request(
+            request,
+            requester,
+            self._do,
+            request,
+            requester,
+            room_id,
+            membership_action,
+            txn_id,
         )


@@ -1028,14 +1082,14 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
         PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
         register_txn_path(self, PATTERNS, http_server)

-    async def on_POST(
+    async def _do(
         self,
         request: SynapseRequest,
+        requester: Requester,
         room_id: str,
         event_id: str,
-        txn_id: Optional[str] = None,
+        txn_id: Optional[str],
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
         content = parse_json_object_from_request(request)

         try:
@@ -1086,13 +1140,23 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
         set_tag("event_id", event_id)
         return 200, {"event_id": event_id}

-    def on_PUT(
+    async def on_POST(
+        self,
+        request: SynapseRequest,
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        return await self._do(request, requester, room_id, event_id, None)
+
+    async def on_PUT(
         self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
-    ) -> Awaitable[Tuple[int, JsonDict]]:
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
         set_tag("txn_id", txn_id)

-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, room_id, event_id, txn_id
+        return await self.txns.fetch_or_execute_request(
+            request, requester, self._do, request, requester, room_id, event_id, txn_id
         )


@@ -1192,7 +1256,7 @@ class SearchRestServlet(RestServlet):
         content = parse_json_object_from_request(request)

         batch = parse_string(request, "next_batch")
-        results = await self.search_handler.search(requester.user, content, batch)
+        results = await self.search_handler.search(requester, content, batch)

         return 200, results

@@ -1216,7 +1280,6 @@ def register_txn_path(
     servlet: RestServlet,
     regex_string: str,
     http_server: HttpServer,
-    with_get: bool = False,
 ) -> None:
     """Registers a transaction-based path.

@@ -1228,7 +1291,6 @@ def register_txn_path(
         regex_string: The regex string to register. Must NOT have a trailing $
             as this string will be appended to.
         http_server: The http_server to register paths with.
-        with_get: True to also register respective GET paths for the PUTs.
     """
     on_POST = getattr(servlet, "on_POST", None)
     on_PUT = getattr(servlet, "on_PUT", None)
@@ -1246,18 +1308,6 @@ def register_txn_path(
         on_PUT,
         servlet.__class__.__name__,
     )
-    on_GET = getattr(servlet, "on_GET", None)
-    if with_get:
-        if on_GET is None:
-            raise RuntimeError(
-                "register_txn_path called with with_get = True, but no on_GET method exists"
-            )
-        http_server.register_paths(
-            "GET",
-            client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
-            on_GET,
-            servlet.__class__.__name__,
-        )


 class TimestampLookupRestServlet(RestServlet):
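
The pattern repeated across these servlets is identical: on_POST authenticates and calls a shared _do, while on_PUT authenticates, tags the tracing span, and routes _do through the transaction cache. A stripped-down sketch of that shape, built from the names the diff itself introduces (the class is illustrative, not a real Synapse servlet):

    from typing import TYPE_CHECKING, Optional, Tuple

    from synapse.http.servlet import RestServlet
    from synapse.http.site import SynapseRequest
    from synapse.logging.opentracing import set_tag
    from synapse.rest.client.transactions import HttpTransactionCache
    from synapse.types import JsonDict, Requester

    if TYPE_CHECKING:
        from synapse.server import HomeServer


    class ExampleTxnServlet(RestServlet):  # illustrative shape only
        def __init__(self, hs: "HomeServer"):
            super().__init__()
            self.auth = hs.get_auth()
            self.txns = HttpTransactionCache(hs)

        async def _do(
            self, request: SynapseRequest, requester: Requester, txn_id: Optional[str]
        ) -> Tuple[int, JsonDict]:
            # ... perform the action exactly once ...
            return 200, {}

        async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
            requester = await self.auth.get_user_by_req(request)
            return await self._do(request, requester, None)

        async def on_PUT(
            self, request: SynapseRequest, txn_id: str
        ) -> Tuple[int, JsonDict]:
            requester = await self.auth.get_user_by_req(request)
            set_tag("txn_id", txn_id)
            # Replays of the same transaction by the same requester get the
            # cached result instead of re-executing _do.
            return await self.txns.fetch_or_execute_request(
                request, requester, self._do, request, requester, txn_id
            )
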
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index 10be4a781b..ef284ecc11 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -15,9 +15,7 @@
 import logging
 import re
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Awaitable, Tuple
-
-from twisted.web.server import Request
+from typing import TYPE_CHECKING, Tuple

 from synapse.api.constants import EventContentFields
 from synapse.api.errors import AuthError, Codes, SynapseError
@@ -30,7 +28,6 @@ from synapse.http.servlet import (
     parse_strings_from_args,
 )
 from synapse.http.site import SynapseRequest
-from synapse.rest.client.transactions import HttpTransactionCache
 from synapse.types import JsonDict

 if TYPE_CHECKING:
@@ -79,7 +76,6 @@ class RoomBatchSendEventRestServlet(RestServlet):
         self.event_creation_handler = hs.get_event_creation_handler()
         self.auth = hs.get_auth()
         self.room_batch_handler = hs.get_room_batch_handler()
-        self.txns = HttpTransactionCache(hs)

     async def on_POST(
         self, request: SynapseRequest, room_id: str
@@ -249,16 +245,6 @@ class RoomBatchSendEventRestServlet(RestServlet):

         return HTTPStatus.OK, response_dict

-    def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]:
-        return HTTPStatus.NOT_IMPLEMENTED, "Not implemented"
-
-    def on_PUT(
-        self, request: SynapseRequest, room_id: str
-    ) -> Awaitable[Tuple[int, JsonDict]]:
-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, room_id
-        )
-

 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     msc2716_enabled = hs.config.experimental.msc2716_enabled
diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py
index f7081f638e..4e7ffdb555 100644
--- a/synapse/rest/client/room_keys.py
+++ b/synapse/rest/client/room_keys.py
@@ -259,6 +259,32 @@ class RoomKeysNewVersionServlet(RestServlet):
         self.auth = hs.get_auth()
         self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()

+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        """
+        Retrieve the version information about the most current backup version (if any)
+
+        It takes out an exclusive lock on this user's room_key backups, to ensure
+        clients only upload to the current backup.
+
+        Returns 404 if the given version does not exist.
+
+        GET /room_keys/version HTTP/1.1
+        {
+            "version": "12345",
+            "algorithm": "m.megolm_backup.v1",
+            "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
+        }
+        """
+        requester = await self.auth.get_user_by_req(request, allow_guest=False)
+        user_id = requester.user.to_string()
+
+        try:
+            info = await self.e2e_room_keys_handler.get_version_info(user_id)
+        except SynapseError as e:
+            if e.code == 404:
+                raise SynapseError(404, "No backup found", Codes.NOT_FOUND)
+        return 200, info
+
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         """
         Create a new backup version for this user's room_keys with the given
@@ -301,7 +327,7 @@ class RoomKeysNewVersionServlet(RestServlet):


 class RoomKeysVersionServlet(RestServlet):
-    PATTERNS = client_patterns("/room_keys/version(/(?P<version>[^/]+))?$")
+    PATTERNS = client_patterns("/room_keys/version/(?P<version>[^/]+)$")

     def __init__(self, hs: "HomeServer"):
         super().__init__()
@@ -309,12 +335,11 @@ class RoomKeysVersionServlet(RestServlet):
         self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()

     async def on_GET(
-        self, request: SynapseRequest, version: Optional[str]
+        self, request: SynapseRequest, version: str
     ) -> Tuple[int, JsonDict]:
         """
         Retrieve the version information about a given version of the user's
-        room_keys backup. If the version part is missing, returns info about the
-        most current backup version (if any)
+        room_keys backup.

         It takes out an exclusive lock on this user's room_key backups, to ensure
         clients only upload to the current backup.
@@ -339,20 +364,16 @@ class RoomKeysVersionServlet(RestServlet):
         return 200, info

     async def on_DELETE(
-        self, request: SynapseRequest, version: Optional[str]
+        self, request: SynapseRequest, version: str
     ) -> Tuple[int, JsonDict]:
         """
         Delete the information about a given version of the user's
-        room_keys backup. If the version part is missing, deletes the most
-        current backup version (if any). Doesn't delete the actual room data.
+        room_keys backup. Doesn't delete the actual room data.

         DELETE /room_keys/version/12345 HTTP/1.1
         HTTP/1.1 200 OK
         {}
         """
-        if version is None:
-            raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND)
-
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
         user_id = requester.user.to_string()

@@ -360,7 +381,7 @@ class RoomKeysVersionServlet(RestServlet):
         return 200, {}

     async def on_PUT(
-        self, request: SynapseRequest, version: Optional[str]
+        self, request: SynapseRequest, version: str
     ) -> Tuple[int, JsonDict]:
         """
         Update the information about a given version of the user's room_keys backup.
@@ -386,11 +407,6 @@ class RoomKeysVersionServlet(RestServlet):
         user_id = requester.user.to_string()
         info = parse_json_object_from_request(request)

-        if version is None:
-            raise SynapseError(
-                400, "No version specified to update", Codes.MISSING_PARAM
-            )
-
         await self.e2e_room_keys_handler.update_version(user_id, version, info)

         return 200, {}
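
Client-side, the split means GET /room_keys/version (no version) and GET /room_keys/version/{version} are now handled by different servlets with the same observable behaviour. A sketch of querying the current backup version (path per the docstring above, using the v3 client prefix; server and token are placeholders):

    import requests

    BASE_URL = "https://homeserver.example.com"  # placeholder
    ACCESS_TOKEN = "syt_user_token"  # placeholder

    # Current backup version, now served by RoomKeysNewVersionServlet.on_GET.
    resp = requests.get(
        f"{BASE_URL}/_matrix/client/v3/room_keys/version",
        headers={"Authorization": f"Bearer {ACCESS_TOKEN}"},
    )
    if resp.status_code == 404:
        print("No backup found")
    else:
        info = resp.json()
        print(info["version"], info["algorithm"])
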
diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py
index 55d52f0b28..110af6df47 100644
--- a/synapse/rest/client/sendtodevice.py
+++ b/synapse/rest/client/sendtodevice.py
@@ -13,7 +13,7 @@
 # limitations under the License.

 import logging
-from typing import TYPE_CHECKING, Awaitable, Tuple
+from typing import TYPE_CHECKING, Tuple

 from synapse.http import servlet
 from synapse.http.server import HttpServer
@@ -21,7 +21,7 @@ from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
 from synapse.http.site import SynapseRequest
 from synapse.logging.opentracing import set_tag
 from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.types import JsonDict
+from synapse.types import JsonDict, Requester

 from ._base import client_patterns

@@ -43,19 +43,26 @@ class SendToDeviceRestServlet(servlet.RestServlet):
         self.txns = HttpTransactionCache(hs)
         self.device_message_handler = hs.get_device_message_handler()

-    def on_PUT(
+    async def on_PUT(
         self, request: SynapseRequest, message_type: str, txn_id: str
-    ) -> Awaitable[Tuple[int, JsonDict]]:
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         set_tag("txn_id", txn_id)
-        return self.txns.fetch_or_execute_request(
-            request, self._put, request, message_type, txn_id
+        return await self.txns.fetch_or_execute_request(
+            request,
+            requester,
+            self._put,
+            request,
+            requester,
+            message_type,
         )

     async def _put(
-        self, request: SynapseRequest, message_type: str, txn_id: str
+        self,
+        request: SynapseRequest,
+        requester: Requester,
+        message_type: str,
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
         content = parse_json_object_from_request(request)
         assert_params_in_dict(content, ("messages",))
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index f2013faeb2..e578b26fa3 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -16,7 +16,7 @@ import logging
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

-from synapse.api.constants import EduTypes, Membership, PresenceState
+from synapse.api.constants import AccountDataTypes, EduTypes, Membership, PresenceState
 from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.api.filtering import FilterCollection
 from synapse.api.presence import UserPresenceState
@@ -38,7 +38,7 @@ from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.logging.opentracing import trace_with_opname
-from synapse.types import JsonDict, StreamToken
+from synapse.types import JsonDict, Requester, StreamToken
 from synapse.util import json_decoder

 from ._base import client_patterns, set_timeline_upper_limit
@@ -139,7 +139,28 @@ class SyncRestServlet(RestServlet):
             device_id,
         )

-        request_key = (user, timeout, since, filter_id, full_state, device_id)
+        # Stream position of the last ignored users account data event for this user,
+        # if we're initial syncing.
+        # We include this in the request key to invalidate an initial sync
+        # in the response cache once the set of ignored users has changed.
+        # (We filter out ignored users from timeline events, so our sync response
+        # is invalid once the set of ignored users changes.)
+        last_ignore_accdata_streampos: Optional[int] = None
+        if not since:
+            # No `since`, so this is an initial sync.
+            last_ignore_accdata_streampos = await self.store.get_latest_stream_id_for_global_account_data_by_type_for_user(
+                user.to_string(), AccountDataTypes.IGNORED_USER_LIST
+            )
+
+        request_key = (
+            user,
+            timeout,
+            since,
+            filter_id,
+            full_state,
+            device_id,
+            last_ignore_accdata_streampos,
+        )

         if filter_id is None:
             filter_collection = self.filtering.DEFAULT_FILTER_COLLECTION
@@ -205,7 +226,7 @@ class SyncRestServlet(RestServlet):
         # We know that the the requester has an access token since appservices
         # cannot use sync.
         response_content = await self.encode_response(
-            time_now, sync_result, requester.access_token_id, filter_collection
+            time_now, sync_result, requester, filter_collection
         )

         logger.debug("Event formatting complete")
@@ -216,7 +237,7 @@ class SyncRestServlet(RestServlet):
         self,
         time_now: int,
         sync_result: SyncResult,
-        access_token_id: Optional[int],
+        requester: Requester,
         filter: FilterCollection,
     ) -> JsonDict:
         logger.debug("Formatting events in sync response")
@@ -229,12 +250,12 @@ class SyncRestServlet(RestServlet):

         serialize_options = SerializeEventConfig(
             event_format=event_formatter,
-            token_id=access_token_id,
+            requester=requester,
             only_event_fields=filter.event_fields,
         )
         stripped_serialize_options = SerializeEventConfig(
             event_format=event_formatter,
-            token_id=access_token_id,
+            requester=requester,
             include_stripped_room_state=True,
         )
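
The request_key change is a standard response-cache trick: fold everything the cached response depends on into the key, so a change invalidates the entry by simply missing the cache. A minimal sketch of the idea (a toy cache, not Synapse's actual ResponseCache):

    from typing import Dict, Hashable

    # Cache keyed by everything the response depends on.
    _cache: Dict[Hashable, dict] = {}

    def cached_initial_sync(user: str, ignored_users_streampos: int) -> dict:
        # Including the stream position of the last ignored-users update means
        # a change to the ignore list produces a new key, so the stale cached
        # response is simply never hit again.
        key = (user, ignored_users_streampos)
        if key not in _cache:
            _cache[key] = {"next_batch": "...", "rooms": {}}  # compute for real
        return _cache[key]
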
diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py
index ca638755c7..dde08417a4 100644
--- a/synapse/rest/client/tags.py
+++ b/synapse/rest/client/tags.py
@@ -34,7 +34,9 @@ class TagListServlet(RestServlet):
         GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
     """

-    PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags")
+    PATTERNS = client_patterns(
+        "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags$"
+    )

     def __init__(self, hs: "HomeServer"):
         super().__init__()
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 3a2c6bd36d..df194f8439 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -15,16 +15,16 @@
 """This module contains logic for storing HTTP PUT transactions. This is used
 to ensure idempotency when performing PUTs using the REST API."""
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Hashable, Tuple

 from typing_extensions import ParamSpec

 from twisted.internet.defer import Deferred
 from twisted.python.failure import Failure
-from twisted.web.server import Request
+from twisted.web.iweb import IRequest

 from synapse.logging.context import make_deferred_yieldable, run_in_background
-from synapse.types import JsonDict
+from synapse.types import JsonDict, Requester
 from synapse.util.async_helpers import ObservableDeferred

 if TYPE_CHECKING:
@@ -41,53 +41,47 @@ P = ParamSpec("P")
 class HttpTransactionCache:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
-        self.auth = self.hs.get_auth()
         self.clock = self.hs.get_clock()
         # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
         self.transactions: Dict[
-            str, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int]
+            Hashable, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int]
         ] = {}
         # Try to clean entries every 30 mins. This means entries will exist
         # for at *LEAST* 30 mins, and at *MOST* 60 mins.
         self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)

-    def _get_transaction_key(self, request: Request) -> str:
+    def _get_transaction_key(self, request: IRequest, requester: Requester) -> Hashable:
         """A helper function which returns a transaction key that can be used
         with TransactionCache for idempotent requests.

         Idempotency is based on the returned key being the same for separate
         requests to the same endpoint. The key is formed from the HTTP request
-        path and the access_token for the requesting user.
+        path and attributes from the requester: the access_token_id for regular users,
+        the user ID for guest users, and the appservice ID for appservice users.

         Args:
-            request: The incoming request. Must contain an access_token.
+            request: The incoming request.
+            requester: The requester doing the request.
         Returns:
             A transaction key
         """
         assert request.path is not None
-        token = self.auth.get_access_token_from_request(request)
-        return request.path.decode("utf8") + "/" + token
+        path: str = request.path.decode("utf8")
+        if requester.is_guest:
+            assert requester.user is not None, "Guest requester must have a user ID set"
+            return (path, "guest", requester.user)
+        elif requester.app_service is not None:
+            return (path, "appservice", requester.app_service.id)
+        else:
+            assert (
+                requester.access_token_id is not None
+            ), "Requester must have an access_token_id"
+            return (path, "user", requester.access_token_id)

     def fetch_or_execute_request(
         self,
-        request: Request,
-        fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
-        *args: P.args,
-        **kwargs: P.kwargs,
-    ) -> Awaitable[Tuple[int, JsonDict]]:
-        """A helper function for fetch_or_execute which extracts
-        a transaction key from the given request.
-
-        See:
-            fetch_or_execute
-        """
-        return self.fetch_or_execute(
-            self._get_transaction_key(request), fn, *args, **kwargs
-        )
-
-    def fetch_or_execute(
-        self,
-        txn_key: str,
+        request: IRequest,
+        requester: Requester,
         fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
         *args: P.args,
         **kwargs: P.kwargs,
@@ -96,14 +90,15 @@ class HttpTransactionCache:
         to produce a response for this transaction.

         Args:
-            txn_key: A key to ensure idempotency should fetch_or_execute be
-                called again at a later point in time.
+            request:
+            requester:
             fn: A function which returns a tuple of (response_code, response_dict).
             *args: Arguments to pass to fn.
             **kwargs: Keyword arguments to pass to fn.
         Returns:
             Deferred which resolves to a tuple of (response_code, response_dict).
         """
+        txn_key = self._get_transaction_key(request, requester)
         if txn_key in self.transactions:
             observable = self.transactions[txn_key][0]
         else:
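
The observable contract is unchanged: two PUTs with the same transaction ID from the same requester yield one execution and one shared response. A sketch from the client's point of view, using the standard Matrix send endpoint (server, token, and room ID are placeholders):

    import requests

    BASE_URL = "https://homeserver.example.com"  # placeholder
    ACCESS_TOKEN = "syt_user_token"  # placeholder
    ROOM_ID = "!room:example.com"  # placeholder

    url = f"{BASE_URL}/_matrix/client/v3/rooms/{ROOM_ID}/send/m.room.message/txn-abc"
    headers = {"Authorization": f"Bearer {ACCESS_TOKEN}"}
    body = {"msgtype": "m.text", "body": "hello"}

    first = requests.put(url, json=body, headers=headers).json()
    retry = requests.put(url, json=body, headers=headers).json()
    assert first["event_id"] == retry["event_id"]  # deduplicated, not re-sent
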
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/config_resource.py
index a95804d327..a95804d327 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/config_resource.py
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/download_resource.py
index 048a042692..8f270cf4cc 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/download_resource.py
@@ -22,11 +22,10 @@ from synapse.http.server import (
 )
 from synapse.http.servlet import parse_boolean
 from synapse.http.site import SynapseRequest
-
-from ._base import parse_media_id, respond_404
+from synapse.media._base import parse_media_id, respond_404

 if TYPE_CHECKING:
-    from synapse.rest.media.v1.media_repository import MediaRepository
+    from synapse.media.media_repository import MediaRepository
     from synapse.server import HomeServer

 logger = logging.getLogger(__name__)
diff --git a/synapse/rest/media/media_repository_resource.py b/synapse/rest/media/media_repository_resource.py
new file mode 100644
index 0000000000..5ebaa3b032
--- /dev/null
+++ b/synapse/rest/media/media_repository_resource.py
@@ -0,0 +1,93 @@
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from synapse.config._base import ConfigError
+from synapse.http.server import UnrecognizedRequestResource
+
+from .config_resource import MediaConfigResource
+from .download_resource import DownloadResource
+from .preview_url_resource import PreviewUrlResource
+from .thumbnail_resource import ThumbnailResource
+from .upload_resource import UploadResource
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+class MediaRepositoryResource(UnrecognizedRequestResource):
+    """File uploading and downloading.
+
+    Uploads are POSTed to a resource which returns a token which is used to GET
+    the download::
+
+        => POST /_matrix/media/r0/upload HTTP/1.1
+           Content-Type: <media-type>
+           Content-Length: <content-length>
+
+           <media>
+
+        <= HTTP/1.1 200 OK
+           Content-Type: application/json
+
+           { "content_uri": "mxc://<server-name>/<media-id>" }
+
+        => GET /_matrix/media/r0/download/<server-name>/<media-id> HTTP/1.1
+
+        <= HTTP/1.1 200 OK
+           Content-Type: <media-type>
+           Content-Disposition: attachment;filename=<upload-filename>
+
+           <media>
+
+    Clients can get thumbnails by supplying a desired width and height and
+    thumbnailing method::
+
+        => GET /_matrix/media/r0/thumbnail/<server_name>
+                /<media-id>?width=<w>&height=<h>&method=<m> HTTP/1.1
+
+        <= HTTP/1.1 200 OK
+           Content-Type: image/jpeg or image/png
+
+           <thumbnail>
+
+    The thumbnail methods are "crop" and "scale". "scale" tries to return an
+    image where either the width or the height is smaller than the requested
+    size. The client should then scale and letterbox the image if it needs to
+    fit within a given rectangle. "crop" tries to return an image where the
+    width and height are close to the requested size and the aspect matches
+    the requested size. The client should scale the image if it needs to fit
+    within a given rectangle.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        # If we're not configured to use it, raise if we somehow got here.
+        if not hs.config.media.can_load_media_repo:
+            raise ConfigError("Synapse is not configured to use a media repo.")
+
+        super().__init__()
+        media_repo = hs.get_media_repository()
+
+        self.putChild(b"upload", UploadResource(hs, media_repo))
+        self.putChild(b"download", DownloadResource(hs, media_repo))
+        self.putChild(
+            b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage)
+        )
+        if hs.config.media.url_preview_enabled:
+            self.putChild(
+                b"preview_url",
+                PreviewUrlResource(hs, media_repo, media_repo.media_storage),
+            )
+        self.putChild(b"config", MediaConfigResource(hs))
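
The docstring's upload/download flow, expressed as a client sketch (the r0 paths are taken directly from the docstring above; server, token, and filename are placeholders):

    import requests

    BASE_URL = "https://homeserver.example.com"  # placeholder
    ACCESS_TOKEN = "syt_user_token"  # placeholder

    # Upload: returns an mxc:// content URI.
    with open("cat.png", "rb") as f:
        upload = requests.post(
            f"{BASE_URL}/_matrix/media/r0/upload",
            data=f.read(),
            headers={
                "Authorization": f"Bearer {ACCESS_TOKEN}",
                "Content-Type": "image/png",
            },
        ).json()
    server_name, media_id = upload["content_uri"][len("mxc://"):].split("/")

    # Download the same media back.
    media = requests.get(
        f"{BASE_URL}/_matrix/media/r0/download/{server_name}/{media_id}"
    )
    assert media.status_code == 200
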
index a8f6fd6b35..7ada728757 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/preview_url_resource.py
@@ -40,21 +40,19 @@ from synapse.http.server import ( from synapse.http.servlet import parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.media._base import FileInfo, get_filename_from_headers +from synapse.media.media_storage import MediaStorage +from synapse.media.oembed import OEmbedProvider +from synapse.media.preview_html import decode_body, parse_html_to_open_graph from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.rest.media.v1._base import get_filename_from_headers -from synapse.rest.media.v1.media_storage import MediaStorage -from synapse.rest.media.v1.oembed import OEmbedProvider -from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph from synapse.types import JsonDict, UserID from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.stringutils import random_string -from ._base import FileInfo - if TYPE_CHECKING: - from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.media.media_repository import MediaRepository from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -163,6 +161,10 @@ class PreviewUrlResource(DirectServeJsonResource): 7. Stores the result in the database cache. 4. Returns the result. + If any additional requests (e.g. from oEmbed autodiscovery, step 5.3, or + image thumbnailing, step 5.4 or 6.4) fail, then the URL preview as a whole + does not fail. As much information as possible is returned. + The in-memory cache expires after 1 hour. Expired entries in the database cache (and their associated media files) are @@ -364,16 +366,25 @@ class PreviewUrlResource(DirectServeJsonResource): oembed_url = self._oembed.autodiscover_from_html(tree) og_from_oembed: JsonDict = {} if oembed_url: - oembed_info = await self._handle_url( - oembed_url, user, allow_data_urls=True - ) - ( - og_from_oembed, - author_name, - expiration_ms, - ) = await self._handle_oembed_response( - url, oembed_info, expiration_ms - ) + try: + oembed_info = await self._handle_url( + oembed_url, user, allow_data_urls=True + ) + except Exception as e: + # Fetching the oEmbed info failed; don't block the entire URL preview. + logger.warning( + "oEmbed fetch failed during URL preview: %s errored with %s", + oembed_url, + e, + ) + else: + ( + og_from_oembed, + author_name, + expiration_ms, + ) = await self._handle_oembed_response( + url, oembed_info, expiration_ms + ) # Parse Open Graph information from the HTML in case the oEmbed # response failed or is incomplete. diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py
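(Editor's note: the try/except introduced above implements a best-effort policy: a failing oEmbed sub-request degrades the preview instead of aborting it. The same pattern in isolation, with illustrative names; merge_best_effort and the step callables are not Synapse APIs::

    import logging
    from typing import Any, Awaitable, Callable, Dict

    logger = logging.getLogger(__name__)

    async def merge_best_effort(
        og: Dict[str, Any],
        steps: Dict[str, Callable[[], Awaitable[Dict[str, Any]]]],
    ) -> Dict[str, Any]:
        # Each step enriches the preview; a failure is logged and skipped
        # so the caller still gets as much information as possible.
        for name, step in steps.items():
            try:
                og.update(await step())
            except Exception as e:
                logger.warning("URL preview step %s failed: %s", name, e)
        return og

)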
index 5f725c7600..4ee2a0dbda 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/thumbnail_resource.py
@@ -27,9 +27,7 @@ from synapse.http.server import ( ) from synapse.http.servlet import parse_integer, parse_string from synapse.http.site import SynapseRequest -from synapse.rest.media.v1.media_storage import MediaStorage - -from ._base import ( +from synapse.media._base import ( FileInfo, ThumbnailInfo, parse_media_id, @@ -37,9 +35,10 @@ from ._base import ( respond_with_file, respond_with_responder, ) +from synapse.media.media_storage import MediaStorage if TYPE_CHECKING: - from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.media.media_repository import MediaRepository from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -69,7 +68,8 @@ class ThumbnailResource(DirectServeJsonResource): width = parse_integer(request, "width", required=True) height = parse_integer(request, "height", required=True) method = parse_string(request, "method", "scale") - m_type = parse_string(request, "type", "image/png") + # TODO Parse the Accept header to get a prioritised list of thumbnail types. + m_type = "image/png" if server_name == self.server_name: if self.dynamic_thumbnails: diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/upload_resource.py
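(Editor's note: the TODO above could plausibly be satisfied by ranking the Accept header's media types by q-value. A rough standalone sketch of such parsing, an assumption about the eventual implementation rather than code from this commit; real Accept parsing has more corner cases such as wildcards and malformed input::

    from typing import List

    def prioritised_types(accept: str) -> List[str]:
        # "image/webp,image/png;q=0.8" -> ["image/webp", "image/png"]
        entries = []
        for part in accept.split(","):
            bits = part.strip().split(";")
            q = 1.0
            for param in bits[1:]:
                name, _, value = param.strip().partition("=")
                if name == "q":
                    try:
                        q = float(value)
                    except ValueError:
                        q = 0.0
            entries.append((q, bits[0].strip()))
        # Highest q first; the sort is stable, so header order breaks ties.
        entries.sort(key=lambda entry: entry[0], reverse=True)
        return [media_type for _, media_type in entries]

)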
index 97548b54e5..697348613b 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/upload_resource.py
@@ -20,10 +20,10 @@ from synapse.api.errors import Codes, SynapseError from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_bytes_from_args from synapse.http.site import SynapseRequest -from synapse.rest.media.v1.media_storage import SpamMediaException +from synapse.media.media_storage import SpamMediaException if TYPE_CHECKING: - from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.media.media_repository import MediaRepository from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index d30878f704..88427a5737 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py
@@ -1,5 +1,4 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# Copyright 2023 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,466 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# -import logging -import os -import urllib -from types import TracebackType -from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type - -import attr - -from twisted.internet.interfaces import IConsumer -from twisted.protocols.basic import FileSender -from twisted.web.server import Request - -from synapse.api.errors import Codes, SynapseError, cs_error -from synapse.http.server import finish_request, respond_with_json -from synapse.http.site import SynapseRequest -from synapse.logging.context import make_deferred_yieldable -from synapse.util.stringutils import is_ascii, parse_and_validate_server_name - -logger = logging.getLogger(__name__) - -# list all text content types that will have the charset default to UTF-8 when -# none is given -TEXT_CONTENT_TYPES = [ - "text/css", - "text/csv", - "text/html", - "text/calendar", - "text/plain", - "text/javascript", - "application/json", - "application/ld+json", - "application/rtf", - "image/svg+xml", - "text/xml", -] - - -def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: - """Parses the server name, media ID and optional file name from the request URI - - Also performs some rough validation on the server name. - - Args: - request: The `Request`. - - Returns: - A tuple containing the parsed server name, media ID and optional file name. - - Raises: - SynapseError(404): if parsing or validation fail for any reason - """ - try: - # The type on postpath seems incorrect in Twisted 21.2.0. - postpath: List[bytes] = request.postpath # type: ignore - assert postpath - - # This allows users to append e.g. /test.png to the URL. Useful for - # clients that parse the URL to see content type. 
- server_name_bytes, media_id_bytes = postpath[:2] - server_name = server_name_bytes.decode("utf-8") - media_id = media_id_bytes.decode("utf8") - - # Validate the server name, raising if invalid - parse_and_validate_server_name(server_name) - - file_name = None - if len(postpath) > 2: - try: - file_name = urllib.parse.unquote(postpath[-1].decode("utf-8")) - except UnicodeDecodeError: - pass - return server_name, media_id, file_name - except Exception: - raise SynapseError( - 404, "Invalid media id token %r" % (request.postpath,), Codes.UNKNOWN - ) - - -def respond_404(request: SynapseRequest) -> None: - respond_with_json( - request, - 404, - cs_error("Not found %r" % (request.postpath,), code=Codes.NOT_FOUND), - send_cors=True, - ) - - -async def respond_with_file( - request: SynapseRequest, - media_type: str, - file_path: str, - file_size: Optional[int] = None, - upload_name: Optional[str] = None, -) -> None: - logger.debug("Responding with %r", file_path) - - if os.path.isfile(file_path): - if file_size is None: - stat = os.stat(file_path) - file_size = stat.st_size - - add_file_headers(request, media_type, file_size, upload_name) - - with open(file_path, "rb") as f: - await make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) - - finish_request(request) - else: - respond_404(request) - - -def add_file_headers( - request: Request, - media_type: str, - file_size: Optional[int], - upload_name: Optional[str], -) -> None: - """Adds the correct response headers in preparation for responding with the - media. - - Args: - request - media_type: The media/content type. - file_size: Size in bytes of the media, if known. - upload_name: The name of the requested file, if any. - """ - - def _quote(x: str) -> str: - return urllib.parse.quote(x.encode("utf-8")) - - # Default to a UTF-8 charset for text content types. - # ex, uses UTF-8 for 'text/css' but not 'text/css; charset=UTF-16' - if media_type.lower() in TEXT_CONTENT_TYPES: - content_type = media_type + "; charset=UTF-8" - else: - content_type = media_type - - request.setHeader(b"Content-Type", content_type.encode("UTF-8")) - if upload_name: - # RFC6266 section 4.1 [1] defines both `filename` and `filename*`. - # - # `filename` is defined to be a `value`, which is defined by RFC2616 - # section 3.6 [2] to be a `token` or a `quoted-string`, where a `token` - # is (essentially) a single US-ASCII word, and a `quoted-string` is a - # US-ASCII string surrounded by double-quotes, using backslash as an - # escape character. Note that %-encoding is *not* permitted. - # - # `filename*` is defined to be an `ext-value`, which is defined in - # RFC5987 section 3.2.1 [3] to be `charset "'" [ language ] "'" value-chars`, - # where `value-chars` is essentially a %-encoded string in the given charset. - # - # [1]: https://tools.ietf.org/html/rfc6266#section-4.1 - # [2]: https://tools.ietf.org/html/rfc2616#section-3.6 - # [3]: https://tools.ietf.org/html/rfc5987#section-3.2.1 - - # We avoid the quoted-string version of `filename`, because (a) synapse didn't - # correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we - # may as well just do the filename* version. - if _can_encode_filename_as_token(upload_name): - disposition = "inline; filename=%s" % (upload_name,) - else: - disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),) - - request.setHeader(b"Content-Disposition", disposition.encode("ascii")) - - # cache for at least a day. 
- # XXX: we might want to turn this off for data we don't want to - # recommend caching as it's sensitive or private - or at least - # select private. don't bother setting Expires as all our - # clients are smart enough to be happy with Cache-Control - request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") - if file_size is not None: - request.setHeader(b"Content-Length", b"%d" % (file_size,)) - - # Tell web crawlers to not index, archive, or follow links in media. This - # should help to prevent things in the media repo from showing up in web - # search results. - request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex") - - -# separators as defined in RFC2616. SP and HT are handled separately. -# see _can_encode_filename_as_token. -_FILENAME_SEPARATOR_CHARS = { - "(", - ")", - "<", - ">", - "@", - ",", - ";", - ":", - "\\", - '"', - "/", - "[", - "]", - "?", - "=", - "{", - "}", -} - - -def _can_encode_filename_as_token(x: str) -> bool: - for c in x: - # from RFC2616: - # - # token = 1*<any CHAR except CTLs or separators> - # - # separators = "(" | ")" | "<" | ">" | "@" - # | "," | ";" | ":" | "\" | <"> - # | "/" | "[" | "]" | "?" | "=" - # | "{" | "}" | SP | HT - # - # CHAR = <any US-ASCII character (octets 0 - 127)> - # - # CTL = <any US-ASCII control character - # (octets 0 - 31) and DEL (127)> - # - if ord(c) >= 127 or ord(c) <= 32 or c in _FILENAME_SEPARATOR_CHARS: - return False - return True - - -async def respond_with_responder( - request: SynapseRequest, - responder: "Optional[Responder]", - media_type: str, - file_size: Optional[int], - upload_name: Optional[str] = None, -) -> None: - """Responds to the request with given responder. If responder is None then - returns 404. - - Args: - request - responder - media_type: The media/content type. - file_size: Size in bytes of the media. If not known it should be None - upload_name: The name of the requested file, if any. - """ - if not responder: - respond_404(request) - return - - # If we have a responder we *must* use it as a context manager. - with responder: - if request._disconnected: - logger.warning( - "Not sending response to request %s, already disconnected.", request - ) - return - - logger.debug("Responding to media request with responder %s", responder) - add_file_headers(request, media_type, file_size, upload_name) - try: - - await responder.write_to_consumer(request) - except Exception as e: - # The majority of the time this will be due to the client having gone - # away. Unfortunately, Twisted simply throws a generic exception at us - # in that case. - logger.warning("Failed to write to consumer: %s %s", type(e), e) - - # Unregister the producer, if it has one, so Twisted doesn't complain - if request.producer: - request.unregisterProducer() - - finish_request(request) - - -class Responder: - """Represents a response that can be streamed to the requester. - - Responder is a context manager which *must* be used, so that any resources - held can be cleaned up. - """ - - def write_to_consumer(self, consumer: IConsumer) -> Awaitable: - """Stream response into consumer - - Args: - consumer: The consumer to stream into. 
- - Returns: - Resolves once the response has finished being written - """ - - def __enter__(self) -> None: - pass - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - pass - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class ThumbnailInfo: - """Details about a generated thumbnail.""" - - width: int - height: int - method: str - # Content type of thumbnail, e.g. image/png - type: str - # The size of the media file, in bytes. - length: Optional[int] = None - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class FileInfo: - """Details about a requested/uploaded file.""" - - # The server name where the media originated from, or None if local. - server_name: Optional[str] - # The local ID of the file. For local files this is the same as the media_id - file_id: str - # If the file is for the url preview cache - url_cache: bool = False - # Whether the file is a thumbnail or not. - thumbnail: Optional[ThumbnailInfo] = None - - # The below properties exist to maintain compatibility with third-party modules. - @property - def thumbnail_width(self) -> Optional[int]: - if not self.thumbnail: - return None - return self.thumbnail.width - - @property - def thumbnail_height(self) -> Optional[int]: - if not self.thumbnail: - return None - return self.thumbnail.height - - @property - def thumbnail_method(self) -> Optional[str]: - if not self.thumbnail: - return None - return self.thumbnail.method - - @property - def thumbnail_type(self) -> Optional[str]: - if not self.thumbnail: - return None - return self.thumbnail.type - - @property - def thumbnail_length(self) -> Optional[int]: - if not self.thumbnail: - return None - return self.thumbnail.length - - -def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: - """ - Get the filename of the downloaded file by inspecting the - Content-Disposition HTTP header. - - Args: - headers: The HTTP request headers. - - Returns: - The filename, or None. - """ - content_disposition = headers.get(b"Content-Disposition", [b""]) - - # No header, bail out. - if not content_disposition[0]: - return None - - _, params = _parse_header(content_disposition[0]) - - upload_name = None - - # First check if there is a valid UTF-8 filename - upload_name_utf8 = params.get(b"filename*", None) - if upload_name_utf8: - if upload_name_utf8.lower().startswith(b"utf-8''"): - upload_name_utf8 = upload_name_utf8[7:] - # We have a filename*= section. This MUST be ASCII, and any UTF-8 - # bytes are %-quoted. - try: - # Once it is decoded, we can then unquote the %-encoded - # parts strictly into a unicode string. - upload_name = urllib.parse.unquote( - upload_name_utf8.decode("ascii"), errors="strict" - ) - except UnicodeDecodeError: - # Incorrect UTF-8. - pass - - # If there isn't check for an ascii name. - if not upload_name: - upload_name_ascii = params.get(b"filename", None) - if upload_name_ascii and is_ascii(upload_name_ascii): - upload_name = upload_name_ascii.decode("ascii") - - # This may be None here, indicating we did not find a matching name. - return upload_name - - -def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]: - """Parse a Content-type like header. - - Cargo-culted from `cgi`, but works on bytes rather than strings. 
- - Args: - line: header to be parsed - - Returns: - The main content-type, followed by the parameter dictionary - """ - parts = _parseparam(b";" + line) - key = next(parts) - pdict = {} - for p in parts: - i = p.find(b"=") - if i >= 0: - name = p[:i].strip().lower() - value = p[i + 1 :].strip() - - # strip double-quotes - if len(value) >= 2 and value[0:1] == value[-1:] == b'"': - value = value[1:-1] - value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"') - pdict[name] = value - - return key, pdict - - -def _parseparam(s: bytes) -> Generator[bytes, None, None]: - """Generator which splits the input on ;, respecting double-quoted sequences - - Cargo-culted from `cgi`, but works on bytes rather than strings. - - Args: - s: header to be parsed - - Returns: - The split input - """ - while s[:1] == b";": - s = s[1:] - - # look for the next ; - end = s.find(b";") - - # if there is an odd number of " marks between here and the next ;, skip to the - # next ; instead - while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2: - end = s.find(b";", end + 1) - - if end < 0: - end = len(s) - f = s[:end] - yield f.strip() - s = s[end:] +# This exists purely for backwards compatibility with media providers and spam checkers. +from synapse.media._base import FileInfo, Responder # noqa: F401 diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py deleted file mode 100644
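(Editor's note: the filename helpers deleted above were moved to synapse.media._base rather than dropped. For reference, their observable behaviour when a Content-Disposition header carries an RFC 5987 filename* parameter, as a small usage sketch::

    from synapse.media._base import get_filename_from_headers

    headers = {b"Content-Disposition": [b"inline; filename*=utf-8''b%C3%A4r.png"]}
    # filename* is preferred over filename: the utf-8'' prefix is stripped
    # and the %-encoded remainder decoded, yielding "bär.png".
    assert get_filename_from_headers(headers) == "b\u00e4r.png"

)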
index 1f6441c412..0000000000 --- a/synapse/rest/media/v1/filepath.py +++ /dev/null
@@ -1,410 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2020-2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -import os -import re -import string -from typing import Any, Callable, List, TypeVar, Union, cast - -NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") - - -F = TypeVar("F", bound=Callable[..., str]) - - -def _wrap_in_base_path(func: F) -> F: - """Takes a function that returns a relative path and turns it into an - absolute path based on the location of the primary media store - """ - - @functools.wraps(func) - def _wrapped(self: "MediaFilePaths", *args: Any, **kwargs: Any) -> str: - path = func(self, *args, **kwargs) - return os.path.join(self.base_path, path) - - return cast(F, _wrapped) - - -GetPathMethod = TypeVar( - "GetPathMethod", bound=Union[Callable[..., str], Callable[..., List[str]]] -) - - -def _wrap_with_jail_check(relative: bool) -> Callable[[GetPathMethod], GetPathMethod]: - """Wraps a path-returning method to check that the returned path(s) do not escape - the media store directory. - - The path-returning method may return either a single path, or a list of paths. - - The check is not expected to ever fail, unless `func` is missing a call to - `_validate_path_component`, or `_validate_path_component` is buggy. - - Args: - relative: A boolean indicating whether the wrapped method returns paths relative - to the media store directory. - - Returns: - A method which will wrap a path-returning method, adding a check to ensure that - the returned path(s) lie within the media store directory. The check will raise - a `ValueError` if it fails. - """ - - def _wrap_with_jail_check_inner(func: GetPathMethod) -> GetPathMethod: - @functools.wraps(func) - def _wrapped( - self: "MediaFilePaths", *args: Any, **kwargs: Any - ) -> Union[str, List[str]]: - path_or_paths = func(self, *args, **kwargs) - - if isinstance(path_or_paths, list): - paths_to_check = path_or_paths - else: - paths_to_check = [path_or_paths] - - for path in paths_to_check: - # Construct the path that will ultimately be used. - # We cannot guess whether `path` is relative to the media store - # directory, since the media store directory may itself be a relative - # path. - if relative: - path = os.path.join(self.base_path, path) - normalized_path = os.path.normpath(path) - - # Now that `normpath` has eliminated `../`s and `./`s from the path, - # `os.path.commonpath` can be used to check whether it lies within the - # media store directory. - if ( - os.path.commonpath([normalized_path, self.normalized_base_path]) - != self.normalized_base_path - ): - # The path resolves to outside the media store directory, - # or `self.base_path` is `.`, which is an unlikely configuration. - raise ValueError(f"Invalid media store path: {path!r}") - - # Note that `os.path.normpath`/`abspath` has a subtle caveat: - # `a/b/c/../c` will normalize to `a/b/c`, but the former refers to a - # different path if `a/b/c` is a symlink. 
That is, the check above is - # not perfect and may allow a certain restricted subset of untrustworthy - # paths through. Since the check above is secondary to the main - # `_validate_path_component` checks, it's less important for it to be - # perfect. - # - # As an alternative, `os.path.realpath` will resolve symlinks, but - # proves problematic if there are symlinks inside the media store. - # eg. if `url_store/` is symlinked to elsewhere, its canonical path - # won't match that of the main media store directory. - - return path_or_paths - - return cast(GetPathMethod, _wrapped) - - return _wrap_with_jail_check_inner - - -ALLOWED_CHARACTERS = set( - string.ascii_letters - + string.digits - + "_-" - + ".[]:" # Domain names, IPv6 addresses and ports in server names -) -FORBIDDEN_NAMES = { - "", - os.path.curdir, # "." for the current platform - os.path.pardir, # ".." for the current platform -} - - -def _validate_path_component(name: str) -> str: - """Checks that the given string can be safely used as a path component - - Args: - name: The path component to check. - - Returns: - The path component if valid. - - Raises: - ValueError: If `name` cannot be safely used as a path component. - """ - if not ALLOWED_CHARACTERS.issuperset(name) or name in FORBIDDEN_NAMES: - raise ValueError(f"Invalid path component: {name!r}") - - return name - - -class MediaFilePaths: - """Describes where files are stored on disk. - - Most of the functions have a `*_rel` variant which returns a file path that - is relative to the base media store path. This is mainly used when we want - to write to the backup media store (when one is configured) - """ - - def __init__(self, primary_base_path: str): - self.base_path = primary_base_path - self.normalized_base_path = os.path.normpath(self.base_path) - - # Refuse to initialize if paths cannot be validated correctly for the current - # platform. - assert os.path.sep not in ALLOWED_CHARACTERS - assert os.path.altsep not in ALLOWED_CHARACTERS - # On Windows, paths have all sorts of weirdness which `_validate_path_component` - # does not consider. In any case, the remote media store can't work correctly - # for certain homeservers there, since ":"s aren't allowed in paths. - assert os.name == "posix" - - @_wrap_with_jail_check(relative=True) - def local_media_filepath_rel(self, media_id: str) -> str: - return os.path.join( - "local_content", - _validate_path_component(media_id[0:2]), - _validate_path_component(media_id[2:4]), - _validate_path_component(media_id[4:]), - ) - - local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) - - @_wrap_with_jail_check(relative=True) - def local_media_thumbnail_rel( - self, media_id: str, width: int, height: int, content_type: str, method: str - ) -> str: - top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) - return os.path.join( - "local_thumbnails", - _validate_path_component(media_id[0:2]), - _validate_path_component(media_id[2:4]), - _validate_path_component(media_id[4:]), - _validate_path_component(file_name), - ) - - local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) - - @_wrap_with_jail_check(relative=False) - def local_media_thumbnail_dir(self, media_id: str) -> str: - """ - Retrieve the local store path of thumbnails of a given media_id - - Args: - media_id: The media ID to query. 
- Returns: - Path of local_thumbnails from media_id - """ - return os.path.join( - self.base_path, - "local_thumbnails", - _validate_path_component(media_id[0:2]), - _validate_path_component(media_id[2:4]), - _validate_path_component(media_id[4:]), - ) - - @_wrap_with_jail_check(relative=True) - def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str: - return os.path.join( - "remote_content", - _validate_path_component(server_name), - _validate_path_component(file_id[0:2]), - _validate_path_component(file_id[2:4]), - _validate_path_component(file_id[4:]), - ) - - remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) - - @_wrap_with_jail_check(relative=True) - def remote_media_thumbnail_rel( - self, - server_name: str, - file_id: str, - width: int, - height: int, - content_type: str, - method: str, - ) -> str: - top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) - return os.path.join( - "remote_thumbnail", - _validate_path_component(server_name), - _validate_path_component(file_id[0:2]), - _validate_path_component(file_id[2:4]), - _validate_path_component(file_id[4:]), - _validate_path_component(file_name), - ) - - remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) - - # Legacy path that was used to store thumbnails previously. - # Should be removed after some time, when most of the thumbnails are stored - # using the new path. - @_wrap_with_jail_check(relative=True) - def remote_media_thumbnail_rel_legacy( - self, server_name: str, file_id: str, width: int, height: int, content_type: str - ) -> str: - top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) - return os.path.join( - "remote_thumbnail", - _validate_path_component(server_name), - _validate_path_component(file_id[0:2]), - _validate_path_component(file_id[2:4]), - _validate_path_component(file_id[4:]), - _validate_path_component(file_name), - ) - - @_wrap_with_jail_check(relative=False) - def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str: - return os.path.join( - self.base_path, - "remote_thumbnail", - _validate_path_component(server_name), - _validate_path_component(file_id[0:2]), - _validate_path_component(file_id[2:4]), - _validate_path_component(file_id[4:]), - ) - - @_wrap_with_jail_check(relative=True) - def url_cache_filepath_rel(self, media_id: str) -> str: - if NEW_FORMAT_ID_RE.match(media_id): - # Media id is of the form <DATE><RANDOM_STRING> - # E.g.: 2017-09-28-fsdRDt24DS234dsf - return os.path.join( - "url_cache", - _validate_path_component(media_id[:10]), - _validate_path_component(media_id[11:]), - ) - else: - return os.path.join( - "url_cache", - _validate_path_component(media_id[0:2]), - _validate_path_component(media_id[2:4]), - _validate_path_component(media_id[4:]), - ) - - url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) - - @_wrap_with_jail_check(relative=False) - def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]: - "The dirs to try and remove if we delete the media_id file" - if NEW_FORMAT_ID_RE.match(media_id): - return [ - os.path.join( - self.base_path, "url_cache", _validate_path_component(media_id[:10]) - ) - ] - else: - return [ - os.path.join( - self.base_path, - "url_cache", - _validate_path_component(media_id[0:2]), - _validate_path_component(media_id[2:4]), - ), - os.path.join( - self.base_path, "url_cache", 
_validate_path_component(media_id[0:2]) - ), - ] - - @_wrap_with_jail_check(relative=True) - def url_cache_thumbnail_rel( - self, media_id: str, width: int, height: int, content_type: str, method: str - ) -> str: - # Media id is of the form <DATE><RANDOM_STRING> - # E.g.: 2017-09-28-fsdRDt24DS234dsf - - top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) - - if NEW_FORMAT_ID_RE.match(media_id): - return os.path.join( - "url_cache_thumbnails", - _validate_path_component(media_id[:10]), - _validate_path_component(media_id[11:]), - _validate_path_component(file_name), - ) - else: - return os.path.join( - "url_cache_thumbnails", - _validate_path_component(media_id[0:2]), - _validate_path_component(media_id[2:4]), - _validate_path_component(media_id[4:]), - _validate_path_component(file_name), - ) - - url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) - - @_wrap_with_jail_check(relative=True) - def url_cache_thumbnail_directory_rel(self, media_id: str) -> str: - # Media id is of the form <DATE><RANDOM_STRING> - # E.g.: 2017-09-28-fsdRDt24DS234dsf - - if NEW_FORMAT_ID_RE.match(media_id): - return os.path.join( - "url_cache_thumbnails", - _validate_path_component(media_id[:10]), - _validate_path_component(media_id[11:]), - ) - else: - return os.path.join( - "url_cache_thumbnails", - _validate_path_component(media_id[0:2]), - _validate_path_component(media_id[2:4]), - _validate_path_component(media_id[4:]), - ) - - url_cache_thumbnail_directory = _wrap_in_base_path( - url_cache_thumbnail_directory_rel - ) - - @_wrap_with_jail_check(relative=False) - def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: - "The dirs to try and remove if we delete the media_id thumbnails" - # Media id is of the form <DATE><RANDOM_STRING> - # E.g.: 2017-09-28-fsdRDt24DS234dsf - if NEW_FORMAT_ID_RE.match(media_id): - return [ - os.path.join( - self.base_path, - "url_cache_thumbnails", - _validate_path_component(media_id[:10]), - _validate_path_component(media_id[11:]), - ), - os.path.join( - self.base_path, - "url_cache_thumbnails", - _validate_path_component(media_id[:10]), - ), - ] - else: - return [ - os.path.join( - self.base_path, - "url_cache_thumbnails", - _validate_path_component(media_id[0:2]), - _validate_path_component(media_id[2:4]), - _validate_path_component(media_id[4:]), - ), - os.path.join( - self.base_path, - "url_cache_thumbnails", - _validate_path_component(media_id[0:2]), - _validate_path_component(media_id[2:4]), - ), - os.path.join( - self.base_path, - "url_cache_thumbnails", - _validate_path_component(media_id[0:2]), - ), - ] diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py deleted file mode 100644
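(Editor's note: the core of the jail check deleted above, whose module now lives at synapse.media.filepath, is a lexical containment test built from normpath and commonpath. A standalone sketch, assuming base is the media store root; ensure_inside is an illustrative name::

    import os

    def ensure_inside(base: str, relative_path: str) -> str:
        # normpath collapses any ../ components, so commonpath can then
        # decide containment lexically (symlinks are deliberately not
        # resolved, mirroring the caveat in the original comment).
        normalized_base = os.path.normpath(base)
        normalized = os.path.normpath(os.path.join(base, relative_path))
        if os.path.commonpath([normalized, normalized_base]) != normalized_base:
            raise ValueError(f"Invalid media store path: {relative_path!r}")
        return normalized

    ensure_inside("/media", "local_content/ab/cd/efgh")  # ok
    # ensure_inside("/media", "../etc/passwd") would raise ValueError

)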
index c70e1837af..0000000000 --- a/synapse/rest/media/v1/media_repository.py +++ /dev/null
@@ -1,1112 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018-2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import errno -import logging -import os -import shutil -from io import BytesIO -from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple - -from matrix_common.types.mxc_uri import MXCUri - -import twisted.internet.error -import twisted.web.http -from twisted.internet.defer import Deferred - -from synapse.api.errors import ( - FederationDeniedError, - HttpResponseException, - NotFoundError, - RequestSendFailed, - SynapseError, -) -from synapse.config._base import ConfigError -from synapse.config.repository import ThumbnailRequirement -from synapse.http.server import UnrecognizedRequestResource -from synapse.http.site import SynapseRequest -from synapse.logging.context import defer_to_thread -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import UserID -from synapse.util.async_helpers import Linearizer -from synapse.util.retryutils import NotRetryingDestination -from synapse.util.stringutils import random_string - -from ._base import ( - FileInfo, - Responder, - ThumbnailInfo, - get_filename_from_headers, - respond_404, - respond_with_responder, -) -from .config_resource import MediaConfigResource -from .download_resource import DownloadResource -from .filepath import MediaFilePaths -from .media_storage import MediaStorage -from .preview_url_resource import PreviewUrlResource -from .storage_provider import StorageProviderWrapper -from .thumbnail_resource import ThumbnailResource -from .thumbnailer import Thumbnailer, ThumbnailError -from .upload_resource import UploadResource - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - -# How often to run the background job to update the "recently accessed" -# attribute of local and remote media. -UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 # 1 minute -# How often to run the background job to check for local and remote media -# that should be purged according to the configured media retention settings. 
-MEDIA_RETENTION_CHECK_PERIOD_MS = 60 * 60 * 1000 # 1 hour - - -class MediaRepository: - def __init__(self, hs: "HomeServer"): - self.hs = hs - self.auth = hs.get_auth() - self.client = hs.get_federation_http_client() - self.clock = hs.get_clock() - self.server_name = hs.hostname - self.store = hs.get_datastores().main - self.max_upload_size = hs.config.media.max_upload_size - self.max_image_pixels = hs.config.media.max_image_pixels - - Thumbnailer.set_limits(self.max_image_pixels) - - self.primary_base_path: str = hs.config.media.media_store_path - self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path) - - self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails - self.thumbnail_requirements = hs.config.media.thumbnail_requirements - - self.remote_media_linearizer = Linearizer(name="media_remote") - - self.recently_accessed_remotes: Set[Tuple[str, str]] = set() - self.recently_accessed_locals: Set[str] = set() - - self.federation_domain_whitelist = ( - hs.config.federation.federation_domain_whitelist - ) - - # List of StorageProviders where we should search for media and - # potentially upload to. - storage_providers = [] - - for ( - clz, - provider_config, - wrapper_config, - ) in hs.config.media.media_storage_providers: - backend = clz(hs, provider_config) - provider = StorageProviderWrapper( - backend, - store_local=wrapper_config.store_local, - store_remote=wrapper_config.store_remote, - store_synchronous=wrapper_config.store_synchronous, - ) - storage_providers.append(provider) - - self.media_storage = MediaStorage( - self.hs, self.primary_base_path, self.filepaths, storage_providers - ) - - self.clock.looping_call( - self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS - ) - - # Media retention configuration options - self._media_retention_local_media_lifetime_ms = ( - hs.config.media.media_retention_local_media_lifetime_ms - ) - self._media_retention_remote_media_lifetime_ms = ( - hs.config.media.media_retention_remote_media_lifetime_ms - ) - - # Check whether local or remote media retention is configured - if ( - hs.config.media.media_retention_local_media_lifetime_ms is not None - or hs.config.media.media_retention_remote_media_lifetime_ms is not None - ): - # Run the background job to apply media retention rules routinely, - # with the duration between runs dictated by the homeserver config. - self.clock.looping_call( - self._start_apply_media_retention_rules, - MEDIA_RETENTION_CHECK_PERIOD_MS, - ) - - def _start_update_recently_accessed(self) -> Deferred: - return run_as_background_process( - "update_recently_accessed_media", self._update_recently_accessed - ) - - def _start_apply_media_retention_rules(self) -> Deferred: - return run_as_background_process( - "apply_media_retention_rules", self._apply_media_retention_rules - ) - - async def _update_recently_accessed(self) -> None: - remote_media = self.recently_accessed_remotes - self.recently_accessed_remotes = set() - - local_media = self.recently_accessed_locals - self.recently_accessed_locals = set() - - await self.store.update_cached_last_access_time( - local_media, remote_media, self.clock.time_msec() - ) - - def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None: - """Mark the given media as recently accessed. 
- - Args: - server_name: Origin server of media, or None if local - media_id: The media ID of the content - """ - if server_name: - self.recently_accessed_remotes.add((server_name, media_id)) - else: - self.recently_accessed_locals.add(media_id) - - async def create_content( - self, - media_type: str, - upload_name: Optional[str], - content: IO, - content_length: int, - auth_user: UserID, - ) -> MXCUri: - """Store uploaded content for a local user and return the mxc URL - - Args: - media_type: The content type of the file. - upload_name: The name of the file, if provided. - content: A file like object that is the content to store - content_length: The length of the content - auth_user: The user_id of the uploader - - Returns: - The mxc url of the stored content - """ - - media_id = random_string(24) - - file_info = FileInfo(server_name=None, file_id=media_id) - - fname = await self.media_storage.store_file(content, file_info) - - logger.info("Stored local media in file %r", fname) - - await self.store.store_local_media( - media_id=media_id, - media_type=media_type, - time_now_ms=self.clock.time_msec(), - upload_name=upload_name, - media_length=content_length, - user_id=auth_user, - ) - - await self._generate_thumbnails(None, media_id, media_id, media_type) - - return MXCUri(self.server_name, media_id) - - async def get_local_media( - self, request: SynapseRequest, media_id: str, name: Optional[str] - ) -> None: - """Responds to requests for local media, if exists, or returns 404. - - Args: - request: The incoming request. - media_id: The media ID of the content. (This is the same as - the file_id for local content.) - name: Optional name that, if specified, will be used as - the filename in the Content-Disposition header of the response. - - Returns: - Resolves once a response has successfully been written to request - """ - media_info = await self.store.get_local_media(media_id) - if not media_info or media_info["quarantined_by"]: - respond_404(request) - return - - self.mark_recently_accessed(None, media_id) - - media_type = media_info["media_type"] - if not media_type: - media_type = "application/octet-stream" - media_length = media_info["media_length"] - upload_name = name if name else media_info["upload_name"] - url_cache = media_info["url_cache"] - - file_info = FileInfo(None, media_id, url_cache=bool(url_cache)) - - responder = await self.media_storage.fetch_media(file_info) - await respond_with_responder( - request, responder, media_type, media_length, upload_name - ) - - async def get_remote_media( - self, - request: SynapseRequest, - server_name: str, - media_id: str, - name: Optional[str], - ) -> None: - """Respond to requests for remote media. - - Args: - request: The incoming request. - server_name: Remote server_name where the media originated. - media_id: The media ID of the content (as defined by the remote server). - name: Optional name that, if specified, will be used as - the filename in the Content-Disposition header of the response. 
- - Returns: - Resolves once a response has successfully been written to request - """ - if ( - self.federation_domain_whitelist is not None - and server_name not in self.federation_domain_whitelist - ): - raise FederationDeniedError(server_name) - - self.mark_recently_accessed(server_name, media_id) - - # We linearize here to ensure that we don't try and download remote - # media multiple times concurrently - key = (server_name, media_id) - async with self.remote_media_linearizer.queue(key): - responder, media_info = await self._get_remote_media_impl( - server_name, media_id - ) - - # We deliberately stream the file outside the lock - if responder: - media_type = media_info["media_type"] - media_length = media_info["media_length"] - upload_name = name if name else media_info["upload_name"] - await respond_with_responder( - request, responder, media_type, media_length, upload_name - ) - else: - respond_404(request) - - async def get_remote_media_info(self, server_name: str, media_id: str) -> dict: - """Gets the media info associated with the remote file, downloading - if necessary. - - Args: - server_name: Remote server_name where the media originated. - media_id: The media ID of the content (as defined by the remote server). - - Returns: - The media info of the file - """ - if ( - self.federation_domain_whitelist is not None - and server_name not in self.federation_domain_whitelist - ): - raise FederationDeniedError(server_name) - - # We linearize here to ensure that we don't try and download remote - # media multiple times concurrently - key = (server_name, media_id) - async with self.remote_media_linearizer.queue(key): - responder, media_info = await self._get_remote_media_impl( - server_name, media_id - ) - - # Ensure we actually use the responder so that it releases resources - if responder: - with responder: - pass - - return media_info - - async def _get_remote_media_impl( - self, server_name: str, media_id: str - ) -> Tuple[Optional[Responder], dict]: - """Looks for media in local cache, if not there then attempt to - download from remote server. - - Args: - server_name: Remote server_name where the media originated. - media_id: The media ID of the content (as defined by the - remote server). - - Returns: - A tuple of responder and the media info of the file. - """ - media_info = await self.store.get_cached_remote_media(server_name, media_id) - - # file_id is the ID we use to track the file locally. If we've already - # seen the file then reuse the existing ID, otherwise generate a new - # one. - - # If we have an entry in the DB, try and look for it - if media_info: - file_id = media_info["filesystem_id"] - file_info = FileInfo(server_name, file_id) - - if media_info["quarantined_by"]: - logger.info("Media is quarantined") - raise NotFoundError() - - if not media_info["media_type"]: - media_info["media_type"] = "application/octet-stream" - - responder = await self.media_storage.fetch_media(file_info) - if responder: - return responder, media_info - - # Failed to find the file anywhere, lets download it. - - try: - media_info = await self._download_remote_file( - server_name, - media_id, - ) - except SynapseError: - raise - except Exception as e: - # An exception may be because we downloaded media in another - # process, so let's check if we magically have the media. 
- media_info = await self.store.get_cached_remote_media(server_name, media_id) - if not media_info: - raise e - - file_id = media_info["filesystem_id"] - if not media_info["media_type"]: - media_info["media_type"] = "application/octet-stream" - file_info = FileInfo(server_name, file_id) - - # We generate thumbnails even if another process downloaded the media - # as a) it's conceivable that the other download request dies before it - # generates thumbnails, but mainly b) we want to be sure the thumbnails - # have finished being generated before responding to the client, - # otherwise they'll request thumbnails and get a 404 if they're not - # ready yet. - await self._generate_thumbnails( - server_name, media_id, file_id, media_info["media_type"] - ) - - responder = await self.media_storage.fetch_media(file_info) - return responder, media_info - - async def _download_remote_file( - self, - server_name: str, - media_id: str, - ) -> dict: - """Attempt to download the remote file from the given server name, - using the given file_id as the local id. - - Args: - server_name: Originating server - media_id: The media ID of the content (as defined by the - remote server). This is different than the file_id, which is - locally generated. - file_id: Local file ID - - Returns: - The media info of the file. - """ - - file_id = random_string(24) - - file_info = FileInfo(server_name=server_name, file_id=file_id) - - with self.media_storage.store_into_file(file_info) as (f, fname, finish): - request_path = "/".join( - ("/_matrix/media/r0/download", server_name, media_id) - ) - try: - length, headers = await self.client.get_file( - server_name, - request_path, - output_stream=f, - max_size=self.max_upload_size, - args={ - # tell the remote server to 404 if it doesn't - # recognise the server_name, to make sure we don't - # end up with a routing loop. - "allow_remote": "false" - }, - ) - except RequestSendFailed as e: - logger.warning( - "Request failed fetching remote media %s/%s: %r", - server_name, - media_id, - e, - ) - raise SynapseError(502, "Failed to fetch remote media") - - except HttpResponseException as e: - logger.warning( - "HTTP error fetching remote media %s/%s: %s", - server_name, - media_id, - e.response, - ) - if e.code == twisted.web.http.NOT_FOUND: - raise e.to_synapse_error() - raise SynapseError(502, "Failed to fetch remote media") - - except SynapseError: - logger.warning( - "Failed to fetch remote media %s/%s", server_name, media_id - ) - raise - except NotRetryingDestination: - logger.warning("Not retrying destination %r", server_name) - raise SynapseError(502, "Failed to fetch remote media") - except Exception: - logger.exception( - "Failed to fetch remote media %s/%s", server_name, media_id - ) - raise SynapseError(502, "Failed to fetch remote media") - - await finish() - - if b"Content-Type" in headers: - media_type = headers[b"Content-Type"][0].decode("ascii") - else: - media_type = "application/octet-stream" - upload_name = get_filename_from_headers(headers) - time_now_ms = self.clock.time_msec() - - # Multiple remote media download requests can race (when using - # multiple media repos), so this may throw a violation constraint - # exception. If it does we'll delete the newly downloaded file from - # disk (as we're in the ctx manager). - # - # However: we've already called `finish()` so we may have also - # written to the storage providers. 
This is preferable to the - # alternative where we call `finish()` *after* this, where we could - # end up having an entry in the DB but fail to write the files to - # the storage providers. - await self.store.store_cached_remote_media( - origin=server_name, - media_id=media_id, - media_type=media_type, - time_now_ms=self.clock.time_msec(), - upload_name=upload_name, - media_length=length, - filesystem_id=file_id, - ) - - logger.info("Stored remote media in file %r", fname) - - media_info = { - "media_type": media_type, - "media_length": length, - "upload_name": upload_name, - "created_ts": time_now_ms, - "filesystem_id": file_id, - } - - return media_info - - def _get_thumbnail_requirements( - self, media_type: str - ) -> Tuple[ThumbnailRequirement, ...]: - scpos = media_type.find(";") - if scpos > 0: - media_type = media_type[:scpos] - return self.thumbnail_requirements.get(media_type, ()) - - def _generate_thumbnail( - self, - thumbnailer: Thumbnailer, - t_width: int, - t_height: int, - t_method: str, - t_type: str, - ) -> Optional[BytesIO]: - m_width = thumbnailer.width - m_height = thumbnailer.height - - if m_width * m_height >= self.max_image_pixels: - logger.info( - "Image too large to thumbnail %r x %r > %r", - m_width, - m_height, - self.max_image_pixels, - ) - return None - - if thumbnailer.transpose_method is not None: - m_width, m_height = thumbnailer.transpose() - - if t_method == "crop": - return thumbnailer.crop(t_width, t_height, t_type) - elif t_method == "scale": - t_width, t_height = thumbnailer.aspect(t_width, t_height) - t_width = min(m_width, t_width) - t_height = min(m_height, t_height) - return thumbnailer.scale(t_width, t_height, t_type) - - return None - - async def generate_local_exact_thumbnail( - self, - media_id: str, - t_width: int, - t_height: int, - t_method: str, - t_type: str, - url_cache: bool, - ) -> Optional[str]: - input_path = await self.media_storage.ensure_media_is_in_local_cache( - FileInfo(None, media_id, url_cache=url_cache) - ) - - try: - thumbnailer = Thumbnailer(input_path) - except ThumbnailError as e: - logger.warning( - "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s", - media_id, - t_method, - t_type, - e, - ) - return None - - with thumbnailer: - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - self._generate_thumbnail, - thumbnailer, - t_width, - t_height, - t_method, - t_type, - ) - - if t_byte_source: - try: - file_info = FileInfo( - server_name=None, - file_id=media_id, - url_cache=url_cache, - thumbnail=ThumbnailInfo( - width=t_width, - height=t_height, - method=t_method, - type=t_type, - ), - ) - - output_path = await self.media_storage.store_file( - t_byte_source, file_info - ) - finally: - t_byte_source.close() - - logger.info("Stored thumbnail in file %r", output_path) - - t_len = os.path.getsize(output_path) - - await self.store.store_local_thumbnail( - media_id, t_width, t_height, t_type, t_method, t_len - ) - - return output_path - - # Could not generate thumbnail. 
- return None - - async def generate_remote_exact_thumbnail( - self, - server_name: str, - file_id: str, - media_id: str, - t_width: int, - t_height: int, - t_method: str, - t_type: str, - ) -> Optional[str]: - input_path = await self.media_storage.ensure_media_is_in_local_cache( - FileInfo(server_name, file_id) - ) - - try: - thumbnailer = Thumbnailer(input_path) - except ThumbnailError as e: - logger.warning( - "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s", - media_id, - server_name, - t_method, - t_type, - e, - ) - return None - - with thumbnailer: - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - self._generate_thumbnail, - thumbnailer, - t_width, - t_height, - t_method, - t_type, - ) - - if t_byte_source: - try: - file_info = FileInfo( - server_name=server_name, - file_id=file_id, - thumbnail=ThumbnailInfo( - width=t_width, - height=t_height, - method=t_method, - type=t_type, - ), - ) - - output_path = await self.media_storage.store_file( - t_byte_source, file_info - ) - finally: - t_byte_source.close() - - logger.info("Stored thumbnail in file %r", output_path) - - t_len = os.path.getsize(output_path) - - await self.store.store_remote_media_thumbnail( - server_name, - media_id, - file_id, - t_width, - t_height, - t_type, - t_method, - t_len, - ) - - return output_path - - # Could not generate thumbnail. - return None - - async def _generate_thumbnails( - self, - server_name: Optional[str], - media_id: str, - file_id: str, - media_type: str, - url_cache: bool = False, - ) -> Optional[dict]: - """Generate and store thumbnails for an image. - - Args: - server_name: The server name if remote media, else None if local - media_id: The media ID of the content. (This is the same as - the file_id for local content) - file_id: Local file ID - media_type: The content type of the file - url_cache: If we are thumbnailing images downloaded for the URL cache, - used exclusively by the url previewer - - Returns: - Dict with "width" and "height" keys of original image or None if the - media cannot be thumbnailed. - """ - requirements = self._get_thumbnail_requirements(media_type) - if not requirements: - return None - - input_path = await self.media_storage.ensure_media_is_in_local_cache( - FileInfo(server_name, file_id, url_cache=url_cache) - ) - - try: - thumbnailer = Thumbnailer(input_path) - except ThumbnailError as e: - logger.warning( - "Unable to generate thumbnails for remote media %s from %s of type %s: %s", - media_id, - server_name, - media_type, - e, - ) - return None - - with thumbnailer: - m_width = thumbnailer.width - m_height = thumbnailer.height - - if m_width * m_height >= self.max_image_pixels: - logger.info( - "Image too large to thumbnail %r x %r > %r", - m_width, - m_height, - self.max_image_pixels, - ) - return None - - if thumbnailer.transpose_method is not None: - m_width, m_height = await defer_to_thread( - self.hs.get_reactor(), thumbnailer.transpose - ) - - # We deduplicate the thumbnail sizes by ignoring the cropped versions if - # they have the same dimensions of a scaled one. 
- thumbnails: Dict[Tuple[int, int, str], str] = {} - for requirement in requirements: - if requirement.method == "crop": - thumbnails.setdefault( - (requirement.width, requirement.height, requirement.media_type), - requirement.method, - ) - elif requirement.method == "scale": - t_width, t_height = thumbnailer.aspect( - requirement.width, requirement.height - ) - t_width = min(m_width, t_width) - t_height = min(m_height, t_height) - thumbnails[ - (t_width, t_height, requirement.media_type) - ] = requirement.method - - # Now we generate the thumbnails for each dimension, store it - for (t_width, t_height, t_type), t_method in thumbnails.items(): - # Generate the thumbnail - if t_method == "crop": - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - thumbnailer.crop, - t_width, - t_height, - t_type, - ) - elif t_method == "scale": - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - thumbnailer.scale, - t_width, - t_height, - t_type, - ) - else: - logger.error("Unrecognized method: %r", t_method) - continue - - if not t_byte_source: - continue - - file_info = FileInfo( - server_name=server_name, - file_id=file_id, - url_cache=url_cache, - thumbnail=ThumbnailInfo( - width=t_width, - height=t_height, - method=t_method, - type=t_type, - ), - ) - - with self.media_storage.store_into_file(file_info) as ( - f, - fname, - finish, - ): - try: - await self.media_storage.write_to_file(t_byte_source, f) - await finish() - finally: - t_byte_source.close() - - t_len = os.path.getsize(fname) - - # Write to database - if server_name: - # Multiple remote media download requests can race (when - # using multiple media repos), so this may throw a violation - # constraint exception. If it does we'll delete the newly - # generated thumbnail from disk (as we're in the ctx - # manager). - # - # However: we've already called `finish()` so we may have - # also written to the storage providers. This is preferable - # to the alternative where we call `finish()` *after* this, - # where we could end up having an entry in the DB but fail - # to write the files to the storage providers. - try: - await self.store.store_remote_media_thumbnail( - server_name, - media_id, - file_id, - t_width, - t_height, - t_type, - t_method, - t_len, - ) - except Exception as e: - thumbnail_exists = ( - await self.store.get_remote_media_thumbnail( - server_name, - media_id, - t_width, - t_height, - t_type, - ) - ) - if not thumbnail_exists: - raise e - else: - await self.store.store_local_thumbnail( - media_id, t_width, t_height, t_type, t_method, t_len - ) - - return {"width": m_width, "height": m_height} - - async def _apply_media_retention_rules(self) -> None: - """ - Purge old local and remote media according to the media retention rules - defined in the homeserver config. - """ - # Purge remote media - if self._media_retention_remote_media_lifetime_ms is not None: - # Calculate a threshold timestamp derived from the configured lifetime. Any - # media that has not been accessed since this timestamp will be removed. 
-            remote_media_threshold_timestamp_ms = (
-                self.clock.time_msec() - self._media_retention_remote_media_lifetime_ms
-            )
-
-            logger.info(
-                "Purging remote media last accessed before"
-                f" {remote_media_threshold_timestamp_ms}"
-            )
-
-            await self.delete_old_remote_media(
-                before_ts=remote_media_threshold_timestamp_ms
-            )
-
-        # And now do the same for local media
-        if self._media_retention_local_media_lifetime_ms is not None:
-            # This works the same as the remote media threshold
-            local_media_threshold_timestamp_ms = (
-                self.clock.time_msec() - self._media_retention_local_media_lifetime_ms
-            )
-
-            logger.info(
-                "Purging local media last accessed before"
-                f" {local_media_threshold_timestamp_ms}"
-            )
-
-            await self.delete_old_local_media(
-                before_ts=local_media_threshold_timestamp_ms,
-                keep_profiles=True,
-                delete_quarantined_media=False,
-                delete_protected_media=False,
-            )
-
-    async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
-        old_media = await self.store.get_remote_media_ids(
-            before_ts, include_quarantined_media=False
-        )
-
-        deleted = 0
-
-        for media in old_media:
-            origin = media["media_origin"]
-            media_id = media["media_id"]
-            file_id = media["filesystem_id"]
-            key = (origin, media_id)
-
-            logger.info("Deleting: %r", key)
-
-            # TODO: Should we delete from the backup store?
-
-            async with self.remote_media_linearizer.queue(key):
-                full_path = self.filepaths.remote_media_filepath(origin, file_id)
-                try:
-                    os.remove(full_path)
-                except OSError as e:
-                    logger.warning("Failed to remove file: %r", full_path)
-                    if e.errno == errno.ENOENT:
-                        pass
-                    else:
-                        continue
-
-                thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
-                    origin, file_id
-                )
-                shutil.rmtree(thumbnail_dir, ignore_errors=True)
-
-                await self.store.delete_remote_media(origin, media_id)
-                deleted += 1
-
-        return {"deleted": deleted}
-
-    async def delete_local_media_ids(
-        self, media_ids: List[str]
-    ) -> Tuple[List[str], int]:
-        """
-        Delete the given local media IDs from this server.
-
-        Args:
-            media_ids: The media IDs to delete.
-        Returns:
-            A tuple of (list of deleted media IDs, total deleted media IDs).
-        """
-        return await self._remove_local_media_from_disk(media_ids)
-
-    async def delete_old_local_media(
-        self,
-        before_ts: int,
-        size_gt: int = 0,
-        keep_profiles: bool = True,
-        delete_quarantined_media: bool = False,
-        delete_protected_media: bool = False,
-    ) -> Tuple[List[str], int]:
-        """
-        Delete local media from this server by size and timestamp. Removes
-        media files, any thumbnails and cached URLs.
-
-        Args:
-            before_ts: Unix timestamp in ms.
-                Files that were last used before this timestamp will be deleted.
-            size_gt: Size of the media in bytes. Files that are larger will be deleted.
-            keep_profiles: If True, files that are still referenced by image data
-                (e.g. a user profile or room avatar) are kept; if False, those
-                files are deleted as well.
-            delete_quarantined_media: If True, media marked as quarantined will be deleted.
-            delete_protected_media: If True, media marked as protected will be deleted.
-
-        Returns:
-            A tuple of (list of deleted media IDs, total deleted media IDs).
- """ - old_media = await self.store.get_local_media_ids( - before_ts, - size_gt, - keep_profiles, - include_quarantined_media=delete_quarantined_media, - include_protected_media=delete_protected_media, - ) - return await self._remove_local_media_from_disk(old_media) - - async def _remove_local_media_from_disk( - self, media_ids: List[str] - ) -> Tuple[List[str], int]: - """ - Delete local or remote media from this server. Removes media files, - any thumbnails and cached URLs. - - Args: - media_ids: List of media_id to delete - Returns: - A tuple of (list of deleted media IDs, total deleted media IDs). - """ - removed_media = [] - for media_id in media_ids: - logger.info("Deleting media with ID '%s'", media_id) - full_path = self.filepaths.local_media_filepath(media_id) - try: - os.remove(full_path) - except OSError as e: - logger.warning("Failed to remove file: %r: %s", full_path, e) - if e.errno == errno.ENOENT: - pass - else: - continue - - thumbnail_dir = self.filepaths.local_media_thumbnail_dir(media_id) - shutil.rmtree(thumbnail_dir, ignore_errors=True) - - await self.store.delete_remote_media(self.server_name, media_id) - - await self.store.delete_url_cache((media_id,)) - await self.store.delete_url_cache_media((media_id,)) - - removed_media.append(media_id) - - return removed_media, len(removed_media) - - -class MediaRepositoryResource(UnrecognizedRequestResource): - """File uploading and downloading. - - Uploads are POSTed to a resource which returns a token which is used to GET - the download:: - - => POST /_matrix/media/r0/upload HTTP/1.1 - Content-Type: <media-type> - Content-Length: <content-length> - - <media> - - <= HTTP/1.1 200 OK - Content-Type: application/json - - { "content_uri": "mxc://<server-name>/<media-id>" } - - => GET /_matrix/media/r0/download/<server-name>/<media-id> HTTP/1.1 - - <= HTTP/1.1 200 OK - Content-Type: <media-type> - Content-Disposition: attachment;filename=<upload-filename> - - <media> - - Clients can get thumbnails by supplying a desired width and height and - thumbnailing method:: - - => GET /_matrix/media/r0/thumbnail/<server_name> - /<media-id>?width=<w>&height=<h>&method=<m> HTTP/1.1 - - <= HTTP/1.1 200 OK - Content-Type: image/jpeg or image/png - - <thumbnail> - - The thumbnail methods are "crop" and "scale". "scale" tries to return an - image where either the width or the height is smaller than the requested - size. The client should then scale and letterbox the image if it needs to - fit within a given rectangle. "crop" tries to return an image where the - width and height are close to the requested size and the aspect matches - the requested size. The client should scale the image if it needs to fit - within a given rectangle. - """ - - def __init__(self, hs: "HomeServer"): - # If we're not configured to use it, raise if we somehow got here. - if not hs.config.media.can_load_media_repo: - raise ConfigError("Synapse is not configured to use a media repo.") - - super().__init__() - media_repo = hs.get_media_repository() - - self.putChild(b"upload", UploadResource(hs, media_repo)) - self.putChild(b"download", DownloadResource(hs, media_repo)) - self.putChild( - b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage) - ) - if hs.config.media.url_preview_enabled: - self.putChild( - b"preview_url", - PreviewUrlResource(hs, media_repo, media_repo.media_storage), - ) - self.putChild(b"config", MediaConfigResource(hs)) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index a5c3de192f..11b0e8e231 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
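The deleted module below documents a strict contract around store_into_file(): the yielded finish callback must be awaited after the file has been written successfully, and skipped on error. A minimal sketch of a caller honouring that contract; the media_storage, source and file_info objects are assumed to come from elsewhere, and this is not the actual Synapse call site:

# Sketch of the store_into_file()/finish_cb contract described in the
# deleted MediaStorage docstring below. All setup is assumed.
async def store_media(media_storage, source, file_info) -> str:
    with media_storage.store_into_file(file_info) as (f, fname, finish_cb):
        # Write the upload into the primary media store first...
        await media_storage.write_to_file(source, f)
        # ...then signal success: finish_cb() runs the spam check and
        # replicates the file to any configured storage providers.
        await finish_cb()
    return fname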
@@ -1,4 +1,4 @@ -# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# Copyright 2023 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,365 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import contextlib -import logging -import os -import shutil -from types import TracebackType -from typing import ( - IO, - TYPE_CHECKING, - Any, - Awaitable, - BinaryIO, - Callable, - Generator, - Optional, - Sequence, - Tuple, - Type, -) - -import attr - -from twisted.internet.defer import Deferred -from twisted.internet.interfaces import IConsumer -from twisted.protocols.basic import FileSender - -import synapse -from synapse.api.errors import NotFoundError -from synapse.logging.context import defer_to_thread, make_deferred_yieldable -from synapse.util import Clock -from synapse.util.file_consumer import BackgroundFileConsumer - -from ._base import FileInfo, Responder -from .filepath import MediaFilePaths - -if TYPE_CHECKING: - from synapse.server import HomeServer - - from .storage_provider import StorageProviderWrapper - -logger = logging.getLogger(__name__) - - -class MediaStorage: - """Responsible for storing/fetching files from local sources. - - Args: - hs - local_media_directory: Base path where we store media on disk - filepaths - storage_providers: List of StorageProvider that are used to fetch and store files. - """ - - def __init__( - self, - hs: "HomeServer", - local_media_directory: str, - filepaths: MediaFilePaths, - storage_providers: Sequence["StorageProviderWrapper"], - ): - self.hs = hs - self.reactor = hs.get_reactor() - self.local_media_directory = local_media_directory - self.filepaths = filepaths - self.storage_providers = storage_providers - self.spam_checker = hs.get_spam_checker() - self.clock = hs.get_clock() - - async def store_file(self, source: IO, file_info: FileInfo) -> str: - """Write `source` to the on disk media store, and also any other - configured storage providers - - Args: - source: A file like object that should be written - file_info: Info about the file to store - - Returns: - the file path written to in the primary media store - """ - - with self.store_into_file(file_info) as (f, fname, finish_cb): - # Write to the main repository - await self.write_to_file(source, f) - await finish_cb() - - return fname - - async def write_to_file(self, source: IO, output: IO) -> None: - """Asynchronously write the `source` to `output`.""" - await defer_to_thread(self.reactor, _write_file_synchronously, source, output) - - @contextlib.contextmanager - def store_into_file( - self, file_info: FileInfo - ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]: - """Context manager used to get a file like object to write into, as - described by file_info. - - Actually yields a 3-tuple (file, fname, finish_cb), where file is a file - like object that can be written to, fname is the absolute path of file - on disk, and finish_cb is a function that returns an awaitable. - - fname can be used to read the contents from after upload, e.g. to - generate thumbnails. - - finish_cb must be called and waited on after the file has been - successfully been written to. Should not be called if there was an - error. 
- - Args: - file_info: Info about the file to store - - Example: - - with media_storage.store_into_file(info) as (f, fname, finish_cb): - # .. write into f ... - await finish_cb() - """ - - path = self._file_info_to_path(file_info) - fname = os.path.join(self.local_media_directory, path) - - dirname = os.path.dirname(fname) - os.makedirs(dirname, exist_ok=True) - - finished_called = [False] - - try: - with open(fname, "wb") as f: - - async def finish() -> None: - # Ensure that all writes have been flushed and close the - # file. - f.flush() - f.close() - - spam_check = await self.spam_checker.check_media_file_for_spam( - ReadableFileWrapper(self.clock, fname), file_info - ) - if spam_check != synapse.module_api.NOT_SPAM: - logger.info("Blocking media due to spam checker") - # Note that we'll delete the stored media, due to the - # try/except below. The media also won't be stored in - # the DB. - # We currently ignore any additional field returned by - # the spam-check API. - raise SpamMediaException(errcode=spam_check[0]) - - for provider in self.storage_providers: - await provider.store_file(path, file_info) - - finished_called[0] = True - - yield f, fname, finish - except Exception as e: - try: - os.remove(fname) - except Exception: - pass - - raise e from None - - if not finished_called: - raise Exception("Finished callback not called") - - async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: - """Attempts to fetch media described by file_info from the local cache - and configured storage providers. - - Args: - file_info - - Returns: - Returns a Responder if the file was found, otherwise None. - """ - paths = [self._file_info_to_path(file_info)] - - # fallback for remote thumbnails with no method in the filename - if file_info.thumbnail and file_info.server_name: - paths.append( - self.filepaths.remote_media_thumbnail_rel_legacy( - server_name=file_info.server_name, - file_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - ) - ) - - for path in paths: - local_path = os.path.join(self.local_media_directory, path) - if os.path.exists(local_path): - logger.debug("responding with local file %s", local_path) - return FileResponder(open(local_path, "rb")) - logger.debug("local file %s did not exist", local_path) - - for provider in self.storage_providers: - for path in paths: - res: Any = await provider.fetch(path, file_info) - if res: - logger.debug("Streaming %s from %s", path, provider) - return res - logger.debug("%s not found on %s", path, provider) - - return None - - async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: - """Ensures that the given file is in the local cache. Attempts to - download it from storage providers if it isn't. 
- - Args: - file_info - - Returns: - Full path to local file - """ - path = self._file_info_to_path(file_info) - local_path = os.path.join(self.local_media_directory, path) - if os.path.exists(local_path): - return local_path - - # Fallback for paths without method names - # Should be removed in the future - if file_info.thumbnail and file_info.server_name: - legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy( - server_name=file_info.server_name, - file_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - ) - legacy_local_path = os.path.join(self.local_media_directory, legacy_path) - if os.path.exists(legacy_local_path): - return legacy_local_path - - dirname = os.path.dirname(local_path) - os.makedirs(dirname, exist_ok=True) - - for provider in self.storage_providers: - res: Any = await provider.fetch(path, file_info) - if res: - with res: - consumer = BackgroundFileConsumer( - open(local_path, "wb"), self.reactor - ) - await res.write_to_consumer(consumer) - await consumer.wait() - return local_path - - raise NotFoundError() - - def _file_info_to_path(self, file_info: FileInfo) -> str: - """Converts file_info into a relative path. - - The path is suitable for storing files under a directory, e.g. used to - store files on local FS under the base media repository directory. - """ - if file_info.url_cache: - if file_info.thumbnail: - return self.filepaths.url_cache_thumbnail_rel( - media_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - method=file_info.thumbnail.method, - ) - return self.filepaths.url_cache_filepath_rel(file_info.file_id) - - if file_info.server_name: - if file_info.thumbnail: - return self.filepaths.remote_media_thumbnail_rel( - server_name=file_info.server_name, - file_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - method=file_info.thumbnail.method, - ) - return self.filepaths.remote_media_filepath_rel( - file_info.server_name, file_info.file_id - ) - - if file_info.thumbnail: - return self.filepaths.local_media_thumbnail_rel( - media_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - method=file_info.thumbnail.method, - ) - return self.filepaths.local_media_filepath_rel(file_info.file_id) - - -def _write_file_synchronously(source: IO, dest: IO) -> None: - """Write `source` to the file like `dest` synchronously. Should be called - from a thread. - - Args: - source: A file like object that's to be written - dest: A file like object to be written to - """ - source.seek(0) # Ensure we read from the start of the file - shutil.copyfileobj(source, dest) - - -class FileResponder(Responder): - """Wraps an open file that can be sent to a request. - - Args: - open_file: A file like object to be streamed ot the client, - is closed when finished streaming. 
- """ - - def __init__(self, open_file: IO): - self.open_file = open_file - - def write_to_consumer(self, consumer: IConsumer) -> Deferred: - return make_deferred_yieldable( - FileSender().beginFileTransfer(self.open_file, consumer) - ) - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> None: - self.open_file.close() - - -class SpamMediaException(NotFoundError): - """The media was blocked by a spam checker, so we simply 404 the request (in - the same way as if it was quarantined). - """ - - -@attr.s(slots=True, auto_attribs=True) -class ReadableFileWrapper: - """Wrapper that allows reading a file in chunks, yielding to the reactor, - and writing to a callback. - - This is simplified `FileSender` that takes an IO object rather than an - `IConsumer`. - """ - - CHUNK_SIZE = 2**14 - - clock: Clock - path: str - - async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None: - """Reads the file in chunks and calls the callback with each chunk.""" - - with open(self.path, "rb") as file: - while True: - chunk = file.read(self.CHUNK_SIZE) - if not chunk: - break - - callback(chunk) +# - # We yield to the reactor by sleeping for 0 seconds. - await self.clock.sleep(0) +# This exists purely for backwards compatibility with spam checkers. +from synapse.media.media_storage import ReadableFileWrapper # noqa: F401 diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py deleted file mode 100644
index 7592aa5d47..0000000000
--- a/synapse/rest/media/v1/oembed.py
+++ /dev/null
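The core of the deleted OEmbedProvider is a rewrite from a matched page URL to the provider's JSON API endpoint. A simplified sketch of that rewrite with a single hard-coded endpoint and pattern (both hypothetical examples); the real code iterates over the patterns from the homeserver's oembed config and also substitutes a "{format}" placeholder in the endpoint:

# Simplified version of get_oembed_url() below; endpoint and pattern
# are illustrative assumptions, not Synapse's configured values.
import re
import urllib.parse

OEMBED_ENDPOINT = "https://publish.twitter.com/oembed"
OEMBED_PATTERN = re.compile(r"https?://twitter\.com/.+/status/.+")

def get_oembed_url(url: str):
    if OEMBED_PATTERN.fullmatch(url):
        # Only the JSON format is supported.
        query = urllib.parse.urlencode({"url": url, "format": "json"}, True)
        return f"{OEMBED_ENDPOINT}?{query}"
    return None  # no match: preview the page itself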
@@ -1,265 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import html -import logging -import urllib.parse -from typing import TYPE_CHECKING, List, Optional - -import attr - -from synapse.rest.media.v1.preview_html import parse_html_description -from synapse.types import JsonDict -from synapse.util import json_decoder - -if TYPE_CHECKING: - from lxml import etree - - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class OEmbedResult: - # The Open Graph result (converted from the oEmbed result). - open_graph_result: JsonDict - # The author_name of the oEmbed result - author_name: Optional[str] - # Number of milliseconds to cache the content, according to the oEmbed response. - # - # This will be None if no cache-age is provided in the oEmbed response (or - # if the oEmbed response cannot be turned into an Open Graph response). - cache_age: Optional[int] - - -class OEmbedProvider: - """ - A helper for accessing oEmbed content. - - It can be used to check if a URL should be accessed via oEmbed and for - requesting/parsing oEmbed content. - """ - - def __init__(self, hs: "HomeServer"): - self._oembed_patterns = {} - for oembed_endpoint in hs.config.oembed.oembed_patterns: - api_endpoint = oembed_endpoint.api_endpoint - - # Only JSON is supported at the moment. This could be declared in - # the formats field. Otherwise, if the endpoint ends in .xml assume - # it doesn't support JSON. - if ( - oembed_endpoint.formats is not None - and "json" not in oembed_endpoint.formats - ) or api_endpoint.endswith(".xml"): - logger.info( - "Ignoring oEmbed endpoint due to not supporting JSON: %s", - api_endpoint, - ) - continue - - # Iterate through each URL pattern and point it to the endpoint. - for pattern in oembed_endpoint.url_patterns: - self._oembed_patterns[pattern] = api_endpoint - - def get_oembed_url(self, url: str) -> Optional[str]: - """ - Check whether the URL should be downloaded as oEmbed content instead. - - Args: - url: The URL to check. - - Returns: - A URL to use instead or None if the original URL should be used. - """ - for url_pattern, endpoint in self._oembed_patterns.items(): - if url_pattern.fullmatch(url): - # TODO Specify max height / width. - - # Note that only the JSON format is supported, some endpoints want - # this in the URL, others want it as an argument. - endpoint = endpoint.replace("{format}", "json") - - args = {"url": url, "format": "json"} - query_str = urllib.parse.urlencode(args, True) - return f"{endpoint}?{query_str}" - - # No match. - return None - - def autodiscover_from_html(self, tree: "etree.Element") -> Optional[str]: - """ - Search an HTML document for oEmbed autodiscovery information. - - Args: - tree: The parsed HTML body. - - Returns: - The URL to use for oEmbed information, or None if no URL was found. - """ - # Search for link elements with the proper rel and type attributes. 
- for tag in tree.xpath( - "//link[@rel='alternate'][@type='application/json+oembed']" - ): - if "href" in tag.attrib: - return tag.attrib["href"] - - # Some providers (e.g. Flickr) use alternative instead of alternate. - for tag in tree.xpath( - "//link[@rel='alternative'][@type='application/json+oembed']" - ): - if "href" in tag.attrib: - return tag.attrib["href"] - - return None - - def parse_oembed_response(self, url: str, raw_body: bytes) -> OEmbedResult: - """ - Parse the oEmbed response into an Open Graph response. - - Args: - url: The URL which is being previewed (not the one which was - requested). - raw_body: The oEmbed response as JSON encoded as bytes. - - Returns: - json-encoded Open Graph data - """ - - try: - # oEmbed responses *must* be UTF-8 according to the spec. - oembed = json_decoder.decode(raw_body.decode("utf-8")) - except ValueError: - return OEmbedResult({}, None, None) - - # The version is a required string field, but not always provided, - # or sometimes provided as a float. Be lenient. - oembed_version = oembed.get("version", "1.0") - if oembed_version != "1.0" and oembed_version != 1: - return OEmbedResult({}, None, None) - - # Attempt to parse the cache age, if possible. - try: - cache_age = int(oembed.get("cache_age")) * 1000 - except (TypeError, ValueError): - # If the cache age cannot be parsed (e.g. wrong type or invalid - # string), ignore it. - cache_age = None - - # The oEmbed response converted to Open Graph. - open_graph_response: JsonDict = {"og:url": url} - - title = oembed.get("title") - if title and isinstance(title, str): - # A common WordPress plug-in seems to incorrectly escape entities - # in the oEmbed response. - open_graph_response["og:title"] = html.unescape(title) - - author_name = oembed.get("author_name") - if not isinstance(author_name, str): - author_name = None - - # Use the provider name and as the site. - provider_name = oembed.get("provider_name") - if provider_name and isinstance(provider_name, str): - open_graph_response["og:site_name"] = provider_name - - # If a thumbnail exists, use it. Note that dimensions will be calculated later. - thumbnail_url = oembed.get("thumbnail_url") - if thumbnail_url and isinstance(thumbnail_url, str): - open_graph_response["og:image"] = thumbnail_url - - # Process each type separately. - oembed_type = oembed.get("type") - if oembed_type == "rich": - html_str = oembed.get("html") - if isinstance(html_str, str): - calc_description_and_urls(open_graph_response, html_str) - - elif oembed_type == "photo": - # If this is a photo, use the full image, not the thumbnail. 
- url = oembed.get("url") - if url and isinstance(url, str): - open_graph_response["og:image"] = url - - elif oembed_type == "video": - open_graph_response["og:type"] = "video.other" - html_str = oembed.get("html") - if html_str and isinstance(html_str, str): - calc_description_and_urls(open_graph_response, oembed["html"]) - for size in ("width", "height"): - val = oembed.get(size) - if type(val) is int: - open_graph_response[f"og:video:{size}"] = val - - elif oembed_type == "link": - open_graph_response["og:type"] = "website" - - else: - logger.warning("Unknown oEmbed type: %s", oembed_type) - - return OEmbedResult(open_graph_response, author_name, cache_age) - - -def _fetch_urls(tree: "etree.Element", tag_name: str) -> List[str]: - results = [] - for tag in tree.xpath("//*/" + tag_name): - if "src" in tag.attrib: - results.append(tag.attrib["src"]) - return results - - -def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) -> None: - """ - Calculate description for an HTML document. - - This uses lxml to convert the HTML document into plaintext. If errors - occur during processing of the document, an empty response is returned. - - Args: - open_graph_response: The current Open Graph summary. This is updated with additional fields. - html_body: The HTML document, as bytes. - - Returns: - The summary - """ - # If there's no body, nothing useful is going to be found. - if not html_body: - return - - from lxml import etree - - # Create an HTML parser. If this fails, log and return no metadata. - parser = etree.HTMLParser(recover=True, encoding="utf-8") - - # Attempt to parse the body. If this fails, log and return no metadata. - tree = etree.fromstring(html_body, parser) - - # The data was successfully parsed, but no tree was found. - if tree is None: - return - - # Attempt to find interesting URLs (images, videos, embeds). - if "og:image" not in open_graph_response: - image_urls = _fetch_urls(tree, "img") - if image_urls: - open_graph_response["og:image"] = image_urls[0] - - video_urls = _fetch_urls(tree, "video") + _fetch_urls(tree, "embed") - if video_urls: - open_graph_response["og:video"] = video_urls[0] - - description = parse_html_description(tree) - if description: - open_graph_response["og:description"] = description diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py deleted file mode 100644
index 516d0434f0..0000000000
--- a/synapse/rest/media/v1/preview_html.py
+++ /dev/null
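The deleted module sniffs the character encoding of a page in a fixed precedence order: a meta charset tag, then an XML encoding declaration, then the Content-Type header, then utf-8 and windows-1252 fallbacks. A rough sketch of that order which omits the XML step and the codecs-based normalisation that the real _get_html_media_encodings() performs:

# Simplified charset sniffing in the precedence order used below.
# Candidates are not deduplicated or normalised here, unlike the original.
import re

_META = re.compile(rb'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', re.I)

def guess_encodings(body: bytes, content_type: str):
    # Limit the search to the first 1kb, as the original does.
    match = _META.search(body[:1024])
    if match:
        yield match.group(1).decode("ascii")
    if "charset=" in content_type:
        yield content_type.split("charset=")[1].strip('"; ')
    yield "utf-8"
    yield "cp1252"

# e.g. next(guess_encodings(b'<meta charset="iso-8859-1">', "text/html"))
# -> "iso-8859-1"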
@@ -1,501 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import codecs -import logging -import re -from typing import ( - TYPE_CHECKING, - Callable, - Dict, - Generator, - Iterable, - List, - Optional, - Set, - Union, -) - -if TYPE_CHECKING: - from lxml import etree - -logger = logging.getLogger(__name__) - -_charset_match = re.compile( - rb'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I -) -_xml_encoding_match = re.compile( - rb'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I -) -_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) - -# Certain elements aren't meant for display. -ARIA_ROLES_TO_IGNORE = {"directory", "menu", "menubar", "toolbar"} - - -def _normalise_encoding(encoding: str) -> Optional[str]: - """Use the Python codec's name as the normalised entry.""" - try: - return codecs.lookup(encoding).name - except LookupError: - return None - - -def _get_html_media_encodings( - body: bytes, content_type: Optional[str] -) -> Iterable[str]: - """ - Get potential encoding of the body based on the (presumably) HTML body or the content-type header. - - The precedence used for finding a character encoding is: - - 1. <meta> tag with a charset declared. - 2. The XML document's character encoding attribute. - 3. The Content-Type header. - 4. Fallback to utf-8. - 5. Fallback to windows-1252. - - This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector. - - Args: - body: The HTML document, as bytes. - content_type: The Content-Type header. - - Returns: - The character encoding of the body, as a string. - """ - # There's no point in returning an encoding more than once. - attempted_encodings: Set[str] = set() - - # Limit searches to the first 1kb, since it ought to be at the top. - body_start = body[:1024] - - # Check if it has an encoding set in a meta tag. - match = _charset_match.search(body_start) - if match: - encoding = _normalise_encoding(match.group(1).decode("ascii")) - if encoding: - attempted_encodings.add(encoding) - yield encoding - - # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/> - - # Check if it has an XML document with an encoding. - match = _xml_encoding_match.match(body_start) - if match: - encoding = _normalise_encoding(match.group(1).decode("ascii")) - if encoding and encoding not in attempted_encodings: - attempted_encodings.add(encoding) - yield encoding - - # Check the HTTP Content-Type header for a character set. - if content_type: - content_match = _content_type_match.match(content_type) - if content_match: - encoding = _normalise_encoding(content_match.group(1)) - if encoding and encoding not in attempted_encodings: - attempted_encodings.add(encoding) - yield encoding - - # Finally, fallback to UTF-8, then windows-1252. 
- for fallback in ("utf-8", "cp1252"): - if fallback not in attempted_encodings: - yield fallback - - -def decode_body( - body: bytes, uri: str, content_type: Optional[str] = None -) -> Optional["etree.Element"]: - """ - This uses lxml to parse the HTML document. - - Args: - body: The HTML document, as bytes. - uri: The URI used to download the body. - content_type: The Content-Type header. - - Returns: - The parsed HTML body, or None if an error occurred during processed. - """ - # If there's no body, nothing useful is going to be found. - if not body: - return None - - # The idea here is that multiple encodings are tried until one works. - # Unfortunately the result is never used and then LXML will decode the string - # again with the found encoding. - for encoding in _get_html_media_encodings(body, content_type): - try: - body.decode(encoding) - except Exception: - pass - else: - break - else: - logger.warning("Unable to decode HTML body for %s", uri) - return None - - from lxml import etree - - # Create an HTML parser. - parser = etree.HTMLParser(recover=True, encoding=encoding) - - # Attempt to parse the body. Returns None if the body was successfully - # parsed, but no tree was found. - return etree.fromstring(body, parser) - - -def _get_meta_tags( - tree: "etree.Element", - property: str, - prefix: str, - property_mapper: Optional[Callable[[str], Optional[str]]] = None, -) -> Dict[str, Optional[str]]: - """ - Search for meta tags prefixed with a particular string. - - Args: - tree: The parsed HTML document. - property: The name of the property which contains the tag name, e.g. - "property" for Open Graph. - prefix: The prefix on the property to search for, e.g. "og" for Open Graph. - property_mapper: An optional callable to map the property to the Open Graph - form. Can return None for a key to ignore that key. - - Returns: - A map of tag name to value. - """ - results: Dict[str, Optional[str]] = {} - for tag in tree.xpath( - f"//*/meta[starts-with(@{property}, '{prefix}:')][@content][not(@content='')]" - ): - # if we've got more than 50 tags, someone is taking the piss - if len(results) >= 50: - logger.warning( - "Skipping parsing of Open Graph for page with too many '%s:' tags", - prefix, - ) - return {} - - key = tag.attrib[property] - if property_mapper: - key = property_mapper(key) - # None is a special value used to ignore a value. - if key is None: - continue - - results[key] = tag.attrib["content"] - - return results - - -def _map_twitter_to_open_graph(key: str) -> Optional[str]: - """ - Map a Twitter card property to the analogous Open Graph property. - - Args: - key: The Twitter card property (starts with "twitter:"). - - Returns: - The Open Graph property (starts with "og:") or None to have this property - be ignored. - """ - # Twitter card properties with no analogous Open Graph property. - if key == "twitter:card" or key == "twitter:creator": - return None - if key == "twitter:site": - return "og:site_name" - # Otherwise, swap twitter to og. - return "og" + key[7:] - - -def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: - """ - Parse the HTML document into an Open Graph response. - - This uses lxml to search the HTML document for Open Graph data (or - synthesizes it from the document). - - Args: - tree: The parsed HTML document. - - Returns: - The Open Graph response as a dictionary. 
- """ - - # Search for Open Graph (og:) meta tags, e.g.: - # - # "og:type" : "video", - # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", - # "og:site_name" : "YouTube", - # "og:video:type" : "application/x-shockwave-flash", - # "og:description" : "Fun stuff happening here", - # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", - # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", - # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", - # "og:video:width" : "1280" - # "og:video:height" : "720", - # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", - - og = _get_meta_tags(tree, "property", "og") - - # TODO: Search for properties specific to the different Open Graph types, - # such as article: meta tags, e.g.: - # - # "article:publisher" : "https://www.facebook.com/thethudonline" /> - # "article:author" content="https://www.facebook.com/thethudonline" /> - # "article:tag" content="baby" /> - # "article:section" content="Breaking News" /> - # "article:published_time" content="2016-03-31T19:58:24+00:00" /> - # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> - - # Search for Twitter Card (twitter:) meta tags, e.g.: - # - # "twitter:site" : "@matrixdotorg" - # "twitter:creator" : "@matrixdotorg" - # - # Twitter cards tags also duplicate Open Graph tags. - # - # See https://developer.twitter.com/en/docs/twitter-for-websites/cards/guides/getting-started - twitter = _get_meta_tags(tree, "name", "twitter", _map_twitter_to_open_graph) - # Merge the Twitter values with the Open Graph values, but do not overwrite - # information from Open Graph tags. - for key, value in twitter.items(): - if key not in og: - og[key] = value - - if "og:title" not in og: - # Attempt to find a title from the title tag, or the biggest header on the page. - title = tree.xpath("((//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1])/text()") - if title: - og["og:title"] = title[0].strip() - else: - og["og:title"] = None - - if "og:image" not in og: - meta_image = tree.xpath( - "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image'][not(@content='')]/@content[1]" - ) - # If a meta image is found, use it. - if meta_image: - og["og:image"] = meta_image[0] - else: - # Try to find images which are larger than 10px by 10px. - # - # TODO: consider inlined CSS styles as well as width & height attribs - images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") - images = sorted( - images, - key=lambda i: ( - -1 * float(i.attrib["width"]) * float(i.attrib["height"]) - ), - ) - # If no images were found, try to find *any* images. - if not images: - images = tree.xpath("//img[@src][1]") - if images: - og["og:image"] = images[0].attrib["src"] - - # Finally, fallback to the favicon if nothing else. - else: - favicons = tree.xpath("//link[@href][contains(@rel, 'icon')]/@href[1]") - if favicons: - og["og:image"] = favicons[0] - - if "og:description" not in og: - # Check the first meta description tag for content. - meta_description = tree.xpath( - "//*/meta[translate(@name, 'DESCRIPTION', 'description')='description'][not(@content='')]/@content[1]" - ) - # If a meta description is found with content, use it. - if meta_description: - og["og:description"] = meta_description[0] - else: - og["og:description"] = parse_html_description(tree) - elif og["og:description"]: - # This must be a non-empty string at this point. 
- assert isinstance(og["og:description"], str) - og["og:description"] = summarize_paragraphs([og["og:description"]]) - - # TODO: delete the url downloads to stop diskfilling, - # as we only ever cared about its OG - return og - - -def parse_html_description(tree: "etree.Element") -> Optional[str]: - """ - Calculate a text description based on an HTML document. - - Grabs any text nodes which are inside the <body/> tag, unless they are within - an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or - if they are within a <script/>, <svg/> or <style/> tag, or if they are within - a tag whose content is usually only shown to old browsers - (<iframe/>, <video/>, <canvas/>, <picture/>). - - This is a very very very coarse approximation to a plain text render of the page. - - Args: - tree: The parsed HTML document. - - Returns: - The plain text description, or None if one cannot be generated. - """ - # We don't just use XPATH here as that is slow on some machines. - - from lxml import etree - - TAGS_TO_REMOVE = { - "header", - "nav", - "aside", - "footer", - "script", - "noscript", - "style", - "svg", - "iframe", - "video", - "canvas", - "img", - "picture", - etree.Comment, - } - - # Split all the text nodes into paragraphs (by splitting on new - # lines) - text_nodes = ( - re.sub(r"\s+", "\n", el).strip() - for el in _iterate_over_text(tree.find("body"), TAGS_TO_REMOVE) - ) - return summarize_paragraphs(text_nodes) - - -def _iterate_over_text( - tree: Optional["etree.Element"], - tags_to_ignore: Set[Union[str, "etree.Comment"]], - stack_limit: int = 1024, -) -> Generator[str, None, None]: - """Iterate over the tree returning text nodes in a depth first fashion, - skipping text nodes inside certain tags. - - Args: - tree: The parent element to iterate. Can be None if there isn't one. - tags_to_ignore: Set of tags to ignore - stack_limit: Maximum stack size limit for depth-first traversal. - Nodes will be dropped if this limit is hit, which may truncate the - textual result. - Intended to limit the maximum working memory when generating a preview. - """ - - if tree is None: - return - - # This is a stack whose items are elements to iterate over *or* strings - # to be returned. - elements: List[Union[str, "etree.Element"]] = [tree] - while elements: - el = elements.pop() - - if isinstance(el, str): - yield el - elif el.tag not in tags_to_ignore: - # If the element isn't meant for display, ignore it. - if el.get("role") in ARIA_ROLES_TO_IGNORE: - continue - - # el.text is the text before the first child, so we can immediately - # return it if the text exists. - if el.text: - yield el.text - - # We add to the stack all the element's children, interspersed with - # each child's tail text (if it exists). - # - # We iterate in reverse order so that earlier pieces of text appear - # closer to the top of the stack. - for child in el.iterchildren(reversed=True): - if len(elements) > stack_limit: - # We've hit our limit for working memory - break - - if child.tail: - # The tail text of a node is text that comes *after* the node, - # so we always include it even if we ignore the child node. - elements.append(child.tail) - - elements.append(child) - - -def summarize_paragraphs( - text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500 -) -> Optional[str]: - """ - Try to get a summary respecting first paragraph and then word boundaries. - - Args: - text_nodes: The paragraphs to summarize. - min_size: The minimum number of words to include. 
- max_size: The maximum number of words to include. - - Returns: - A summary of the text nodes, or None if that was not possible. - """ - - # TODO: Respect sentences? - - description = "" - - # Keep adding paragraphs until we get to the MIN_SIZE. - for text_node in text_nodes: - if len(description) < min_size: - text_node = re.sub(r"[\t \r\n]+", " ", text_node) - description += text_node + "\n\n" - else: - break - - description = description.strip() - description = re.sub(r"[\t ]+", " ", description) - description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description) - - # If the concatenation of paragraphs to get above MIN_SIZE - # took us over MAX_SIZE, then we need to truncate mid paragraph - if len(description) > max_size: - new_desc = "" - - # This splits the paragraph into words, but keeping the - # (preceding) whitespace intact so we can easily concat - # words back together. - for match in re.finditer(r"\s*\S+", description): - word = match.group() - - # Keep adding words while the total length is less than - # MAX_SIZE. - if len(word) + len(new_desc) < max_size: - new_desc += word - else: - # At this point the next word *will* take us over - # MAX_SIZE, but we also want to ensure that its not - # a huge word. If it is add it anyway and we'll - # truncate later. - if len(new_desc) < min_size: - new_desc += word - break - - # Double check that we're not over the limit - if len(new_desc) > max_size: - new_desc = new_desc[:max_size] - - # We always add an ellipsis because at the very least - # we chopped mid paragraph. - description = new_desc.strip() + "…" - return description if description else None diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 3f48439af1..d7653f30ae 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
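Any object exposing store_file() and fetch() coroutines can act as a storage provider. A hypothetical third-party provider sketched against the interface deleted below; the cache and backup directories are assumptions, and the real FileStorageProviderBackend additionally defers the copy to a worker thread and wraps file handles in a FileResponder:

# Hypothetical provider sketched against the StorageProvider interface.
# Paths are placeholder assumptions; error handling and config parsing
# are omitted.
import os
import shutil

class BackupDirectoryProvider:
    def __init__(self, cache_directory: str, base_directory: str):
        self.cache_directory = cache_directory  # primary media store
        self.base_directory = base_directory    # backup location

    async def store_file(self, path: str, file_info) -> None:
        backup_fname = os.path.join(self.base_directory, path)
        os.makedirs(os.path.dirname(backup_fname), exist_ok=True)
        shutil.copyfile(os.path.join(self.cache_directory, path), backup_fname)

    async def fetch(self, path: str, file_info):
        backup_fname = os.path.join(self.base_directory, path)
        if os.path.isfile(backup_fname):
            return open(backup_fname, "rb")  # Synapse wraps this in a Responder
        return None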
@@ -1,4 +1,4 @@ -# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# Copyright 2023 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,171 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# -import abc -import logging -import os -import shutil -from typing import TYPE_CHECKING, Callable, Optional - -from synapse.config._base import Config -from synapse.logging.context import defer_to_thread, run_in_background -from synapse.util.async_helpers import maybe_awaitable - -from ._base import FileInfo, Responder -from .media_storage import FileResponder - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class StorageProvider(metaclass=abc.ABCMeta): - """A storage provider is a service that can store uploaded media and - retrieve them. - """ - - @abc.abstractmethod - async def store_file(self, path: str, file_info: FileInfo) -> None: - """Store the file described by file_info. The actual contents can be - retrieved by reading the file in file_info.upload_path. - - Args: - path: Relative path of file in local cache - file_info: The metadata of the file. - """ - - @abc.abstractmethod - async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: - """Attempt to fetch the file described by file_info and stream it - into writer. - - Args: - path: Relative path of file in local cache - file_info: The metadata of the file. - - Returns: - Returns a Responder if the provider has the file, otherwise returns None. - """ - - -class StorageProviderWrapper(StorageProvider): - """Wraps a storage provider and provides various config options - - Args: - backend: The storage provider to wrap. - store_local: Whether to store new local files or not. - store_synchronous: Whether to wait for file to be successfully - uploaded, or todo the upload in the background. - store_remote: Whether remote media should be uploaded - """ - - def __init__( - self, - backend: StorageProvider, - store_local: bool, - store_synchronous: bool, - store_remote: bool, - ): - self.backend = backend - self.store_local = store_local - self.store_synchronous = store_synchronous - self.store_remote = store_remote - - def __str__(self) -> str: - return "StorageProviderWrapper[%s]" % (self.backend,) - - async def store_file(self, path: str, file_info: FileInfo) -> None: - if not file_info.server_name and not self.store_local: - return None - - if file_info.server_name and not self.store_remote: - return None - - if file_info.url_cache: - # The URL preview cache is short lived and not worth offloading or - # backing up. - return None - - if self.store_synchronous: - # store_file is supposed to return an Awaitable, but guard - # against improper implementations. - await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore - else: - # TODO: Handle errors. 
- async def store() -> None: - try: - return await maybe_awaitable( - self.backend.store_file(path, file_info) - ) - except Exception: - logger.exception("Error storing file") - - run_in_background(store) # type: ignore[unused-awaitable] - - async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: - if file_info.url_cache: - # Files in the URL preview cache definitely aren't stored here, - # so avoid any potentially slow I/O or network access. - return None - - # store_file is supposed to return an Awaitable, but guard - # against improper implementations. - return await maybe_awaitable(self.backend.fetch(path, file_info)) - - -class FileStorageProviderBackend(StorageProvider): - """A storage provider that stores files in a directory on a filesystem. - - Args: - hs - config: The config returned by `parse_config`. - """ - - def __init__(self, hs: "HomeServer", config: str): - self.hs = hs - self.cache_directory = hs.config.media.media_store_path - self.base_directory = config - - def __str__(self) -> str: - return "FileStorageProviderBackend[%s]" % (self.base_directory,) - - async def store_file(self, path: str, file_info: FileInfo) -> None: - """See StorageProvider.store_file""" - - primary_fname = os.path.join(self.cache_directory, path) - backup_fname = os.path.join(self.base_directory, path) - - dirname = os.path.dirname(backup_fname) - os.makedirs(dirname, exist_ok=True) - - # mypy needs help inferring the type of the second parameter, which is generic - shutil_copyfile: Callable[[str, str], str] = shutil.copyfile - await defer_to_thread( - self.hs.get_reactor(), - shutil_copyfile, - primary_fname, - backup_fname, - ) - - async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: - """See StorageProvider.fetch""" - - backup_fname = os.path.join(self.base_directory, path) - if os.path.isfile(backup_fname): - return FileResponder(open(backup_fname, "rb")) - - return None - - @staticmethod - def parse_config(config: dict) -> str: - """Called on startup to parse config supplied. This should parse - the config and raise if there is a problem. - - The returned value is passed into the constructor. - - In this case we only care about a single param, the directory, so let's - just pull that out. - """ - return Config.ensure_directory(config["directory"]) +# This exists purely for backwards compatibility with media providers. +from synapse.media.storage_provider import StorageProvider # noqa: F401 diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py deleted file mode 100644
index 9480cc5763..0000000000
--- a/synapse/rest/media/v1/thumbnailer.py
+++ /dev/null
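The aspect() helper in the deleted Thumbnailer below picks the largest size that preserves aspect ratio while fitting a bounding box. A worked example of the same integer arithmetic, extracted from the method so it can be run standalone:

# The aspect-preserving "scale" computation from Thumbnailer.aspect(),
# applied to a 1280x720 source with a 640x640 bound.
def aspect(width, height, max_width, max_height):
    if max_width * height < max_height * width:
        return max_width, max((max_width * height) // width, 1)
    return max((max_height * width) // height, 1), max_height

print(aspect(1280, 720, 640, 640))  # -> (640, 360): width is the limiting side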
@@ -1,222 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2020-2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from io import BytesIO -from types import TracebackType -from typing import Optional, Tuple, Type - -from PIL import Image - -logger = logging.getLogger(__name__) - -EXIF_ORIENTATION_TAG = 0x0112 -EXIF_TRANSPOSE_MAPPINGS = { - 2: Image.FLIP_LEFT_RIGHT, - 3: Image.ROTATE_180, - 4: Image.FLIP_TOP_BOTTOM, - 5: Image.TRANSPOSE, - 6: Image.ROTATE_270, - 7: Image.TRANSVERSE, - 8: Image.ROTATE_90, -} - - -class ThumbnailError(Exception): - """An error occurred generating a thumbnail.""" - - -class Thumbnailer: - - FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} - - @staticmethod - def set_limits(max_image_pixels: int) -> None: - Image.MAX_IMAGE_PIXELS = max_image_pixels - - def __init__(self, input_path: str): - # Have we closed the image? - self._closed = False - - try: - self.image = Image.open(input_path) - except OSError as e: - # If an error occurs opening the image, a thumbnail won't be able to - # be generated. - raise ThumbnailError from e - except Image.DecompressionBombError as e: - # If an image decompression bomb error occurs opening the image, - # then the image exceeds the pixel limit and a thumbnail won't - # be able to be generated. - raise ThumbnailError from e - - self.width, self.height = self.image.size - self.transpose_method = None - try: - # We don't use ImageOps.exif_transpose since it crashes with big EXIF - # - # Ignore safety: Pillow seems to acknowledge that this method is - # "private, experimental, but generally widely used". Pillow 6 - # includes a public getexif() method (no underscore) that we might - # consider using instead when we can bump that dependency. - # - # At the time of writing, Debian buster (currently oldstable) - # provides version 5.4.1. It's expected to EOL in mid-2022, see - # https://wiki.debian.org/DebianReleases#Production_Releases - image_exif = self.image._getexif() # type: ignore - if image_exif is not None: - image_orientation = image_exif.get(EXIF_ORIENTATION_TAG) - assert type(image_orientation) is int - self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation) - except Exception as e: - # A lot of parsing errors can happen when parsing EXIF - logger.info("Error parsing image EXIF information: %s", e) - - def transpose(self) -> Tuple[int, int]: - """Transpose the image using its EXIF Orientation tag - - Returns: - A tuple containing the new image size in pixels as (width, height). - """ - if self.transpose_method is not None: - # Safety: `transpose` takes an int rather than e.g. an IntEnum. - # self.transpose_method is set above to be a value in - # EXIF_TRANSPOSE_MAPPINGS, and that only contains correct values. 
- with self.image: - self.image = self.image.transpose(self.transpose_method) # type: ignore[arg-type] - self.width, self.height = self.image.size - self.transpose_method = None - # We don't need EXIF any more - self.image.info["exif"] = None - return self.image.size - - def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]: - """Calculate the largest size that preserves aspect ratio which - fits within the given rectangle:: - - (w_in / h_in) = (w_out / h_out) - w_out = max(min(w_max, h_max * (w_in / h_in)), 1) - h_out = max(min(h_max, w_max * (h_in / w_in)), 1) - - Args: - max_width: The largest possible width. - max_height: The largest possible height. - """ - - if max_width * self.height < max_height * self.width: - return max_width, max((max_width * self.height) // self.width, 1) - else: - return max((max_height * self.width) // self.height, 1), max_height - - def _resize(self, width: int, height: int) -> Image.Image: - # 1-bit or 8-bit color palette images need converting to RGB - # otherwise they will be scaled using nearest neighbour which - # looks awful. - # - # If the image has transparency, use RGBA instead. - if self.image.mode in ["1", "L", "P"]: - if self.image.info.get("transparency", None) is not None: - with self.image: - self.image = self.image.convert("RGBA") - else: - with self.image: - self.image = self.image.convert("RGB") - return self.image.resize((width, height), Image.ANTIALIAS) - - def scale(self, width: int, height: int, output_type: str) -> BytesIO: - """Rescales the image to the given dimensions. - - Returns: - The bytes of the encoded image ready to be written to disk - """ - with self._resize(width, height) as scaled: - return self._encode_image(scaled, output_type) - - def crop(self, width: int, height: int, output_type: str) -> BytesIO: - """Rescales and crops the image to the given dimensions preserving - aspect:: - (w_in / h_in) = (w_scaled / h_scaled) - w_scaled = max(w_out, h_out * (w_in / h_in)) - h_scaled = max(h_out, w_out * (h_in / w_in)) - - Args: - max_width: The largest possible width. - max_height: The largest possible height. - - Returns: - The bytes of the encoded image ready to be written to disk - """ - if width * self.height > height * self.width: - scaled_width = width - scaled_height = (width * self.height) // self.width - crop_top = (scaled_height - height) // 2 - crop_bottom = height + crop_top - crop = (0, crop_top, width, crop_bottom) - else: - scaled_width = (height * self.width) // self.height - scaled_height = height - crop_left = (scaled_width - width) // 2 - crop_right = width + crop_left - crop = (crop_left, 0, crop_right, height) - - with self._resize(scaled_width, scaled_height) as scaled_image: - with scaled_image.crop(crop) as cropped: - return self._encode_image(cropped, output_type) - - def _encode_image(self, output_image: Image.Image, output_type: str) -> BytesIO: - output_bytes_io = BytesIO() - fmt = self.FORMATS[output_type] - if fmt == "JPEG": - output_image = output_image.convert("RGB") - output_image.save(output_bytes_io, fmt, quality=80) - return output_bytes_io - - def close(self) -> None: - """Closes the underlying image file. - - Once closed no other functions can be called. - - Can be called multiple times. - """ - - if self._closed: - return - - self._closed = True - - # Since we run this on the finalizer then we need to handle `__init__` - # raising an exception before it can define `self.image`. 
-        image = getattr(self, "image", None)
-        if image is None:
-            return
-
-        image.close()
-
-    def __enter__(self) -> "Thumbnailer":
-        """Make `Thumbnailer` a context manager that calls `close` on
-        `__exit__`.
-        """
-        return self
-
-    def __exit__(
-        self,
-        type: Optional[Type[BaseException]],
-        value: Optional[BaseException],
-        traceback: Optional[TracebackType],
-    ) -> None:
-        self.close()
-
-    def __del__(self) -> None:
-        # Make sure we actually do close the image, rather than leak data.
-        self.close()
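To close: a sketch of how the media repository drives this class, relying on the context-manager support above so close() always runs. The paths are placeholders, and in Synapse the whole call is pushed off the reactor thread with defer_to_thread rather than invoked directly:

# Sketch of a Thumbnailer caller; paths and dimensions are assumptions.
def make_thumbnail(input_path: str, out_path: str) -> None:
    with Thumbnailer(input_path) as t:  # __exit__ guarantees close()
        if t.transpose_method is not None:
            t.transpose()  # apply the EXIF orientation first
        data = t.scale(320, 240, "image/jpeg")  # returns a BytesIO
        with open(out_path, "wb") as f:
            f.write(data.getbuffer())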