diff options
Diffstat (limited to 'tests/handlers/test_send_email.py')
-rw-r--r-- | tests/handlers/test_send_email.py | 69 |
1 files changed, 59 insertions, 10 deletions
diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py index 8b6e4a40b6..a066745d70 100644 --- a/tests/handlers/test_send_email.py +++ b/tests/handlers/test_send_email.py @@ -13,19 +13,40 @@ # limitations under the License. -from typing import Callable, List, Tuple +from typing import Callable, List, Tuple, Type, Union +from unittest.mock import patch from zope.interface import implementer from twisted.internet import defer -from twisted.internet.address import IPv4Address +from twisted.internet._sslverify import ClientTLSOptions +from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.defer import ensureDeferred +from twisted.internet.interfaces import IProtocolFactory +from twisted.internet.ssl import ContextFactory from twisted.mail import interfaces, smtp from tests.server import FakeTransport from tests.unittest import HomeserverTestCase, override_config +def TestingESMTPTLSClientFactory( + contextFactory: ContextFactory, + _connectWrapped: bool, + wrappedProtocol: IProtocolFactory, +) -> IProtocolFactory: + """We use this to pass through in testing without using TLS, but + saving the context information to check that it would have happened. + + Note that this is what the MemoryReactor does on connectSSL. + It only saves the contextFactory, but starts the connection with the + underlying Factory. + See: L{twisted.internet.testing.MemoryReactor.connectSSL}""" + + wrappedProtocol._testingContextFactory = contextFactory # type: ignore[attr-defined] + return wrappedProtocol + + @implementer(interfaces.IMessageDelivery) class _DummyMessageDelivery: def __init__(self) -> None: @@ -75,7 +96,13 @@ class _DummyMessage: pass -class SendEmailHandlerTestCase(HomeserverTestCase): +class SendEmailHandlerTestCaseIPv4(HomeserverTestCase): + ip_class: Union[Type[IPv4Address], Type[IPv6Address]] = IPv4Address + + def setUp(self) -> None: + super().setUp() + self.reactor.lookups["localhost"] = "127.0.0.1" + def test_send_email(self) -> None: """Happy-path test that we can send email to a non-TLS server.""" h = self.hs.get_send_email_handler() @@ -89,7 +116,7 @@ class SendEmailHandlerTestCase(HomeserverTestCase): (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[ 0 ] - self.assertEqual(host, "localhost") + self.assertEqual(host, self.reactor.lookups["localhost"]) self.assertEqual(port, 25) # wire it up to an SMTP server @@ -105,7 +132,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase): FakeTransport( client_protocol, self.reactor, - peer_address=IPv4Address("TCP", "127.0.0.1", 1234), + peer_address=self.ip_class( + "TCP", self.reactor.lookups["localhost"], 1234 + ), ) ) @@ -118,6 +147,10 @@ class SendEmailHandlerTestCase(HomeserverTestCase): self.assertEqual(str(user), "foo@bar.com") self.assertIn(b"Subject: test subject", msg) + @patch( + "synapse.handlers.send_email.TLSMemoryBIOFactory", + TestingESMTPTLSClientFactory, + ) @override_config( { "email": { @@ -135,17 +168,23 @@ class SendEmailHandlerTestCase(HomeserverTestCase): ) ) # there should be an attempt to connect to localhost:465 - self.assertEqual(len(self.reactor.sslClients), 1) + self.assertEqual(len(self.reactor.tcpClients), 1) ( host, port, client_factory, - contextFactory, _timeout, _bindAddress, - ) = self.reactor.sslClients[0] - self.assertEqual(host, "localhost") + ) = self.reactor.tcpClients[0] + self.assertEqual(host, self.reactor.lookups["localhost"]) self.assertEqual(port, 465) + # We need to make sure that TLS is happenning + self.assertIsInstance( + client_factory._wrappedFactory._testingContextFactory, + ClientTLSOptions, + ) + # And since we use endpoints, they go through reactor.connectTCP + # which works differently to connectSSL on the testing reactor # wire it up to an SMTP server message_delivery = _DummyMessageDelivery() @@ -160,7 +199,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase): FakeTransport( client_protocol, self.reactor, - peer_address=IPv4Address("TCP", "127.0.0.1", 1234), + peer_address=self.ip_class( + "TCP", self.reactor.lookups["localhost"], 1234 + ), ) ) @@ -172,3 +213,11 @@ class SendEmailHandlerTestCase(HomeserverTestCase): user, msg = message_delivery.messages.pop() self.assertEqual(str(user), "foo@bar.com") self.assertIn(b"Subject: test subject", msg) + + +class SendEmailHandlerTestCaseIPv6(SendEmailHandlerTestCaseIPv4): + ip_class = IPv6Address + + def setUp(self) -> None: + super().setUp() + self.reactor.lookups["localhost"] = "::1" |