diff options
Diffstat (limited to 'synapse/rest/media/upload_resource.py')
-rw-r--r-- | synapse/rest/media/upload_resource.py | 75 |
1 files changed, 68 insertions, 7 deletions
diff --git a/synapse/rest/media/upload_resource.py b/synapse/rest/media/upload_resource.py index 949326d85d..62d3e228a8 100644 --- a/synapse/rest/media/upload_resource.py +++ b/synapse/rest/media/upload_resource.py @@ -15,7 +15,7 @@ import logging import re -from typing import IO, TYPE_CHECKING, Dict, List, Optional +from typing import IO, TYPE_CHECKING, Dict, List, Optional, Tuple from synapse.api.errors import Codes, SynapseError from synapse.http.server import respond_with_json @@ -29,23 +29,24 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# The name of the lock to use when uploading media. +_UPLOAD_MEDIA_LOCK_NAME = "upload_media" -class UploadResource(RestServlet): - PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload")] +class BaseUploadServlet(RestServlet): def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): super().__init__() self.media_repo = media_repo self.filepaths = media_repo.filepaths self.store = hs.get_datastores().main - self.clock = hs.get_clock() + self.server_name = hs.hostname self.auth = hs.get_auth() self.max_upload_size = hs.config.media.max_upload_size - self.clock = hs.get_clock() - async def on_POST(self, request: SynapseRequest) -> None: - requester = await self.auth.get_user_by_req(request) + def _get_file_metadata( + self, request: SynapseRequest + ) -> Tuple[int, Optional[str], str]: raw_content_length = request.getHeader("Content-Length") if raw_content_length is None: raise SynapseError(msg="Request must specify a Content-Length", code=400) @@ -88,6 +89,16 @@ class UploadResource(RestServlet): # disposition = headers.getRawHeaders(b"Content-Disposition")[0] # TODO(markjh): parse content-dispostion + return content_length, upload_name, media_type + + +class UploadServlet(BaseUploadServlet): + PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload$")] + + async def on_POST(self, request: SynapseRequest) -> None: + requester = await self.auth.get_user_by_req(request) + content_length, upload_name, media_type = self._get_file_metadata(request) + try: content: IO = request.content # type: ignore content_uri = await self.media_repo.create_content( @@ -103,3 +114,53 @@ class UploadResource(RestServlet): respond_with_json( request, 200, {"content_uri": str(content_uri)}, send_cors=True ) + + +class AsyncUploadServlet(BaseUploadServlet): + PATTERNS = [ + re.compile( + "/_matrix/media/v3/upload/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$" + ) + ] + + async def on_PUT( + self, request: SynapseRequest, server_name: str, media_id: str + ) -> None: + requester = await self.auth.get_user_by_req(request) + + if server_name != self.server_name: + raise SynapseError( + 404, + "Non-local server name specified", + errcode=Codes.NOT_FOUND, + ) + + lock = await self.store.try_acquire_lock(_UPLOAD_MEDIA_LOCK_NAME, media_id) + if not lock: + raise SynapseError( + 409, + "Media ID cannot be overwritten", + errcode=Codes.CANNOT_OVERWRITE_MEDIA, + ) + + async with lock: + await self.media_repo.verify_can_upload(media_id, requester.user) + content_length, upload_name, media_type = self._get_file_metadata(request) + + try: + content: IO = request.content # type: ignore + await self.media_repo.update_content( + media_id, + media_type, + upload_name, + content, + content_length, + requester.user, + ) + except SpamMediaException: + # For uploading of media we want to respond with a 400, instead of + # the default 404, as that would just be confusing. + raise SynapseError(400, "Bad content") + + logger.info("Uploaded content for media ID %r", media_id) + respond_with_json(request, 200, {}, send_cors=True) |