summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2019-01-24 09:34:44 +0000
committerGitHub <noreply@github.com>2019-01-24 09:34:44 +0000
commit97fd29c019ae92cd3dc0635de249acfc9c892340 (patch)
tree2a85a700325501b61ebc41fce8514ade812da8bc /tests
parentMerge pull request #4445 from matrix-org/anoa/user_dir_develop_backport (diff)
downloadsynapse-97fd29c019ae92cd3dc0635de249acfc9c892340.tar.xz
Don't send IP addresses as SNI (#4452)
The problem here is that we have cut-and-pasted an impl from Twisted, and then
failed to maintain it. It was fixed in Twisted in
https://github.com/twisted/twisted/pull/1047/files; let's do the same here.
Diffstat (limited to 'tests')
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py63
1 files changed, 60 insertions, 3 deletions
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index eb963d80fb..7a3881f558 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -46,7 +46,7 @@ class MatrixFederationAgentTests(TestCase):
             _srv_resolver=self.mock_resolver,
         )
 
-    def _make_connection(self, client_factory):
+    def _make_connection(self, client_factory, expected_sni):
         """Builds a test server, and completes the outgoing client connection
 
         Returns:
@@ -69,9 +69,17 @@ class MatrixFederationAgentTests(TestCase):
         # tell the server tls protocol to send its stuff back to the client, too
         server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
 
-        # finally, give the reactor a pump to get the TLS juices flowing.
+        # give the reactor a pump to get the TLS juices flowing.
         self.reactor.pump((0.1,))
 
+        # check the SNI
+        server_name = server_tls_protocol._tlsConnection.get_servername()
+        self.assertEqual(
+            server_name,
+            expected_sni,
+            "Expected SNI %s but got %s" % (expected_sni, server_name),
+        )
+
         # fish the test server back out of the server-side TLS protocol.
         return server_tls_protocol.wrappedProtocol
 
@@ -113,7 +121,10 @@ class MatrixFederationAgentTests(TestCase):
         self.assertEqual(port, 8448)
 
         # make a test server, and wire up the client
-        http_server = self._make_connection(client_factory)
+        http_server = self._make_connection(
+            client_factory,
+            expected_sni=b"testserv",
+        )
 
         self.assertEqual(len(http_server.requests), 1)
         request = http_server.requests[0]
@@ -150,6 +161,52 @@ class MatrixFederationAgentTests(TestCase):
         json = self.successResultOf(treq.json_content(response))
         self.assertEqual(json, {"a": 1})
 
+    def test_get_ip_address(self):
+        """
+        Test the behaviour when the server name contains an explicit IP (with no port)
+        """
+
+        # the SRV lookup will return an empty list (XXX: why do we even do an SRV lookup?)
+        self.mock_resolver.resolve_service.side_effect = lambda _: []
+
+        # then there will be a getaddrinfo on the IP
+        self.reactor.lookups["1.2.3.4"] = "1.2.3.4"
+
+        test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar")
+
+        # Nothing happened yet
+        self.assertNoResult(test_d)
+
+        self.mock_resolver.resolve_service.assert_called_once()
+
+        # Make sure treq is trying to connect
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(port, 8448)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(
+            client_factory,
+            expected_sni=None,
+        )
+
+        self.assertEqual(len(http_server.requests), 1)
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b'GET')
+        self.assertEqual(request.path, b'/foo/bar')
+        # XXX currently broken
+        # self.assertEqual(
+        #     request.requestHeaders.getRawHeaders(b'host'),
+        #     [b'1.2.3.4:8448']
+        # )
+
+        # finish the request
+        request.finish()
+        self.reactor.pump((0.1,))
+        self.successResultOf(test_d)
+
 
 def _check_logcontext(context):
     current = LoggingContext.current_context()