summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9563.misc1
-rw-r--r--synapse/http/client.py26
-rw-r--r--tests/http/test_client.py126
3 files changed, 139 insertions, 14 deletions
diff --git a/changelog.d/9563.misc b/changelog.d/9563.misc
new file mode 100644
index 0000000000..7a3493e4a1
--- /dev/null
+++ b/changelog.d/9563.misc
@@ -0,0 +1 @@
+Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper.
diff --git a/synapse/http/client.py b/synapse/http/client.py
index af34d583ad..8f3da486b3 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -39,6 +39,7 @@ from zope.interface import implementer, provider
 from OpenSSL import SSL
 from OpenSSL.SSL import VERIFY_NONE
 from twisted.internet import defer, error as twisted_error, protocol, ssl
+from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.interfaces import (
     IAddress,
     IHostResolution,
@@ -151,16 +152,17 @@ class _IPBlacklistingResolver:
     def resolveHostName(
         self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
     ) -> IResolutionReceiver:
-
-        r = recv()
         addresses = []  # type: List[IAddress]
 
         def _callback() -> None:
-            r.resolutionBegan(None)
-
             has_bad_ip = False
-            for i in addresses:
-                ip_address = IPAddress(i.host)
+            for address in addresses:
+                # We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
+                # should go through this path.
+                if not isinstance(address, (IPv4Address, IPv6Address)):
+                    continue
+
+                ip_address = IPAddress(address.host)
 
                 if check_against_blacklist(
                     ip_address, self._ip_whitelist, self._ip_blacklist
@@ -175,15 +177,15 @@ class _IPBlacklistingResolver:
             # request, but all we can really do from here is claim that there were no
             # valid results.
             if not has_bad_ip:
-                for i in addresses:
-                    r.addressResolved(i)
-            r.resolutionComplete()
+                for address in addresses:
+                    recv.addressResolved(address)
+            recv.resolutionComplete()
 
         @provider(IResolutionReceiver)
         class EndpointReceiver:
             @staticmethod
             def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
-                pass
+                recv.resolutionBegan(resolutionInProgress)
 
             @staticmethod
             def addressResolved(address: IAddress) -> None:
@@ -197,7 +199,7 @@ class _IPBlacklistingResolver:
             EndpointReceiver, hostname, portNumber=portNumber
         )
 
-        return r
+        return recv
 
 
 @implementer(ISynapseReactor)
@@ -346,7 +348,7 @@ class SimpleHttpClient:
             contextFactory=self.hs.get_http_client_context_factory(),
             pool=pool,
             use_proxy=use_proxy,
-        )
+        )  # type: IAgent
 
         if self._ip_blacklist:
             # If we have an IP blacklist, we then install the blacklisting Agent
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 21ecb81c99..0ce181a51e 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -16,12 +16,23 @@ from io import BytesIO
 
 from mock import Mock
 
+from netaddr import IPSet
+
+from twisted.internet.error import DNSLookupError
 from twisted.python.failure import Failure
-from twisted.web.client import ResponseDone
+from twisted.test.proto_helpers import AccumulatingProtocol
+from twisted.web.client import Agent, ResponseDone
 from twisted.web.iweb import UNKNOWN_LENGTH
 
-from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
+from synapse.api.errors import SynapseError
+from synapse.http.client import (
+    BlacklistingAgentWrapper,
+    BlacklistingReactorWrapper,
+    BodyExceededMaxSize,
+    read_body_with_max_size,
+)
 
+from tests.server import FakeTransport, get_clock
 from tests.unittest import TestCase
 
 
@@ -119,3 +130,114 @@ class ReadBodyWithMaxSizeTests(TestCase):
 
         # The data is never consumed.
         self.assertEqual(result.getvalue(), b"")
+
+
+class BlacklistingAgentTest(TestCase):
+    def setUp(self):
+        self.reactor, self.clock = get_clock()
+
+        self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
+        self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
+        self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
+
+        # Configure the reactor's DNS resolver.
+        for (domain, ip) in (
+            (self.safe_domain, self.safe_ip),
+            (self.unsafe_domain, self.unsafe_ip),
+            (self.allowed_domain, self.allowed_ip),
+        ):
+            self.reactor.lookups[domain.decode()] = ip.decode()
+            self.reactor.lookups[ip.decode()] = ip.decode()
+
+        self.ip_whitelist = IPSet([self.allowed_ip.decode()])
+        self.ip_blacklist = IPSet(["5.0.0.0/8"])
+
+    def test_reactor(self):
+        """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
+        agent = Agent(
+            BlacklistingReactorWrapper(
+                self.reactor,
+                ip_whitelist=self.ip_whitelist,
+                ip_blacklist=self.ip_blacklist,
+            ),
+        )
+
+        # The unsafe domains and IPs should be rejected.
+        for domain in (self.unsafe_domain, self.unsafe_ip):
+            self.failureResultOf(
+                agent.request(b"GET", b"http://" + domain), DNSLookupError
+            )
+
+        # The safe domains IPs should be accepted.
+        for domain in (
+            self.safe_domain,
+            self.allowed_domain,
+            self.safe_ip,
+            self.allowed_ip,
+        ):
+            d = agent.request(b"GET", b"http://" + domain)
+
+            # Grab the latest TCP connection.
+            (
+                host,
+                port,
+                client_factory,
+                _timeout,
+                _bindAddress,
+            ) = self.reactor.tcpClients[-1]
+
+            # Make the connection and pump data through it.
+            client = client_factory.buildProtocol(None)
+            server = AccumulatingProtocol()
+            server.makeConnection(FakeTransport(client, self.reactor))
+            client.makeConnection(FakeTransport(server, self.reactor))
+            client.dataReceived(
+                b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
+            )
+
+            response = self.successResultOf(d)
+            self.assertEqual(response.code, 200)
+
+    def test_agent(self):
+        """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
+        agent = BlacklistingAgentWrapper(
+            Agent(self.reactor),
+            ip_whitelist=self.ip_whitelist,
+            ip_blacklist=self.ip_blacklist,
+        )
+
+        # The unsafe IPs should be rejected.
+        self.failureResultOf(
+            agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError
+        )
+
+        # The safe and unsafe domains and safe IPs should be accepted.
+        for domain in (
+            self.safe_domain,
+            self.unsafe_domain,
+            self.allowed_domain,
+            self.safe_ip,
+            self.allowed_ip,
+        ):
+            d = agent.request(b"GET", b"http://" + domain)
+
+            # Grab the latest TCP connection.
+            (
+                host,
+                port,
+                client_factory,
+                _timeout,
+                _bindAddress,
+            ) = self.reactor.tcpClients[-1]
+
+            # Make the connection and pump data through it.
+            client = client_factory.buildProtocol(None)
+            server = AccumulatingProtocol()
+            server.makeConnection(FakeTransport(client, self.reactor))
+            client.makeConnection(FakeTransport(server, self.reactor))
+            client.dataReceived(
+                b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
+            )
+
+            response = self.successResultOf(d)
+            self.assertEqual(response.code, 200)