summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2018-09-18 18:17:15 +0100
committerGitHub <noreply@github.com>2018-09-18 18:17:15 +0100
commit31c15dcb80c8f11fd03dbb9b0ccff4777dc8e457 (patch)
tree2ee45023e518888d996b2f0743bf16f2c5769f5f /tests
parentMerge pull request #3894 from matrix-org/hs/phone_home_py_version (diff)
downloadsynapse-31c15dcb80c8f11fd03dbb9b0ccff4777dc8e457.tar.xz
Refactor matrixfederationclient to fix logging (#3906)
We want to wait until we have read the response body before we log the request
as complete, otherwise a confusing thing happens where the request appears to
have completed, but we later fail it.

To do this, we factor the salient details of a request out to a separate
object, which can then keep track of the txn_id, so that it can be logged.
Diffstat (limited to 'tests')
-rw-r--r--tests/http/test_fedclient.py43
-rw-r--r--tests/replication/slave/storage/_base.py35
-rw-r--r--tests/server.py81
3 files changed, 122 insertions, 37 deletions
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 1c46c9cfeb..66c09f63b6 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -18,9 +18,14 @@ from mock import Mock
 from twisted.internet.defer import TimeoutError
 from twisted.internet.error import ConnectingCancelledError, DNSLookupError
 from twisted.web.client import ResponseNeverReceived
+from twisted.web.http import HTTPChannel
 
-from synapse.http.matrixfederationclient import MatrixFederationHttpClient
+from synapse.http.matrixfederationclient import (
+    MatrixFederationHttpClient,
+    MatrixFederationRequest,
+)
 
+from tests.server import FakeTransport
 from tests.unittest import HomeserverTestCase
 
 
@@ -40,7 +45,7 @@ class FederationClientTests(HomeserverTestCase):
         """
         If the DNS raising returns an error, it will bubble up.
         """
-        d = self.cl._request("testserv2:8008", "GET", "foo/bar", timeout=10000)
+        d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
         self.pump()
 
         f = self.failureResultOf(d)
@@ -51,7 +56,7 @@ class FederationClientTests(HomeserverTestCase):
         If the HTTP request is not connected and is timed out, it'll give a
         ConnectingCancelledError.
         """
-        d = self.cl._request("testserv:8008", "GET", "foo/bar", timeout=10000)
+        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
 
         self.pump()
 
@@ -78,7 +83,7 @@ class FederationClientTests(HomeserverTestCase):
         If the HTTP request is connected, but gets no response before being
         timed out, it'll give a ResponseNeverReceived.
         """
-        d = self.cl._request("testserv:8008", "GET", "foo/bar", timeout=10000)
+        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
 
         self.pump()
 
@@ -108,7 +113,12 @@ class FederationClientTests(HomeserverTestCase):
         """
         Once the client gets the headers, _request returns successfully.
         """
-        d = self.cl._request("testserv:8008", "GET", "foo/bar", timeout=10000)
+        request = MatrixFederationRequest(
+            method="GET",
+            destination="testserv:8008",
+            path="foo/bar",
+        )
+        d = self.cl._send_request(request, timeout=10000)
 
         self.pump()
 
@@ -155,3 +165,26 @@ class FederationClientTests(HomeserverTestCase):
         f = self.failureResultOf(d)
 
         self.assertIsInstance(f.value, TimeoutError)
+
+    def test_client_sends_body(self):
+        self.cl.post_json(
+            "testserv:8008", "foo/bar", timeout=10000,
+            data={"a": "b"}
+        )
+
+        self.pump()
+
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        client = clients[0][2].buildProtocol(None)
+        server = HTTPChannel()
+
+        client.makeConnection(FakeTransport(server, self.reactor))
+        server.makeConnection(FakeTransport(client, self.reactor))
+
+        self.pump(0.1)
+
+        self.assertEqual(len(server.requests), 1)
+        request = server.requests[0]
+        content = request.content.read()
+        self.assertEqual(content, b'{"a":"b"}')
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 089cecfbee..9e9fbbfe93 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -15,8 +15,6 @@
 
 from mock import Mock, NonCallableMock
 
-import attr
-
 from synapse.replication.tcp.client import (
     ReplicationClientFactory,
     ReplicationClientHandler,
@@ -24,6 +22,7 @@ from synapse.replication.tcp.client import (
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
 
 from tests import unittest
+from tests.server import FakeTransport
 
 
 class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
@@ -56,36 +55,8 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
         server = server_factory.buildProtocol(None)
         client = client_factory.buildProtocol(None)
 
-        @attr.s
-        class FakeTransport(object):
-
-            other = attr.ib()
-            disconnecting = False
-            buffer = attr.ib(default=b'')
-
-            def registerProducer(self, producer, streaming):
-
-                self.producer = producer
-
-                def _produce():
-                    self.producer.resumeProducing()
-                    reactor.callLater(0.1, _produce)
-
-                reactor.callLater(0.0, _produce)
-
-            def write(self, byt):
-                self.buffer = self.buffer + byt
-
-                if getattr(self.other, "transport") is not None:
-                    self.other.dataReceived(self.buffer)
-                    self.buffer = b""
-
-            def writeSequence(self, seq):
-                for x in seq:
-                    self.write(x)
-
-        client.makeConnection(FakeTransport(server))
-        server.makeConnection(FakeTransport(client))
+        client.makeConnection(FakeTransport(server, reactor))
+        server.makeConnection(FakeTransport(client, reactor))
 
     def replicate(self):
         """Tell the master side of replication that something has happened, and then
diff --git a/tests/server.py b/tests/server.py
index 420ec4e088..ccea3baa55 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -280,3 +280,84 @@ def get_clock():
     clock = ThreadedMemoryReactorClock()
     hs_clock = Clock(clock)
     return (clock, hs_clock)
+
+
+@attr.s
+class FakeTransport(object):
+    """
+    A twisted.internet.interfaces.ITransport implementation which sends all its data
+    straight into an IProtocol object: it exists to connect two IProtocols together.
+
+    To use it, instantiate it with the receiving IProtocol, and then pass it to the
+    sending IProtocol's makeConnection method:
+
+        server = HTTPChannel()
+        client.makeConnection(FakeTransport(server, self.reactor))
+
+    If you want bidirectional communication, you'll need two instances.
+    """
+
+    other = attr.ib()
+    """The Protocol object which will receive any data written to this transport.
+
+    :type: twisted.internet.interfaces.IProtocol
+    """
+
+    _reactor = attr.ib()
+    """Test reactor
+
+    :type: twisted.internet.interfaces.IReactorTime
+    """
+
+    disconnecting = False
+    buffer = attr.ib(default=b'')
+    producer = attr.ib(default=None)
+
+    def getPeer(self):
+        return None
+
+    def getHost(self):
+        return None
+
+    def loseConnection(self):
+        self.disconnecting = True
+
+    def abortConnection(self):
+        self.disconnecting = True
+
+    def pauseProducing(self):
+        self.producer.pauseProducing()
+
+    def unregisterProducer(self):
+        if not self.producer:
+            return
+
+        self.producer = None
+
+    def registerProducer(self, producer, streaming):
+        self.producer = producer
+        self.producerStreaming = streaming
+
+        def _produce():
+            d = self.producer.resumeProducing()
+            d.addCallback(lambda x: self._reactor.callLater(0.1, _produce))
+
+        if not streaming:
+            self._reactor.callLater(0.0, _produce)
+
+    def write(self, byt):
+        self.buffer = self.buffer + byt
+
+        def _write():
+            if getattr(self.other, "transport") is not None:
+                self.other.dataReceived(self.buffer)
+                self.buffer = b""
+                return
+
+            self._reactor.callLater(0.0, _write)
+
+        _write()
+
+    def writeSequence(self, seq):
+        for x in seq:
+            self.write(x)