summary refs log tree commit diff
path: root/tests/rest/client/test_media.py
diff options
context:
space:
mode:
authorShay <hillerys@element.io>2024-07-02 06:07:04 -0700
committerGitHub <noreply@github.com>2024-07-02 14:07:04 +0100
commit8f890447b0f8b6cbe369b162670185e8c746b2f2 (patch)
treec8c290661a59b06257ce7e2fda19e799d83825eb /tests/rest/client/test_media.py
parentFix sync waiting for an invalid token from the "future" (#17386) (diff)
downloadsynapse-8f890447b0f8b6cbe369b162670185e8c746b2f2.tar.xz
Support MSC3916 by adding `_matrix/client/v1/media/download` endpoint (#17365)
Diffstat (limited to '')
-rw-r--r--tests/rest/client/test_media.py609
1 files changed, 606 insertions, 3 deletions
diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py
index be4a289ec1..6b5af2dbb6 100644
--- a/tests/rest/client/test_media.py
+++ b/tests/rest/client/test_media.py
@@ -19,31 +19,54 @@
 #
 #
 import base64
+import io
 import json
 import os
 import re
-from typing import Any, Dict, Optional, Sequence, Tuple, Type
+from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Sequence, Tuple, Type
+from unittest.mock import MagicMock, Mock, patch
+from urllib import parse
 from urllib.parse import quote, urlencode
 
+from parameterized import parameterized_class
+
+from twisted.internet import defer
 from twisted.internet._resolver import HostResolution
 from twisted.internet.address import IPv4Address, IPv6Address
+from twisted.internet.defer import Deferred
 from twisted.internet.error import DNSLookupError
 from twisted.internet.interfaces import IAddress, IResolutionReceiver
+from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import UNKNOWN_LENGTH, IResponse
 from twisted.web.resource import Resource
 
+from synapse.api.errors import HttpResponseException
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.config.oembed import OEmbedEndpointConfig
+from synapse.http.client import MultipartResponse
+from synapse.http.types import QueryParams
+from synapse.logging.context import make_deferred_yieldable
 from synapse.media._base import FileInfo
 from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS
 from synapse.rest import admin
 from synapse.rest.client import login, media
 from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
 from synapse.util import Clock
 from synapse.util.stringutils import parse_and_validate_mxc_uri
 
 from tests import unittest
-from tests.server import FakeTransport, ThreadedMemoryReactorClock
+from tests.media.test_media_storage import (
+    SVG,
+    TestImage,
+    empty_file,
+    small_lossless_webp,
+    small_png,
+    small_png_with_transparency,
+)
+from tests.server import FakeChannel, FakeTransport, ThreadedMemoryReactorClock
 from tests.test_utils import SMALL_PNG
 from tests.unittest import override_config
 
@@ -1607,3 +1630,583 @@ class UnstableMediaConfigTest(unittest.HomeserverTestCase):
         self.assertEqual(
             channel.json_body["m.upload.size"], self.hs.config.media.max_upload_size
         )
+
+
+class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        media.register_servlets,
+        login.register_servlets,
+        admin.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        config = self.default_config()
+
+        self.storage_path = self.mktemp()
+        self.media_store_path = self.mktemp()
+        os.mkdir(self.storage_path)
+        os.mkdir(self.media_store_path)
+        config["media_store_path"] = self.media_store_path
+
+        provider_config = {
+            "module": "synapse.media.storage_provider.FileStorageProviderBackend",
+            "store_local": True,
+            "store_synchronous": False,
+            "store_remote": True,
+            "config": {"directory": self.storage_path},
+        }
+
+        config["media_storage_providers"] = [provider_config]
+
+        return self.setup_test_homeserver(config=config)
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.repo = hs.get_media_repository()
+        self.client = hs.get_federation_http_client()
+        self.store = hs.get_datastores().main
+        self.user = self.register_user("user", "pass")
+        self.tok = self.login("user", "pass")
+
+    # mock actually reading file body
+    def read_multipart_response_30MiB(*args: Any, **kwargs: Any) -> Deferred:
+        d: Deferred = defer.Deferred()
+        d.callback(MultipartResponse(b"{}", 31457280, b"img/png", None))
+        return d
+
+    def read_multipart_response_50MiB(*args: Any, **kwargs: Any) -> Deferred:
+        d: Deferred = defer.Deferred()
+        d.callback(MultipartResponse(b"{}", 31457280, b"img/png", None))
+        return d
+
+    @patch(
+        "synapse.http.matrixfederationclient.read_multipart_response",
+        read_multipart_response_30MiB,
+    )
+    def test_download_ratelimit_default(self) -> None:
+        """
+        Test remote media download ratelimiting against default configuration - 500MB bucket
+        and 87kb/second drain rate
+        """
+
+        # mock out actually sending the request, returns a 30MiB response
+        async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+            resp = MagicMock(spec=IResponse)
+            resp.code = 200
+            resp.length = 31457280
+            resp.headers = Headers(
+                {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
+            )
+            resp.phrase = b"OK"
+            return resp
+
+        self.client._send_request = _send_request  # type: ignore
+
+        # first request should go through
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abc",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        assert channel.code == 200
+
+        # next 15 should go through
+        for i in range(15):
+            channel2 = self.make_request(
+                "GET",
+                f"/_matrix/client/v1/media/download/remote.org/abc{i}",
+                shorthand=False,
+                access_token=self.tok,
+            )
+            assert channel2.code == 200
+
+        # 17th will hit ratelimit
+        channel3 = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abcd",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        assert channel3.code == 429
+
+        # however, a request from a different IP will go through
+        channel4 = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abcde",
+            shorthand=False,
+            client_ip="187.233.230.159",
+            access_token=self.tok,
+        )
+        assert channel4.code == 200
+
+        # at 87Kib/s it should take about 2 minutes for enough to drain from bucket that another
+        # 30MiB download is authorized - The last download was blocked at 503,316,480.
+        # The next download will be authorized when bucket hits 492,830,720
+        # (524,288,000 total capacity - 31,457,280 download size) so 503,316,480 - 492,830,720 ~= 10,485,760
+        # needs to drain before another download will be authorized, that will take ~=
+        # 2 minutes (10,485,760/89,088/60)
+        self.reactor.pump([2.0 * 60.0])
+
+        # enough has drained and next request goes through
+        channel5 = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abcdef",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        assert channel5.code == 200
+
+    @override_config(
+        {
+            "remote_media_download_per_second": "50M",
+            "remote_media_download_burst_count": "50M",
+        }
+    )
+    @patch(
+        "synapse.http.matrixfederationclient.read_multipart_response",
+        read_multipart_response_50MiB,
+    )
+    def test_download_rate_limit_config(self) -> None:
+        """
+        Test that download rate limit config options are correctly picked up and applied
+        """
+
+        async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+            resp = MagicMock(spec=IResponse)
+            resp.code = 200
+            resp.length = 52428800
+            resp.headers = Headers(
+                {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
+            )
+            resp.phrase = b"OK"
+            return resp
+
+        self.client._send_request = _send_request  # type: ignore
+
+        # first request should go through
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abc",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        assert channel.code == 200
+
+        # immediate second request should fail
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abcd",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        assert channel.code == 429
+
+        # advance half a second
+        self.reactor.pump([0.5])
+
+        # request still fails
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abcde",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        assert channel.code == 429
+
+        # advance another half second
+        self.reactor.pump([0.5])
+
+        # enough has drained from bucket and request is successful
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abcdef",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        assert channel.code == 200
+
+    @patch(
+        "synapse.http.matrixfederationclient.read_multipart_response",
+        read_multipart_response_30MiB,
+    )
+    def test_download_ratelimit_max_size_sub(self) -> None:
+        """
+        Test that if no content-length is provided, the default max size is applied instead
+        """
+
+        # mock out actually sending the request
+        async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+            resp = MagicMock(spec=IResponse)
+            resp.code = 200
+            resp.length = UNKNOWN_LENGTH
+            resp.headers = Headers(
+                {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
+            )
+            resp.phrase = b"OK"
+            return resp
+
+        self.client._send_request = _send_request  # type: ignore
+
+        # ten requests should go through using the max size (500MB/50MB)
+        for i in range(10):
+            channel2 = self.make_request(
+                "GET",
+                f"/_matrix/client/v1/media/download/remote.org/abc{i}",
+                shorthand=False,
+                access_token=self.tok,
+            )
+            assert channel2.code == 200
+
+        # eleventh will hit ratelimit
+        channel3 = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abcd",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        assert channel3.code == 429
+
+    def test_file_download(self) -> None:
+        content = io.BytesIO(b"file_to_stream")
+        content_uri = self.get_success(
+            self.repo.create_content(
+                "text/plain",
+                "test_upload",
+                content,
+                46,
+                UserID.from_string("@user_id:whatever.org"),
+            )
+        )
+        # test with a text file
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/v1/media/download/test/{content_uri.media_id}",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        self.pump()
+        self.assertEqual(200, channel.code)
+
+
+test_images = [
+    small_png,
+    small_png_with_transparency,
+    small_lossless_webp,
+    empty_file,
+    SVG,
+]
+input_values = [(x,) for x in test_images]
+
+
+@parameterized_class(("test_image",), input_values)
+class DownloadTestCase(unittest.HomeserverTestCase):
+    test_image: ClassVar[TestImage]
+    servlets = [
+        media.register_servlets,
+        login.register_servlets,
+        admin.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        self.fetches: List[
+            Tuple[
+                "Deferred[Any]",
+                str,
+                str,
+                Optional[QueryParams],
+            ]
+        ] = []
+
+        def federation_get_file(
+            destination: str,
+            path: str,
+            output_stream: BinaryIO,
+            download_ratelimiter: Ratelimiter,
+            ip_address: Any,
+            max_size: int,
+            args: Optional[QueryParams] = None,
+            retry_on_dns_fail: bool = True,
+            ignore_backoff: bool = False,
+            follow_redirects: bool = False,
+        ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]], bytes]]":
+            """A mock for MatrixFederationHttpClient.federation_get_file."""
+
+            def write_to(
+                r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]
+            ) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
+                data, response = r
+                output_stream.write(data)
+                return response
+
+            def write_err(f: Failure) -> Failure:
+                f.trap(HttpResponseException)
+                output_stream.write(f.value.response)
+                return f
+
+            d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]] = (
+                Deferred()
+            )
+            self.fetches.append((d, destination, path, args))
+            # Note that this callback changes the value held by d.
+            d_after_callback = d.addCallbacks(write_to, write_err)
+            return make_deferred_yieldable(d_after_callback)
+
+        def get_file(
+            destination: str,
+            path: str,
+            output_stream: BinaryIO,
+            download_ratelimiter: Ratelimiter,
+            ip_address: Any,
+            max_size: int,
+            args: Optional[QueryParams] = None,
+            retry_on_dns_fail: bool = True,
+            ignore_backoff: bool = False,
+            follow_redirects: bool = False,
+        ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
+            """A mock for MatrixFederationHttpClient.get_file."""
+
+            def write_to(
+                r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
+            ) -> Tuple[int, Dict[bytes, List[bytes]]]:
+                data, response = r
+                output_stream.write(data)
+                return response
+
+            def write_err(f: Failure) -> Failure:
+                f.trap(HttpResponseException)
+                output_stream.write(f.value.response)
+                return f
+
+            d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
+            self.fetches.append((d, destination, path, args))
+            # Note that this callback changes the value held by d.
+            d_after_callback = d.addCallbacks(write_to, write_err)
+            return make_deferred_yieldable(d_after_callback)
+
+        # Mock out the homeserver's MatrixFederationHttpClient
+        client = Mock()
+        client.federation_get_file = federation_get_file
+        client.get_file = get_file
+
+        self.storage_path = self.mktemp()
+        self.media_store_path = self.mktemp()
+        os.mkdir(self.storage_path)
+        os.mkdir(self.media_store_path)
+
+        config = self.default_config()
+        config["media_store_path"] = self.media_store_path
+        config["max_image_pixels"] = 2000000
+
+        provider_config = {
+            "module": "synapse.media.storage_provider.FileStorageProviderBackend",
+            "store_local": True,
+            "store_synchronous": False,
+            "store_remote": True,
+            "config": {"directory": self.storage_path},
+        }
+        config["media_storage_providers"] = [provider_config]
+        config["experimental_features"] = {"msc3916_authenticated_media_enabled": True}
+
+        hs = self.setup_test_homeserver(config=config, federation_http_client=client)
+
+        return hs
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.media_repo = hs.get_media_repository()
+
+        self.remote = "example.com"
+        self.media_id = "12345"
+
+        self.user = self.register_user("user", "pass")
+        self.tok = self.login("user", "pass")
+
+    def _req(
+        self, content_disposition: Optional[bytes], include_content_type: bool = True
+    ) -> FakeChannel:
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}",
+            shorthand=False,
+            await_result=False,
+            access_token=self.tok,
+        )
+        self.pump()
+
+        # We've made one fetch, to example.com, using the federation media URL
+        self.assertEqual(len(self.fetches), 1)
+        self.assertEqual(self.fetches[0][1], "example.com")
+        self.assertEqual(
+            self.fetches[0][2], "/_matrix/federation/v1/media/download/" + self.media_id
+        )
+        self.assertEqual(
+            self.fetches[0][3],
+            {"timeout_ms": "20000"},
+        )
+
+        headers = {
+            b"Content-Length": [b"%d" % (len(self.test_image.data))],
+        }
+
+        if include_content_type:
+            headers[b"Content-Type"] = [self.test_image.content_type]
+
+        if content_disposition:
+            headers[b"Content-Disposition"] = [content_disposition]
+
+        self.fetches[0][0].callback(
+            (self.test_image.data, (len(self.test_image.data), headers, b"{}"))
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
+
+        return channel
+
+    def test_handle_missing_content_type(self) -> None:
+        channel = self._req(
+            b"attachment; filename=out" + self.test_image.extension,
+            include_content_type=False,
+        )
+        headers = channel.headers
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"]
+        )
+
+    def test_disposition_filename_ascii(self) -> None:
+        """
+        If the filename is filename=<ascii> then Synapse will decode it as an
+        ASCII string, and use filename= in the response.
+        """
+        channel = self._req(b"attachment; filename=out" + self.test_image.extension)
+
+        headers = channel.headers
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
+        )
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Disposition"),
+            [
+                (b"inline" if self.test_image.is_inline else b"attachment")
+                + b"; filename=out"
+                + self.test_image.extension
+            ],
+        )
+
+    def test_disposition_filenamestar_utf8escaped(self) -> None:
+        """
+        If the filename is filename=*utf8''<utf8 escaped> then Synapse will
+        correctly decode it as the UTF-8 string, and use filename* in the
+        response.
+        """
+        filename = parse.quote("\u2603".encode()).encode("ascii")
+        channel = self._req(
+            b"attachment; filename*=utf-8''" + filename + self.test_image.extension
+        )
+
+        headers = channel.headers
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
+        )
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Disposition"),
+            [
+                (b"inline" if self.test_image.is_inline else b"attachment")
+                + b"; filename*=utf-8''"
+                + filename
+                + self.test_image.extension
+            ],
+        )
+
+    def test_disposition_none(self) -> None:
+        """
+        If there is no filename, Content-Disposition should only
+        be a disposition type.
+        """
+        channel = self._req(None)
+
+        headers = channel.headers
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
+        )
+        self.assertEqual(
+            headers.getRawHeaders(b"Content-Disposition"),
+            [b"inline" if self.test_image.is_inline else b"attachment"],
+        )
+
+    def test_x_robots_tag_header(self) -> None:
+        """
+        Tests that the `X-Robots-Tag` header is present, which informs web crawlers
+        to not index, archive, or follow links in media.
+        """
+        channel = self._req(b"attachment; filename=out" + self.test_image.extension)
+
+        headers = channel.headers
+        self.assertEqual(
+            headers.getRawHeaders(b"X-Robots-Tag"),
+            [b"noindex, nofollow, noarchive, noimageindex"],
+        )
+
+    def test_cross_origin_resource_policy_header(self) -> None:
+        """
+        Test that the Cross-Origin-Resource-Policy header is set to "cross-origin"
+        allowing web clients to embed media from the downloads API.
+        """
+        channel = self._req(b"attachment; filename=out" + self.test_image.extension)
+
+        headers = channel.headers
+
+        self.assertEqual(
+            headers.getRawHeaders(b"Cross-Origin-Resource-Policy"),
+            [b"cross-origin"],
+        )
+
+    def test_unknown_federation_endpoint(self) -> None:
+        """
+        Test that if the downloadd request to remote federation endpoint returns a 404
+        we fall back to the _matrix/media endpoint
+        """
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}",
+            shorthand=False,
+            await_result=False,
+            access_token=self.tok,
+        )
+        self.pump()
+
+        # We've made one fetch, to example.com, using the media URL, and asking
+        # the other server not to do a remote fetch
+        self.assertEqual(len(self.fetches), 1)
+        self.assertEqual(self.fetches[0][1], "example.com")
+        self.assertEqual(
+            self.fetches[0][2], f"/_matrix/federation/v1/media/download/{self.media_id}"
+        )
+
+        # The result which says the endpoint is unknown.
+        unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}'
+        self.fetches[0][0].errback(
+            HttpResponseException(404, "NOT FOUND", unknown_endpoint)
+        )
+
+        self.pump()
+
+        # There should now be another request to the _matrix/media/v3/download URL.
+        self.assertEqual(len(self.fetches), 2)
+        self.assertEqual(self.fetches[1][1], "example.com")
+        self.assertEqual(
+            self.fetches[1][2],
+            f"/_matrix/media/v3/download/example.com/{self.media_id}",
+        )
+
+        headers = {
+            b"Content-Length": [b"%d" % (len(self.test_image.data))],
+        }
+
+        self.fetches[1][0].callback(
+            (self.test_image.data, (len(self.test_image.data), headers))
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)