diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py
index 466c5a0b70..30b6d31d0a 100644
--- a/tests/rest/client/test_media.py
+++ b/tests/rest/client/test_media.py
@@ -43,6 +43,7 @@ from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
from twisted.web.http_headers import Headers
from twisted.web.iweb import UNKNOWN_LENGTH, IResponse
+from twisted.web.resource import Resource
from synapse.api.errors import HttpResponseException
from synapse.api.ratelimiting import Ratelimiter
@@ -2466,3 +2467,211 @@ class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase):
server_name=None,
)
)
+
+
+configs = [
+ {"extra_config": {"dynamic_thumbnails": True}},
+ {"extra_config": {"dynamic_thumbnails": False}},
+]
+
+
+@parameterized_class(configs)
+class AuthenticatedMediaTestCase(unittest.HomeserverTestCase):
+ extra_config: Dict[str, Any]
+ servlets = [
+ media.register_servlets,
+ login.register_servlets,
+ admin.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ config = self.default_config()
+
+ self.clock = clock
+ self.storage_path = self.mktemp()
+ self.media_store_path = self.mktemp()
+ os.mkdir(self.storage_path)
+ os.mkdir(self.media_store_path)
+ config["media_store_path"] = self.media_store_path
+ config["enable_authenticated_media"] = True
+
+ provider_config = {
+ "module": "synapse.media.storage_provider.FileStorageProviderBackend",
+ "store_local": True,
+ "store_synchronous": False,
+ "store_remote": True,
+ "config": {"directory": self.storage_path},
+ }
+
+ config["media_storage_providers"] = [provider_config]
+ config.update(self.extra_config)
+
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.repo = hs.get_media_repository()
+ self.client = hs.get_federation_http_client()
+ self.store = hs.get_datastores().main
+ self.user = self.register_user("user", "pass")
+ self.tok = self.login("user", "pass")
+
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ resources = super().create_resource_dict()
+ resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+ return resources
+
+ def test_authenticated_media(self) -> None:
+ # upload some local media with authentication on
+ channel = self.make_request(
+ "POST",
+ "_matrix/media/v3/upload?filename=test_png_upload",
+ SMALL_PNG,
+ self.tok,
+ shorthand=False,
+ content_type=b"image/png",
+ custom_headers=[("Content-Length", str(67))],
+ )
+ self.assertEqual(channel.code, 200)
+ res = channel.json_body.get("content_uri")
+ assert res is not None
+ uri = res.split("mxc://")[1]
+
+ # request media over authenticated endpoint, should be found
+ channel2 = self.make_request(
+ "GET",
+ f"_matrix/client/v1/media/download/{uri}",
+ access_token=self.tok,
+ shorthand=False,
+ )
+ self.assertEqual(channel2.code, 200)
+
+ # request same media over unauthenticated media, should raise 404 not found
+ channel3 = self.make_request(
+ "GET", f"_matrix/media/v3/download/{uri}", shorthand=False
+ )
+ self.assertEqual(channel3.code, 404)
+
+ # check thumbnails as well
+ params = "?width=32&height=32&method=crop"
+ channel4 = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/media/thumbnail/{uri}{params}",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ self.assertEqual(channel4.code, 200)
+
+ params = "?width=32&height=32&method=crop"
+ channel5 = self.make_request(
+ "GET",
+ f"/_matrix/media/r0/thumbnail/{uri}{params}",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ self.assertEqual(channel5.code, 404)
+
+ # Inject a piece of remote media.
+ file_id = "abcdefg12345"
+ file_info = FileInfo(server_name="lonelyIsland", file_id=file_id)
+
+ media_storage = self.hs.get_media_repository().media_storage
+
+ ctx = media_storage.store_into_file(file_info)
+ (f, fname) = self.get_success(ctx.__aenter__())
+ f.write(SMALL_PNG)
+ self.get_success(ctx.__aexit__(None, None, None))
+
+ # we write the authenticated status when storing media, so this should pick up
+ # config and authenticate the media
+ self.get_success(
+ self.store.store_cached_remote_media(
+ origin="lonelyIsland",
+ media_id="52",
+ media_type="image/png",
+ media_length=1,
+ time_now_ms=self.clock.time_msec(),
+ upload_name="remote_test.png",
+ filesystem_id=file_id,
+ )
+ )
+
+ # ensure we have thumbnails for the non-dynamic code path
+ if self.extra_config == {"dynamic_thumbnails": False}:
+ self.get_success(
+ self.repo._generate_thumbnails(
+ "lonelyIsland", "52", file_id, "image/png"
+ )
+ )
+
+ channel6 = self.make_request(
+ "GET",
+ "_matrix/client/v1/media/download/lonelyIsland/52",
+ access_token=self.tok,
+ shorthand=False,
+ )
+ self.assertEqual(channel6.code, 200)
+
+ channel7 = self.make_request(
+ "GET", f"_matrix/media/v3/download/{uri}", shorthand=False
+ )
+ self.assertEqual(channel7.code, 404)
+
+ params = "?width=32&height=32&method=crop"
+ channel8 = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/media/thumbnail/lonelyIsland/52{params}",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ self.assertEqual(channel8.code, 200)
+
+ channel9 = self.make_request(
+ "GET",
+ f"/_matrix/media/r0/thumbnail/lonelyIsland/52{params}",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ self.assertEqual(channel9.code, 404)
+
+ # Inject a piece of local media that isn't authenticated
+ file_id = "abcdefg123456"
+ file_info = FileInfo(None, file_id=file_id)
+
+ ctx = media_storage.store_into_file(file_info)
+ (f, fname) = self.get_success(ctx.__aenter__())
+ f.write(SMALL_PNG)
+ self.get_success(ctx.__aexit__(None, None, None))
+
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "local_media_repository",
+ {
+ "media_id": "abcdefg123456",
+ "media_type": "image/png",
+ "created_ts": self.clock.time_msec(),
+ "upload_name": "test_local",
+ "media_length": 1,
+ "user_id": "someone",
+ "url_cache": None,
+ "authenticated": False,
+ },
+ desc="store_local_media",
+ )
+ )
+
+ # check that unauthenticated media is still available over both endpoints
+ channel9 = self.make_request(
+ "GET",
+ "/_matrix/client/v1/media/download/test/abcdefg123456",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ self.assertEqual(channel9.code, 200)
+
+ channel10 = self.make_request(
+ "GET",
+ "/_matrix/media/r0/download/test/abcdefg123456",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ self.assertEqual(channel10.code, 200)
|