diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/media/test_media_storage.py | 62 | ||||
-rw-r--r-- | tests/replication/test_multi_media_repo.py | 2 |
2 files changed, 59 insertions, 5 deletions
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index f262304c3d..f981d1c0d8 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -27,10 +27,11 @@ from typing_extensions import Literal from twisted.internet import defer from twisted.internet.defer import Deferred +from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource -from synapse.api.errors import Codes +from synapse.api.errors import Codes, HttpResponseException from synapse.events import EventBase from synapse.http.types import QueryParams from synapse.logging.context import make_deferred_yieldable @@ -247,6 +248,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): retry_on_dns_fail: bool = True, max_size: Optional[int] = None, ignore_backoff: bool = False, + follow_redirects: bool = False, ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]": """A mock for MatrixFederationHttpClient.get_file.""" @@ -257,10 +259,15 @@ class MediaRepoTests(unittest.HomeserverTestCase): output_stream.write(data) return response + def write_err(f: Failure) -> Failure: + f.trap(HttpResponseException) + output_stream.write(f.value.response) + return f + d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred() self.fetches.append((d, destination, path, args)) # Note that this callback changes the value held by d. - d_after_callback = d.addCallback(write_to) + d_after_callback = d.addCallbacks(write_to, write_err) return make_deferred_yieldable(d_after_callback) # Mock out the homeserver's MatrixFederationHttpClient @@ -316,10 +323,11 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.assertEqual(len(self.fetches), 1) self.assertEqual(self.fetches[0][1], "example.com") self.assertEqual( - self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id + self.fetches[0][2], "/_matrix/media/v3/download/" + self.media_id ) self.assertEqual( - self.fetches[0][3], {"allow_remote": "false", "timeout_ms": "20000"} + self.fetches[0][3], + {"allow_remote": "false", "timeout_ms": "20000", "allow_redirect": "true"}, ) headers = { @@ -671,6 +679,52 @@ class MediaRepoTests(unittest.HomeserverTestCase): [b"cross-origin"], ) + def test_unknown_v3_endpoint(self) -> None: + """ + If the v3 endpoint fails, try the r0 one. + """ + channel = self.make_request( + "GET", + f"/_matrix/media/v3/download/{self.media_id}", + shorthand=False, + await_result=False, + ) + self.pump() + + # We've made one fetch, to example.com, using the media URL, and asking + # the other server not to do a remote fetch + self.assertEqual(len(self.fetches), 1) + self.assertEqual(self.fetches[0][1], "example.com") + self.assertEqual( + self.fetches[0][2], "/_matrix/media/v3/download/" + self.media_id + ) + + # The result which says the endpoint is unknown. + unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}' + self.fetches[0][0].errback( + HttpResponseException(404, "NOT FOUND", unknown_endpoint) + ) + + self.pump() + + # There should now be another request to the r0 URL. + self.assertEqual(len(self.fetches), 2) + self.assertEqual(self.fetches[1][1], "example.com") + self.assertEqual( + self.fetches[1][2], f"/_matrix/media/r0/download/{self.media_id}" + ) + + headers = { + b"Content-Length": [b"%d" % (len(self.test_image.data))], + } + + self.fetches[1][0].callback( + (self.test_image.data, (len(self.test_image.data), headers)) + ) + + self.pump() + self.assertEqual(channel.code, 200) + class TestSpamCheckerLegacy: """A spam checker module that rejects all media that includes the bytes diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 1e9994cc0b..9a7b675f54 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -133,7 +133,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual(request.method, b"GET") self.assertEqual( request.path, - f"/_matrix/media/r0/download/{target}/{media_id}".encode(), + f"/_matrix/media/v3/download/{target}/{media_id}".encode(), ) self.assertEqual( request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] |