diff options
Diffstat (limited to '')
-rw-r--r-- | changelog.d/10713.bugfix | 1 | ||||
-rw-r--r-- | mypy.ini | 1 | ||||
-rw-r--r-- | synapse/handlers/send_email.py | 65 | ||||
-rw-r--r-- | tests/handlers/test_send_email.py | 112 | ||||
-rw-r--r-- | tests/server.py | 15 |
5 files changed, 173 insertions, 21 deletions
diff --git a/changelog.d/10713.bugfix b/changelog.d/10713.bugfix new file mode 100644 index 0000000000..e8caf3d23a --- /dev/null +++ b/changelog.d/10713.bugfix @@ -0,0 +1 @@ +Fix a regression introduced in Synapse 1.41 which broke email transmission on Systems using older versions of the Twisted library. diff --git a/mypy.ini b/mypy.ini index 745e6b78eb..f6de668edd 100644 --- a/mypy.ini +++ b/mypy.ini @@ -91,6 +91,7 @@ files = tests/test_utils, tests/handlers/test_password_providers.py, tests/handlers/test_room_summary.py, + tests/handlers/test_send_email.py, tests/handlers/test_sync.py, tests/rest/client/test_login.py, tests/rest/client/test_auth.py, diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py index dda9659c11..a31fe3e3c7 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py @@ -19,9 +19,12 @@ from email.mime.text import MIMEText from io import BytesIO from typing import TYPE_CHECKING, Optional +from pkg_resources import parse_version + +import twisted from twisted.internet.defer import Deferred -from twisted.internet.interfaces import IReactorTCP -from twisted.mail.smtp import ESMTPSenderFactory +from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorTCP +from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory from synapse.logging.context import make_deferred_yieldable @@ -30,6 +33,19 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +_is_old_twisted = parse_version(twisted.__version__) < parse_version("21") + + +class _NoTLSESMTPSender(ESMTPSender): + """Extend ESMTPSender to disable TLS + + Unfortunately, before Twisted 21.2, ESMTPSender doesn't give an easy way to disable + TLS, so we override its internal method which it uses to generate a context factory. + """ + + def _getContextFactory(self) -> Optional[IOpenSSLContextFactory]: + return None + async def _sendmail( reactor: IReactorTCP, @@ -42,7 +58,7 @@ async def _sendmail( password: Optional[bytes] = None, require_auth: bool = False, require_tls: bool = False, - tls_hostname: Optional[str] = None, + enable_tls: bool = True, ) -> None: """A simple wrapper around ESMTPSenderFactory, to allow substitution in tests @@ -57,24 +73,37 @@ async def _sendmail( password: password to give when authenticating require_auth: if auth is not offered, fail the request require_tls: if TLS is not offered, fail the reqest - tls_hostname: TLS hostname to check for. None to disable TLS. + enable_tls: True to enable TLS. If this is False and require_tls is True, + the request will fail. """ msg = BytesIO(msg_bytes) - d: "Deferred[object]" = Deferred() - factory = ESMTPSenderFactory( - username, - password, - from_addr, - to_addr, - msg, - d, - heloFallback=True, - requireAuthentication=require_auth, - requireTransportSecurity=require_tls, - hostname=tls_hostname, - ) + def build_sender_factory(**kwargs) -> ESMTPSenderFactory: + return ESMTPSenderFactory( + username, + password, + from_addr, + to_addr, + msg, + d, + heloFallback=True, + requireAuthentication=require_auth, + requireTransportSecurity=require_tls, + **kwargs, + ) + + if _is_old_twisted: + # before twisted 21.2, we have to override the ESMTPSender protocol to disable + # TLS + factory = build_sender_factory() + + if not enable_tls: + factory.protocol = _NoTLSESMTPSender + else: + # for twisted 21.2 and later, there is a 'hostname' parameter which we should + # set to enable TLS. + factory = build_sender_factory(hostname=smtphost if enable_tls else None) # the IReactorTCP interface claims host has to be a bytes, which seems to be wrong reactor.connectTCP(smtphost, smtpport, factory, timeout=30, bindAddress=None) # type: ignore[arg-type] @@ -154,5 +183,5 @@ class SendEmailHandler: password=self._smtp_pass, require_auth=self._smtp_user is not None, require_tls=self._require_transport_security, - tls_hostname=self._smtp_host if self._enable_tls else None, + enable_tls=self._enable_tls, ) diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py new file mode 100644 index 0000000000..6f77b1237c --- /dev/null +++ b/tests/handlers/test_send_email.py @@ -0,0 +1,112 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Tuple + +from zope.interface import implementer + +from twisted.internet import defer +from twisted.internet.address import IPv4Address +from twisted.internet.defer import ensureDeferred +from twisted.mail import interfaces, smtp + +from tests.server import FakeTransport +from tests.unittest import HomeserverTestCase + + +@implementer(interfaces.IMessageDelivery) +class _DummyMessageDelivery: + def __init__(self): + # (recipient, message) tuples + self.messages: List[Tuple[smtp.Address, bytes]] = [] + + def receivedHeader(self, helo, origin, recipients): + return None + + def validateFrom(self, helo, origin): + return origin + + def record_message(self, recipient: smtp.Address, message: bytes): + self.messages.append((recipient, message)) + + def validateTo(self, user: smtp.User): + return lambda: _DummyMessage(self, user) + + +@implementer(interfaces.IMessageSMTP) +class _DummyMessage: + """IMessageSMTP implementation which saves the message delivered to it + to the _DummyMessageDelivery object. + """ + + def __init__(self, delivery: _DummyMessageDelivery, user: smtp.User): + self._delivery = delivery + self._user = user + self._buffer: List[bytes] = [] + + def lineReceived(self, line): + self._buffer.append(line) + + def eomReceived(self): + message = b"\n".join(self._buffer) + b"\n" + self._delivery.record_message(self._user.dest, message) + return defer.succeed(b"saved") + + def connectionLost(self): + pass + + +class SendEmailHandlerTestCase(HomeserverTestCase): + def test_send_email(self): + """Happy-path test that we can send email to a non-TLS server.""" + h = self.hs.get_send_email_handler() + d = ensureDeferred( + h.send_email( + "foo@bar.com", "test subject", "Tests", "HTML content", "Text content" + ) + ) + # there should be an attempt to connect to localhost:25 + self.assertEqual(len(self.reactor.tcpClients), 1) + (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[ + 0 + ] + self.assertEqual(host, "localhost") + self.assertEqual(port, 25) + + # wire it up to an SMTP server + message_delivery = _DummyMessageDelivery() + server_protocol = smtp.ESMTP() + server_protocol.delivery = message_delivery + # make sure that the server uses the test reactor to set timeouts + server_protocol.callLater = self.reactor.callLater # type: ignore[assignment] + + client_protocol = client_factory.buildProtocol(None) + client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor)) + server_protocol.makeConnection( + FakeTransport( + client_protocol, + self.reactor, + peer_address=IPv4Address("TCP", "127.0.0.1", 1234), + ) + ) + + # the message should now get delivered + self.get_success(d, by=0.1) + + # check it arrived + self.assertEqual(len(message_delivery.messages), 1) + user, msg = message_delivery.messages.pop() + self.assertEqual(str(user), "foo@bar.com") + self.assertIn(b"Subject: test subject", msg) diff --git a/tests/server.py b/tests/server.py index 6fddd3b305..b861c7b866 100644 --- a/tests/server.py +++ b/tests/server.py @@ -10,9 +10,10 @@ from zope.interface import implementer from twisted.internet import address, threads, udp from twisted.internet._resolver import SimpleResolverComplexifier -from twisted.internet.defer import Deferred, fail, succeed +from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import ( + IAddress, IHostnameResolver, IProtocol, IPullProducer, @@ -511,6 +512,9 @@ class FakeTransport: will get called back for connectionLost() notifications etc. """ + _peer_address: Optional[IAddress] = attr.ib(default=None) + """The value to be returend by getPeer""" + disconnecting = False disconnected = False connected = True @@ -519,7 +523,7 @@ class FakeTransport: autoflush = attr.ib(default=True) def getPeer(self): - return None + return self._peer_address def getHost(self): return None @@ -572,7 +576,12 @@ class FakeTransport: self.producerStreaming = streaming def _produce(): - d = self.producer.resumeProducing() + if not self.producer: + # we've been unregistered + return + # some implementations of IProducer (for example, FileSender) + # don't return a deferred. + d = maybeDeferred(self.producer.resumeProducing) d.addCallback(lambda x: self._reactor.callLater(0.1, _produce)) if not streaming: |