summary refs log tree commit diff
path: root/tests/rest/client/test_media.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_media.py')
-rw-r--r--tests/rest/client/test_media.py259
1 files changed, 254 insertions, 5 deletions
diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py
index 7f2caed7d5..30b6d31d0a 100644
--- a/tests/rest/client/test_media.py
+++ b/tests/rest/client/test_media.py
@@ -43,6 +43,7 @@ from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import UNKNOWN_LENGTH, IResponse
+from twisted.web.resource import Resource
 
 from synapse.api.errors import HttpResponseException
 from synapse.api.ratelimiting import Ratelimiter
@@ -1809,13 +1810,19 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
         )
         assert channel.code == 200
 
+    @override_config(
+        {
+            "remote_media_download_burst_count": "87M",
+        }
+    )
     @patch(
         "synapse.http.matrixfederationclient.read_multipart_response",
         read_multipart_response_30MiB,
     )
-    def test_download_ratelimit_max_size_sub(self) -> None:
+    def test_download_ratelimit_unknown_length(self) -> None:
         """
-        Test that if no content-length is provided, the default max size is applied instead
+        Test that if no content-length is provided, ratelimiting is still applied after
+        media is downloaded and length is known
         """
 
         # mock out actually sending the request
@@ -1831,8 +1838,9 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
 
         self.client._send_request = _send_request  # type: ignore
 
-        # ten requests should go through using the max size (500MB/50MB)
-        for i in range(10):
+        # first 3 will go through (note that 3rd request technically violates rate limit but
+        # that since the ratelimiting is applied *after* download it goes through, but next one fails)
+        for i in range(3):
             channel2 = self.make_request(
                 "GET",
                 f"/_matrix/client/v1/media/download/remote.org/abc{i}",
@@ -1841,7 +1849,7 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
             )
             assert channel2.code == 200
 
-        # eleventh will hit ratelimit
+        # 4th will hit ratelimit
         channel3 = self.make_request(
             "GET",
             "/_matrix/client/v1/media/download/remote.org/abcd",
@@ -1850,6 +1858,39 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
         )
         assert channel3.code == 429
 
+    @override_config({"max_upload_size": "29M"})
+    @patch(
+        "synapse.http.matrixfederationclient.read_multipart_response",
+        read_multipart_response_30MiB,
+    )
+    def test_max_download_respected(self) -> None:
+        """
+        Test that the max download size is enforced - note that max download size is determined
+        by the max_upload_size
+        """
+
+        # mock out actually sending the request, returns a 30MiB response
+        async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+            resp = MagicMock(spec=IResponse)
+            resp.code = 200
+            resp.length = 31457280
+            resp.headers = Headers(
+                {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
+            )
+            resp.phrase = b"OK"
+            return resp
+
+        self.client._send_request = _send_request  # type: ignore
+
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/remote.org/abcd",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        assert channel.code == 502
+        assert channel.json_body["errcode"] == "M_TOO_LARGE"
+
     def test_file_download(self) -> None:
         content = io.BytesIO(b"file_to_stream")
         content_uri = self.get_success(
@@ -2426,3 +2467,211 @@ class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase):
                 server_name=None,
             )
         )
+
+
+configs = [
+    {"extra_config": {"dynamic_thumbnails": True}},
+    {"extra_config": {"dynamic_thumbnails": False}},
+]
+
+
+@parameterized_class(configs)
+class AuthenticatedMediaTestCase(unittest.HomeserverTestCase):
+    extra_config: Dict[str, Any]
+    servlets = [
+        media.register_servlets,
+        login.register_servlets,
+        admin.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        config = self.default_config()
+
+        self.clock = clock
+        self.storage_path = self.mktemp()
+        self.media_store_path = self.mktemp()
+        os.mkdir(self.storage_path)
+        os.mkdir(self.media_store_path)
+        config["media_store_path"] = self.media_store_path
+        config["enable_authenticated_media"] = True
+
+        provider_config = {
+            "module": "synapse.media.storage_provider.FileStorageProviderBackend",
+            "store_local": True,
+            "store_synchronous": False,
+            "store_remote": True,
+            "config": {"directory": self.storage_path},
+        }
+
+        config["media_storage_providers"] = [provider_config]
+        config.update(self.extra_config)
+
+        return self.setup_test_homeserver(config=config)
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.repo = hs.get_media_repository()
+        self.client = hs.get_federation_http_client()
+        self.store = hs.get_datastores().main
+        self.user = self.register_user("user", "pass")
+        self.tok = self.login("user", "pass")
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        resources = super().create_resource_dict()
+        resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+        return resources
+
+    def test_authenticated_media(self) -> None:
+        # upload some local media with authentication on
+        channel = self.make_request(
+            "POST",
+            "_matrix/media/v3/upload?filename=test_png_upload",
+            SMALL_PNG,
+            self.tok,
+            shorthand=False,
+            content_type=b"image/png",
+            custom_headers=[("Content-Length", str(67))],
+        )
+        self.assertEqual(channel.code, 200)
+        res = channel.json_body.get("content_uri")
+        assert res is not None
+        uri = res.split("mxc://")[1]
+
+        # request media over authenticated endpoint, should be found
+        channel2 = self.make_request(
+            "GET",
+            f"_matrix/client/v1/media/download/{uri}",
+            access_token=self.tok,
+            shorthand=False,
+        )
+        self.assertEqual(channel2.code, 200)
+
+        # request same media over unauthenticated media, should raise 404 not found
+        channel3 = self.make_request(
+            "GET", f"_matrix/media/v3/download/{uri}", shorthand=False
+        )
+        self.assertEqual(channel3.code, 404)
+
+        # check thumbnails as well
+        params = "?width=32&height=32&method=crop"
+        channel4 = self.make_request(
+            "GET",
+            f"/_matrix/client/v1/media/thumbnail/{uri}{params}",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        self.assertEqual(channel4.code, 200)
+
+        params = "?width=32&height=32&method=crop"
+        channel5 = self.make_request(
+            "GET",
+            f"/_matrix/media/r0/thumbnail/{uri}{params}",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        self.assertEqual(channel5.code, 404)
+
+        # Inject a piece of remote media.
+        file_id = "abcdefg12345"
+        file_info = FileInfo(server_name="lonelyIsland", file_id=file_id)
+
+        media_storage = self.hs.get_media_repository().media_storage
+
+        ctx = media_storage.store_into_file(file_info)
+        (f, fname) = self.get_success(ctx.__aenter__())
+        f.write(SMALL_PNG)
+        self.get_success(ctx.__aexit__(None, None, None))
+
+        # we write the authenticated status when storing media, so this should pick up
+        # config and authenticate the media
+        self.get_success(
+            self.store.store_cached_remote_media(
+                origin="lonelyIsland",
+                media_id="52",
+                media_type="image/png",
+                media_length=1,
+                time_now_ms=self.clock.time_msec(),
+                upload_name="remote_test.png",
+                filesystem_id=file_id,
+            )
+        )
+
+        # ensure we have thumbnails for the non-dynamic code path
+        if self.extra_config == {"dynamic_thumbnails": False}:
+            self.get_success(
+                self.repo._generate_thumbnails(
+                    "lonelyIsland", "52", file_id, "image/png"
+                )
+            )
+
+        channel6 = self.make_request(
+            "GET",
+            "_matrix/client/v1/media/download/lonelyIsland/52",
+            access_token=self.tok,
+            shorthand=False,
+        )
+        self.assertEqual(channel6.code, 200)
+
+        channel7 = self.make_request(
+            "GET", f"_matrix/media/v3/download/{uri}", shorthand=False
+        )
+        self.assertEqual(channel7.code, 404)
+
+        params = "?width=32&height=32&method=crop"
+        channel8 = self.make_request(
+            "GET",
+            f"/_matrix/client/v1/media/thumbnail/lonelyIsland/52{params}",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        self.assertEqual(channel8.code, 200)
+
+        channel9 = self.make_request(
+            "GET",
+            f"/_matrix/media/r0/thumbnail/lonelyIsland/52{params}",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        self.assertEqual(channel9.code, 404)
+
+        # Inject a piece of local media that isn't authenticated
+        file_id = "abcdefg123456"
+        file_info = FileInfo(None, file_id=file_id)
+
+        ctx = media_storage.store_into_file(file_info)
+        (f, fname) = self.get_success(ctx.__aenter__())
+        f.write(SMALL_PNG)
+        self.get_success(ctx.__aexit__(None, None, None))
+
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "local_media_repository",
+                {
+                    "media_id": "abcdefg123456",
+                    "media_type": "image/png",
+                    "created_ts": self.clock.time_msec(),
+                    "upload_name": "test_local",
+                    "media_length": 1,
+                    "user_id": "someone",
+                    "url_cache": None,
+                    "authenticated": False,
+                },
+                desc="store_local_media",
+            )
+        )
+
+        # check that unauthenticated media is still available over both endpoints
+        channel9 = self.make_request(
+            "GET",
+            "/_matrix/client/v1/media/download/test/abcdefg123456",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        self.assertEqual(channel9.code, 200)
+
+        channel10 = self.make_request(
+            "GET",
+            "/_matrix/media/r0/download/test/abcdefg123456",
+            shorthand=False,
+            access_token=self.tok,
+        )
+        self.assertEqual(channel10.code, 200)