summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/3894.feature1
-rw-r--r--changelog.d/3899.bugfix1
-rw-r--r--changelog.d/3906.misc1
-rwxr-xr-xsynapse/app/homeserver.py4
-rw-r--r--synapse/handlers/room_member.py5
-rw-r--r--synapse/http/matrixfederationclient.py385
-rw-r--r--tests/http/test_fedclient.py43
-rw-r--r--tests/replication/slave/storage/_base.py35
-rw-r--r--tests/server.py81
9 files changed, 370 insertions, 186 deletions
diff --git a/changelog.d/3894.feature b/changelog.d/3894.feature
new file mode 100644
index 0000000000..1ed0cccdb2
--- /dev/null
+++ b/changelog.d/3894.feature
@@ -0,0 +1 @@
+Report "python_version" in the phone home stats
diff --git a/changelog.d/3899.bugfix b/changelog.d/3899.bugfix
new file mode 100644
index 0000000000..5120e3a823
--- /dev/null
+++ b/changelog.d/3899.bugfix
@@ -0,0 +1 @@
+When we join a room, always try the server we used for the alias lookup first, to avoid unresponsive and out-of-date servers.
diff --git a/changelog.d/3906.misc b/changelog.d/3906.misc
new file mode 100644
index 0000000000..11709186d3
--- /dev/null
+++ b/changelog.d/3906.misc
@@ -0,0 +1 @@
+Improve logging of outbound federation requests
\ No newline at end of file
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index ac97e19649..3241ded188 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -457,6 +457,10 @@ def run(hs):
         stats["homeserver"] = hs.config.server_name
         stats["timestamp"] = now
         stats["uptime_seconds"] = uptime
+        version = sys.version_info
+        stats["python_version"] = "{}.{}.{}".format(
+            version.major, version.minor, version.micro
+        )
         stats["total_users"] = yield hs.get_datastore().count_all_users()
 
         total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users()
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index f643619047..07fd3e82fc 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -583,6 +583,11 @@ class RoomMemberHandler(object):
         room_id = mapping["room_id"]
         servers = mapping["servers"]
 
+        # put the server which owns the alias at the front of the server list.
+        if room_alias.domain in servers:
+            servers.remove(room_alias.domain)
+        servers.insert(0, room_alias.domain)
+
         defer.returnValue((RoomID.from_string(room_id), servers))
 
     @defer.inlineCallbacks
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 083484a687..6a2d447289 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -17,10 +17,12 @@ import cgi
 import logging
 import random
 import sys
+from io import BytesIO
 
 from six import PY3, string_types
 from six.moves import urllib
 
+import attr
 import treq
 from canonicaljson import encode_canonical_json
 from prometheus_client import Counter
@@ -28,8 +30,9 @@ from signedjson.sign import sign_json
 
 from twisted.internet import defer, protocol
 from twisted.internet.error import DNSLookupError
+from twisted.internet.task import _EPSILON, Cooperator
 from twisted.web._newclient import ResponseDone
-from twisted.web.client import Agent, HTTPConnectionPool
+from twisted.web.client import Agent, FileBodyProducer, HTTPConnectionPool
 from twisted.web.http_headers import Headers
 
 import synapse.metrics
@@ -41,13 +44,11 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util import logcontext
 from synapse.util.async_helpers import timeout_no_seriously
 from synapse.util.logcontext import make_deferred_yieldable
 from synapse.util.metrics import Measure
 
 logger = logging.getLogger(__name__)
-outbound_logger = logging.getLogger("synapse.http.outbound")
 
 outgoing_requests_counter = Counter("synapse_http_matrixfederationclient_requests",
                                     "", ["method"])
@@ -78,6 +79,93 @@ class MatrixFederationEndpointFactory(object):
         )
 
 
+_next_id = 1
+
+
+@attr.s
+class MatrixFederationRequest(object):
+    method = attr.ib()
+    """HTTP method
+    :type: str
+    """
+
+    path = attr.ib()
+    """HTTP path
+    :type: str
+    """
+
+    destination = attr.ib()
+    """The remote server to send the HTTP request to.
+    :type: str"""
+
+    json = attr.ib(default=None)
+    """JSON to send in the body.
+    :type: dict|None
+    """
+
+    json_callback = attr.ib(default=None)
+    """A callback to generate the JSON.
+    :type: func|None
+    """
+
+    query = attr.ib(default=None)
+    """Query arguments.
+    :type: dict|None
+    """
+
+    txn_id = attr.ib(default=None)
+    """Unique ID for this request (for logging)
+    :type: str|None
+    """
+
+    def __attrs_post_init__(self):
+        global _next_id
+        self.txn_id = "%s-O-%s" % (self.method, _next_id)
+        _next_id = (_next_id + 1) % (MAXINT - 1)
+
+    def get_json(self):
+        if self.json_callback:
+            return self.json_callback()
+        return self.json
+
+
+@defer.inlineCallbacks
+def _handle_json_response(reactor, timeout_sec, request, response):
+    """
+    Reads the JSON body of a response, with a timeout
+
+    Args:
+        reactor (IReactor): twisted reactor, for the timeout
+        timeout_sec (float): number of seconds to wait for response to complete
+        request (MatrixFederationRequest): the request that triggered the response
+        response (IResponse): response to the request
+
+    Returns:
+        dict: parsed JSON response
+    """
+    try:
+        check_content_type_is_json(response.headers)
+        d = treq.json_content(response)
+        d.addTimeout(timeout_sec, reactor)
+        body = yield make_deferred_yieldable(d)
+    except Exception as e:
+        logger.warn(
+            "{%s} [%d] Error reading response: %s",
+            request.txn_id,
+            request.destination,
+            e,
+        )
+        raise
+    logger.info(
+        "{%s} [%d] Completed: %d %s",
+        request.txn_id,
+        request.destination,
+        response.code,
+        response.phrase.decode('ascii', errors='replace'),
+    )
+    defer.returnValue(body)
+
+
 class MatrixFederationHttpClient(object):
     """HTTP client used to talk to other homeservers over the federation
     protocol. Send client certificates and signs requests.
@@ -102,34 +190,35 @@ class MatrixFederationHttpClient(object):
         self.clock = hs.get_clock()
         self._store = hs.get_datastore()
         self.version_string = hs.version_string.encode('ascii')
-        self._next_id = 1
         self.default_timeout = 60
 
-    def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
-        return urllib.parse.urlunparse(
-            (b"matrix", destination, path_bytes, param_bytes, query_bytes, b"")
-        )
+        def schedule(x):
+            reactor.callLater(_EPSILON, x)
+
+        self._cooperator = Cooperator(scheduler=schedule)
 
     @defer.inlineCallbacks
-    def _request(self, destination, method, path,
-                 json=None, json_callback=None,
-                 param_bytes=b"",
-                 query=None, retry_on_dns_fail=True,
-                 timeout=None, long_retries=False,
-                 ignore_backoff=False,
-                 backoff_on_404=False):
+    def _send_request(
+        self,
+        request,
+        retry_on_dns_fail=True,
+        timeout=None,
+        long_retries=False,
+        ignore_backoff=False,
+        backoff_on_404=False
+    ):
         """
-        Creates and sends a request to the given server.
+        Sends a request to the given server.
 
         Args:
-            destination (str): The remote server to send the HTTP request to.
-            method (str): HTTP method
-            path (str): The HTTP path
-            json (dict or None): JSON to send in the body.
-            json_callback (func or None): A callback to generate the JSON.
-            query (dict or None): Query arguments.
+            request (MatrixFederationRequest): details of request to be sent
+
+            timeout (int|None): number of milliseconds to wait for the response headers
+                (including connecting to the server). 60s by default.
+
             ignore_backoff (bool): true to ignore the historical backoff data
                 and try the request anyway.
+
             backoff_on_404 (bool): Back off if we get a 404
 
         Returns:
@@ -154,38 +243,32 @@ class MatrixFederationHttpClient(object):
 
         if (
             self.hs.config.federation_domain_whitelist is not None and
-            destination not in self.hs.config.federation_domain_whitelist
+            request.destination not in self.hs.config.federation_domain_whitelist
         ):
-            raise FederationDeniedError(destination)
+            raise FederationDeniedError(request.destination)
 
         limiter = yield synapse.util.retryutils.get_retry_limiter(
-            destination,
+            request.destination,
             self.clock,
             self._store,
             backoff_on_404=backoff_on_404,
             ignore_backoff=ignore_backoff,
         )
 
-        headers_dict = {}
-        path_bytes = path.encode("ascii")
-        if query:
-            query_bytes = encode_query_args(query)
+        method = request.method
+        destination = request.destination
+        path_bytes = request.path.encode("ascii")
+        if request.query:
+            query_bytes = encode_query_args(request.query)
         else:
             query_bytes = b""
 
         headers_dict = {
             "User-Agent": [self.version_string],
-            "Host": [destination],
+            "Host": [request.destination],
         }
 
         with limiter:
-            url = self._create_url(
-                destination.encode("ascii"), path_bytes, param_bytes, query_bytes
-            ).decode('ascii')
-
-            txn_id = "%s-O-%s" % (method, self._next_id)
-            self._next_id = (self._next_id + 1) % (MAXINT - 1)
-
             # XXX: Would be much nicer to retry only at the transaction-layer
             # (once we have reliable transactions in place)
             if long_retries:
@@ -193,16 +276,19 @@ class MatrixFederationHttpClient(object):
             else:
                 retries_left = MAX_SHORT_RETRIES
 
-            http_url = urllib.parse.urlunparse(
-                (b"", b"", path_bytes, param_bytes, query_bytes, b"")
-            ).decode('ascii')
+            url = urllib.parse.urlunparse((
+                b"matrix", destination.encode("ascii"),
+                path_bytes, None, query_bytes, b"",
+            )).decode('ascii')
+
+            http_url = urllib.parse.urlunparse((
+                b"", b"",
+                path_bytes, None, query_bytes, b"",
+            )).decode('ascii')
 
-            log_result = None
             while True:
                 try:
-                    if json_callback:
-                        json = json_callback()
-
+                    json = request.get_json()
                     if json:
                         data = encode_canonical_json(json)
                         headers_dict["Content-Type"] = ["application/json"]
@@ -213,16 +299,24 @@ class MatrixFederationHttpClient(object):
                         data = None
                         self.sign_request(destination, method, http_url, headers_dict)
 
-                    outbound_logger.info(
+                    logger.info(
                         "{%s} [%s] Sending request: %s %s",
-                        txn_id, destination, method, url
+                        request.txn_id, destination, method, url
                     )
 
+                    if data:
+                        producer = FileBodyProducer(
+                            BytesIO(data),
+                            cooperator=self._cooperator
+                        )
+                    else:
+                        producer = None
+
                     request_deferred = treq.request(
                         method,
                         url,
                         headers=Headers(headers_dict),
-                        data=data,
+                        data=producer,
                         agent=self.agent,
                         reactor=self.hs.get_reactor(),
                         unbuffered=True
@@ -244,33 +338,19 @@ class MatrixFederationHttpClient(object):
                             request_deferred,
                         )
 
-                    log_result = "%d %s" % (
-                        response.code,
-                        response.phrase.decode('ascii', errors='replace'),
-                    )
                     break
                 except Exception as e:
-                    if not retry_on_dns_fail and isinstance(e, DNSLookupError):
-                        logger.warn(
-                            "DNS Lookup failed to %s with %s",
-                            destination,
-                            e
-                        )
-                        log_result = "DNS Lookup failed to %s with %s" % (
-                            destination, e
-                        )
-                        raise
-
                     logger.warn(
-                        "{%s} Sending request failed to %s: %s %s: %s",
-                        txn_id,
+                        "{%s} [%s] Request failed: %s %s: %s",
+                        request.txn_id,
                         destination,
                         method,
                         url,
                         _flatten_response_never_received(e),
                     )
 
-                    log_result = _flatten_response_never_received(e)
+                    if not retry_on_dns_fail and isinstance(e, DNSLookupError):
+                        raise
 
                     if retries_left and not timeout:
                         if long_retries:
@@ -283,33 +363,33 @@ class MatrixFederationHttpClient(object):
                             delay *= random.uniform(0.8, 1.4)
 
                         logger.debug(
-                            "{%s} Waiting %s before sending to %s...",
-                            txn_id,
+                            "{%s} [%s] Waiting %ss before re-sending...",
+                            request.txn_id,
+                            destination,
                             delay,
-                            destination
                         )
 
                         yield self.clock.sleep(delay)
                         retries_left -= 1
                     else:
                         raise
-                finally:
-                    outbound_logger.info(
-                        "{%s} [%s] Result: %s",
-                        txn_id,
-                        destination,
-                        log_result,
-                    )
+
+            logger.info(
+                "{%s} [%s] Got response headers: %d %s",
+                request.txn_id,
+                destination,
+                response.code,
+                response.phrase.decode('ascii', errors='replace'),
+            )
 
             if 200 <= response.code < 300:
                 pass
             else:
                 # :'(
                 # Update transactions table?
-                with logcontext.PreserveLoggingContext():
-                    d = treq.content(response)
-                    d.addTimeout(_sec_timeout, self.hs.get_reactor())
-                    body = yield make_deferred_yieldable(d)
+                d = treq.content(response)
+                d.addTimeout(_sec_timeout, self.hs.get_reactor())
+                body = yield make_deferred_yieldable(d)
                 raise HttpResponseException(
                     response.code, response.phrase, body
                 )
@@ -403,29 +483,26 @@ class MatrixFederationHttpClient(object):
             is not on our federation whitelist
         """
 
-        if not json_data_callback:
-            json_data_callback = lambda: data
-
-        response = yield self._request(
-            destination,
-            "PUT",
-            path,
-            json_callback=json_data_callback,
+        request = MatrixFederationRequest(
+            method="PUT",
+            destination=destination,
+            path=path,
             query=args,
+            json_callback=json_data_callback,
+            json=data,
+        )
+
+        response = yield self._send_request(
+            request,
             long_retries=long_retries,
             timeout=timeout,
             ignore_backoff=ignore_backoff,
             backoff_on_404=backoff_on_404,
         )
 
-        if 200 <= response.code < 300:
-            # We need to update the transactions table to say it was sent?
-            check_content_type_is_json(response.headers)
-
-        with logcontext.PreserveLoggingContext():
-            d = treq.json_content(response)
-            d.addTimeout(self.default_timeout, self.hs.get_reactor())
-            body = yield make_deferred_yieldable(d)
+        body = yield _handle_json_response(
+            self.hs.get_reactor(), self.default_timeout, request, response,
+        )
         defer.returnValue(body)
 
     @defer.inlineCallbacks
@@ -459,31 +536,30 @@ class MatrixFederationHttpClient(object):
             Fails with ``FederationDeniedError`` if this destination
             is not on our federation whitelist
         """
-        response = yield self._request(
-            destination,
-            "POST",
-            path,
+
+        request = MatrixFederationRequest(
+            method="POST",
+            destination=destination,
+            path=path,
             query=args,
             json=data,
+        )
+
+        response = yield self._send_request(
+            request,
             long_retries=long_retries,
             timeout=timeout,
             ignore_backoff=ignore_backoff,
         )
 
-        if 200 <= response.code < 300:
-            # We need to update the transactions table to say it was sent?
-            check_content_type_is_json(response.headers)
-
-        with logcontext.PreserveLoggingContext():
-            d = treq.json_content(response)
-            if timeout:
-                _sec_timeout = timeout / 1000
-            else:
-                _sec_timeout = self.default_timeout
-
-            d.addTimeout(_sec_timeout, self.hs.get_reactor())
-            body = yield make_deferred_yieldable(d)
+        if timeout:
+            _sec_timeout = timeout / 1000
+        else:
+            _sec_timeout = self.default_timeout
 
+        body = yield _handle_json_response(
+            self.hs.get_reactor(), _sec_timeout, request, response,
+        )
         defer.returnValue(body)
 
     @defer.inlineCallbacks
@@ -519,25 +595,23 @@ class MatrixFederationHttpClient(object):
 
         logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
 
-        response = yield self._request(
-            destination,
-            "GET",
-            path,
+        request = MatrixFederationRequest(
+            method="GET",
+            destination=destination,
+            path=path,
             query=args,
+        )
+
+        response = yield self._send_request(
+            request,
             retry_on_dns_fail=retry_on_dns_fail,
             timeout=timeout,
             ignore_backoff=ignore_backoff,
         )
 
-        if 200 <= response.code < 300:
-            # We need to update the transactions table to say it was sent?
-            check_content_type_is_json(response.headers)
-
-        with logcontext.PreserveLoggingContext():
-            d = treq.json_content(response)
-            d.addTimeout(self.default_timeout, self.hs.get_reactor())
-            body = yield make_deferred_yieldable(d)
-
+        body = yield _handle_json_response(
+            self.hs.get_reactor(), self.default_timeout, request, response,
+        )
         defer.returnValue(body)
 
     @defer.inlineCallbacks
@@ -568,25 +642,23 @@ class MatrixFederationHttpClient(object):
             Fails with ``FederationDeniedError`` if this destination
             is not on our federation whitelist
         """
-        response = yield self._request(
-            destination,
-            "DELETE",
-            path,
+        request = MatrixFederationRequest(
+            method="DELETE",
+            destination=destination,
+            path=path,
             query=args,
+        )
+
+        response = yield self._send_request(
+            request,
             long_retries=long_retries,
             timeout=timeout,
             ignore_backoff=ignore_backoff,
         )
 
-        if 200 <= response.code < 300:
-            # We need to update the transactions table to say it was sent?
-            check_content_type_is_json(response.headers)
-
-        with logcontext.PreserveLoggingContext():
-            d = treq.json_content(response)
-            d.addTimeout(self.default_timeout, self.hs.get_reactor())
-            body = yield make_deferred_yieldable(d)
-
+        body = yield _handle_json_response(
+            self.hs.get_reactor(), self.default_timeout, request, response,
+        )
         defer.returnValue(body)
 
     @defer.inlineCallbacks
@@ -614,11 +686,15 @@ class MatrixFederationHttpClient(object):
             Fails with ``FederationDeniedError`` if this destination
             is not on our federation whitelist
         """
-        response = yield self._request(
-            destination,
-            "GET",
-            path,
+        request = MatrixFederationRequest(
+            method="GET",
+            destination=destination,
+            path=path,
             query=args,
+        )
+
+        response = yield self._send_request(
+            request,
             retry_on_dns_fail=retry_on_dns_fail,
             ignore_backoff=ignore_backoff,
         )
@@ -626,14 +702,25 @@ class MatrixFederationHttpClient(object):
         headers = dict(response.headers.getAllRawHeaders())
 
         try:
-            with logcontext.PreserveLoggingContext():
-                d = _readBodyToFile(response, output_stream, max_size)
-                d.addTimeout(self.default_timeout, self.hs.get_reactor())
-                length = yield make_deferred_yieldable(d)
-        except Exception:
-            logger.exception("Failed to download body")
+            d = _readBodyToFile(response, output_stream, max_size)
+            d.addTimeout(self.default_timeout, self.hs.get_reactor())
+            length = yield make_deferred_yieldable(d)
+        except Exception as e:
+            logger.warn(
+                "{%s} [%d] Error reading response: %s",
+                request.txn_id,
+                request.destination,
+                e,
+            )
             raise
-
+        logger.info(
+            "{%s} [%d] Completed: %d %s [%d bytes]",
+            request.txn_id,
+            request.destination,
+            response.code,
+            response.phrase.decode('ascii', errors='replace'),
+            length,
+        )
         defer.returnValue((length, headers))
 
 
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 1c46c9cfeb..66c09f63b6 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -18,9 +18,14 @@ from mock import Mock
 from twisted.internet.defer import TimeoutError
 from twisted.internet.error import ConnectingCancelledError, DNSLookupError
 from twisted.web.client import ResponseNeverReceived
+from twisted.web.http import HTTPChannel
 
-from synapse.http.matrixfederationclient import MatrixFederationHttpClient
+from synapse.http.matrixfederationclient import (
+    MatrixFederationHttpClient,
+    MatrixFederationRequest,
+)
 
+from tests.server import FakeTransport
 from tests.unittest import HomeserverTestCase
 
 
@@ -40,7 +45,7 @@ class FederationClientTests(HomeserverTestCase):
         """
         If the DNS raising returns an error, it will bubble up.
         """
-        d = self.cl._request("testserv2:8008", "GET", "foo/bar", timeout=10000)
+        d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
         self.pump()
 
         f = self.failureResultOf(d)
@@ -51,7 +56,7 @@ class FederationClientTests(HomeserverTestCase):
         If the HTTP request is not connected and is timed out, it'll give a
         ConnectingCancelledError.
         """
-        d = self.cl._request("testserv:8008", "GET", "foo/bar", timeout=10000)
+        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
 
         self.pump()
 
@@ -78,7 +83,7 @@ class FederationClientTests(HomeserverTestCase):
         If the HTTP request is connected, but gets no response before being
         timed out, it'll give a ResponseNeverReceived.
         """
-        d = self.cl._request("testserv:8008", "GET", "foo/bar", timeout=10000)
+        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
 
         self.pump()
 
@@ -108,7 +113,12 @@ class FederationClientTests(HomeserverTestCase):
         """
         Once the client gets the headers, _request returns successfully.
         """
-        d = self.cl._request("testserv:8008", "GET", "foo/bar", timeout=10000)
+        request = MatrixFederationRequest(
+            method="GET",
+            destination="testserv:8008",
+            path="foo/bar",
+        )
+        d = self.cl._send_request(request, timeout=10000)
 
         self.pump()
 
@@ -155,3 +165,26 @@ class FederationClientTests(HomeserverTestCase):
         f = self.failureResultOf(d)
 
         self.assertIsInstance(f.value, TimeoutError)
+
+    def test_client_sends_body(self):
+        self.cl.post_json(
+            "testserv:8008", "foo/bar", timeout=10000,
+            data={"a": "b"}
+        )
+
+        self.pump()
+
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        client = clients[0][2].buildProtocol(None)
+        server = HTTPChannel()
+
+        client.makeConnection(FakeTransport(server, self.reactor))
+        server.makeConnection(FakeTransport(client, self.reactor))
+
+        self.pump(0.1)
+
+        self.assertEqual(len(server.requests), 1)
+        request = server.requests[0]
+        content = request.content.read()
+        self.assertEqual(content, b'{"a":"b"}')
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 089cecfbee..9e9fbbfe93 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -15,8 +15,6 @@
 
 from mock import Mock, NonCallableMock
 
-import attr
-
 from synapse.replication.tcp.client import (
     ReplicationClientFactory,
     ReplicationClientHandler,
@@ -24,6 +22,7 @@ from synapse.replication.tcp.client import (
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
 
 from tests import unittest
+from tests.server import FakeTransport
 
 
 class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
@@ -56,36 +55,8 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
         server = server_factory.buildProtocol(None)
         client = client_factory.buildProtocol(None)
 
-        @attr.s
-        class FakeTransport(object):
-
-            other = attr.ib()
-            disconnecting = False
-            buffer = attr.ib(default=b'')
-
-            def registerProducer(self, producer, streaming):
-
-                self.producer = producer
-
-                def _produce():
-                    self.producer.resumeProducing()
-                    reactor.callLater(0.1, _produce)
-
-                reactor.callLater(0.0, _produce)
-
-            def write(self, byt):
-                self.buffer = self.buffer + byt
-
-                if getattr(self.other, "transport") is not None:
-                    self.other.dataReceived(self.buffer)
-                    self.buffer = b""
-
-            def writeSequence(self, seq):
-                for x in seq:
-                    self.write(x)
-
-        client.makeConnection(FakeTransport(server))
-        server.makeConnection(FakeTransport(client))
+        client.makeConnection(FakeTransport(server, reactor))
+        server.makeConnection(FakeTransport(client, reactor))
 
     def replicate(self):
         """Tell the master side of replication that something has happened, and then
diff --git a/tests/server.py b/tests/server.py
index 420ec4e088..ccea3baa55 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -280,3 +280,84 @@ def get_clock():
     clock = ThreadedMemoryReactorClock()
     hs_clock = Clock(clock)
     return (clock, hs_clock)
+
+
+@attr.s
+class FakeTransport(object):
+    """
+    A twisted.internet.interfaces.ITransport implementation which sends all its data
+    straight into an IProtocol object: it exists to connect two IProtocols together.
+
+    To use it, instantiate it with the receiving IProtocol, and then pass it to the
+    sending IProtocol's makeConnection method:
+
+        server = HTTPChannel()
+        client.makeConnection(FakeTransport(server, self.reactor))
+
+    If you want bidirectional communication, you'll need two instances.
+    """
+
+    other = attr.ib()
+    """The Protocol object which will receive any data written to this transport.
+
+    :type: twisted.internet.interfaces.IProtocol
+    """
+
+    _reactor = attr.ib()
+    """Test reactor
+
+    :type: twisted.internet.interfaces.IReactorTime
+    """
+
+    disconnecting = False
+    buffer = attr.ib(default=b'')
+    producer = attr.ib(default=None)
+
+    def getPeer(self):
+        return None
+
+    def getHost(self):
+        return None
+
+    def loseConnection(self):
+        self.disconnecting = True
+
+    def abortConnection(self):
+        self.disconnecting = True
+
+    def pauseProducing(self):
+        self.producer.pauseProducing()
+
+    def unregisterProducer(self):
+        if not self.producer:
+            return
+
+        self.producer = None
+
+    def registerProducer(self, producer, streaming):
+        self.producer = producer
+        self.producerStreaming = streaming
+
+        def _produce():
+            d = self.producer.resumeProducing()
+            d.addCallback(lambda x: self._reactor.callLater(0.1, _produce))
+
+        if not streaming:
+            self._reactor.callLater(0.0, _produce)
+
+    def write(self, byt):
+        self.buffer = self.buffer + byt
+
+        def _write():
+            if getattr(self.other, "transport") is not None:
+                self.other.dataReceived(self.buffer)
+                self.buffer = b""
+                return
+
+            self._reactor.callLater(0.0, _write)
+
+        _write()
+
+    def writeSequence(self, seq):
+        for x in seq:
+            self.write(x)