diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 6e78daa830..b230a6c361 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
import os
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.internet.protocol import Factory
@@ -29,7 +29,7 @@ from synapse.util import Clock
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
+from tests.server import FakeChannel, FakeTransport, make_request
from tests.test_utils import SMALL_PNG
logger = logging.getLogger(__name__)
@@ -56,6 +56,16 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
return conf
+ def make_worker_hs(
+ self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
+ ) -> HomeServer:
+ worker_hs = super().make_worker_hs(worker_app, extra_config, **kwargs)
+ # Force the media paths onto the replication resource.
+ worker_hs.get_media_repository_resource().register_servlets(
+ self._hs_to_site[worker_hs].resource, worker_hs
+ )
+ return worker_hs
+
def _get_media_req(
self, hs: HomeServer, target: str, media_id: str
) -> Tuple[FakeChannel, Request]:
@@ -68,12 +78,11 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
The channel for the *client* request and the *outbound* request for
the media which the caller should respond to.
"""
- resource = hs.get_media_repository_resource().children[b"download"]
channel = make_request(
self.reactor,
- FakeSite(resource, self.reactor),
+ self._hs_to_site[hs],
"GET",
- f"/{target}/{media_id}",
+ f"/_matrix/media/r0/download/{target}/{media_id}",
shorthand=False,
access_token=self.access_token,
await_result=False,
|