summary refs log tree commit diff
path: root/tests/media/test_media_storage.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/media/test_media_storage.py')
-rw-r--r--tests/media/test_media_storage.py122
1 files changed, 57 insertions, 65 deletions
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index 04fc7bdcef..15f5d644e4 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -28,12 +28,13 @@ from typing_extensions import Literal
 from twisted.internet import defer
 from twisted.internet.defer import Deferred
 from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import Resource
 
 from synapse.api.errors import Codes
 from synapse.events import EventBase
 from synapse.http.types import QueryParams
 from synapse.logging.context import make_deferred_yieldable
-from synapse.media._base import FileInfo
+from synapse.media._base import FileInfo, ThumbnailInfo
 from synapse.media.filepath import MediaFilePaths
 from synapse.media.media_storage import MediaStorage, ReadableFileWrapper
 from synapse.media.storage_provider import FileStorageProviderBackend
@@ -41,12 +42,13 @@ from synapse.module_api import ModuleApi
 from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers
 from synapse.rest import admin
 from synapse.rest.client import login
+from synapse.rest.media.thumbnail_resource import ThumbnailResource
 from synapse.server import HomeServer
 from synapse.types import JsonDict, RoomAlias
 from synapse.util import Clock
 
 from tests import unittest
-from tests.server import FakeChannel, FakeSite, make_request
+from tests.server import FakeChannel
 from tests.test_utils import SMALL_PNG
 from tests.utils import default_config
 
@@ -288,22 +290,22 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         return hs
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        media_resource = hs.get_media_repository_resource()
-        self.download_resource = media_resource.children[b"download"]
-        self.thumbnail_resource = media_resource.children[b"thumbnail"]
         self.store = hs.get_datastores().main
         self.media_repo = hs.get_media_repository()
 
         self.media_id = "example.com/12345"
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        resources = super().create_resource_dict()
+        resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+        return resources
+
     def _req(
         self, content_disposition: Optional[bytes], include_content_type: bool = True
     ) -> FakeChannel:
-        channel = make_request(
-            self.reactor,
-            FakeSite(self.download_resource, self.reactor),
+        channel = self.make_request(
             "GET",
-            self.media_id,
+            f"/_matrix/media/v3/download/{self.media_id}",
             shorthand=False,
             await_result=False,
         )
@@ -481,11 +483,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         # Fetching again should work, without re-requesting the image from the
         # remote.
         params = "?width=32&height=32&method=scale"
-        channel = make_request(
-            self.reactor,
-            FakeSite(self.thumbnail_resource, self.reactor),
+        channel = self.make_request(
             "GET",
-            self.media_id + params,
+            f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
             shorthand=False,
             await_result=False,
         )
@@ -511,11 +511,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         )
         shutil.rmtree(thumbnail_dir, ignore_errors=True)
 
-        channel = make_request(
-            self.reactor,
-            FakeSite(self.thumbnail_resource, self.reactor),
+        channel = self.make_request(
             "GET",
-            self.media_id + params,
+            f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
             shorthand=False,
             await_result=False,
         )
@@ -549,11 +547,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         """
 
         params = "?width=32&height=32&method=" + method
-        channel = make_request(
-            self.reactor,
-            FakeSite(self.thumbnail_resource, self.reactor),
+        channel = self.make_request(
             "GET",
-            self.media_id + params,
+            f"/_matrix/media/r0/thumbnail/{self.media_id}{params}",
             shorthand=False,
             await_result=False,
         )
@@ -590,7 +586,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
                 channel.json_body,
                 {
                     "errcode": "M_UNKNOWN",
-                    "error": "Cannot find any thumbnails for the requested media ([b'example.com', b'12345']). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
+                    "error": "Cannot find any thumbnails for the requested media ('/_matrix/media/r0/thumbnail/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
                 },
             )
         else:
@@ -600,7 +596,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
                 channel.json_body,
                 {
                     "errcode": "M_NOT_FOUND",
-                    "error": "Not found [b'example.com', b'12345']",
+                    "error": "Not found '/_matrix/media/r0/thumbnail/example.com/12345'",
                 },
             )
 
@@ -609,34 +605,39 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         """Test that choosing between thumbnails with the same quality rating succeeds.
 
         We are not particular about which thumbnail is chosen."""
+
+        content_type = self.test_image.content_type.decode()
+        media_repo = self.hs.get_media_repository()
+        thumbnail_resouce = ThumbnailResource(
+            self.hs, media_repo, media_repo.media_storage
+        )
+
         self.assertIsNotNone(
-            self.thumbnail_resource._select_thumbnail(
+            thumbnail_resouce._select_thumbnail(
                 desired_width=desired_size,
                 desired_height=desired_size,
                 desired_method=method,
-                desired_type=self.test_image.content_type,
+                desired_type=content_type,
                 # Provide two identical thumbnails which are guaranteed to have the same
                 # quality rating.
                 thumbnail_infos=[
-                    {
-                        "thumbnail_width": 32,
-                        "thumbnail_height": 32,
-                        "thumbnail_method": method,
-                        "thumbnail_type": self.test_image.content_type,
-                        "thumbnail_length": 256,
-                        "filesystem_id": f"thumbnail1{self.test_image.extension.decode()}",
-                    },
-                    {
-                        "thumbnail_width": 32,
-                        "thumbnail_height": 32,
-                        "thumbnail_method": method,
-                        "thumbnail_type": self.test_image.content_type,
-                        "thumbnail_length": 256,
-                        "filesystem_id": f"thumbnail2{self.test_image.extension.decode()}",
-                    },
+                    ThumbnailInfo(
+                        width=32,
+                        height=32,
+                        method=method,
+                        type=content_type,
+                        length=256,
+                    ),
+                    ThumbnailInfo(
+                        width=32,
+                        height=32,
+                        method=method,
+                        type=content_type,
+                        length=256,
+                    ),
                 ],
                 file_id=f"image{self.test_image.extension.decode()}",
-                url_cache=None,
+                url_cache=False,
                 server_name=None,
             )
         )
@@ -725,13 +726,13 @@ class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase):
         self.user = self.register_user("user", "pass")
         self.tok = self.login("user", "pass")
 
-        # Allow for uploading and downloading to/from the media repo
-        self.media_repo = hs.get_media_repository_resource()
-        self.download_resource = self.media_repo.children[b"download"]
-        self.upload_resource = self.media_repo.children[b"upload"]
-
         load_legacy_spam_checkers(hs)
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        resources = super().create_resource_dict()
+        resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+        return resources
+
     def default_config(self) -> Dict[str, Any]:
         config = default_config("test")
 
@@ -751,9 +752,7 @@ class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase):
 
     def test_upload_innocent(self) -> None:
         """Attempt to upload some innocent data that should be allowed."""
-        self.helper.upload_media(
-            self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
-        )
+        self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
 
     def test_upload_ban(self) -> None:
         """Attempt to upload some data that includes bytes "evil", which should
@@ -762,9 +761,7 @@ class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase):
 
         data = b"Some evil data"
 
-        self.helper.upload_media(
-            self.upload_resource, data, tok=self.tok, expect_code=400
-        )
+        self.helper.upload_media(data, tok=self.tok, expect_code=400)
 
 
 EVIL_DATA = b"Some evil data"
@@ -781,15 +778,15 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
         self.user = self.register_user("user", "pass")
         self.tok = self.login("user", "pass")
 
-        # Allow for uploading and downloading to/from the media repo
-        self.media_repo = hs.get_media_repository_resource()
-        self.download_resource = self.media_repo.children[b"download"]
-        self.upload_resource = self.media_repo.children[b"upload"]
-
         hs.get_module_api().register_spam_checker_callbacks(
             check_media_file_for_spam=self.check_media_file_for_spam
         )
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        resources = super().create_resource_dict()
+        resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+        return resources
+
     async def check_media_file_for_spam(
         self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
     ) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]:
@@ -805,21 +802,16 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
 
     def test_upload_innocent(self) -> None:
         """Attempt to upload some innocent data that should be allowed."""
-        self.helper.upload_media(
-            self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
-        )
+        self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
 
     def test_upload_ban(self) -> None:
         """Attempt to upload some data that includes bytes "evil", which should
         get rejected by the spam checker.
         """
 
-        self.helper.upload_media(
-            self.upload_resource, EVIL_DATA, tok=self.tok, expect_code=400
-        )
+        self.helper.upload_media(EVIL_DATA, tok=self.tok, expect_code=400)
 
         self.helper.upload_media(
-            self.upload_resource,
             EVIL_DATA_EXPERIMENT,
             tok=self.tok,
             expect_code=400,