summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/http/__init__.py37
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py60
-rw-r--r--tests/http/test_proxyagent.py44
-rw-r--r--tests/replication/test_multi_media_repo.py52
-rw-r--r--tests/server.py12
5 files changed, 94 insertions, 111 deletions
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 528cdee34b..d5306e7ee0 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -15,14 +15,20 @@ import os.path
 import subprocess
 from typing import List
 
+from incremental import Version
 from zope.interface import implementer
 
+import twisted
 from OpenSSL import SSL
 from OpenSSL.SSL import Connection
 from twisted.internet.address import IPv4Address
-from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
+from twisted.internet.interfaces import (
+    IOpenSSLServerConnectionCreator,
+    IProtocolFactory,
+    IReactorTime,
+)
 from twisted.internet.ssl import Certificate, trustRootFromCertificates
-from twisted.protocols.tls import TLSMemoryBIOProtocol
+from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
 from twisted.web.client import BrowserLikePolicyForHTTPS  # noqa: F401
 from twisted.web.iweb import IPolicyForHTTPS  # noqa: F401
 
@@ -153,6 +159,33 @@ class TestServerTLSConnectionFactory:
         return Connection(ctx, None)
 
 
+def wrap_server_factory_for_tls(
+    factory: IProtocolFactory, clock: IReactorTime, sanlist: List[bytes]
+) -> TLSMemoryBIOFactory:
+    """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
+
+    The resultant factory will create a TLS server which presents a certificate
+    signed by our test CA, valid for the domains in `sanlist`
+
+    Args:
+        factory: protocol factory to wrap
+        sanlist: list of domains the cert should be valid for
+
+    Returns:
+        interfaces.IProtocolFactory
+    """
+    connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
+    # Twisted > 23.8.0 has a different API that accepts a clock.
+    if twisted.version <= Version("Twisted", 23, 8, 0):
+        return TLSMemoryBIOFactory(
+            connection_creator, isClient=False, wrappedFactory=factory
+        )
+    else:
+        return TLSMemoryBIOFactory(
+            connection_creator, isClient=False, wrappedFactory=factory, clock=clock  # type: ignore[call-arg]
+        )
+
+
 # A dummy address, useful for tests that use FakeTransport and don't care about where
 # packets are going to/coming from.
 dummy_address = IPv4Address("TCP", "127.0.0.1", 80)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 9f63fa6fa8..0f623ae50b 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -31,7 +31,7 @@ from twisted.internet.interfaces import (
     IProtocolFactory,
 )
 from twisted.internet.protocol import Factory, Protocol
-from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
+from twisted.protocols.tls import TLSMemoryBIOProtocol
 from twisted.web._newclient import ResponseNeverReceived
 from twisted.web.client import Agent
 from twisted.web.http import HTTPChannel, Request
@@ -57,11 +57,7 @@ from synapse.types import ISynapseReactor
 from synapse.util.caches.ttlcache import TTLCache
 
 from tests import unittest
-from tests.http import (
-    TestServerTLSConnectionFactory,
-    dummy_address,
-    get_test_ca_cert_file,
-)
+from tests.http import dummy_address, get_test_ca_cert_file, wrap_server_factory_for_tls
 from tests.server import FakeTransport, ThreadedMemoryReactorClock
 from tests.utils import checked_cast, default_config
 
@@ -125,7 +121,18 @@ class MatrixFederationAgentTests(unittest.TestCase):
         # build the test server
         server_factory = _get_test_protocol_factory()
         if ssl:
-            server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
+            server_factory = wrap_server_factory_for_tls(
+                server_factory,
+                self.reactor,
+                tls_sanlist
+                or [
+                    b"DNS:testserv",
+                    b"DNS:target-server",
+                    b"DNS:xn--bcher-kva.com",
+                    b"IP:1.2.3.4",
+                    b"IP:::1",
+                ],
+            )
 
         server_protocol = server_factory.buildProtocol(dummy_address)
         assert server_protocol is not None
@@ -435,8 +442,16 @@ class MatrixFederationAgentTests(unittest.TestCase):
         request.finish()
 
         # now we make another test server to act as the upstream HTTP server.
-        server_ssl_protocol = _wrap_server_factory_for_tls(
-            _get_test_protocol_factory()
+        server_ssl_protocol = wrap_server_factory_for_tls(
+            _get_test_protocol_factory(),
+            self.reactor,
+            sanlist=[
+                b"DNS:testserv",
+                b"DNS:target-server",
+                b"DNS:xn--bcher-kva.com",
+                b"IP:1.2.3.4",
+                b"IP:::1",
+            ],
         ).buildProtocol(dummy_address)
 
         # Tell the HTTP server to send outgoing traffic back via the proxy's transport.
@@ -1786,33 +1801,6 @@ def _check_logcontext(context: LoggingContextOrSentinel) -> None:
         raise AssertionError("Expected logcontext %s but was %s" % (context, current))
 
 
-def _wrap_server_factory_for_tls(
-    factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
-) -> TLSMemoryBIOFactory:
-    """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
-    The resultant factory will create a TLS server which presents a certificate
-    signed by our test CA, valid for the domains in `sanlist`
-    Args:
-        factory: protocol factory to wrap
-        sanlist: list of domains the cert should be valid for
-    Returns:
-        interfaces.IProtocolFactory
-    """
-    if sanlist is None:
-        sanlist = [
-            b"DNS:testserv",
-            b"DNS:target-server",
-            b"DNS:xn--bcher-kva.com",
-            b"IP:1.2.3.4",
-            b"IP:::1",
-        ]
-
-    connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
-    return TLSMemoryBIOFactory(
-        connection_creator, isClient=False, wrappedFactory=factory
-    )
-
-
 def _get_test_protocol_factory() -> IProtocolFactory:
     """Get a protocol Factory which will build an HTTPChannel
     Returns:
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 41dfd5dc17..1f117276cf 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -29,18 +29,14 @@ from twisted.internet.endpoints import (
 )
 from twisted.internet.interfaces import IProtocol, IProtocolFactory
 from twisted.internet.protocol import Factory, Protocol
-from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
+from twisted.protocols.tls import TLSMemoryBIOProtocol
 from twisted.web.http import HTTPChannel
 
 from synapse.http.client import BlocklistingReactorWrapper
 from synapse.http.connectproxyclient import BasicProxyCredentials
 from synapse.http.proxyagent import ProxyAgent, parse_proxy
 
-from tests.http import (
-    TestServerTLSConnectionFactory,
-    dummy_address,
-    get_test_https_policy,
-)
+from tests.http import dummy_address, get_test_https_policy, wrap_server_factory_for_tls
 from tests.server import FakeTransport, ThreadedMemoryReactorClock
 from tests.unittest import TestCase
 from tests.utils import checked_cast
@@ -272,7 +268,9 @@ class MatrixFederationAgentTests(TestCase):
             the server Protocol returned by server_factory
         """
         if ssl:
-            server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
+            server_factory = wrap_server_factory_for_tls(
+                server_factory, self.reactor, tls_sanlist or [b"DNS:test.com"]
+            )
 
         server_protocol = server_factory.buildProtocol(dummy_address)
         assert server_protocol is not None
@@ -639,8 +637,8 @@ class MatrixFederationAgentTests(TestCase):
         request.finish()
 
         # now we make another test server to act as the upstream HTTP server.
-        server_ssl_protocol = _wrap_server_factory_for_tls(
-            _get_test_protocol_factory()
+        server_ssl_protocol = wrap_server_factory_for_tls(
+            _get_test_protocol_factory(), self.reactor, sanlist=[b"DNS:test.com"]
         ).buildProtocol(dummy_address)
 
         # Tell the HTTP server to send outgoing traffic back via the proxy's transport.
@@ -806,7 +804,9 @@ class MatrixFederationAgentTests(TestCase):
         request.finish()
 
         # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
-        ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
+        ssl_factory = wrap_server_factory_for_tls(
+            _get_test_protocol_factory(), self.reactor, sanlist=[b"DNS:test.com"]
+        )
         ssl_protocol = ssl_factory.buildProtocol(dummy_address)
         assert isinstance(ssl_protocol, TLSMemoryBIOProtocol)
         http_server = ssl_protocol.wrappedProtocol
@@ -870,30 +870,6 @@ class MatrixFederationAgentTests(TestCase):
         self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888)
 
 
-def _wrap_server_factory_for_tls(
-    factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
-) -> TLSMemoryBIOFactory:
-    """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
-
-    The resultant factory will create a TLS server which presents a certificate
-    signed by our test CA, valid for the domains in `sanlist`
-
-    Args:
-        factory: protocol factory to wrap
-        sanlist: list of domains the cert should be valid for
-
-    Returns:
-        interfaces.IProtocolFactory
-    """
-    if sanlist is None:
-        sanlist = [b"DNS:test.com"]
-
-    connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
-    return TLSMemoryBIOFactory(
-        connection_creator, isClient=False, wrappedFactory=factory
-    )
-
-
 def _get_test_protocol_factory() -> IProtocolFactory:
     """Get a protocol Factory which will build an HTTPChannel
 
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index b230a6c361..1e9994cc0b 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -15,9 +15,7 @@ import logging
 import os
 from typing import Any, Optional, Tuple
 
-from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
 from twisted.internet.protocol import Factory
-from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
 from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.http import HTTPChannel
 from twisted.web.server import Request
@@ -27,7 +25,11 @@ from synapse.rest.client import login
 from synapse.server import HomeServer
 from synapse.util import Clock
 
-from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
+from tests.http import (
+    TestServerTLSConnectionFactory,
+    get_test_ca_cert_file,
+    wrap_server_factory_for_tls,
+)
 from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.server import FakeChannel, FakeTransport, make_request
 from tests.test_utils import SMALL_PNG
@@ -94,7 +96,13 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
 
         # build the test server
-        server_tls_protocol = _build_test_server(get_connection_factory())
+        server_factory = Factory.forProtocol(HTTPChannel)
+        # Request.finish expects the factory to have a 'log' method.
+        server_factory.log = _log_request
+
+        server_tls_protocol = wrap_server_factory_for_tls(
+            server_factory, self.reactor, sanlist=[b"DNS:example.com"]
+        ).buildProtocol(None)
 
         # now, tell the client protocol factory to build the client protocol (it will be a
         # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
@@ -114,7 +122,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         # fish the test server back out of the server-side TLS protocol.
-        http_server: HTTPChannel = server_tls_protocol.wrappedProtocol  # type: ignore[assignment]
+        http_server: HTTPChannel = server_tls_protocol.wrappedProtocol
 
         # give the reactor a pump to get the TLS juices flowing.
         self.reactor.pump((0.1,))
@@ -240,40 +248,6 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         return sum(len(files) for _, _, files in os.walk(path))
 
 
-def get_connection_factory() -> TestServerTLSConnectionFactory:
-    # this needs to happen once, but not until we are ready to run the first test
-    global test_server_connection_factory
-    if test_server_connection_factory is None:
-        test_server_connection_factory = TestServerTLSConnectionFactory(
-            sanlist=[b"DNS:example.com"]
-        )
-    return test_server_connection_factory
-
-
-def _build_test_server(
-    connection_creator: IOpenSSLServerConnectionCreator,
-) -> TLSMemoryBIOProtocol:
-    """Construct a test server
-
-    This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
-
-    Args:
-        connection_creator: thing to build SSL connections
-
-    Returns:
-        TLSMemoryBIOProtocol
-    """
-    server_factory = Factory.forProtocol(HTTPChannel)
-    # Request.finish expects the factory to have a 'log' method.
-    server_factory.log = _log_request
-
-    server_tls_factory = TLSMemoryBIOFactory(
-        connection_creator, isClient=False, wrappedFactory=server_factory
-    )
-
-    return server_tls_factory.buildProtocol(None)
-
-
 def _log_request(request: Request) -> None:
     """Implements Factory.log, which is expected by Request.finish"""
     logger.info("Completed request %s", request)
diff --git a/tests/server.py b/tests/server.py
index 08633fe640..cfb0fb823b 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -43,9 +43,11 @@ from typing import (
 from unittest.mock import Mock
 
 import attr
+from incremental import Version
 from typing_extensions import ParamSpec
 from zope.interface import implementer
 
+import twisted
 from twisted.internet import address, tcp, threads, udp
 from twisted.internet._resolver import SimpleResolverComplexifier
 from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
@@ -474,6 +476,16 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
                     return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
                 return succeed(lookups[name])
 
+        # In order for the TLS protocol tests to work, modify _get_default_clock
+        # on newer Twisted versions to use the test reactor's clock.
+        #
+        # This is *super* dirty since it is never undone and relies on the next
+        # test to overwrite it.
+        if twisted.version > Version("Twisted", 23, 8, 0):
+            from twisted.protocols import tls
+
+            tls._get_default_clock = lambda: self  # type: ignore[attr-defined]
+
         self.nameResolver = SimpleResolverComplexifier(FakeResolver())
         super().__init__()