summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/6764.misc1
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py2
-rw-r--r--tests/rest/media/v1/test_media_storage.py48
3 files changed, 46 insertions, 5 deletions
diff --git a/changelog.d/6764.misc b/changelog.d/6764.misc
new file mode 100644
index 0000000000..8edd767405
--- /dev/null
+++ b/changelog.d/6764.misc
@@ -0,0 +1 @@
+Fixup `synapse.rest` to pass mypy.
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index eee93b4313..d57480f761 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -331,7 +331,7 @@ class ThumbnailResource(DirectServeResource):
                             )
                         )
             if crop_info_list:
-                return min(crop_info_list2)[-1]
+                return min(crop_info_list)[-1]
             else:
                 return min(crop_info_list2)[-1]
         else:
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index bc662b61db..1809ceb839 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -149,6 +149,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
 
         self.media_repo = hs.get_media_repository_resource()
         self.download_resource = self.media_repo.children[b"download"]
+        self.thumbnail_resource = self.media_repo.children[b"thumbnail"]
 
         # smol png
         self.end_content = unhexlify(
@@ -157,11 +158,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
             b"0a2db40000000049454e44ae426082"
         )
 
+        self.media_id = "example.com/12345"
+
     def _req(self, content_disposition):
 
-        request, channel = self.make_request(
-            "GET", "example.com/12345", shorthand=False
-        )
+        request, channel = self.make_request("GET", self.media_id, shorthand=False)
         request.render(self.download_resource)
         self.pump()
 
@@ -170,7 +171,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         self.assertEqual(len(self.fetches), 1)
         self.assertEqual(self.fetches[0][1], "example.com")
         self.assertEqual(
-            self.fetches[0][2], "/_matrix/media/v1/download/example.com/12345"
+            self.fetches[0][2], "/_matrix/media/v1/download/" + self.media_id
         )
         self.assertEqual(self.fetches[0][3], {"allow_remote": "false"})
 
@@ -229,3 +230,42 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         headers = channel.headers
         self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"])
         self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
+
+    def test_thumbnail_crop(self):
+        expected_body = unhexlify(
+            b"89504e470d0a1a0a0000000d4948445200000020000000200806"
+            b"000000737a7af40000001a49444154789cedc101010000008220"
+            b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
+            b"44ae426082"
+        )
+
+        self._test_thumbnail("crop", expected_body)
+
+    def test_thumbnail_scale(self):
+        expected_body = unhexlify(
+            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+            b"0000001f15c4890000000d49444154789c636060606000000005"
+            b"0001a5f645400000000049454e44ae426082"
+        )
+
+        self._test_thumbnail("scale", expected_body)
+
+    def _test_thumbnail(self, method, expected_body):
+        params = "?width=32&height=32&method=" + method
+        request, channel = self.make_request(
+            "GET", self.media_id + params, shorthand=False
+        )
+        request.render(self.thumbnail_resource)
+        self.pump()
+
+        headers = {
+            b"Content-Length": [b"%d" % (len(self.end_content))],
+            b"Content-Type": [b"image/png"],
+        }
+        self.fetches[0][0].callback(
+            (self.end_content, (len(self.end_content), headers))
+        )
+        self.pump()
+
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.result["body"], expected_body, channel.result["body"])