summary refs log tree commit diff
path: root/tests/http/test_fedclient.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/http/test_fedclient.py')
-rw-r--r--tests/http/test_fedclient.py169
1 files changed, 155 insertions, 14 deletions
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index f3cb1423f0..b03b37affe 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -15,41 +15,138 @@
 
 from mock import Mock
 
+from twisted.internet import defer
 from twisted.internet.defer import TimeoutError
 from twisted.internet.error import ConnectingCancelledError, DNSLookupError
+from twisted.test.proto_helpers import StringTransport
 from twisted.web.client import ResponseNeverReceived
 from twisted.web.http import HTTPChannel
 
+from synapse.api.errors import RequestSendFailed
 from synapse.http.matrixfederationclient import (
     MatrixFederationHttpClient,
     MatrixFederationRequest,
 )
+from synapse.util.logcontext import LoggingContext
 
 from tests.server import FakeTransport
 from tests.unittest import HomeserverTestCase
 
 
+def check_logcontext(context):
+    current = LoggingContext.current_context()
+    if current is not context:
+        raise AssertionError(
+            "Expected logcontext %s but was %s" % (context, current),
+        )
+
+
 class FederationClientTests(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
-
         hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
-        hs.tls_client_options_factory = None
         return hs
 
     def prepare(self, reactor, clock, homeserver):
-
-        self.cl = MatrixFederationHttpClient(self.hs)
+        self.cl = MatrixFederationHttpClient(self.hs, None)
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
+    def test_client_get(self):
+        """
+        happy-path test of a GET request
+        """
+        @defer.inlineCallbacks
+        def do_request():
+            with LoggingContext("one") as context:
+                fetch_d = self.cl.get_json("testserv:8008", "foo/bar")
+
+                # Nothing happened yet
+                self.assertNoResult(fetch_d)
+
+                # should have reset logcontext to the sentinel
+                check_logcontext(LoggingContext.sentinel)
+
+                try:
+                    fetch_res = yield fetch_d
+                    defer.returnValue(fetch_res)
+                finally:
+                    check_logcontext(context)
+
+        test_d = do_request()
+
+        self.pump()
+
+        # Nothing happened yet
+        self.assertNoResult(test_d)
+
+        # Make sure treq is trying to connect
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(port, 8008)
+
+        # complete the connection and wire it up to a fake transport
+        protocol = factory.buildProtocol(None)
+        transport = StringTransport()
+        protocol.makeConnection(transport)
+
+        # that should have made it send the request to the transport
+        self.assertRegex(transport.value(), b"^GET /foo/bar")
+        self.assertRegex(transport.value(), b"Host: testserv:8008")
+
+        # Deferred is still without a result
+        self.assertNoResult(test_d)
+
+        # Send it the HTTP response
+        res_json = '{ "a": 1 }'.encode('ascii')
+        protocol.dataReceived(
+            b"HTTP/1.1 200 OK\r\n"
+            b"Server: Fake\r\n"
+            b"Content-Type: application/json\r\n"
+            b"Content-Length: %i\r\n"
+            b"\r\n"
+            b"%s" % (len(res_json), res_json)
+        )
+
+        self.pump()
+
+        res = self.successResultOf(test_d)
+
+        # check the response is as expected
+        self.assertEqual(res, {"a": 1})
+
     def test_dns_error(self):
         """
-        If the DNS raising returns an error, it will bubble up.
+        If the DNS lookup returns an error, it will bubble up.
         """
         d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
         self.pump()
 
         f = self.failureResultOf(d)
-        self.assertIsInstance(f.value, DNSLookupError)
+        self.assertIsInstance(f.value, RequestSendFailed)
+        self.assertIsInstance(f.value.inner_exception, DNSLookupError)
+
+    def test_client_connection_refused(self):
+        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+
+        self.pump()
+
+        # Nothing happened yet
+        self.assertNoResult(d)
+
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(port, 8008)
+        e = Exception("go away")
+        factory.clientConnectionFailed(None, e)
+        self.pump(0.5)
+
+        f = self.failureResultOf(d)
+
+        self.assertIsInstance(f.value, RequestSendFailed)
+        self.assertIs(f.value.inner_exception, e)
 
     def test_client_never_connect(self):
         """
@@ -61,7 +158,7 @@ class FederationClientTests(HomeserverTestCase):
         self.pump()
 
         # Nothing happened yet
-        self.assertFalse(d.called)
+        self.assertNoResult(d)
 
         # Make sure treq is trying to connect
         clients = self.reactor.tcpClients
@@ -70,13 +167,17 @@ class FederationClientTests(HomeserverTestCase):
         self.assertEqual(clients[0][1], 8008)
 
         # Deferred is still without a result
-        self.assertFalse(d.called)
+        self.assertNoResult(d)
 
         # Push by enough to time it out
         self.reactor.advance(10.5)
         f = self.failureResultOf(d)
 
-        self.assertIsInstance(f.value, (ConnectingCancelledError, TimeoutError))
+        self.assertIsInstance(f.value, RequestSendFailed)
+        self.assertIsInstance(
+            f.value.inner_exception,
+            (ConnectingCancelledError, TimeoutError),
+        )
 
     def test_client_connect_no_response(self):
         """
@@ -88,7 +189,7 @@ class FederationClientTests(HomeserverTestCase):
         self.pump()
 
         # Nothing happened yet
-        self.assertFalse(d.called)
+        self.assertNoResult(d)
 
         # Make sure treq is trying to connect
         clients = self.reactor.tcpClients
@@ -101,13 +202,14 @@ class FederationClientTests(HomeserverTestCase):
         client.makeConnection(conn)
 
         # Deferred is still without a result
-        self.assertFalse(d.called)
+        self.assertNoResult(d)
 
         # Push by enough to time it out
         self.reactor.advance(10.5)
         f = self.failureResultOf(d)
 
-        self.assertIsInstance(f.value, ResponseNeverReceived)
+        self.assertIsInstance(f.value, RequestSendFailed)
+        self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived)
 
     def test_client_gets_headers(self):
         """
@@ -128,7 +230,7 @@ class FederationClientTests(HomeserverTestCase):
         client.makeConnection(conn)
 
         # Deferred does not have a result
-        self.assertFalse(d.called)
+        self.assertNoResult(d)
 
         # Send it the HTTP response
         client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n")
@@ -152,7 +254,7 @@ class FederationClientTests(HomeserverTestCase):
         client.makeConnection(conn)
 
         # Deferred does not have a result
-        self.assertFalse(d.called)
+        self.assertNoResult(d)
 
         # Send it the HTTP response
         client.dataReceived(
@@ -188,3 +290,42 @@ class FederationClientTests(HomeserverTestCase):
         request = server.requests[0]
         content = request.content.read()
         self.assertEqual(content, b'{"a":"b"}')
+
+    def test_closes_connection(self):
+        """Check that the client closes unused HTTP connections"""
+        d = self.cl.get_json("testserv:8008", "foo/bar")
+
+        self.pump()
+
+        # there should have been a call to connectTCP
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (_host, _port, factory, _timeout, _bindAddress) = clients[0]
+
+        # complete the connection and wire it up to a fake transport
+        client = factory.buildProtocol(None)
+        conn = StringTransport()
+        client.makeConnection(conn)
+
+        # that should have made it send the request to the connection
+        self.assertRegex(conn.value(), b"^GET /foo/bar")
+
+        # Send the HTTP response
+        client.dataReceived(
+            b"HTTP/1.1 200 OK\r\n"
+            b"Content-Type: application/json\r\n"
+            b"Content-Length: 2\r\n"
+            b"\r\n"
+            b"{}"
+        )
+
+        # We should get a successful response
+        r = self.successResultOf(d)
+        self.assertEqual(r, {})
+
+        self.assertFalse(conn.disconnecting)
+
+        # wait for a while
+        self.pump(120)
+
+        self.assertTrue(conn.disconnecting)