summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-x.circleci/merge_base_branch.sh4
-rw-r--r--changelog.d/3857.misc1
-rw-r--r--changelog.d/3858.misc1
-rw-r--r--changelog.d/3859.bugfix1
-rw-r--r--synapse/events/__init__.py5
-rw-r--r--synapse/federation/federation_base.py2
-rw-r--r--synapse/http/matrixfederationclient.py98
-rw-r--r--tests/http/test_fedclient.py157
-rw-r--r--tests/server.py42
9 files changed, 253 insertions, 58 deletions
diff --git a/.circleci/merge_base_branch.sh b/.circleci/merge_base_branch.sh
index 4e297d77da..9614eb91b6 100755
--- a/.circleci/merge_base_branch.sh
+++ b/.circleci/merge_base_branch.sh
@@ -19,6 +19,10 @@ GITBASE=`curl -q https://api.github.com/repos/matrix-org/synapse/pulls/${CIRCLE_
 # Show what we are before
 git show -s
 
+# Set up username so it can do a merge
+git config --global user.email bot@matrix.org
+git config --global user.name "A robot"
+
 # Fetch and merge. If it doesn't work, it will raise due to set -e.
 git fetch -u origin $GITBASE
 git merge --no-edit origin/$GITBASE
diff --git a/changelog.d/3857.misc b/changelog.d/3857.misc
new file mode 100644
index 0000000000..e128d193d9
--- /dev/null
+++ b/changelog.d/3857.misc
@@ -0,0 +1 @@
+Refactor some HTTP timeout code.
\ No newline at end of file
diff --git a/changelog.d/3858.misc b/changelog.d/3858.misc
new file mode 100644
index 0000000000..4644db5330
--- /dev/null
+++ b/changelog.d/3858.misc
@@ -0,0 +1 @@
+Fix running merged builds on CircleCI
\ No newline at end of file
diff --git a/changelog.d/3859.bugfix b/changelog.d/3859.bugfix
new file mode 100644
index 0000000000..ec5b172464
--- /dev/null
+++ b/changelog.d/3859.bugfix
@@ -0,0 +1 @@
+Fix handling of redacted events from federation
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 51f9084b90..b782af6308 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import six
+
 from synapse.util.caches import intern_dict
 from synapse.util.frozenutils import freeze
 
@@ -147,6 +149,9 @@ class EventBase(object):
     def items(self):
         return list(self._event_dict.items())
 
+    def keys(self):
+        return six.iterkeys(self._event_dict)
+
 
 class FrozenEvent(EventBase):
     def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 61782ae1c0..b7ad729c63 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -153,7 +153,7 @@ class FederationBase(object):
                     # *actual* redacted copy to be on the safe side.)
                     redacted_event = prune_event(pdu)
                     if (
-                        set(six.iterkeys(redacted_event)) == set(six.iterkeys(pdu)) and
+                        set(redacted_event.keys()) == set(pdu.keys()) and
                         set(six.iterkeys(redacted_event.content))
                             == set(six.iterkeys(pdu.content))
                     ):
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index cf920bc041..c49dbacd93 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -26,7 +26,7 @@ from canonicaljson import encode_canonical_json
 from prometheus_client import Counter
 from signedjson.sign import sign_json
 
-from twisted.internet import defer, protocol, reactor
+from twisted.internet import defer, protocol
 from twisted.internet.error import DNSLookupError
 from twisted.web._newclient import ResponseDone
 from twisted.web.client import Agent, HTTPConnectionPool
@@ -40,10 +40,8 @@ from synapse.api.errors import (
     HttpResponseException,
     SynapseError,
 )
-from synapse.http import cancelled_to_request_timed_out_error
 from synapse.http.endpoint import matrix_federation_endpoint
 from synapse.util import logcontext
-from synapse.util.async_helpers import add_timeout_to_deferred
 from synapse.util.logcontext import make_deferred_yieldable
 
 logger = logging.getLogger(__name__)
@@ -66,13 +64,14 @@ else:
 
 class MatrixFederationEndpointFactory(object):
     def __init__(self, hs):
+        self.reactor = hs.get_reactor()
         self.tls_client_options_factory = hs.tls_client_options_factory
 
     def endpointForURI(self, uri):
         destination = uri.netloc.decode('ascii')
 
         return matrix_federation_endpoint(
-            reactor, destination, timeout=10,
+            self.reactor, destination, timeout=10,
             tls_client_options_factory=self.tls_client_options_factory
         )
 
@@ -90,6 +89,7 @@ class MatrixFederationHttpClient(object):
         self.hs = hs
         self.signing_key = hs.config.signing_key[0]
         self.server_name = hs.hostname
+        reactor = hs.get_reactor()
         pool = HTTPConnectionPool(reactor)
         pool.maxPersistentPerHost = 5
         pool.cachedConnectionTimeout = 2 * 60
@@ -100,6 +100,7 @@ class MatrixFederationHttpClient(object):
         self._store = hs.get_datastore()
         self.version_string = hs.version_string.encode('ascii')
         self._next_id = 1
+        self.default_timeout = 60
 
     def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
         return urllib.parse.urlunparse(
@@ -143,6 +144,11 @@ class MatrixFederationHttpClient(object):
             (May also fail with plenty of other Exceptions for things like DNS
                 failures, connection failures, SSL failures.)
         """
+        if timeout:
+            _sec_timeout = timeout / 1000
+        else:
+            _sec_timeout = self.default_timeout
+
         if (
             self.hs.config.federation_domain_whitelist is not None and
             destination not in self.hs.config.federation_domain_whitelist
@@ -215,13 +221,9 @@ class MatrixFederationHttpClient(object):
                         headers=Headers(headers_dict),
                         data=data,
                         agent=self.agent,
+                        reactor=self.hs.get_reactor()
                     )
-                    add_timeout_to_deferred(
-                        request_deferred,
-                        timeout / 1000. if timeout else 60,
-                        self.hs.get_reactor(),
-                        cancelled_to_request_timed_out_error,
-                    )
+                    request_deferred.addTimeout(_sec_timeout, self.hs.get_reactor())
                     response = yield make_deferred_yieldable(
                         request_deferred,
                     )
@@ -261,6 +263,13 @@ class MatrixFederationHttpClient(object):
                             delay = min(delay, 2)
                             delay *= random.uniform(0.8, 1.4)
 
+                        logger.debug(
+                            "{%s} Waiting %s before sending to %s...",
+                            txn_id,
+                            delay,
+                            destination
+                        )
+
                         yield self.clock.sleep(delay)
                         retries_left -= 1
                     else:
@@ -279,10 +288,9 @@ class MatrixFederationHttpClient(object):
                 # :'(
                 # Update transactions table?
                 with logcontext.PreserveLoggingContext():
-                    body = yield self._timeout_deferred(
-                        treq.content(response),
-                        timeout,
-                    )
+                    d = treq.content(response)
+                    d.addTimeout(_sec_timeout, self.hs.get_reactor())
+                    body = yield make_deferred_yieldable(d)
                 raise HttpResponseException(
                     response.code, response.phrase, body
                 )
@@ -396,10 +404,9 @@ class MatrixFederationHttpClient(object):
             check_content_type_is_json(response.headers)
 
         with logcontext.PreserveLoggingContext():
-            body = yield self._timeout_deferred(
-                treq.json_content(response),
-                timeout,
-            )
+            d = treq.json_content(response)
+            d.addTimeout(self.default_timeout, self.hs.get_reactor())
+            body = yield make_deferred_yieldable(d)
         defer.returnValue(body)
 
     @defer.inlineCallbacks
@@ -449,10 +456,14 @@ class MatrixFederationHttpClient(object):
             check_content_type_is_json(response.headers)
 
         with logcontext.PreserveLoggingContext():
-            body = yield self._timeout_deferred(
-                treq.json_content(response),
-                timeout,
-            )
+            d = treq.json_content(response)
+            if timeout:
+                _sec_timeout = timeout / 1000
+            else:
+                _sec_timeout = self.default_timeout
+
+            d.addTimeout(_sec_timeout, self.hs.get_reactor())
+            body = yield make_deferred_yieldable(d)
 
         defer.returnValue(body)
 
@@ -504,10 +515,9 @@ class MatrixFederationHttpClient(object):
             check_content_type_is_json(response.headers)
 
         with logcontext.PreserveLoggingContext():
-            body = yield self._timeout_deferred(
-                treq.json_content(response),
-                timeout,
-            )
+            d = treq.json_content(response)
+            d.addTimeout(self.default_timeout, self.hs.get_reactor())
+            body = yield make_deferred_yieldable(d)
 
         defer.returnValue(body)
 
@@ -554,10 +564,9 @@ class MatrixFederationHttpClient(object):
             check_content_type_is_json(response.headers)
 
         with logcontext.PreserveLoggingContext():
-            body = yield self._timeout_deferred(
-                treq.json_content(response),
-                timeout,
-            )
+            d = treq.json_content(response)
+            d.addTimeout(self.default_timeout, self.hs.get_reactor())
+            body = yield make_deferred_yieldable(d)
 
         defer.returnValue(body)
 
@@ -599,38 +608,15 @@ class MatrixFederationHttpClient(object):
 
         try:
             with logcontext.PreserveLoggingContext():
-                length = yield self._timeout_deferred(
-                    _readBodyToFile(
-                        response, output_stream, max_size
-                    ),
-                )
+                d = _readBodyToFile(response, output_stream, max_size)
+                d.addTimeout(self.default_timeout, self.hs.get_reactor())
+                length = yield make_deferred_yieldable(d)
         except Exception:
             logger.exception("Failed to download body")
             raise
 
         defer.returnValue((length, headers))
 
-    def _timeout_deferred(self, deferred, timeout_ms=None):
-        """Times the deferred out after `timeout_ms` ms
-
-        Args:
-            deferred (Deferred)
-            timeout_ms (int|None): Timeout in milliseconds. If None defaults
-                to 60 seconds.
-
-        Returns:
-            Deferred
-        """
-
-        add_timeout_to_deferred(
-            deferred,
-            timeout_ms / 1000. if timeout_ms else 60,
-            self.hs.get_reactor(),
-            cancelled_to_request_timed_out_error,
-        )
-
-        return deferred
-
 
 class _ReadBodyToFileProtocol(protocol.Protocol):
     def __init__(self, stream, deferred, max_size):
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
new file mode 100644
index 0000000000..1c46c9cfeb
--- /dev/null
+++ b/tests/http/test_fedclient.py
@@ -0,0 +1,157 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from twisted.internet.defer import TimeoutError
+from twisted.internet.error import ConnectingCancelledError, DNSLookupError
+from twisted.web.client import ResponseNeverReceived
+
+from synapse.http.matrixfederationclient import MatrixFederationHttpClient
+
+from tests.unittest import HomeserverTestCase
+
+
+class FederationClientTests(HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+
+        hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
+        hs.tls_client_options_factory = None
+        return hs
+
+    def prepare(self, reactor, clock, homeserver):
+
+        self.cl = MatrixFederationHttpClient(self.hs)
+        self.reactor.lookups["testserv"] = "1.2.3.4"
+
+    def test_dns_error(self):
+        """
+        If the DNS raising returns an error, it will bubble up.
+        """
+        d = self.cl._request("testserv2:8008", "GET", "foo/bar", timeout=10000)
+        self.pump()
+
+        f = self.failureResultOf(d)
+        self.assertIsInstance(f.value, DNSLookupError)
+
+    def test_client_never_connect(self):
+        """
+        If the HTTP request is not connected and is timed out, it'll give a
+        ConnectingCancelledError.
+        """
+        d = self.cl._request("testserv:8008", "GET", "foo/bar", timeout=10000)
+
+        self.pump()
+
+        # Nothing happened yet
+        self.assertFalse(d.called)
+
+        # Make sure treq is trying to connect
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        self.assertEqual(clients[0][0], '1.2.3.4')
+        self.assertEqual(clients[0][1], 8008)
+
+        # Deferred is still without a result
+        self.assertFalse(d.called)
+
+        # Push by enough to time it out
+        self.reactor.advance(10.5)
+        f = self.failureResultOf(d)
+
+        self.assertIsInstance(f.value, ConnectingCancelledError)
+
+    def test_client_connect_no_response(self):
+        """
+        If the HTTP request is connected, but gets no response before being
+        timed out, it'll give a ResponseNeverReceived.
+        """
+        d = self.cl._request("testserv:8008", "GET", "foo/bar", timeout=10000)
+
+        self.pump()
+
+        # Nothing happened yet
+        self.assertFalse(d.called)
+
+        # Make sure treq is trying to connect
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        self.assertEqual(clients[0][0], '1.2.3.4')
+        self.assertEqual(clients[0][1], 8008)
+
+        conn = Mock()
+        client = clients[0][2].buildProtocol(None)
+        client.makeConnection(conn)
+
+        # Deferred is still without a result
+        self.assertFalse(d.called)
+
+        # Push by enough to time it out
+        self.reactor.advance(10.5)
+        f = self.failureResultOf(d)
+
+        self.assertIsInstance(f.value, ResponseNeverReceived)
+
+    def test_client_gets_headers(self):
+        """
+        Once the client gets the headers, _request returns successfully.
+        """
+        d = self.cl._request("testserv:8008", "GET", "foo/bar", timeout=10000)
+
+        self.pump()
+
+        conn = Mock()
+        clients = self.reactor.tcpClients
+        client = clients[0][2].buildProtocol(None)
+        client.makeConnection(conn)
+
+        # Deferred does not have a result
+        self.assertFalse(d.called)
+
+        # Send it the HTTP response
+        client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n")
+
+        # We should get a successful response
+        r = self.successResultOf(d)
+        self.assertEqual(r.code, 200)
+
+    def test_client_headers_no_body(self):
+        """
+        If the HTTP request is connected, but gets no response before being
+        timed out, it'll give a ResponseNeverReceived.
+        """
+        d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
+
+        self.pump()
+
+        conn = Mock()
+        clients = self.reactor.tcpClients
+        client = clients[0][2].buildProtocol(None)
+        client.makeConnection(conn)
+
+        # Deferred does not have a result
+        self.assertFalse(d.called)
+
+        # Send it the HTTP response
+        client.dataReceived(
+            (b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n"
+             b"Server: Fake\r\n\r\n")
+        )
+
+        # Push by enough to time it out
+        self.reactor.advance(10.5)
+        f = self.failureResultOf(d)
+
+        self.assertIsInstance(f.value, TimeoutError)
diff --git a/tests/server.py b/tests/server.py
index a2c3ca61f6..420ec4e088 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -4,9 +4,14 @@ from io import BytesIO
 from six import text_type
 
 import attr
+from zope.interface import implementer
 
-from twisted.internet import address, threads
+from twisted.internet import address, threads, udp
+from twisted.internet._resolver import HostResolution
+from twisted.internet.address import IPv4Address
 from twisted.internet.defer import Deferred
+from twisted.internet.error import DNSLookupError
+from twisted.internet.interfaces import IReactorPluggableNameResolver
 from twisted.python.failure import Failure
 from twisted.test.proto_helpers import MemoryReactorClock
 
@@ -154,11 +159,46 @@ def render(request, resource, clock):
     wait_until_result(clock, request)
 
 
+@implementer(IReactorPluggableNameResolver)
 class ThreadedMemoryReactorClock(MemoryReactorClock):
     """
     A MemoryReactorClock that supports callFromThread.
     """
 
+    def __init__(self):
+        self._udp = []
+        self.lookups = {}
+
+        class Resolver(object):
+            def resolveHostName(
+                _self,
+                resolutionReceiver,
+                hostName,
+                portNumber=0,
+                addressTypes=None,
+                transportSemantics='TCP',
+            ):
+
+                resolution = HostResolution(hostName)
+                resolutionReceiver.resolutionBegan(resolution)
+                if hostName not in self.lookups:
+                    raise DNSLookupError("OH NO")
+
+                resolutionReceiver.addressResolved(
+                    IPv4Address('TCP', self.lookups[hostName], portNumber)
+                )
+                resolutionReceiver.resolutionComplete()
+                return resolution
+
+        self.nameResolver = Resolver()
+        super(ThreadedMemoryReactorClock, self).__init__()
+
+    def listenUDP(self, port, protocol, interface='', maxPacketSize=8196):
+        p = udp.Port(port, protocol, interface, maxPacketSize, self)
+        p.startListening()
+        self._udp.append(p)
+        return p
+
     def callFromThread(self, callback, *args, **kwargs):
         """
         Make the callback fire in the next reactor iteration.