diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index b230a6c361..1e9994cc0b 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -15,9 +15,7 @@ import logging
import os
from typing import Any, Optional, Tuple
-from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.internet.protocol import Factory
-from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.http import HTTPChannel
from twisted.web.server import Request
@@ -27,7 +25,11 @@ from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.util import Clock
-from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
+from tests.http import (
+ TestServerTLSConnectionFactory,
+ get_test_ca_cert_file,
+ wrap_server_factory_for_tls,
+)
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, FakeTransport, make_request
from tests.test_utils import SMALL_PNG
@@ -94,7 +96,13 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
# build the test server
- server_tls_protocol = _build_test_server(get_connection_factory())
+ server_factory = Factory.forProtocol(HTTPChannel)
+ # Request.finish expects the factory to have a 'log' method.
+ server_factory.log = _log_request
+
+ server_tls_protocol = wrap_server_factory_for_tls(
+ server_factory, self.reactor, sanlist=[b"DNS:example.com"]
+ ).buildProtocol(None)
# now, tell the client protocol factory to build the client protocol (it will be a
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
@@ -114,7 +122,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
)
# fish the test server back out of the server-side TLS protocol.
- http_server: HTTPChannel = server_tls_protocol.wrappedProtocol # type: ignore[assignment]
+ http_server: HTTPChannel = server_tls_protocol.wrappedProtocol
# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
@@ -240,40 +248,6 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return sum(len(files) for _, _, files in os.walk(path))
-def get_connection_factory() -> TestServerTLSConnectionFactory:
- # this needs to happen once, but not until we are ready to run the first test
- global test_server_connection_factory
- if test_server_connection_factory is None:
- test_server_connection_factory = TestServerTLSConnectionFactory(
- sanlist=[b"DNS:example.com"]
- )
- return test_server_connection_factory
-
-
-def _build_test_server(
- connection_creator: IOpenSSLServerConnectionCreator,
-) -> TLSMemoryBIOProtocol:
- """Construct a test server
-
- This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
-
- Args:
- connection_creator: thing to build SSL connections
-
- Returns:
- TLSMemoryBIOProtocol
- """
- server_factory = Factory.forProtocol(HTTPChannel)
- # Request.finish expects the factory to have a 'log' method.
- server_factory.log = _log_request
-
- server_tls_factory = TLSMemoryBIOFactory(
- connection_creator, isClient=False, wrappedFactory=server_factory
- )
-
- return server_tls_factory.buildProtocol(None)
-
-
def _log_request(request: Request) -> None:
"""Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request)
|