summary refs log tree commit diff
path: root/tests/rest/client/test_media.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/rest/client/test_media.py183
1 files changed, 177 insertions, 6 deletions
diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py

index 30b6d31d0a..6ee761e44b 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py
@@ -24,14 +24,13 @@ import json import os import re import shutil -from typing import Any, BinaryIO, Dict, List, Optional, Sequence, Tuple, Type +from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Sequence, Tuple, Type from unittest.mock import MagicMock, Mock, patch from urllib import parse from urllib.parse import quote, urlencode from parameterized import parameterized, parameterized_class from PIL import Image as Image -from typing_extensions import ClassVar from twisted.internet import defer from twisted.internet._resolver import HostResolution @@ -66,6 +65,7 @@ from tests.media.test_media_storage import ( SVG, TestImage, empty_file, + small_cmyk_jpeg, small_lossless_webp, small_png, small_png_with_transparency, @@ -137,6 +137,7 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase): time_now_ms=clock.time_msec(), upload_name="test.png", filesystem_id=file_id, + sha256=file_id, ) ) self.register_user("user", "password") @@ -1005,7 +1006,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): data = base64.b64encode(SMALL_PNG) end_content = ( - b"<html><head>" b'<img src="data:image/png;base64,%s" />' b"</head></html>" + b'<html><head><img src="data:image/png;base64,%s" /></head></html>' ) % (data,) channel = self.make_request( @@ -1617,6 +1618,63 @@ class MediaConfigTest(unittest.HomeserverTestCase): ) +class MediaConfigModuleCallbackTestCase(unittest.HomeserverTestCase): + servlets = [ + media.register_servlets, + admin.register_servlets, + login.register_servlets, + ] + + def make_homeserver( + self, reactor: ThreadedMemoryReactorClock, clock: Clock + ) -> HomeServer: + config = self.default_config() + + self.storage_path = self.mktemp() + self.media_store_path = self.mktemp() + os.mkdir(self.storage_path) + os.mkdir(self.media_store_path) + config["media_store_path"] = self.media_store_path + + provider_config = { + "module": "synapse.media.storage_provider.FileStorageProviderBackend", + "store_local": True, + "store_synchronous": False, + "store_remote": True, + "config": {"directory": self.storage_path}, + } + + config["media_storage_providers"] = [provider_config] + + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.user = self.register_user("user", "password") + self.tok = self.login("user", "password") + + hs.get_module_api().register_media_repository_callbacks( + get_media_config_for_user=self.get_media_config_for_user, + ) + + async def get_media_config_for_user( + self, + user_id: str, + ) -> Optional[JsonDict]: + # We echo back the user_id and set a custom upload size. + return {"m.upload.size": 1024, "user_id": user_id} + + def test_media_config(self) -> None: + channel = self.make_request( + "GET", + "/_matrix/client/v1/media/config", + shorthand=False, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["m.upload.size"], 1024) + self.assertEqual(channel.json_body["user_id"], self.user) + + class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase): servlets = [ media.register_servlets, @@ -1916,6 +1974,7 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase): test_images = [ small_png, small_png_with_transparency, + small_cmyk_jpeg, small_lossless_webp, empty_file, SVG, @@ -1957,7 +2016,7 @@ class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase): """A mock for MatrixFederationHttpClient.federation_get_file.""" def write_to( - r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]] + r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]], ) -> Tuple[int, Dict[bytes, List[bytes]], bytes]: data, response = r output_stream.write(data) @@ -1991,7 +2050,7 @@ class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase): """A mock for MatrixFederationHttpClient.get_file.""" def write_to( - r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]] + r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]], ) -> Tuple[int, Dict[bytes, List[bytes]]]: data, response = r output_stream.write(data) @@ -2400,7 +2459,7 @@ class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase): if expected_body is not None: self.assertEqual( - channel.result["body"], expected_body, channel.result["body"] + channel.result["body"], expected_body, channel.result["body"].hex() ) else: # ensure that the result is at least some valid image @@ -2592,6 +2651,7 @@ class AuthenticatedMediaTestCase(unittest.HomeserverTestCase): time_now_ms=self.clock.time_msec(), upload_name="remote_test.png", filesystem_id=file_id, + sha256=file_id, ) ) @@ -2675,3 +2735,114 @@ class AuthenticatedMediaTestCase(unittest.HomeserverTestCase): access_token=self.tok, ) self.assertEqual(channel10.code, 200) + + def test_authenticated_media_etag(self) -> None: + """Test that ETag works correctly with authenticated media over client + APIs""" + + # upload some local media with authentication on + channel = self.make_request( + "POST", + "_matrix/media/v3/upload?filename=test_png_upload", + SMALL_PNG, + self.tok, + shorthand=False, + content_type=b"image/png", + custom_headers=[("Content-Length", str(67))], + ) + self.assertEqual(channel.code, 200) + res = channel.json_body.get("content_uri") + assert res is not None + uri = res.split("mxc://")[1] + + # Check standard media endpoint + self._check_caching(f"/download/{uri}") + + # check thumbnails as well + params = "?width=32&height=32&method=crop" + self._check_caching(f"/thumbnail/{uri}{params}") + + # Inject a piece of remote media. + file_id = "abcdefg12345" + file_info = FileInfo(server_name="lonelyIsland", file_id=file_id) + + media_storage = self.hs.get_media_repository().media_storage + + ctx = media_storage.store_into_file(file_info) + (f, fname) = self.get_success(ctx.__aenter__()) + f.write(SMALL_PNG) + self.get_success(ctx.__aexit__(None, None, None)) + + # we write the authenticated status when storing media, so this should pick up + # config and authenticate the media + self.get_success( + self.store.store_cached_remote_media( + origin="lonelyIsland", + media_id="52", + media_type="image/png", + media_length=1, + time_now_ms=self.clock.time_msec(), + upload_name="remote_test.png", + filesystem_id=file_id, + sha256=file_id, + ) + ) + + # ensure we have thumbnails for the non-dynamic code path + if self.extra_config == {"dynamic_thumbnails": False}: + self.get_success( + self.repo._generate_thumbnails( + "lonelyIsland", "52", file_id, "image/png" + ) + ) + + self._check_caching("/download/lonelyIsland/52") + + params = "?width=32&height=32&method=crop" + self._check_caching(f"/thumbnail/lonelyIsland/52{params}") + + def _check_caching(self, path: str) -> None: + """ + Checks that: + 1. fetching the path returns an ETag header + 2. refetching with the ETag returns a 304 without a body + 3. refetching with the ETag but through unauthenticated endpoint + returns 404 + """ + + # Request media over authenticated endpoint, should be found + channel1 = self.make_request( + "GET", + f"/_matrix/client/v1/media{path}", + access_token=self.tok, + shorthand=False, + ) + self.assertEqual(channel1.code, 200) + + # Should have a single ETag field + etags = channel1.headers.getRawHeaders("ETag") + self.assertIsNotNone(etags) + assert etags is not None # For mypy + self.assertEqual(len(etags), 1) + etag = etags[0] + + # Refetching with the etag should result in 304 and empty body. + channel2 = self.make_request( + "GET", + f"/_matrix/client/v1/media{path}", + access_token=self.tok, + shorthand=False, + custom_headers=[("If-None-Match", etag)], + ) + self.assertEqual(channel2.code, 304) + self.assertEqual(channel2.is_finished(), True) + self.assertNotIn("body", channel2.result) + + # Refetching with the etag but no access token should result in 404. + channel3 = self.make_request( + "GET", + f"/_matrix/media/r0{path}", + shorthand=False, + custom_headers=[("If-None-Match", etag)], + ) + self.assertEqual(channel3.code, 404)