summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
authorShay <hillerys@element.io>2024-07-08 02:11:20 -0700
committerGitHub <noreply@github.com>2024-07-08 10:11:20 +0100
commitcf69f8d59b0a1fad2b0f313281647e3ea527cf5e (patch)
tree6542c9ad652b88d6653cf720cbbf9e3711942bdb /tests/rest
parentBump ruff from 0.3.7 to 0.5.0 (#17381) (diff)
downloadsynapse-cf69f8d59b0a1fad2b0f313281647e3ea527cf5e.tar.xz
Support MSC3916 by adding a federation /thumbnail endpoint and authenticated `_matrix/client/v1/media/thumbnail` endpoint (#17388)
[MSC3916](https://github.com/matrix-org/matrix-spec-proposals/pull/3916)
added the endpoints `_matrix/federation/v1/media/thumbnail` and the
authenticated `_matrix/client/v1/media/thumbnail`.

This PR implements those endpoints, along with stabilizing
`_matrix/client/v1/media/config` and
`_matrix/client/v1/media/preview_url`.

Complement tests are at
https://github.com/matrix-org/complement/pull/728
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/client/test_media.py358
1 files changed, 287 insertions, 71 deletions
diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py
index 6b5af2dbb6..7f2caed7d5 100644
--- a/tests/rest/client/test_media.py
+++ b/tests/rest/client/test_media.py
@@ -23,12 +23,15 @@ import io
 import json
 import os
 import re
-from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Sequence, Tuple, Type
+import shutil
+from typing import Any, BinaryIO, Dict, List, Optional, Sequence, Tuple, Type
 from unittest.mock import MagicMock, Mock, patch
 from urllib import parse
 from urllib.parse import quote, urlencode
 
-from parameterized import parameterized_class
+from parameterized import parameterized, parameterized_class
+from PIL import Image as Image
+from typing_extensions import ClassVar
 
 from twisted.internet import defer
 from twisted.internet._resolver import HostResolution
@@ -40,7 +43,6 @@ from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import UNKNOWN_LENGTH, IResponse
-from twisted.web.resource import Resource
 
 from synapse.api.errors import HttpResponseException
 from synapse.api.ratelimiting import Ratelimiter
@@ -48,7 +50,8 @@ from synapse.config.oembed import OEmbedEndpointConfig
 from synapse.http.client import MultipartResponse
 from synapse.http.types import QueryParams
 from synapse.logging.context import make_deferred_yieldable
-from synapse.media._base import FileInfo
+from synapse.media._base import FileInfo, ThumbnailInfo
+from synapse.media.thumbnailer import ThumbnailProvider
 from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS
 from synapse.rest import admin
 from synapse.rest.client import login, media
@@ -76,7 +79,7 @@ except ImportError:
     lxml = None  # type: ignore[assignment]
 
 
-class UnstableMediaDomainBlockingTests(unittest.HomeserverTestCase):
+class MediaDomainBlockingTests(unittest.HomeserverTestCase):
     remote_media_id = "doesnotmatter"
     remote_server_name = "evil.com"
     servlets = [
@@ -144,7 +147,6 @@ class UnstableMediaDomainBlockingTests(unittest.HomeserverTestCase):
             # Should result in a 404.
             "prevent_media_downloads_from": ["evil.com"],
             "dynamic_thumbnails": True,
-            "experimental_features": {"msc3916_authenticated_media_enabled": True},
         }
     )
     def test_cannot_download_blocked_media_thumbnail(self) -> None:
@@ -153,7 +155,7 @@ class UnstableMediaDomainBlockingTests(unittest.HomeserverTestCase):
         """
         response = self.make_request(
             "GET",
-            f"/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/evil.com/{self.remote_media_id}?width=100&height=100",
+            f"/_matrix/client/v1/media/thumbnail/evil.com/{self.remote_media_id}?width=100&height=100",
             shorthand=False,
             content={"width": 100, "height": 100},
             access_token=self.tok,
@@ -166,7 +168,6 @@ class UnstableMediaDomainBlockingTests(unittest.HomeserverTestCase):
             # This proves we haven't broken anything.
             "prevent_media_downloads_from": ["not-listed.com"],
             "dynamic_thumbnails": True,
-            "experimental_features": {"msc3916_authenticated_media_enabled": True},
         }
     )
     def test_remote_media_thumbnail_normally_unblocked(self) -> None:
@@ -175,14 +176,14 @@ class UnstableMediaDomainBlockingTests(unittest.HomeserverTestCase):
         """
         response = self.make_request(
             "GET",
-            f"/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/evil.com/{self.remote_media_id}?width=100&height=100",
+            f"/_matrix/client/v1/media/thumbnail/evil.com/{self.remote_media_id}?width=100&height=100",
             shorthand=False,
             access_token=self.tok,
         )
         self.assertEqual(response.code, 200)
 
 
-class UnstableURLPreviewTests(unittest.HomeserverTestCase):
+class URLPreviewTests(unittest.HomeserverTestCase):
     if not lxml:
         skip = "url preview feature requires lxml"
 
@@ -198,7 +199,6 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
-        config["experimental_features"] = {"msc3916_authenticated_media_enabled": True}
         config["url_preview_enabled"] = True
         config["max_spider_size"] = 9999999
         config["url_preview_ip_range_blacklist"] = (
@@ -284,18 +284,6 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         self.reactor.nameResolver = Resolver()  # type: ignore[assignment]
 
-    def create_resource_dict(self) -> Dict[str, Resource]:
-        """Create a resource tree for the test server
-
-        A resource tree is a mapping from path to twisted.web.resource.
-
-        The default implementation creates a JsonResource and calls each function in
-        `servlets` to register servlets against it.
-        """
-        resources = super().create_resource_dict()
-        resources["/_matrix/media"] = self.hs.get_media_repository_resource()
-        return resources
-
     def _assert_small_png(self, json_body: JsonDict) -> None:
         """Assert properties from the SMALL_PNG test image."""
         self.assertTrue(json_body["og:image"].startswith("mxc://"))
@@ -309,7 +297,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+            "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -334,7 +322,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
         # Check the cache returns the correct response
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+            "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
             shorthand=False,
         )
 
@@ -352,7 +340,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
         # Check the database cache returns the correct response
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+            "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
             shorthand=False,
         )
 
@@ -375,7 +363,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+            "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -405,7 +393,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+            "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -441,7 +429,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+            "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -482,7 +470,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+            "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -517,7 +505,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+            "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -550,7 +538,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com",
+            "/_matrix/client/v1/media/preview_url?url=http://example.com",
             shorthand=False,
             await_result=False,
         )
@@ -580,7 +568,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com",
+            "/_matrix/client/v1/media/preview_url?url=http://example.com",
             shorthand=False,
         )
 
@@ -603,7 +591,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com",
+            "/_matrix/client/v1/media/preview_url?url=http://example.com",
             shorthand=False,
         )
 
@@ -622,7 +610,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
         """
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://192.168.1.1",
+            "/_matrix/client/v1/media/preview_url?url=http://192.168.1.1",
             shorthand=False,
         )
 
@@ -640,7 +628,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
         """
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://1.1.1.2",
+            "/_matrix/client/v1/media/preview_url?url=http://1.1.1.2",
             shorthand=False,
         )
 
@@ -659,7 +647,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com",
+            "/_matrix/client/v1/media/preview_url?url=http://example.com",
             shorthand=False,
             await_result=False,
         )
@@ -696,7 +684,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com",
+            "/_matrix/client/v1/media/preview_url?url=http://example.com",
             shorthand=False,
         )
         self.assertEqual(channel.code, 502)
@@ -718,7 +706,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com",
+            "/_matrix/client/v1/media/preview_url?url=http://example.com",
             shorthand=False,
         )
 
@@ -741,7 +729,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com",
+            "/_matrix/client/v1/media/preview_url?url=http://example.com",
             shorthand=False,
         )
 
@@ -760,7 +748,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
         """
         channel = self.make_request(
             "OPTIONS",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com",
+            "/_matrix/client/v1/media/preview_url?url=http://example.com",
             shorthand=False,
         )
         self.assertEqual(channel.code, 204)
@@ -774,7 +762,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
         # Build and make a request to the server
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com",
+            "/_matrix/client/v1/media/preview_url?url=http://example.com",
             shorthand=False,
             await_result=False,
         )
@@ -827,7 +815,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+            "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -877,7 +865,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+            "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -919,7 +907,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+            "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -959,7 +947,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+            "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -1000,7 +988,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            f"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?{query_params}",
+            f"/_matrix/client/v1/media/preview_url?{query_params}",
             shorthand=False,
         )
         self.pump()
@@ -1021,7 +1009,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org",
+            "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
             shorthand=False,
             await_result=False,
         )
@@ -1058,7 +1046,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
+            "/_matrix/client/v1/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
             shorthand=False,
             await_result=False,
         )
@@ -1118,7 +1106,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
+            "/_matrix/client/v1/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
             shorthand=False,
             await_result=False,
         )
@@ -1167,7 +1155,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://www.hulu.com/watch/12345",
+            "/_matrix/client/v1/media/preview_url?url=http://www.hulu.com/watch/12345",
             shorthand=False,
             await_result=False,
         )
@@ -1212,7 +1200,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
+            "/_matrix/client/v1/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
             shorthand=False,
             await_result=False,
         )
@@ -1241,7 +1229,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
+            "/_matrix/client/v1/media/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
             shorthand=False,
             await_result=False,
         )
@@ -1333,7 +1321,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
+            "/_matrix/client/v1/media/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
             shorthand=False,
             await_result=False,
         )
@@ -1374,7 +1362,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://cdn.twitter.com/matrixdotorg",
+            "/_matrix/client/v1/media/preview_url?url=http://cdn.twitter.com/matrixdotorg",
             shorthand=False,
             await_result=False,
         )
@@ -1416,7 +1404,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
         # Check fetching
         channel = self.make_request(
             "GET",
-            f"/_matrix/media/v3/download/{host}/{media_id}",
+            f"/_matrix/client/v1/media/download/{host}/{media_id}",
             shorthand=False,
             await_result=False,
         )
@@ -1429,7 +1417,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            f"/_matrix/media/v3/download/{host}/{media_id}",
+            f"/_matrix/client/v1/download/{host}/{media_id}",
             shorthand=False,
             await_result=False,
         )
@@ -1464,7 +1452,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
         # Check fetching
         channel = self.make_request(
             "GET",
-            f"/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
+            f"/_matrix/client/v1/media/thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
             shorthand=False,
             await_result=False,
         )
@@ -1482,7 +1470,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            f"/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
+            f"/_matrix/client/v1/media/thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
             shorthand=False,
             await_result=False,
         )
@@ -1532,8 +1520,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url="
-            + bad_url,
+            "/_matrix/client/v1/media/preview_url?url=" + bad_url,
             shorthand=False,
             await_result=False,
         )
@@ -1542,8 +1529,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url="
-            + good_url,
+            "/_matrix/client/v1/media/preview_url?url=" + good_url,
             shorthand=False,
             await_result=False,
         )
@@ -1575,8 +1561,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url="
-            + bad_url,
+            "/_matrix/client/v1/media/preview_url?url=" + bad_url,
             shorthand=False,
             await_result=False,
         )
@@ -1584,7 +1569,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 403, channel.result)
 
 
-class UnstableMediaConfigTest(unittest.HomeserverTestCase):
+class MediaConfigTest(unittest.HomeserverTestCase):
     servlets = [
         media.register_servlets,
         admin.register_servlets,
@@ -1595,7 +1580,6 @@ class UnstableMediaConfigTest(unittest.HomeserverTestCase):
         self, reactor: ThreadedMemoryReactorClock, clock: Clock
     ) -> HomeServer:
         config = self.default_config()
-        config["experimental_features"] = {"msc3916_authenticated_media_enabled": True}
 
         self.storage_path = self.mktemp()
         self.media_store_path = self.mktemp()
@@ -1622,7 +1606,7 @@ class UnstableMediaConfigTest(unittest.HomeserverTestCase):
     def test_media_config(self) -> None:
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/org.matrix.msc3916/media/config",
+            "/_matrix/client/v1/media/config",
             shorthand=False,
             access_token=self.tok,
         )
@@ -1899,7 +1883,7 @@ input_values = [(x,) for x in test_images]
 
 
 @parameterized_class(("test_image",), input_values)
-class DownloadTestCase(unittest.HomeserverTestCase):
+class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase):
     test_image: ClassVar[TestImage]
     servlets = [
         media.register_servlets,
@@ -2005,7 +1989,6 @@ class DownloadTestCase(unittest.HomeserverTestCase):
             "config": {"directory": self.storage_path},
         }
         config["media_storage_providers"] = [provider_config]
-        config["experimental_features"] = {"msc3916_authenticated_media_enabled": True}
 
         hs = self.setup_test_homeserver(config=config, federation_http_client=client)
 
@@ -2164,7 +2147,7 @@ class DownloadTestCase(unittest.HomeserverTestCase):
 
     def test_unknown_federation_endpoint(self) -> None:
         """
-        Test that if the downloadd request to remote federation endpoint returns a 404
+        Test that if the download request to remote federation endpoint returns a 404
         we fall back to the _matrix/media endpoint
         """
         channel = self.make_request(
@@ -2210,3 +2193,236 @@ class DownloadTestCase(unittest.HomeserverTestCase):
 
         self.pump()
         self.assertEqual(channel.code, 200)
+
+    def test_thumbnail_crop(self) -> None:
+        """Test that a cropped remote thumbnail is available."""
+        self._test_thumbnail(
+            "crop",
+            self.test_image.expected_cropped,
+            expected_found=self.test_image.expected_found,
+            unable_to_thumbnail=self.test_image.unable_to_thumbnail,
+        )
+
+    def test_thumbnail_scale(self) -> None:
+        """Test that a scaled remote thumbnail is available."""
+        self._test_thumbnail(
+            "scale",
+            self.test_image.expected_scaled,
+            expected_found=self.test_image.expected_found,
+            unable_to_thumbnail=self.test_image.unable_to_thumbnail,
+        )
+
+    def test_invalid_type(self) -> None:
+        """An invalid thumbnail type is never available."""
+        self._test_thumbnail(
+            "invalid",
+            None,
+            expected_found=False,
+            unable_to_thumbnail=self.test_image.unable_to_thumbnail,
+        )
+
+    @unittest.override_config(
+        {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
+    )
+    def test_no_thumbnail_crop(self) -> None:
+        """
+        Override the config to generate only scaled thumbnails, but request a cropped one.
+        """
+        self._test_thumbnail(
+            "crop",
+            None,
+            expected_found=False,
+            unable_to_thumbnail=self.test_image.unable_to_thumbnail,
+        )
+
+    @unittest.override_config(
+        {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
+    )
+    def test_no_thumbnail_scale(self) -> None:
+        """
+        Override the config to generate only cropped thumbnails, but request a scaled one.
+        """
+        self._test_thumbnail(
+            "scale",
+            None,
+            expected_found=False,
+            unable_to_thumbnail=self.test_image.unable_to_thumbnail,
+        )
+
+    def test_thumbnail_repeated_thumbnail(self) -> None:
+        """Test that fetching the same thumbnail works, and deleting the on disk
+        thumbnail regenerates it.
+        """
+        self._test_thumbnail(
+            "scale",
+            self.test_image.expected_scaled,
+            expected_found=self.test_image.expected_found,
+            unable_to_thumbnail=self.test_image.unable_to_thumbnail,
+        )
+
+        if not self.test_image.expected_found:
+            return
+
+        # Fetching again should work, without re-requesting the image from the
+        # remote.
+        params = "?width=32&height=32&method=scale"
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/v1/media/thumbnail/{self.remote}/{self.media_id}{params}",
+            shorthand=False,
+            await_result=False,
+            access_token=self.tok,
+        )
+        self.pump()
+
+        self.assertEqual(channel.code, 200)
+        if self.test_image.expected_scaled:
+            self.assertEqual(
+                channel.result["body"],
+                self.test_image.expected_scaled,
+                channel.result["body"],
+            )
+
+        # Deleting the thumbnail on disk then re-requesting it should work as
+        # Synapse should regenerate missing thumbnails.
+        info = self.get_success(
+            self.store.get_cached_remote_media(self.remote, self.media_id)
+        )
+        assert info is not None
+        file_id = info.filesystem_id
+
+        thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
+            self.remote, file_id
+        )
+        shutil.rmtree(thumbnail_dir, ignore_errors=True)
+
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/v1/media/thumbnail/{self.remote}/{self.media_id}{params}",
+            shorthand=False,
+            await_result=False,
+            access_token=self.tok,
+        )
+        self.pump()
+
+        self.assertEqual(channel.code, 200)
+        if self.test_image.expected_scaled:
+            self.assertEqual(
+                channel.result["body"],
+                self.test_image.expected_scaled,
+                channel.result["body"],
+            )
+
+    def _test_thumbnail(
+        self,
+        method: str,
+        expected_body: Optional[bytes],
+        expected_found: bool,
+        unable_to_thumbnail: bool = False,
+    ) -> None:
+        """Test the given thumbnailing method works as expected.
+
+        Args:
+            method: The thumbnailing method to use (crop, scale).
+            expected_body: The expected bytes from thumbnailing, or None if
+                test should just check for a valid image.
+            expected_found: True if the file should exist on the server, or False if
+                a 404/400 is expected.
+            unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or
+                False if the thumbnailing should succeed or a normal 404 is expected.
+        """
+
+        params = "?width=32&height=32&method=" + method
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/v1/media/thumbnail/{self.remote}/{self.media_id}{params}",
+            shorthand=False,
+            await_result=False,
+            access_token=self.tok,
+        )
+        self.pump()
+        headers = {
+            b"Content-Length": [b"%d" % (len(self.test_image.data))],
+            b"Content-Type": [self.test_image.content_type],
+        }
+        self.fetches[0][0].callback(
+            (self.test_image.data, (len(self.test_image.data), headers))
+        )
+        self.pump()
+        if expected_found:
+            self.assertEqual(channel.code, 200)
+
+            self.assertEqual(
+                channel.headers.getRawHeaders(b"Cross-Origin-Resource-Policy"),
+                [b"cross-origin"],
+            )
+
+            if expected_body is not None:
+                self.assertEqual(
+                    channel.result["body"], expected_body, channel.result["body"]
+                )
+            else:
+                # ensure that the result is at least some valid image
+                Image.open(io.BytesIO(channel.result["body"]))
+        elif unable_to_thumbnail:
+            # A 400 with a JSON body.
+            self.assertEqual(channel.code, 400)
+            self.assertEqual(
+                channel.json_body,
+                {
+                    "errcode": "M_UNKNOWN",
+                    "error": "Cannot find any thumbnails for the requested media ('/_matrix/client/v1/media/thumbnail/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
+                },
+            )
+        else:
+            # A 404 with a JSON body.
+            self.assertEqual(channel.code, 404)
+            self.assertEqual(
+                channel.json_body,
+                {
+                    "errcode": "M_NOT_FOUND",
+                    "error": "Not found '/_matrix/client/v1/media/thumbnail/example.com/12345'",
+                },
+            )
+
+    @parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)])
+    def test_same_quality(self, method: str, desired_size: int) -> None:
+        """Test that choosing between thumbnails with the same quality rating succeeds.
+
+        We are not particular about which thumbnail is chosen."""
+
+        content_type = self.test_image.content_type.decode()
+        media_repo = self.hs.get_media_repository()
+        thumbnail_provider = ThumbnailProvider(
+            self.hs, media_repo, media_repo.media_storage
+        )
+
+        self.assertIsNotNone(
+            thumbnail_provider._select_thumbnail(
+                desired_width=desired_size,
+                desired_height=desired_size,
+                desired_method=method,
+                desired_type=content_type,
+                # Provide two identical thumbnails which are guaranteed to have the same
+                # quality rating.
+                thumbnail_infos=[
+                    ThumbnailInfo(
+                        width=32,
+                        height=32,
+                        method=method,
+                        type=content_type,
+                        length=256,
+                    ),
+                    ThumbnailInfo(
+                        width=32,
+                        height=32,
+                        method=method,
+                        type=content_type,
+                        length=256,
+                    ),
+                ],
+                file_id=f"image{self.test_image.extension.decode()}",
+                url_cache=False,
+                server_name=None,
+            )
+        )