summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/http/test_fedclient.py54
-rw-r--r--tests/util/test_async_utils.py104
2 files changed, 151 insertions, 7 deletions
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index b2e38276d8..8426eee400 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -17,6 +17,7 @@ from mock import Mock
 
 from twisted.internet.defer import TimeoutError
 from twisted.internet.error import ConnectingCancelledError, DNSLookupError
+from twisted.test.proto_helpers import StringTransport
 from twisted.web.client import ResponseNeverReceived
 from twisted.web.http import HTTPChannel
 
@@ -44,7 +45,7 @@ class FederationClientTests(HomeserverTestCase):
 
     def test_dns_error(self):
         """
-        If the DNS raising returns an error, it will bubble up.
+        If the DNS lookup returns an error, it will bubble up.
         """
         d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
         self.pump()
@@ -63,7 +64,7 @@ class FederationClientTests(HomeserverTestCase):
         self.pump()
 
         # Nothing happened yet
-        self.assertFalse(d.called)
+        self.assertNoResult(d)
 
         # Make sure treq is trying to connect
         clients = self.reactor.tcpClients
@@ -72,7 +73,7 @@ class FederationClientTests(HomeserverTestCase):
         self.assertEqual(clients[0][1], 8008)
 
         # Deferred is still without a result
-        self.assertFalse(d.called)
+        self.assertNoResult(d)
 
         # Push by enough to time it out
         self.reactor.advance(10.5)
@@ -94,7 +95,7 @@ class FederationClientTests(HomeserverTestCase):
         self.pump()
 
         # Nothing happened yet
-        self.assertFalse(d.called)
+        self.assertNoResult(d)
 
         # Make sure treq is trying to connect
         clients = self.reactor.tcpClients
@@ -107,7 +108,7 @@ class FederationClientTests(HomeserverTestCase):
         client.makeConnection(conn)
 
         # Deferred is still without a result
-        self.assertFalse(d.called)
+        self.assertNoResult(d)
 
         # Push by enough to time it out
         self.reactor.advance(10.5)
@@ -135,7 +136,7 @@ class FederationClientTests(HomeserverTestCase):
         client.makeConnection(conn)
 
         # Deferred does not have a result
-        self.assertFalse(d.called)
+        self.assertNoResult(d)
 
         # Send it the HTTP response
         client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n")
@@ -159,7 +160,7 @@ class FederationClientTests(HomeserverTestCase):
         client.makeConnection(conn)
 
         # Deferred does not have a result
-        self.assertFalse(d.called)
+        self.assertNoResult(d)
 
         # Send it the HTTP response
         client.dataReceived(
@@ -195,3 +196,42 @@ class FederationClientTests(HomeserverTestCase):
         request = server.requests[0]
         content = request.content.read()
         self.assertEqual(content, b'{"a":"b"}')
+
+    def test_closes_connection(self):
+        """Check that the client closes unused HTTP connections"""
+        d = self.cl.get_json("testserv:8008", "foo/bar")
+
+        self.pump()
+
+        # there should have been a call to connectTCP
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (_host, _port, factory, _timeout, _bindAddress) = clients[0]
+
+        # complete the connection and wire it up to a fake transport
+        client = factory.buildProtocol(None)
+        conn = StringTransport()
+        client.makeConnection(conn)
+
+        # that should have made it send the request to the connection
+        self.assertRegex(conn.value(), b"^GET /foo/bar")
+
+        # Send the HTTP response
+        client.dataReceived(
+            b"HTTP/1.1 200 OK\r\n"
+            b"Content-Type: application/json\r\n"
+            b"Content-Length: 2\r\n"
+            b"\r\n"
+            b"{}"
+        )
+
+        # We should get a successful response
+        r = self.successResultOf(d)
+        self.assertEqual(r, {})
+
+        self.assertFalse(conn.disconnecting)
+
+        # wait for a while
+        self.pump(120)
+
+        self.assertTrue(conn.disconnecting)
diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py
new file mode 100644
index 0000000000..84dd71e47a
--- /dev/null
+++ b/tests/util/test_async_utils.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+from twisted.internet.defer import CancelledError, Deferred
+from twisted.internet.task import Clock
+
+from synapse.util import logcontext
+from synapse.util.async_helpers import timeout_deferred
+from synapse.util.logcontext import LoggingContext
+
+from tests.unittest import TestCase
+
+
+class TimeoutDeferredTest(TestCase):
+    def setUp(self):
+        self.clock = Clock()
+
+    def test_times_out(self):
+        """Basic test case that checks that the original deferred is cancelled and that
+        the timing-out deferred is errbacked
+        """
+        cancelled = [False]
+
+        def canceller(_d):
+            cancelled[0] = True
+
+        non_completing_d = Deferred(canceller)
+        timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
+
+        self.assertNoResult(timing_out_d)
+        self.assertFalse(cancelled[0], "deferred was cancelled prematurely")
+
+        self.clock.pump((1.0, ))
+
+        self.assertTrue(cancelled[0], "deferred was not cancelled by timeout")
+        self.failureResultOf(timing_out_d, defer.TimeoutError, )
+
+    def test_times_out_when_canceller_throws(self):
+        """Test that we have successfully worked around
+        https://twistedmatrix.com/trac/ticket/9534"""
+
+        def canceller(_d):
+            raise Exception("can't cancel this deferred")
+
+        non_completing_d = Deferred(canceller)
+        timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
+
+        self.assertNoResult(timing_out_d)
+
+        self.clock.pump((1.0, ))
+
+        self.failureResultOf(timing_out_d, defer.TimeoutError, )
+
+    def test_logcontext_is_preserved_on_cancellation(self):
+        blocking_was_cancelled = [False]
+
+        @defer.inlineCallbacks
+        def blocking():
+            non_completing_d = Deferred()
+            with logcontext.PreserveLoggingContext():
+                try:
+                    yield non_completing_d
+                except CancelledError:
+                    blocking_was_cancelled[0] = True
+                    raise
+
+        with logcontext.LoggingContext("one") as context_one:
+            # the errbacks should be run in the test logcontext
+            def errback(res, deferred_name):
+                self.assertIs(
+                    LoggingContext.current_context(), context_one,
+                    "errback %s run in unexpected logcontext %s" % (
+                        deferred_name, LoggingContext.current_context(),
+                    )
+                )
+                return res
+
+            original_deferred = blocking()
+            original_deferred.addErrback(errback, "orig")
+            timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock)
+            self.assertNoResult(timing_out_d)
+            self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+            timing_out_d.addErrback(errback, "timingout")
+
+            self.clock.pump((1.0, ))
+
+            self.assertTrue(
+                blocking_was_cancelled[0],
+                "non-completing deferred was not cancelled",
+            )
+            self.failureResultOf(timing_out_d, defer.TimeoutError, )
+            self.assertIs(LoggingContext.current_context(), context_one)