diff --git a/changelog.d/4409.misc b/changelog.d/4409.misc
new file mode 100644
index 0000000000..9cf2adfbb1
--- /dev/null
+++ b/changelog.d/4409.misc
@@ -0,0 +1 @@
+Remove redundant federation connection wrapping code
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index f86a0b624e..1c3b7ea28a 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -140,82 +140,15 @@ def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=
default_port = 8448
if port is None:
- return _WrappingEndpointFac(SRVClientEndpoint(
+ return SRVClientEndpoint(
reactor, "matrix", domain, protocol="tcp",
default_port=default_port, endpoint=transport_endpoint,
endpoint_kw_args=endpoint_kw_args
- ), reactor)
+ )
else:
- return _WrappingEndpointFac(transport_endpoint(
+ return transport_endpoint(
reactor, domain, port, **endpoint_kw_args
- ), reactor)
-
-
-class _WrappingEndpointFac(object):
- def __init__(self, endpoint_fac, reactor):
- self.endpoint_fac = endpoint_fac
- self.reactor = reactor
-
- @defer.inlineCallbacks
- def connect(self, protocolFactory):
- conn = yield self.endpoint_fac.connect(protocolFactory)
- conn = _WrappedConnection(conn, self.reactor)
- defer.returnValue(conn)
-
-
-class _WrappedConnection(object):
- """Wraps a connection and calls abort on it if it hasn't seen any action
- for 2.5-3 minutes.
- """
- __slots__ = ["conn", "last_request"]
-
- def __init__(self, conn, reactor):
- object.__setattr__(self, "conn", conn)
- object.__setattr__(self, "last_request", time.time())
- self._reactor = reactor
-
- def __getattr__(self, name):
- return getattr(self.conn, name)
-
- def __setattr__(self, name, value):
- setattr(self.conn, name, value)
-
- def _time_things_out_maybe(self):
- # We use a slightly shorter timeout here just in case the callLater is
- # triggered early. Paranoia ftw.
- # TODO: Cancel the previous callLater rather than comparing time.time()?
- if time.time() - self.last_request >= 2.5 * 60:
- self.abort()
- # Abort the underlying TLS connection. The abort() method calls
- # loseConnection() on the TLS connection which tries to
- # shutdown the connection cleanly. We call abortConnection()
- # since that will promptly close the TLS connection.
- #
- # In Twisted >18.4; the TLS connection will be None if it has closed
- # which will make abortConnection() throw. Check that the TLS connection
- # is not None before trying to close it.
- if self.transport.getHandle() is not None:
- self.transport.abortConnection()
-
- def request(self, request):
- self.last_request = time.time()
-
- # Time this connection out if we haven't send a request in the last
- # N minutes
- # TODO: Cancel the previous callLater?
- self._reactor.callLater(3 * 60, self._time_things_out_maybe)
-
- d = self.conn.request(request)
-
- def update_request_time(res):
- self.last_request = time.time()
- # TODO: Cancel the previous callLater?
- self._reactor.callLater(3 * 60, self._time_things_out_maybe)
- return res
-
- d.addCallback(update_request_time)
-
- return d
+ )
class SRVClientEndpoint(object):
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index ea2fc64b99..250bb1ef91 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -321,23 +321,23 @@ class MatrixFederationHttpClient(object):
url_str,
)
- # we don't want all the fancy cookie and redirect handling that
- # treq.request gives: just use the raw Agent.
- request_deferred = self.agent.request(
- method_bytes,
- url_bytes,
- headers=Headers(headers_dict),
- bodyProducer=producer,
- )
-
- request_deferred = timeout_deferred(
- request_deferred,
- timeout=_sec_timeout,
- reactor=self.hs.get_reactor(),
- )
-
try:
with Measure(self.clock, "outbound_request"):
+ # we don't want all the fancy cookie and redirect handling
+ # that treq.request gives: just use the raw Agent.
+ request_deferred = self.agent.request(
+ method_bytes,
+ url_bytes,
+ headers=Headers(headers_dict),
+ bodyProducer=producer,
+ )
+
+ request_deferred = timeout_deferred(
+ request_deferred,
+ timeout=_sec_timeout,
+ reactor=self.hs.get_reactor(),
+ )
+
response = yield make_deferred_yieldable(
request_deferred,
)
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index b2e38276d8..8426eee400 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -17,6 +17,7 @@ from mock import Mock
from twisted.internet.defer import TimeoutError
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
+from twisted.test.proto_helpers import StringTransport
from twisted.web.client import ResponseNeverReceived
from twisted.web.http import HTTPChannel
@@ -44,7 +45,7 @@ class FederationClientTests(HomeserverTestCase):
def test_dns_error(self):
"""
- If the DNS raising returns an error, it will bubble up.
+ If the DNS lookup returns an error, it will bubble up.
"""
d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
self.pump()
@@ -63,7 +64,7 @@ class FederationClientTests(HomeserverTestCase):
self.pump()
# Nothing happened yet
- self.assertFalse(d.called)
+ self.assertNoResult(d)
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
@@ -72,7 +73,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertEqual(clients[0][1], 8008)
# Deferred is still without a result
- self.assertFalse(d.called)
+ self.assertNoResult(d)
# Push by enough to time it out
self.reactor.advance(10.5)
@@ -94,7 +95,7 @@ class FederationClientTests(HomeserverTestCase):
self.pump()
# Nothing happened yet
- self.assertFalse(d.called)
+ self.assertNoResult(d)
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
@@ -107,7 +108,7 @@ class FederationClientTests(HomeserverTestCase):
client.makeConnection(conn)
# Deferred is still without a result
- self.assertFalse(d.called)
+ self.assertNoResult(d)
# Push by enough to time it out
self.reactor.advance(10.5)
@@ -135,7 +136,7 @@ class FederationClientTests(HomeserverTestCase):
client.makeConnection(conn)
# Deferred does not have a result
- self.assertFalse(d.called)
+ self.assertNoResult(d)
# Send it the HTTP response
client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n")
@@ -159,7 +160,7 @@ class FederationClientTests(HomeserverTestCase):
client.makeConnection(conn)
# Deferred does not have a result
- self.assertFalse(d.called)
+ self.assertNoResult(d)
# Send it the HTTP response
client.dataReceived(
@@ -195,3 +196,42 @@ class FederationClientTests(HomeserverTestCase):
request = server.requests[0]
content = request.content.read()
self.assertEqual(content, b'{"a":"b"}')
+
+ def test_closes_connection(self):
+ """Check that the client closes unused HTTP connections"""
+ d = self.cl.get_json("testserv:8008", "foo/bar")
+
+ self.pump()
+
+ # there should have been a call to connectTCP
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (_host, _port, factory, _timeout, _bindAddress) = clients[0]
+
+ # complete the connection and wire it up to a fake transport
+ client = factory.buildProtocol(None)
+ conn = StringTransport()
+ client.makeConnection(conn)
+
+ # that should have made it send the request to the connection
+ self.assertRegex(conn.value(), b"^GET /foo/bar")
+
+ # Send the HTTP response
+ client.dataReceived(
+ b"HTTP/1.1 200 OK\r\n"
+ b"Content-Type: application/json\r\n"
+ b"Content-Length: 2\r\n"
+ b"\r\n"
+ b"{}"
+ )
+
+ # We should get a successful response
+ r = self.successResultOf(d)
+ self.assertEqual(r, {})
+
+ self.assertFalse(conn.disconnecting)
+
+ # wait for a while
+ self.pump(120)
+
+ self.assertTrue(conn.disconnecting)
|