summary refs log tree commit diff
path: root/synapse/media
diff options
context:
space:
mode:
authorShay <hillerys@element.io>2024-07-02 06:07:04 -0700
committerGitHub <noreply@github.com>2024-07-02 14:07:04 +0100
commit8f890447b0f8b6cbe369b162670185e8c746b2f2 (patch)
treec8c290661a59b06257ce7e2fda19e799d83825eb /synapse/media
parentFix sync waiting for an invalid token from the "future" (#17386) (diff)
downloadsynapse-8f890447b0f8b6cbe369b162670185e8c746b2f2.tar.xz
Support MSC3916 by adding `_matrix/client/v1/media/download` endpoint (#17365)
Diffstat (limited to 'synapse/media')
-rw-r--r--synapse/media/_base.py28
-rw-r--r--synapse/media/media_repository.py151
-rw-r--r--synapse/media/media_storage.py27
3 files changed, 191 insertions, 15 deletions
diff --git a/synapse/media/_base.py b/synapse/media/_base.py
index 7ad0b7c3cf..1b268ce4d4 100644
--- a/synapse/media/_base.py
+++ b/synapse/media/_base.py
@@ -221,6 +221,7 @@ def add_file_headers(
     # select private. don't bother setting Expires as all our
     # clients are smart enough to be happy with Cache-Control
     request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
+
     if file_size is not None:
         request.setHeader(b"Content-Length", b"%d" % (file_size,))
 
@@ -302,12 +303,37 @@ async def respond_with_multipart_responder(
             )
             return
 
+        if media_info.media_type.lower().split(";", 1)[0] in INLINE_CONTENT_TYPES:
+            disposition = "inline"
+        else:
+            disposition = "attachment"
+
+        def _quote(x: str) -> str:
+            return urllib.parse.quote(x.encode("utf-8"))
+
+        if media_info.upload_name:
+            if _can_encode_filename_as_token(media_info.upload_name):
+                disposition = "%s; filename=%s" % (
+                    disposition,
+                    media_info.upload_name,
+                )
+            else:
+                disposition = "%s; filename*=utf-8''%s" % (
+                    disposition,
+                    _quote(media_info.upload_name),
+                )
+
         from synapse.media.media_storage import MultipartFileConsumer
 
         # note that currently the json_object is just {}, this will change when linked media
         # is implemented
         multipart_consumer = MultipartFileConsumer(
-            clock, request, media_info.media_type, {}, media_info.media_length
+            clock,
+            request,
+            media_info.media_type,
+            {},
+            disposition,
+            media_info.media_length,
         )
 
         logger.debug("Responding to media request with responder %s", responder)
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 1436329fad..542642b900 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -480,6 +480,7 @@ class MediaRepository:
         name: Optional[str],
         max_timeout_ms: int,
         ip_address: str,
+        use_federation_endpoint: bool,
     ) -> None:
         """Respond to requests for remote media.
 
@@ -492,6 +493,8 @@ class MediaRepository:
             max_timeout_ms: the maximum number of milliseconds to wait for the
                 media to be uploaded.
             ip_address: the IP address of the requester
+            use_federation_endpoint: whether to request the remote media over the new
+                federation `/download` endpoint
 
         Returns:
             Resolves once a response has successfully been written to request
@@ -522,6 +525,7 @@ class MediaRepository:
                 max_timeout_ms,
                 self.download_ratelimiter,
                 ip_address,
+                use_federation_endpoint,
             )
 
         # We deliberately stream the file outside the lock
@@ -569,6 +573,7 @@ class MediaRepository:
                 max_timeout_ms,
                 self.download_ratelimiter,
                 ip_address,
+                False,
             )
 
         # Ensure we actually use the responder so that it releases resources
@@ -585,6 +590,7 @@ class MediaRepository:
         max_timeout_ms: int,
         download_ratelimiter: Ratelimiter,
         ip_address: str,
+        use_federation_endpoint: bool,
     ) -> Tuple[Optional[Responder], RemoteMedia]:
         """Looks for media in local cache, if not there then attempt to
         download from remote server.
@@ -598,6 +604,8 @@ class MediaRepository:
             download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
                 requester IP.
             ip_address: the IP address of the requester
+            use_federation_endpoint: whether to request the remote media over the new federation
+            /download endpoint
 
         Returns:
             A tuple of responder and the media info of the file.
@@ -629,9 +637,23 @@ class MediaRepository:
         # Failed to find the file anywhere, lets download it.
 
         try:
-            media_info = await self._download_remote_file(
-                server_name, media_id, max_timeout_ms, download_ratelimiter, ip_address
-            )
+            if not use_federation_endpoint:
+                media_info = await self._download_remote_file(
+                    server_name,
+                    media_id,
+                    max_timeout_ms,
+                    download_ratelimiter,
+                    ip_address,
+                )
+            else:
+                media_info = await self._federation_download_remote_file(
+                    server_name,
+                    media_id,
+                    max_timeout_ms,
+                    download_ratelimiter,
+                    ip_address,
+                )
+
         except SynapseError:
             raise
         except Exception as e:
@@ -775,6 +797,129 @@ class MediaRepository:
             quarantined_by=None,
         )
 
+    async def _federation_download_remote_file(
+        self,
+        server_name: str,
+        media_id: str,
+        max_timeout_ms: int,
+        download_ratelimiter: Ratelimiter,
+        ip_address: str,
+    ) -> RemoteMedia:
+        """Attempt to download the remote file from the given server name.
+        Uses the given file_id as the local id and downloads the file over the federation
+        v1 download endpoint
+
+        Args:
+            server_name: Originating server
+            media_id: The media ID of the content (as defined by the
+                remote server). This is different than the file_id, which is
+                locally generated.
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
+            download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
+                requester IP
+            ip_address: the IP address of the requester
+
+        Returns:
+            The media info of the file.
+        """
+
+        file_id = random_string(24)
+
+        file_info = FileInfo(server_name=server_name, file_id=file_id)
+
+        async with self.media_storage.store_into_file(file_info) as (f, fname):
+            try:
+                res = await self.client.federation_download_media(
+                    server_name,
+                    media_id,
+                    output_stream=f,
+                    max_size=self.max_upload_size,
+                    max_timeout_ms=max_timeout_ms,
+                    download_ratelimiter=download_ratelimiter,
+                    ip_address=ip_address,
+                )
+                # if we had to fall back to the _matrix/media endpoint it will only return
+                # the headers and length, check the length of the tuple before unpacking
+                if len(res) == 3:
+                    length, headers, json = res
+                else:
+                    length, headers = res
+            except RequestSendFailed as e:
+                logger.warning(
+                    "Request failed fetching remote media %s/%s: %r",
+                    server_name,
+                    media_id,
+                    e,
+                )
+                raise SynapseError(502, "Failed to fetch remote media")
+
+            except HttpResponseException as e:
+                logger.warning(
+                    "HTTP error fetching remote media %s/%s: %s",
+                    server_name,
+                    media_id,
+                    e.response,
+                )
+                if e.code == twisted.web.http.NOT_FOUND:
+                    raise e.to_synapse_error()
+                raise SynapseError(502, "Failed to fetch remote media")
+
+            except SynapseError:
+                logger.warning(
+                    "Failed to fetch remote media %s/%s", server_name, media_id
+                )
+                raise
+            except NotRetryingDestination:
+                logger.warning("Not retrying destination %r", server_name)
+                raise SynapseError(502, "Failed to fetch remote media")
+            except Exception:
+                logger.exception(
+                    "Failed to fetch remote media %s/%s", server_name, media_id
+                )
+                raise SynapseError(502, "Failed to fetch remote media")
+
+            if b"Content-Type" in headers:
+                media_type = headers[b"Content-Type"][0].decode("ascii")
+            else:
+                media_type = "application/octet-stream"
+            upload_name = get_filename_from_headers(headers)
+            time_now_ms = self.clock.time_msec()
+
+            # Multiple remote media download requests can race (when using
+            # multiple media repos), so this may throw a violation constraint
+            # exception. If it does we'll delete the newly downloaded file from
+            # disk (as we're in the ctx manager).
+            #
+            # However: we've already called `finish()` so we may have also
+            # written to the storage providers. This is preferable to the
+            # alternative where we call `finish()` *after* this, where we could
+            # end up having an entry in the DB but fail to write the files to
+            # the storage providers.
+            await self.store.store_cached_remote_media(
+                origin=server_name,
+                media_id=media_id,
+                media_type=media_type,
+                time_now_ms=time_now_ms,
+                upload_name=upload_name,
+                media_length=length,
+                filesystem_id=file_id,
+            )
+
+        logger.debug("Stored remote media in file %r", fname)
+
+        return RemoteMedia(
+            media_origin=server_name,
+            media_id=media_id,
+            media_type=media_type,
+            media_length=length,
+            upload_name=upload_name,
+            created_ts=time_now_ms,
+            filesystem_id=file_id,
+            last_access_ts=time_now_ms,
+            quarantined_by=None,
+        )
+
     def _get_thumbnail_requirements(
         self, media_type: str
     ) -> Tuple[ThumbnailRequirement, ...]:
diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py
index 1be2c9b5f5..2a106bb0eb 100644
--- a/synapse/media/media_storage.py
+++ b/synapse/media/media_storage.py
@@ -401,13 +401,14 @@ class MultipartFileConsumer:
         wrapped_consumer: interfaces.IConsumer,
         file_content_type: str,
         json_object: JsonDict,
-        content_length: Optional[int] = None,
+        disposition: str,
+        content_length: Optional[int],
     ) -> None:
         self.clock = clock
         self.wrapped_consumer = wrapped_consumer
         self.json_field = json_object
         self.json_field_written = False
-        self.content_type_written = False
+        self.file_headers_written = False
         self.file_content_type = file_content_type
         self.boundary = uuid4().hex.encode("ascii")
 
@@ -420,6 +421,7 @@ class MultipartFileConsumer:
         self.paused = False
 
         self.length = content_length
+        self.disposition = disposition
 
     ### IConsumer APIs ###
 
@@ -488,11 +490,13 @@ class MultipartFileConsumer:
             self.json_field_written = True
 
         # if we haven't written the content type yet, do so
-        if not self.content_type_written:
+        if not self.file_headers_written:
             type = self.file_content_type.encode("utf-8")
             content_type = Header(b"Content-Type", type)
-            self.wrapped_consumer.write(bytes(content_type) + CRLF + CRLF)
-            self.content_type_written = True
+            self.wrapped_consumer.write(bytes(content_type) + CRLF)
+            disp_header = Header(b"Content-Disposition", self.disposition)
+            self.wrapped_consumer.write(bytes(disp_header) + CRLF + CRLF)
+            self.file_headers_written = True
 
         self.wrapped_consumer.write(data)
 
@@ -506,7 +510,6 @@ class MultipartFileConsumer:
         producing data for good.
         """
         assert self.producer is not None
-
         self.paused = True
         self.producer.stopProducing()
 
@@ -518,7 +521,6 @@ class MultipartFileConsumer:
         the time being, and to stop until C{resumeProducing()} is called.
         """
         assert self.producer is not None
-
         self.paused = True
 
         if self.streaming:
@@ -549,7 +551,7 @@ class MultipartFileConsumer:
         """
         if not self.length:
             return None
-        # calculate length of json field and content-type header
+        # calculate length of json field and content-type, disposition headers
         json_field = json.dumps(self.json_field)
         json_bytes = json_field.encode("utf-8")
         json_length = len(json_bytes)
@@ -558,9 +560,13 @@ class MultipartFileConsumer:
         content_type = Header(b"Content-Type", type)
         type_length = len(bytes(content_type))
 
-        # 154 is the length of the elements that aren't variable, ie
+        disp = self.disposition.encode("utf-8")
+        disp_header = Header(b"Content-Disposition", disp)
+        disp_length = len(bytes(disp_header))
+
+        # 156 is the length of the elements that aren't variable, ie
         # CRLFs and boundary strings, etc
-        self.length += json_length + type_length + 154
+        self.length += json_length + type_length + disp_length + 156
 
         return self.length
 
@@ -569,7 +575,6 @@ class MultipartFileConsumer:
     async def _resumeProducingRepeatedly(self) -> None:
         assert self.producer is not None
         assert not self.streaming
-
         producer = cast("interfaces.IPullProducer", self.producer)
 
         self.paused = False