summary refs log tree commit diff
path: root/tests/replication/test_multi_media_repo.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication/test_multi_media_repo.py')
-rw-r--r--tests/replication/test_multi_media_repo.py52
1 files changed, 13 insertions, 39 deletions
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index b230a6c361..1e9994cc0b 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -15,9 +15,7 @@ import logging
 import os
 from typing import Any, Optional, Tuple
 
-from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
 from twisted.internet.protocol import Factory
-from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
 from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.http import HTTPChannel
 from twisted.web.server import Request
@@ -27,7 +25,11 @@ from synapse.rest.client import login
 from synapse.server import HomeServer
 from synapse.util import Clock
 
-from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
+from tests.http import (
+    TestServerTLSConnectionFactory,
+    get_test_ca_cert_file,
+    wrap_server_factory_for_tls,
+)
 from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.server import FakeChannel, FakeTransport, make_request
 from tests.test_utils import SMALL_PNG
@@ -94,7 +96,13 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
 
         # build the test server
-        server_tls_protocol = _build_test_server(get_connection_factory())
+        server_factory = Factory.forProtocol(HTTPChannel)
+        # Request.finish expects the factory to have a 'log' method.
+        server_factory.log = _log_request
+
+        server_tls_protocol = wrap_server_factory_for_tls(
+            server_factory, self.reactor, sanlist=[b"DNS:example.com"]
+        ).buildProtocol(None)
 
         # now, tell the client protocol factory to build the client protocol (it will be a
         # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
@@ -114,7 +122,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         # fish the test server back out of the server-side TLS protocol.
-        http_server: HTTPChannel = server_tls_protocol.wrappedProtocol  # type: ignore[assignment]
+        http_server: HTTPChannel = server_tls_protocol.wrappedProtocol
 
         # give the reactor a pump to get the TLS juices flowing.
         self.reactor.pump((0.1,))
@@ -240,40 +248,6 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         return sum(len(files) for _, _, files in os.walk(path))
 
 
-def get_connection_factory() -> TestServerTLSConnectionFactory:
-    # this needs to happen once, but not until we are ready to run the first test
-    global test_server_connection_factory
-    if test_server_connection_factory is None:
-        test_server_connection_factory = TestServerTLSConnectionFactory(
-            sanlist=[b"DNS:example.com"]
-        )
-    return test_server_connection_factory
-
-
-def _build_test_server(
-    connection_creator: IOpenSSLServerConnectionCreator,
-) -> TLSMemoryBIOProtocol:
-    """Construct a test server
-
-    This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
-
-    Args:
-        connection_creator: thing to build SSL connections
-
-    Returns:
-        TLSMemoryBIOProtocol
-    """
-    server_factory = Factory.forProtocol(HTTPChannel)
-    # Request.finish expects the factory to have a 'log' method.
-    server_factory.log = _log_request
-
-    server_tls_factory = TLSMemoryBIOFactory(
-        connection_creator, isClient=False, wrappedFactory=server_factory
-    )
-
-    return server_tls_factory.buildProtocol(None)
-
-
 def _log_request(request: Request) -> None:
     """Implements Factory.log, which is expected by Request.finish"""
     logger.info("Completed request %s", request)