diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py
index 8b6e4a40b6..a066745d70 100644
--- a/tests/handlers/test_send_email.py
+++ b/tests/handlers/test_send_email.py
@@ -13,19 +13,40 @@
# limitations under the License.
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Type, Union
+from unittest.mock import patch
from zope.interface import implementer
from twisted.internet import defer
-from twisted.internet.address import IPv4Address
+from twisted.internet._sslverify import ClientTLSOptions
+from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.defer import ensureDeferred
+from twisted.internet.interfaces import IProtocolFactory
+from twisted.internet.ssl import ContextFactory
from twisted.mail import interfaces, smtp
from tests.server import FakeTransport
from tests.unittest import HomeserverTestCase, override_config
+def TestingESMTPTLSClientFactory(
+ contextFactory: ContextFactory,
+ _connectWrapped: bool,
+ wrappedProtocol: IProtocolFactory,
+) -> IProtocolFactory:
+ """We use this to pass through in testing without using TLS, but
+ saving the context information to check that it would have happened.
+
+ Note that this is what the MemoryReactor does on connectSSL.
+ It only saves the contextFactory, but starts the connection with the
+ underlying Factory.
+ See: L{twisted.internet.testing.MemoryReactor.connectSSL}"""
+
+ wrappedProtocol._testingContextFactory = contextFactory # type: ignore[attr-defined]
+ return wrappedProtocol
+
+
@implementer(interfaces.IMessageDelivery)
class _DummyMessageDelivery:
def __init__(self) -> None:
@@ -75,7 +96,13 @@ class _DummyMessage:
pass
-class SendEmailHandlerTestCase(HomeserverTestCase):
+class SendEmailHandlerTestCaseIPv4(HomeserverTestCase):
+ ip_class: Union[Type[IPv4Address], Type[IPv6Address]] = IPv4Address
+
+ def setUp(self) -> None:
+ super().setUp()
+ self.reactor.lookups["localhost"] = "127.0.0.1"
+
def test_send_email(self) -> None:
"""Happy-path test that we can send email to a non-TLS server."""
h = self.hs.get_send_email_handler()
@@ -89,7 +116,7 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
(host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[
0
]
- self.assertEqual(host, "localhost")
+ self.assertEqual(host, self.reactor.lookups["localhost"])
self.assertEqual(port, 25)
# wire it up to an SMTP server
@@ -105,7 +132,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
FakeTransport(
client_protocol,
self.reactor,
- peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
+ peer_address=self.ip_class(
+ "TCP", self.reactor.lookups["localhost"], 1234
+ ),
)
)
@@ -118,6 +147,10 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
self.assertEqual(str(user), "foo@bar.com")
self.assertIn(b"Subject: test subject", msg)
+ @patch(
+ "synapse.handlers.send_email.TLSMemoryBIOFactory",
+ TestingESMTPTLSClientFactory,
+ )
@override_config(
{
"email": {
@@ -135,17 +168,23 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
)
)
# there should be an attempt to connect to localhost:465
- self.assertEqual(len(self.reactor.sslClients), 1)
+ self.assertEqual(len(self.reactor.tcpClients), 1)
(
host,
port,
client_factory,
- contextFactory,
_timeout,
_bindAddress,
- ) = self.reactor.sslClients[0]
- self.assertEqual(host, "localhost")
+ ) = self.reactor.tcpClients[0]
+ self.assertEqual(host, self.reactor.lookups["localhost"])
self.assertEqual(port, 465)
+ # We need to make sure that TLS is happenning
+ self.assertIsInstance(
+ client_factory._wrappedFactory._testingContextFactory,
+ ClientTLSOptions,
+ )
+ # And since we use endpoints, they go through reactor.connectTCP
+ # which works differently to connectSSL on the testing reactor
# wire it up to an SMTP server
message_delivery = _DummyMessageDelivery()
@@ -160,7 +199,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
FakeTransport(
client_protocol,
self.reactor,
- peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
+ peer_address=self.ip_class(
+ "TCP", self.reactor.lookups["localhost"], 1234
+ ),
)
)
@@ -172,3 +213,11 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
user, msg = message_delivery.messages.pop()
self.assertEqual(str(user), "foo@bar.com")
self.assertIn(b"Subject: test subject", msg)
+
+
+class SendEmailHandlerTestCaseIPv6(SendEmailHandlerTestCaseIPv4):
+ ip_class = IPv6Address
+
+ def setUp(self) -> None:
+ super().setUp()
+ self.reactor.lookups["localhost"] = "::1"
|