summary refs log tree commit diff
path: root/synapse/http/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/client.py')
-rw-r--r--synapse/http/client.py64
1 files changed, 36 insertions, 28 deletions
diff --git a/synapse/http/client.py b/synapse/http/client.py
index ad454f4964..5c073fff07 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -17,7 +17,7 @@
 import logging
 from io import BytesIO
 
-from six import text_type
+from six import raise_from, text_type
 from six.moves import urllib
 
 import treq
@@ -90,45 +90,50 @@ class IPBlacklistingResolver(object):
     def resolveHostName(self, recv, hostname, portNumber=0):
 
         r = recv()
-        d = defer.Deferred()
         addresses = []
 
-        @provider(IResolutionReceiver)
-        class EndpointReceiver(object):
-            @staticmethod
-            def resolutionBegan(resolutionInProgress):
-                pass
+        def _callback():
+            r.resolutionBegan(None)
 
-            @staticmethod
-            def addressResolved(address):
-                ip_address = IPAddress(address.host)
+            has_bad_ip = False
+            for i in addresses:
+                ip_address = IPAddress(i.host)
 
                 if check_against_blacklist(
                     ip_address, self._ip_whitelist, self._ip_blacklist
                 ):
                     logger.info(
-                        "Dropped %s from DNS resolution to %s" % (ip_address, hostname)
+                        "Dropped %s from DNS resolution to %s due to blacklist" %
+                        (ip_address, hostname)
                     )
-                    raise SynapseError(403, "IP address blocked by IP blacklist entry")
+                    has_bad_ip = True
+
+            # if we have a blacklisted IP, we'd like to raise an error to block the
+            # request, but all we can really do from here is claim that there were no
+            # valid results.
+            if not has_bad_ip:
+                for i in addresses:
+                    r.addressResolved(i)
+            r.resolutionComplete()
+
+        @provider(IResolutionReceiver)
+        class EndpointReceiver(object):
+            @staticmethod
+            def resolutionBegan(resolutionInProgress):
+                pass
 
+            @staticmethod
+            def addressResolved(address):
                 addresses.append(address)
 
             @staticmethod
             def resolutionComplete():
-                d.callback(addresses)
+                _callback()
 
         self._reactor.nameResolver.resolveHostName(
             EndpointReceiver, hostname, portNumber=portNumber
         )
 
-        def _callback(addrs):
-            r.resolutionBegan(None)
-            for i in addrs:
-                r.addressResolved(i)
-            r.resolutionComplete()
-
-        d.addCallback(_callback)
-
         return r
 
 
@@ -160,7 +165,8 @@ class BlacklistingAgentWrapper(Agent):
                 ip_address, self._ip_whitelist, self._ip_blacklist
             ):
                 logger.info(
-                    "Blocking access to %s because of blacklist" % (ip_address,)
+                    "Blocking access to %s due to blacklist" %
+                    (ip_address,)
                 )
                 e = SynapseError(403, "IP address blocked by IP blacklist entry")
                 return defer.fail(Failure(e))
@@ -258,9 +264,6 @@ class SimpleHttpClient(object):
             uri (str): URI to query.
             data (bytes): Data to send in the request body, if applicable.
             headers (t.w.http_headers.Headers): Request headers.
-
-        Raises:
-            SynapseError: If the IP is blacklisted.
         """
         # A small wrapper around self.agent.request() so we can easily attach
         # counters to it
@@ -539,10 +542,15 @@ class SimpleHttpClient(object):
             length = yield make_deferred_yieldable(
                 _readBodyToFile(response, output_stream, max_size)
             )
+        except SynapseError:
+            # This can happen e.g. because the body is too large.
+            raise
         except Exception as e:
-            logger.exception("Failed to download body")
-            raise SynapseError(
-                502, ("Failed to download remote body: %s" % e), Codes.UNKNOWN
+            raise_from(
+                SynapseError(
+                    502, ("Failed to download remote body: %s" % e),
+                ),
+                e
             )
 
         defer.returnValue(