diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index eb963d80fb..b32d7566a5 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -26,6 +26,7 @@ from twisted.web.http import HTTPChannel
from synapse.crypto.context_factory import ClientTLSOptionsFactory
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
+from synapse.http.federation.srv_resolver import Server
from synapse.util.logcontext import LoggingContext
from tests.server import FakeTransport, ThreadedMemoryReactorClock
@@ -46,7 +47,7 @@ class MatrixFederationAgentTests(TestCase):
_srv_resolver=self.mock_resolver,
)
- def _make_connection(self, client_factory):
+ def _make_connection(self, client_factory, expected_sni):
"""Builds a test server, and completes the outgoing client connection
Returns:
@@ -69,9 +70,17 @@ class MatrixFederationAgentTests(TestCase):
# tell the server tls protocol to send its stuff back to the client, too
server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
- # finally, give the reactor a pump to get the TLS juices flowing.
+ # give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
+ # check the SNI
+ server_name = server_tls_protocol._tlsConnection.get_servername()
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
# fish the test server back out of the server-side TLS protocol.
return server_tls_protocol.wrappedProtocol
@@ -97,7 +106,7 @@ class MatrixFederationAgentTests(TestCase):
def test_get(self):
"""
- happy-path test of a GET request
+ happy-path test of a GET request with an explicit port
"""
self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")
@@ -113,16 +122,15 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(port, 8448)
# make a test server, and wire up the client
- http_server = self._make_connection(client_factory)
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=b"testserv",
+ )
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
- self.assertEqual(
- request.requestHeaders.getRawHeaders(b'host'),
- [b'testserv:8448']
- )
content = request.content.read()
self.assertEqual(content, b'')
@@ -150,6 +158,130 @@ class MatrixFederationAgentTests(TestCase):
json = self.successResultOf(treq.json_content(response))
self.assertEqual(json, {"a": 1})
+ def test_get_ip_address(self):
+ """
+ Test the behaviour when the server name contains an explicit IP (with no port)
+ """
+
+ # the SRV lookup will return an empty list (XXX: why do we even do an SRV lookup?)
+ self.mock_resolver.resolve_service.side_effect = lambda _: []
+
+ # then there will be a getaddrinfo on the IP
+ self.reactor.lookups["1.2.3.4"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ self.mock_resolver.resolve_service.assert_called_once_with(
+ b"_matrix._tcp.1.2.3.4",
+ )
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 8448)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=None,
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/foo/bar')
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ def test_get_hostname_no_srv(self):
+ """
+ Test the behaviour when the server name has no port, and no SRV record
+ """
+
+ self.mock_resolver.resolve_service.side_effect = lambda _: []
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ self.mock_resolver.resolve_service.assert_called_once_with(
+ b"_matrix._tcp.testserv",
+ )
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 8448)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=b'testserv',
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/foo/bar')
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+ def test_get_hostname_srv(self):
+ """
+ Test the behaviour when there is a single SRV record
+ """
+ self.mock_resolver.resolve_service.side_effect = lambda _: [
+ Server(host="srvtarget", port=8443)
+ ]
+ self.reactor.lookups["srvtarget"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ self.mock_resolver.resolve_service.assert_called_once_with(
+ b"_matrix._tcp.testserv",
+ )
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 8443)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=b'testserv',
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/foo/bar')
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
def _check_logcontext(context):
current = LoggingContext.current_context()
|