diff options
Diffstat (limited to 'synapse/http')
-rw-r--r-- | synapse/http/matrixfederationclient.py | 79 |
1 files changed, 26 insertions, 53 deletions
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index ed47e701e7..854e17a473 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -16,7 +16,7 @@ from twisted.internet import defer, reactor, protocol from twisted.internet.error import DNSLookupError -from twisted.web.client import readBody, _AgentBase, _URI, HTTPConnectionPool +from twisted.web.client import readBody, HTTPConnectionPool, Agent from twisted.web.http_headers import Headers from twisted.web._newclient import ResponseDone @@ -55,41 +55,17 @@ incoming_responses_counter = metrics.register_counter( ) -class MatrixFederationHttpAgent(_AgentBase): - - def __init__(self, reactor, pool=None): - _AgentBase.__init__(self, reactor, pool) - - def request(self, destination, endpoint, method, path, params, query, - headers, body_producer): - - outgoing_requests_counter.inc(method) - - host = b"" - port = 0 - fragment = b"" - - parsed_URI = _URI(b"http", destination, host, port, path, params, - query, fragment) - - # Set the connection pool key to be the destination. - key = destination - - d = self._requestWithEndpoint(key, endpoint, method, parsed_URI, - headers, body_producer, - parsed_URI.originForm) - - def _cb(response): - incoming_responses_counter.inc(method, response.code) - return response - - def _eb(failure): - incoming_responses_counter.inc(method, "ERR") - return failure +class MatrixFederationEndpointFactory(object): + def __init__(self, hs): + self.tls_context_factory = hs.tls_context_factory - d.addCallbacks(_cb, _eb) + def endpointForURI(self, uri): + destination = uri.netloc - return d + return matrix_federation_endpoint( + reactor, destination, timeout=10, + ssl_context_factory=self.tls_context_factory + ) class MatrixFederationHttpClient(object): @@ -107,12 +83,18 @@ class MatrixFederationHttpClient(object): self.server_name = hs.hostname pool = HTTPConnectionPool(reactor) pool.maxPersistentPerHost = 10 - self.agent = MatrixFederationHttpAgent(reactor, pool=pool) + self.agent = Agent.usingEndpointFactory( + reactor, MatrixFederationEndpointFactory(hs), pool=pool + ) self.clock = hs.get_clock() self.version_string = hs.version_string - self._next_id = 1 + def _create_url(self, destination, path_bytes, param_bytes, query_bytes): + return urlparse.urlunparse( + ("matrix", destination, path_bytes, param_bytes, query_bytes, "") + ) + @defer.inlineCallbacks def _create_request(self, destination, method, path_bytes, body_callback, headers_dict={}, param_bytes=b"", @@ -123,8 +105,8 @@ class MatrixFederationHttpClient(object): headers_dict[b"User-Agent"] = [self.version_string] headers_dict[b"Host"] = [destination] - url_bytes = urlparse.urlunparse( - ("", "", path_bytes, param_bytes, query_bytes, "",) + url_bytes = self._create_url( + destination, path_bytes, param_bytes, query_bytes ) txn_id = "%s-O-%s" % (method, self._next_id) @@ -139,8 +121,8 @@ class MatrixFederationHttpClient(object): # (once we have reliable transactions in place) retries_left = 5 - endpoint = preserve_context_over_fn( - self._getEndpoint, reactor, destination + http_url_bytes = urlparse.urlunparse( + ("", "", path_bytes, param_bytes, query_bytes, "") ) log_result = None @@ -148,17 +130,14 @@ class MatrixFederationHttpClient(object): while True: producer = None if body_callback: - producer = body_callback(method, url_bytes, headers_dict) + producer = body_callback(method, http_url_bytes, headers_dict) try: def send_request(): - request_deferred = self.agent.request( - destination, - endpoint, + request_deferred = preserve_context_over_fn( + self.agent.request, method, - path_bytes, - param_bytes, - query_bytes, + url_bytes, Headers(headers_dict), producer ) @@ -452,12 +431,6 @@ class MatrixFederationHttpClient(object): defer.returnValue((length, headers)) - def _getEndpoint(self, reactor, destination): - return matrix_federation_endpoint( - reactor, destination, timeout=10, - ssl_context_factory=self.hs.tls_context_factory - ) - class _ReadBodyToFileProtocol(protocol.Protocol): def __init__(self, stream, deferred, max_size): |