diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index f4f3e56777..5f897d49cf 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -120,12 +120,13 @@ class _TestImage:
extension = attr.ib(type=bytes)
expected_cropped = attr.ib(type=Optional[bytes])
expected_scaled = attr.ib(type=Optional[bytes])
+ expected_found = attr.ib(default=True, type=bool)
@parameterized_class(
("test_image",),
[
- # smol png
+ # smoll png
(
_TestImage(
unhexlify(
@@ -161,6 +162,8 @@ class _TestImage:
None,
),
),
+ # an empty file
+ (_TestImage(b"", b"image/gif", b".gif", None, None, False,),),
],
)
class MediaRepoTests(unittest.HomeserverTestCase):
@@ -303,12 +306,16 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
def test_thumbnail_crop(self):
- self._test_thumbnail("crop", self.test_image.expected_cropped)
+ self._test_thumbnail(
+ "crop", self.test_image.expected_cropped, self.test_image.expected_found
+ )
def test_thumbnail_scale(self):
- self._test_thumbnail("scale", self.test_image.expected_scaled)
+ self._test_thumbnail(
+ "scale", self.test_image.expected_scaled, self.test_image.expected_found
+ )
- def _test_thumbnail(self, method, expected_body):
+ def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method
request, channel = self.make_request(
"GET", self.media_id + params, shorthand=False
@@ -325,11 +332,23 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
self.pump()
- self.assertEqual(channel.code, 200)
- if expected_body is not None:
+ if expected_found:
+ self.assertEqual(channel.code, 200)
+ if expected_body is not None:
+ self.assertEqual(
+ channel.result["body"], expected_body, channel.result["body"]
+ )
+ else:
+ # ensure that the result is at least some valid image
+ Image.open(BytesIO(channel.result["body"]))
+ else:
+ # A 404 with a JSON body.
+ self.assertEqual(channel.code, 404)
self.assertEqual(
- channel.result["body"], expected_body, channel.result["body"]
+ channel.json_body,
+ {
+ "errcode": "M_NOT_FOUND",
+ "error": "Not found [b'example.com', b'12345?width=32&height=32&method=%s']"
+ % method,
+ },
)
- else:
- # ensure that the result is at least some valid image
- Image.open(BytesIO(channel.result["body"]))
|