summary refs log tree commit diff
path: root/tests/handlers/test_send_email.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_send_email.py')
-rw-r--r--tests/handlers/test_send_email.py69
1 files changed, 59 insertions, 10 deletions
diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py
index 8b6e4a40b6..a066745d70 100644
--- a/tests/handlers/test_send_email.py
+++ b/tests/handlers/test_send_email.py
@@ -13,19 +13,40 @@
 # limitations under the License.
 
 
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Type, Union
+from unittest.mock import patch
 
 from zope.interface import implementer
 
 from twisted.internet import defer
-from twisted.internet.address import IPv4Address
+from twisted.internet._sslverify import ClientTLSOptions
+from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.defer import ensureDeferred
+from twisted.internet.interfaces import IProtocolFactory
+from twisted.internet.ssl import ContextFactory
 from twisted.mail import interfaces, smtp
 
 from tests.server import FakeTransport
 from tests.unittest import HomeserverTestCase, override_config
 
 
+def TestingESMTPTLSClientFactory(
+    contextFactory: ContextFactory,
+    _connectWrapped: bool,
+    wrappedProtocol: IProtocolFactory,
+) -> IProtocolFactory:
+    """We use this to pass through in testing without using TLS, but
+    saving the context information to check that it would have happened.
+
+    Note that this is what the MemoryReactor does on connectSSL.
+    It only saves the contextFactory, but starts the connection with the
+    underlying Factory.
+    See: L{twisted.internet.testing.MemoryReactor.connectSSL}"""
+
+    wrappedProtocol._testingContextFactory = contextFactory  # type: ignore[attr-defined]
+    return wrappedProtocol
+
+
 @implementer(interfaces.IMessageDelivery)
 class _DummyMessageDelivery:
     def __init__(self) -> None:
@@ -75,7 +96,13 @@ class _DummyMessage:
         pass
 
 
-class SendEmailHandlerTestCase(HomeserverTestCase):
+class SendEmailHandlerTestCaseIPv4(HomeserverTestCase):
+    ip_class: Union[Type[IPv4Address], Type[IPv6Address]] = IPv4Address
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.reactor.lookups["localhost"] = "127.0.0.1"
+
     def test_send_email(self) -> None:
         """Happy-path test that we can send email to a non-TLS server."""
         h = self.hs.get_send_email_handler()
@@ -89,7 +116,7 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
         (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[
             0
         ]
-        self.assertEqual(host, "localhost")
+        self.assertEqual(host, self.reactor.lookups["localhost"])
         self.assertEqual(port, 25)
 
         # wire it up to an SMTP server
@@ -105,7 +132,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
             FakeTransport(
                 client_protocol,
                 self.reactor,
-                peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
+                peer_address=self.ip_class(
+                    "TCP", self.reactor.lookups["localhost"], 1234
+                ),
             )
         )
 
@@ -118,6 +147,10 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
         self.assertEqual(str(user), "foo@bar.com")
         self.assertIn(b"Subject: test subject", msg)
 
+    @patch(
+        "synapse.handlers.send_email.TLSMemoryBIOFactory",
+        TestingESMTPTLSClientFactory,
+    )
     @override_config(
         {
             "email": {
@@ -135,17 +168,23 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
             )
         )
         # there should be an attempt to connect to localhost:465
-        self.assertEqual(len(self.reactor.sslClients), 1)
+        self.assertEqual(len(self.reactor.tcpClients), 1)
         (
             host,
             port,
             client_factory,
-            contextFactory,
             _timeout,
             _bindAddress,
-        ) = self.reactor.sslClients[0]
-        self.assertEqual(host, "localhost")
+        ) = self.reactor.tcpClients[0]
+        self.assertEqual(host, self.reactor.lookups["localhost"])
         self.assertEqual(port, 465)
+        # We need to make sure that TLS is happenning
+        self.assertIsInstance(
+            client_factory._wrappedFactory._testingContextFactory,
+            ClientTLSOptions,
+        )
+        # And since we use endpoints, they go through reactor.connectTCP
+        # which works differently to connectSSL on the testing reactor
 
         # wire it up to an SMTP server
         message_delivery = _DummyMessageDelivery()
@@ -160,7 +199,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
             FakeTransport(
                 client_protocol,
                 self.reactor,
-                peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
+                peer_address=self.ip_class(
+                    "TCP", self.reactor.lookups["localhost"], 1234
+                ),
             )
         )
 
@@ -172,3 +213,11 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
         user, msg = message_delivery.messages.pop()
         self.assertEqual(str(user), "foo@bar.com")
         self.assertIn(b"Subject: test subject", msg)
+
+
+class SendEmailHandlerTestCaseIPv6(SendEmailHandlerTestCaseIPv4):
+    ip_class = IPv6Address
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.reactor.lookups["localhost"] = "::1"