diff options
Diffstat (limited to '')
-rw-r--r-- | changelog.d/4468.misc | 1 | ||||
-rw-r--r-- | synapse/http/federation/matrix_federation_agent.py | 10 | ||||
-rw-r--r-- | synapse/http/matrixfederationclient.py | 1 | ||||
-rw-r--r-- | tests/http/federation/test_matrix_federation_agent.py | 16 | ||||
-rw-r--r-- | tests/http/test_fedclient.py | 2 |
5 files changed, 28 insertions, 2 deletions
diff --git a/changelog.d/4468.misc b/changelog.d/4468.misc new file mode 100644 index 0000000000..9a51434755 --- /dev/null +++ b/changelog.d/4468.misc @@ -0,0 +1 @@ +Move SRV logic into the Agent layer diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 0ec28c6696..1788e9a34a 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -19,6 +19,7 @@ from zope.interface import implementer from twisted.internet import defer from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.web.client import URI, Agent, HTTPConnectionPool +from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent from synapse.http.endpoint import parse_server_name @@ -109,6 +110,15 @@ class MatrixFederationAgent(object): else: target = pick_server_from_list(server_list) + # make sure that the Host header is set correctly + if headers is None: + headers = Headers() + else: + headers = headers.copy() + + if not headers.hasHeader(b'host'): + headers.addRawHeader(b'host', server_name_bytes) + class EndpointFactory(object): @staticmethod def endpointForURI(_uri): diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 980e912348..bb2e64ed80 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -255,7 +255,6 @@ class MatrixFederationHttpClient(object): headers_dict = { b"User-Agent": [self.version_string_bytes], - b"Host": [destination_bytes], } with limiter: diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index b32d7566a5..261afb5f41 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -131,6 +131,10 @@ class MatrixFederationAgentTests(TestCase): request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') + self.assertEqual( + request.requestHeaders.getRawHeaders(b'host'), + [b'testserv:8448'] + ) content = request.content.read() self.assertEqual(content, b'') @@ -195,6 +199,10 @@ class MatrixFederationAgentTests(TestCase): request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') + self.assertEqual( + request.requestHeaders.getRawHeaders(b'host'), + [b'1.2.3.4'], + ) # finish the request request.finish() @@ -235,6 +243,10 @@ class MatrixFederationAgentTests(TestCase): request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') + self.assertEqual( + request.requestHeaders.getRawHeaders(b'host'), + [b'testserv'], + ) # finish the request request.finish() @@ -276,6 +288,10 @@ class MatrixFederationAgentTests(TestCase): request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') + self.assertEqual( + request.requestHeaders.getRawHeaders(b'host'), + [b'testserv'], + ) # finish the request request.finish() diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index d37f8f9981..018c77ebcd 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -49,7 +49,6 @@ class FederationClientTests(HomeserverTestCase): return hs def prepare(self, reactor, clock, homeserver): - self.cl = MatrixFederationHttpClient(self.hs) self.reactor.lookups["testserv"] = "1.2.3.4" @@ -95,6 +94,7 @@ class FederationClientTests(HomeserverTestCase): # that should have made it send the request to the transport self.assertRegex(transport.value(), b"^GET /foo/bar") + self.assertRegex(transport.value(), b"Host: testserv:8008") # Deferred is still without a result self.assertNoResult(test_d) |