diff options
author | Sumner Evans <me@sumnerevans.com> | 2023-11-15 07:19:24 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-15 09:19:24 -0500 |
commit | 999bd77d3abb7b0a4430f31f5912956c3bc100ee (patch) | |
tree | 28c85ef5ced509995ed15a2eee3467313a3db69e /synapse/media/media_repository.py | |
parent | Add links to pre-1.0 changelog issue/PR references. (#16638) (diff) | |
download | synapse-999bd77d3abb7b0a4430f31f5912956c3bc100ee.tar.xz |
Asynchronous Uploads (#15503)
Support asynchronous uploads as defined in MSC2246.
Diffstat (limited to 'synapse/media/media_repository.py')
-rw-r--r-- | synapse/media/media_repository.py | 220 |
1 files changed, 206 insertions, 14 deletions
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 1957426c6a..bf976b9e7c 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -27,13 +27,16 @@ import twisted.web.http from twisted.internet.defer import Deferred from synapse.api.errors import ( + Codes, FederationDeniedError, HttpResponseException, NotFoundError, RequestSendFailed, SynapseError, + cs_error, ) from synapse.config.repository import ThumbnailRequirement +from synapse.http.server import respond_with_json from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.logging.opentracing import trace @@ -51,7 +54,7 @@ from synapse.media.storage_provider import StorageProviderWrapper from synapse.media.thumbnailer import Thumbnailer, ThumbnailError from synapse.media.url_previewer import UrlPreviewer from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.databases.main.media_repository import RemoteMedia +from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia from synapse.types import UserID from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination @@ -80,6 +83,8 @@ class MediaRepository: self.store = hs.get_datastores().main self.max_upload_size = hs.config.media.max_upload_size self.max_image_pixels = hs.config.media.max_image_pixels + self.unused_expiration_time = hs.config.media.unused_expiration_time + self.max_pending_media_uploads = hs.config.media.max_pending_media_uploads Thumbnailer.set_limits(self.max_image_pixels) @@ -186,6 +191,117 @@ class MediaRepository: self.recently_accessed_locals.add(media_id) @trace + async def create_media_id(self, auth_user: UserID) -> Tuple[str, int]: + """Create and store a media ID for a local user and return the MXC URI and its + expiration. + + Args: + auth_user: The user_id of the uploader + + Returns: + A tuple containing the MXC URI of the stored content and the timestamp at + which the MXC URI expires. + """ + media_id = random_string(24) + now = self.clock.time_msec() + await self.store.store_local_media_id( + media_id=media_id, + time_now_ms=now, + user_id=auth_user, + ) + return f"mxc://{self.server_name}/{media_id}", now + self.unused_expiration_time + + @trace + async def reached_pending_media_limit(self, auth_user: UserID) -> Tuple[bool, int]: + """Check if the user is over the limit for pending media uploads. + + Args: + auth_user: The user_id of the uploader + + Returns: + A tuple with a boolean and an integer indicating whether the user has too + many pending media uploads and the timestamp at which the first pending + media will expire, respectively. + """ + pending, first_expiration_ts = await self.store.count_pending_media( + user_id=auth_user + ) + return pending >= self.max_pending_media_uploads, first_expiration_ts + + @trace + async def verify_can_upload(self, media_id: str, auth_user: UserID) -> None: + """Verify that the media ID can be uploaded to by the given user. This + function checks that: + + * the media ID exists + * the media ID does not already have content + * the user uploading is the same as the one who created the media ID + * the media ID has not expired + + Args: + media_id: The media ID to verify + auth_user: The user_id of the uploader + """ + media = await self.store.get_local_media(media_id) + if media is None: + raise SynapseError(404, "Unknow media ID", errcode=Codes.NOT_FOUND) + + if media.user_id != auth_user.to_string(): + raise SynapseError( + 403, + "Only the creator of the media ID can upload to it", + errcode=Codes.FORBIDDEN, + ) + + if media.media_length is not None: + raise SynapseError( + 409, + "Media ID already has content", + errcode=Codes.CANNOT_OVERWRITE_MEDIA, + ) + + expired_time_ms = self.clock.time_msec() - self.unused_expiration_time + if media.created_ts < expired_time_ms: + raise NotFoundError("Media ID has expired") + + @trace + async def update_content( + self, + media_id: str, + media_type: str, + upload_name: Optional[str], + content: IO, + content_length: int, + auth_user: UserID, + ) -> None: + """Update the content of the given media ID. + + Args: + media_id: The media ID to replace. + media_type: The content type of the file. + upload_name: The name of the file, if provided. + content: A file like object that is the content to store + content_length: The length of the content + auth_user: The user_id of the uploader + """ + file_info = FileInfo(server_name=None, file_id=media_id) + fname = await self.media_storage.store_file(content, file_info) + logger.info("Stored local media in file %r", fname) + + await self.store.update_local_media( + media_id=media_id, + media_type=media_type, + upload_name=upload_name, + media_length=content_length, + user_id=auth_user, + ) + + try: + await self._generate_thumbnails(None, media_id, media_id, media_type) + except Exception as e: + logger.info("Failed to generate thumbnails: %s", e) + + @trace async def create_content( self, media_type: str, @@ -231,8 +347,74 @@ class MediaRepository: return MXCUri(self.server_name, media_id) + def respond_not_yet_uploaded(self, request: SynapseRequest) -> None: + respond_with_json( + request, + 504, + cs_error("Media has not been uploaded yet", code=Codes.NOT_YET_UPLOADED), + send_cors=True, + ) + + async def get_local_media_info( + self, request: SynapseRequest, media_id: str, max_timeout_ms: int + ) -> Optional[LocalMedia]: + """Gets the info dictionary for given local media ID. If the media has + not been uploaded yet, this function will wait up to ``max_timeout_ms`` + milliseconds for the media to be uploaded. + + Args: + request: The incoming request. + media_id: The media ID of the content. (This is the same as + the file_id for local content.) + max_timeout_ms: the maximum number of milliseconds to wait for the + media to be uploaded. + + Returns: + Either the info dictionary for the given local media ID or + ``None``. If ``None``, then no further processing is necessary as + this function will send the necessary JSON response. + """ + wait_until = self.clock.time_msec() + max_timeout_ms + while True: + # Get the info for the media + media_info = await self.store.get_local_media(media_id) + if not media_info: + logger.info("Media %s is unknown", media_id) + respond_404(request) + return None + + if media_info.quarantined_by: + logger.info("Media %s is quarantined", media_id) + respond_404(request) + return None + + # The file has been uploaded, so stop looping + if media_info.media_length is not None: + return media_info + + # Check if the media ID has expired and still hasn't been uploaded to. + now = self.clock.time_msec() + expired_time_ms = now - self.unused_expiration_time + if media_info.created_ts < expired_time_ms: + logger.info("Media %s has expired without being uploaded", media_id) + respond_404(request) + return None + + if now >= wait_until: + break + + await self.clock.sleep(0.5) + + logger.info("Media %s has not yet been uploaded", media_id) + self.respond_not_yet_uploaded(request) + return None + async def get_local_media( - self, request: SynapseRequest, media_id: str, name: Optional[str] + self, + request: SynapseRequest, + media_id: str, + name: Optional[str], + max_timeout_ms: int, ) -> None: """Responds to requests for local media, if exists, or returns 404. @@ -242,13 +424,14 @@ class MediaRepository: the file_id for local content.) name: Optional name that, if specified, will be used as the filename in the Content-Disposition header of the response. + max_timeout_ms: the maximum number of milliseconds to wait for the + media to be uploaded. Returns: Resolves once a response has successfully been written to request """ - media_info = await self.store.get_local_media(media_id) - if not media_info or media_info.quarantined_by: - respond_404(request) + media_info = await self.get_local_media_info(request, media_id, max_timeout_ms) + if not media_info: return self.mark_recently_accessed(None, media_id) @@ -273,6 +456,7 @@ class MediaRepository: server_name: str, media_id: str, name: Optional[str], + max_timeout_ms: int, ) -> None: """Respond to requests for remote media. @@ -282,6 +466,8 @@ class MediaRepository: media_id: The media ID of the content (as defined by the remote server). name: Optional name that, if specified, will be used as the filename in the Content-Disposition header of the response. + max_timeout_ms: the maximum number of milliseconds to wait for the + media to be uploaded. Returns: Resolves once a response has successfully been written to request @@ -307,11 +493,11 @@ class MediaRepository: key = (server_name, media_id) async with self.remote_media_linearizer.queue(key): responder, media_info = await self._get_remote_media_impl( - server_name, media_id + server_name, media_id, max_timeout_ms ) # We deliberately stream the file outside the lock - if responder: + if responder and media_info: upload_name = name if name else media_info.upload_name await respond_with_responder( request, @@ -324,7 +510,7 @@ class MediaRepository: respond_404(request) async def get_remote_media_info( - self, server_name: str, media_id: str + self, server_name: str, media_id: str, max_timeout_ms: int ) -> RemoteMedia: """Gets the media info associated with the remote file, downloading if necessary. @@ -332,6 +518,8 @@ class MediaRepository: Args: server_name: Remote server_name where the media originated. media_id: The media ID of the content (as defined by the remote server). + max_timeout_ms: the maximum number of milliseconds to wait for the + media to be uploaded. Returns: The media info of the file @@ -347,7 +535,7 @@ class MediaRepository: key = (server_name, media_id) async with self.remote_media_linearizer.queue(key): responder, media_info = await self._get_remote_media_impl( - server_name, media_id + server_name, media_id, max_timeout_ms ) # Ensure we actually use the responder so that it releases resources @@ -358,7 +546,7 @@ class MediaRepository: return media_info async def _get_remote_media_impl( - self, server_name: str, media_id: str + self, server_name: str, media_id: str, max_timeout_ms: int ) -> Tuple[Optional[Responder], RemoteMedia]: """Looks for media in local cache, if not there then attempt to download from remote server. @@ -367,6 +555,8 @@ class MediaRepository: server_name: Remote server_name where the media originated. media_id: The media ID of the content (as defined by the remote server). + max_timeout_ms: the maximum number of milliseconds to wait for the + media to be uploaded. Returns: A tuple of responder and the media info of the file. @@ -399,8 +589,7 @@ class MediaRepository: try: media_info = await self._download_remote_file( - server_name, - media_id, + server_name, media_id, max_timeout_ms ) except SynapseError: raise @@ -433,6 +622,7 @@ class MediaRepository: self, server_name: str, media_id: str, + max_timeout_ms: int, ) -> RemoteMedia: """Attempt to download the remote file from the given server name, using the given file_id as the local id. @@ -442,7 +632,8 @@ class MediaRepository: media_id: The media ID of the content (as defined by the remote server). This is different than the file_id, which is locally generated. - file_id: Local file ID + max_timeout_ms: the maximum number of milliseconds to wait for the + media to be uploaded. Returns: The media info of the file. @@ -466,7 +657,8 @@ class MediaRepository: # tell the remote server to 404 if it doesn't # recognise the server_name, to make sure we don't # end up with a routing loop. - "allow_remote": "false" + "allow_remote": "false", + "timeout_ms": str(max_timeout_ms), }, ) except RequestSendFailed as e: |