summary refs log tree commit diff
path: root/synapse/media/media_repository.py
diff options
context:
space:
mode:
authorSumner Evans <me@sumnerevans.com>2023-11-15 07:19:24 -0700
committerGitHub <noreply@github.com>2023-11-15 09:19:24 -0500
commit999bd77d3abb7b0a4430f31f5912956c3bc100ee (patch)
tree28c85ef5ced509995ed15a2eee3467313a3db69e /synapse/media/media_repository.py
parentAdd links to pre-1.0 changelog issue/PR references. (#16638) (diff)
downloadsynapse-999bd77d3abb7b0a4430f31f5912956c3bc100ee.tar.xz
Asynchronous Uploads (#15503)
Support asynchronous uploads as defined in MSC2246.
Diffstat (limited to 'synapse/media/media_repository.py')
-rw-r--r--synapse/media/media_repository.py220
1 files changed, 206 insertions, 14 deletions
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 1957426c6a..bf976b9e7c 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -27,13 +27,16 @@ import twisted.web.http
 from twisted.internet.defer import Deferred
 
 from synapse.api.errors import (
+    Codes,
     FederationDeniedError,
     HttpResponseException,
     NotFoundError,
     RequestSendFailed,
     SynapseError,
+    cs_error,
 )
 from synapse.config.repository import ThumbnailRequirement
+from synapse.http.server import respond_with_json
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.logging.opentracing import trace
@@ -51,7 +54,7 @@ from synapse.media.storage_provider import StorageProviderWrapper
 from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
 from synapse.media.url_previewer import UrlPreviewer
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.databases.main.media_repository import RemoteMedia
+from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
 from synapse.types import UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.retryutils import NotRetryingDestination
@@ -80,6 +83,8 @@ class MediaRepository:
         self.store = hs.get_datastores().main
         self.max_upload_size = hs.config.media.max_upload_size
         self.max_image_pixels = hs.config.media.max_image_pixels
+        self.unused_expiration_time = hs.config.media.unused_expiration_time
+        self.max_pending_media_uploads = hs.config.media.max_pending_media_uploads
 
         Thumbnailer.set_limits(self.max_image_pixels)
 
@@ -186,6 +191,117 @@ class MediaRepository:
             self.recently_accessed_locals.add(media_id)
 
     @trace
+    async def create_media_id(self, auth_user: UserID) -> Tuple[str, int]:
+        """Create and store a media ID for a local user and return the MXC URI and its
+        expiration.
+
+        Args:
+            auth_user: The user_id of the uploader
+
+        Returns:
+            A tuple containing the MXC URI of the stored content and the timestamp at
+            which the MXC URI expires.
+        """
+        media_id = random_string(24)
+        now = self.clock.time_msec()
+        await self.store.store_local_media_id(
+            media_id=media_id,
+            time_now_ms=now,
+            user_id=auth_user,
+        )
+        return f"mxc://{self.server_name}/{media_id}", now + self.unused_expiration_time
+
+    @trace
+    async def reached_pending_media_limit(self, auth_user: UserID) -> Tuple[bool, int]:
+        """Check if the user is over the limit for pending media uploads.
+
+        Args:
+            auth_user: The user_id of the uploader
+
+        Returns:
+            A tuple with a boolean and an integer indicating whether the user has too
+            many pending media uploads and the timestamp at which the first pending
+            media will expire, respectively.
+        """
+        pending, first_expiration_ts = await self.store.count_pending_media(
+            user_id=auth_user
+        )
+        return pending >= self.max_pending_media_uploads, first_expiration_ts
+
+    @trace
+    async def verify_can_upload(self, media_id: str, auth_user: UserID) -> None:
+        """Verify that the media ID can be uploaded to by the given user. This
+        function checks that:
+
+        * the media ID exists
+        * the media ID does not already have content
+        * the user uploading is the same as the one who created the media ID
+        * the media ID has not expired
+
+        Args:
+            media_id: The media ID to verify
+            auth_user: The user_id of the uploader
+        """
+        media = await self.store.get_local_media(media_id)
+        if media is None:
+            raise SynapseError(404, "Unknow media ID", errcode=Codes.NOT_FOUND)
+
+        if media.user_id != auth_user.to_string():
+            raise SynapseError(
+                403,
+                "Only the creator of the media ID can upload to it",
+                errcode=Codes.FORBIDDEN,
+            )
+
+        if media.media_length is not None:
+            raise SynapseError(
+                409,
+                "Media ID already has content",
+                errcode=Codes.CANNOT_OVERWRITE_MEDIA,
+            )
+
+        expired_time_ms = self.clock.time_msec() - self.unused_expiration_time
+        if media.created_ts < expired_time_ms:
+            raise NotFoundError("Media ID has expired")
+
+    @trace
+    async def update_content(
+        self,
+        media_id: str,
+        media_type: str,
+        upload_name: Optional[str],
+        content: IO,
+        content_length: int,
+        auth_user: UserID,
+    ) -> None:
+        """Update the content of the given media ID.
+
+        Args:
+            media_id: The media ID to replace.
+            media_type: The content type of the file.
+            upload_name: The name of the file, if provided.
+            content: A file like object that is the content to store
+            content_length: The length of the content
+            auth_user: The user_id of the uploader
+        """
+        file_info = FileInfo(server_name=None, file_id=media_id)
+        fname = await self.media_storage.store_file(content, file_info)
+        logger.info("Stored local media in file %r", fname)
+
+        await self.store.update_local_media(
+            media_id=media_id,
+            media_type=media_type,
+            upload_name=upload_name,
+            media_length=content_length,
+            user_id=auth_user,
+        )
+
+        try:
+            await self._generate_thumbnails(None, media_id, media_id, media_type)
+        except Exception as e:
+            logger.info("Failed to generate thumbnails: %s", e)
+
+    @trace
     async def create_content(
         self,
         media_type: str,
@@ -231,8 +347,74 @@ class MediaRepository:
 
         return MXCUri(self.server_name, media_id)
 
+    def respond_not_yet_uploaded(self, request: SynapseRequest) -> None:
+        respond_with_json(
+            request,
+            504,
+            cs_error("Media has not been uploaded yet", code=Codes.NOT_YET_UPLOADED),
+            send_cors=True,
+        )
+
+    async def get_local_media_info(
+        self, request: SynapseRequest, media_id: str, max_timeout_ms: int
+    ) -> Optional[LocalMedia]:
+        """Gets the info dictionary for given local media ID. If the media has
+        not been uploaded yet, this function will wait up to ``max_timeout_ms``
+        milliseconds for the media to be uploaded.
+
+        Args:
+            request: The incoming request.
+            media_id: The media ID of the content. (This is the same as
+                the file_id for local content.)
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
+
+        Returns:
+            Either the info dictionary for the given local media ID or
+            ``None``. If ``None``, then no further processing is necessary as
+            this function will send the necessary JSON response.
+        """
+        wait_until = self.clock.time_msec() + max_timeout_ms
+        while True:
+            # Get the info for the media
+            media_info = await self.store.get_local_media(media_id)
+            if not media_info:
+                logger.info("Media %s is unknown", media_id)
+                respond_404(request)
+                return None
+
+            if media_info.quarantined_by:
+                logger.info("Media %s is quarantined", media_id)
+                respond_404(request)
+                return None
+
+            # The file has been uploaded, so stop looping
+            if media_info.media_length is not None:
+                return media_info
+
+            # Check if the media ID has expired and still hasn't been uploaded to.
+            now = self.clock.time_msec()
+            expired_time_ms = now - self.unused_expiration_time
+            if media_info.created_ts < expired_time_ms:
+                logger.info("Media %s has expired without being uploaded", media_id)
+                respond_404(request)
+                return None
+
+            if now >= wait_until:
+                break
+
+            await self.clock.sleep(0.5)
+
+        logger.info("Media %s has not yet been uploaded", media_id)
+        self.respond_not_yet_uploaded(request)
+        return None
+
     async def get_local_media(
-        self, request: SynapseRequest, media_id: str, name: Optional[str]
+        self,
+        request: SynapseRequest,
+        media_id: str,
+        name: Optional[str],
+        max_timeout_ms: int,
     ) -> None:
         """Responds to requests for local media, if exists, or returns 404.
 
@@ -242,13 +424,14 @@ class MediaRepository:
                 the file_id for local content.)
             name: Optional name that, if specified, will be used as
                 the filename in the Content-Disposition header of the response.
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
 
         Returns:
             Resolves once a response has successfully been written to request
         """
-        media_info = await self.store.get_local_media(media_id)
-        if not media_info or media_info.quarantined_by:
-            respond_404(request)
+        media_info = await self.get_local_media_info(request, media_id, max_timeout_ms)
+        if not media_info:
             return
 
         self.mark_recently_accessed(None, media_id)
@@ -273,6 +456,7 @@ class MediaRepository:
         server_name: str,
         media_id: str,
         name: Optional[str],
+        max_timeout_ms: int,
     ) -> None:
         """Respond to requests for remote media.
 
@@ -282,6 +466,8 @@ class MediaRepository:
             media_id: The media ID of the content (as defined by the remote server).
             name: Optional name that, if specified, will be used as
                 the filename in the Content-Disposition header of the response.
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
 
         Returns:
             Resolves once a response has successfully been written to request
@@ -307,11 +493,11 @@ class MediaRepository:
         key = (server_name, media_id)
         async with self.remote_media_linearizer.queue(key):
             responder, media_info = await self._get_remote_media_impl(
-                server_name, media_id
+                server_name, media_id, max_timeout_ms
             )
 
         # We deliberately stream the file outside the lock
-        if responder:
+        if responder and media_info:
             upload_name = name if name else media_info.upload_name
             await respond_with_responder(
                 request,
@@ -324,7 +510,7 @@ class MediaRepository:
             respond_404(request)
 
     async def get_remote_media_info(
-        self, server_name: str, media_id: str
+        self, server_name: str, media_id: str, max_timeout_ms: int
     ) -> RemoteMedia:
         """Gets the media info associated with the remote file, downloading
         if necessary.
@@ -332,6 +518,8 @@ class MediaRepository:
         Args:
             server_name: Remote server_name where the media originated.
             media_id: The media ID of the content (as defined by the remote server).
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
 
         Returns:
             The media info of the file
@@ -347,7 +535,7 @@ class MediaRepository:
         key = (server_name, media_id)
         async with self.remote_media_linearizer.queue(key):
             responder, media_info = await self._get_remote_media_impl(
-                server_name, media_id
+                server_name, media_id, max_timeout_ms
             )
 
         # Ensure we actually use the responder so that it releases resources
@@ -358,7 +546,7 @@ class MediaRepository:
         return media_info
 
     async def _get_remote_media_impl(
-        self, server_name: str, media_id: str
+        self, server_name: str, media_id: str, max_timeout_ms: int
     ) -> Tuple[Optional[Responder], RemoteMedia]:
         """Looks for media in local cache, if not there then attempt to
         download from remote server.
@@ -367,6 +555,8 @@ class MediaRepository:
             server_name: Remote server_name where the media originated.
             media_id: The media ID of the content (as defined by the
                 remote server).
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
 
         Returns:
             A tuple of responder and the media info of the file.
@@ -399,8 +589,7 @@ class MediaRepository:
 
         try:
             media_info = await self._download_remote_file(
-                server_name,
-                media_id,
+                server_name, media_id, max_timeout_ms
             )
         except SynapseError:
             raise
@@ -433,6 +622,7 @@ class MediaRepository:
         self,
         server_name: str,
         media_id: str,
+        max_timeout_ms: int,
     ) -> RemoteMedia:
         """Attempt to download the remote file from the given server name,
         using the given file_id as the local id.
@@ -442,7 +632,8 @@ class MediaRepository:
             media_id: The media ID of the content (as defined by the
                 remote server). This is different than the file_id, which is
                 locally generated.
-            file_id: Local file ID
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
 
         Returns:
             The media info of the file.
@@ -466,7 +657,8 @@ class MediaRepository:
                         # tell the remote server to 404 if it doesn't
                         # recognise the server_name, to make sure we don't
                         # end up with a routing loop.
-                        "allow_remote": "false"
+                        "allow_remote": "false",
+                        "timeout_ms": str(max_timeout_ms),
                     },
                 )
             except RequestSendFailed as e: