summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_send_email.py69
-rw-r--r--tests/server.py54
-rw-r--r--tests/unittest.py2
3 files changed, 113 insertions, 12 deletions
diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py
index 8b6e4a40b6..a066745d70 100644
--- a/tests/handlers/test_send_email.py
+++ b/tests/handlers/test_send_email.py
@@ -13,19 +13,40 @@
 # limitations under the License.
 
 
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Type, Union
+from unittest.mock import patch
 
 from zope.interface import implementer
 
 from twisted.internet import defer
-from twisted.internet.address import IPv4Address
+from twisted.internet._sslverify import ClientTLSOptions
+from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.defer import ensureDeferred
+from twisted.internet.interfaces import IProtocolFactory
+from twisted.internet.ssl import ContextFactory
 from twisted.mail import interfaces, smtp
 
 from tests.server import FakeTransport
 from tests.unittest import HomeserverTestCase, override_config
 
 
+def TestingESMTPTLSClientFactory(
+    contextFactory: ContextFactory,
+    _connectWrapped: bool,
+    wrappedProtocol: IProtocolFactory,
+) -> IProtocolFactory:
+    """We use this to pass through in testing without using TLS, but
+    saving the context information to check that it would have happened.
+
+    Note that this is what the MemoryReactor does on connectSSL.
+    It only saves the contextFactory, but starts the connection with the
+    underlying Factory.
+    See: L{twisted.internet.testing.MemoryReactor.connectSSL}"""
+
+    wrappedProtocol._testingContextFactory = contextFactory  # type: ignore[attr-defined]
+    return wrappedProtocol
+
+
 @implementer(interfaces.IMessageDelivery)
 class _DummyMessageDelivery:
     def __init__(self) -> None:
@@ -75,7 +96,13 @@ class _DummyMessage:
         pass
 
 
-class SendEmailHandlerTestCase(HomeserverTestCase):
+class SendEmailHandlerTestCaseIPv4(HomeserverTestCase):
+    ip_class: Union[Type[IPv4Address], Type[IPv6Address]] = IPv4Address
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.reactor.lookups["localhost"] = "127.0.0.1"
+
     def test_send_email(self) -> None:
         """Happy-path test that we can send email to a non-TLS server."""
         h = self.hs.get_send_email_handler()
@@ -89,7 +116,7 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
         (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[
             0
         ]
-        self.assertEqual(host, "localhost")
+        self.assertEqual(host, self.reactor.lookups["localhost"])
         self.assertEqual(port, 25)
 
         # wire it up to an SMTP server
@@ -105,7 +132,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
             FakeTransport(
                 client_protocol,
                 self.reactor,
-                peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
+                peer_address=self.ip_class(
+                    "TCP", self.reactor.lookups["localhost"], 1234
+                ),
             )
         )
 
@@ -118,6 +147,10 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
         self.assertEqual(str(user), "foo@bar.com")
         self.assertIn(b"Subject: test subject", msg)
 
+    @patch(
+        "synapse.handlers.send_email.TLSMemoryBIOFactory",
+        TestingESMTPTLSClientFactory,
+    )
     @override_config(
         {
             "email": {
@@ -135,17 +168,23 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
             )
         )
         # there should be an attempt to connect to localhost:465
-        self.assertEqual(len(self.reactor.sslClients), 1)
+        self.assertEqual(len(self.reactor.tcpClients), 1)
         (
             host,
             port,
             client_factory,
-            contextFactory,
             _timeout,
             _bindAddress,
-        ) = self.reactor.sslClients[0]
-        self.assertEqual(host, "localhost")
+        ) = self.reactor.tcpClients[0]
+        self.assertEqual(host, self.reactor.lookups["localhost"])
         self.assertEqual(port, 465)
+        # We need to make sure that TLS is happenning
+        self.assertIsInstance(
+            client_factory._wrappedFactory._testingContextFactory,
+            ClientTLSOptions,
+        )
+        # And since we use endpoints, they go through reactor.connectTCP
+        # which works differently to connectSSL on the testing reactor
 
         # wire it up to an SMTP server
         message_delivery = _DummyMessageDelivery()
@@ -160,7 +199,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
             FakeTransport(
                 client_protocol,
                 self.reactor,
-                peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
+                peer_address=self.ip_class(
+                    "TCP", self.reactor.lookups["localhost"], 1234
+                ),
             )
         )
 
@@ -172,3 +213,11 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
         user, msg = message_delivery.messages.pop()
         self.assertEqual(str(user), "foo@bar.com")
         self.assertIn(b"Subject: test subject", msg)
+
+
+class SendEmailHandlerTestCaseIPv6(SendEmailHandlerTestCaseIPv4):
+    ip_class = IPv6Address
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.reactor.lookups["localhost"] = "::1"
diff --git a/tests/server.py b/tests/server.py
index ff03d28864..659ccce838 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import hashlib
+import ipaddress
 import json
 import logging
 import os
@@ -45,7 +46,7 @@ import attr
 from typing_extensions import ParamSpec
 from zope.interface import implementer
 
-from twisted.internet import address, threads, udp
+from twisted.internet import address, tcp, threads, udp
 from twisted.internet._resolver import SimpleResolverComplexifier
 from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
 from twisted.internet.error import DNSLookupError
@@ -567,6 +568,8 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         conn = super().connectTCP(
             host, port, factory, timeout=timeout, bindAddress=None
         )
+        if self.lookups and host in self.lookups:
+            validate_connector(conn, self.lookups[host])
 
         callback = self._tcp_callbacks.get((host, port))
         if callback:
@@ -599,6 +602,55 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
             super().advance(0)
 
 
+def validate_connector(connector: tcp.Connector, expected_ip: str) -> None:
+    """Try to validate the obtained connector as it would happen when
+    synapse is running and the conection will be established.
+
+    This method will raise a useful exception when necessary, else it will
+    just do nothing.
+
+    This is in order to help catch quirks related to reactor.connectTCP,
+    since when called directly, the connector's destination will be of type
+    IPv4Address, with the hostname as the literal host that was given (which
+    could be an IPv6-only host or an IPv6 literal).
+
+    But when called from reactor.connectTCP *through* e.g. an Endpoint, the
+    connector's destination will contain the specific IP address with the
+    correct network stack class.
+
+    Note that testing code paths that use connectTCP directly should not be
+    affected by this check, unless they specifically add a test with a
+    matching reactor.lookups[HOSTNAME] = "IPv6Literal", where reactor is of
+    type ThreadedMemoryReactorClock.
+    For an example of implementing such tests, see test/handlers/send_email.py.
+    """
+    destination = connector.getDestination()
+
+    # We use address.IPv{4,6}Address to check what the reactor thinks it is
+    # is sending but check for validity with ipaddress.IPv{4,6}Address
+    # because they fail with IPs on the wrong network stack.
+    cls_mapping = {
+        address.IPv4Address: ipaddress.IPv4Address,
+        address.IPv6Address: ipaddress.IPv6Address,
+    }
+
+    cls = cls_mapping.get(destination.__class__)
+
+    if cls is not None:
+        try:
+            cls(expected_ip)
+        except Exception as exc:
+            raise ValueError(
+                "Invalid IP type and resolution for %s. Expected %s to be %s"
+                % (destination, expected_ip, cls.__name__)
+            ) from exc
+    else:
+        raise ValueError(
+            "Unknown address type %s for %s"
+            % (destination.__class__.__name__, destination)
+        )
+
+
 class ThreadPool:
     """
     Threadless thread pool.
diff --git a/tests/unittest.py b/tests/unittest.py
index b0721e060c..40672a4415 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -313,7 +313,7 @@ class HomeserverTestCase(TestCase):
         servlets: List of servlet registration function.
         user_id (str): The user ID to assume if auth is hijacked.
         hijack_auth: Whether to hijack auth to return the user specified
-        in user_id.
+           in user_id.
     """
 
     hijack_auth: ClassVar[bool] = True