diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index ed47e701e7..4d74bd5d78 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -16,7 +16,7 @@
from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError
-from twisted.web.client import readBody, _AgentBase, _URI, HTTPConnectionPool
+from twisted.web.client import readBody, HTTPConnectionPool, Agent
from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone
@@ -55,41 +55,17 @@ incoming_responses_counter = metrics.register_counter(
)
-class MatrixFederationHttpAgent(_AgentBase):
-
- def __init__(self, reactor, pool=None):
- _AgentBase.__init__(self, reactor, pool)
-
- def request(self, destination, endpoint, method, path, params, query,
- headers, body_producer):
-
- outgoing_requests_counter.inc(method)
-
- host = b""
- port = 0
- fragment = b""
-
- parsed_URI = _URI(b"http", destination, host, port, path, params,
- query, fragment)
-
- # Set the connection pool key to be the destination.
- key = destination
-
- d = self._requestWithEndpoint(key, endpoint, method, parsed_URI,
- headers, body_producer,
- parsed_URI.originForm)
-
- def _cb(response):
- incoming_responses_counter.inc(method, response.code)
- return response
-
- def _eb(failure):
- incoming_responses_counter.inc(method, "ERR")
- return failure
+class MatrixFederationEndpointFactory(object):
+ def __init__(self, hs):
+ self.tls_context_factory = hs.tls_context_factory
- d.addCallbacks(_cb, _eb)
+ def endpointForURI(self, uri):
+ destination = uri.netloc
- return d
+ return matrix_federation_endpoint(
+ reactor, destination, timeout=10,
+ ssl_context_factory=self.tls_context_factory
+ )
class MatrixFederationHttpClient(object):
@@ -107,12 +83,18 @@ class MatrixFederationHttpClient(object):
self.server_name = hs.hostname
pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10
- self.agent = MatrixFederationHttpAgent(reactor, pool=pool)
+ self.agent = Agent.usingEndpointFactory(
+ reactor, MatrixFederationEndpointFactory(hs), pool=pool
+ )
self.clock = hs.get_clock()
self.version_string = hs.version_string
-
self._next_id = 1
+ def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
+ return urlparse.urlunparse(
+ ("matrix", destination, path_bytes, param_bytes, query_bytes, "")
+ )
+
@defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes,
body_callback, headers_dict={}, param_bytes=b"",
@@ -123,8 +105,8 @@ class MatrixFederationHttpClient(object):
headers_dict[b"User-Agent"] = [self.version_string]
headers_dict[b"Host"] = [destination]
- url_bytes = urlparse.urlunparse(
- ("", "", path_bytes, param_bytes, query_bytes, "",)
+ url_bytes = self._create_url(
+ destination, path_bytes, param_bytes, query_bytes
)
txn_id = "%s-O-%s" % (method, self._next_id)
@@ -139,8 +121,8 @@ class MatrixFederationHttpClient(object):
# (once we have reliable transactions in place)
retries_left = 5
- endpoint = preserve_context_over_fn(
- self._getEndpoint, reactor, destination
+ http_url_bytes = urlparse.urlunparse(
+ ("", "", path_bytes, param_bytes, query_bytes, "")
)
log_result = None
@@ -148,21 +130,19 @@ class MatrixFederationHttpClient(object):
while True:
producer = None
if body_callback:
- producer = body_callback(method, url_bytes, headers_dict)
+ producer = body_callback(method, http_url_bytes, headers_dict)
try:
def send_request():
- request_deferred = self.agent.request(
- destination,
- endpoint,
+ request_deferred = preserve_context_over_fn(
+ self.agent.request,
method,
- path_bytes,
- param_bytes,
- query_bytes,
+ url_bytes,
Headers(headers_dict),
producer
)
+
return self.clock.time_bound_deferred(
request_deferred,
time_out=timeout/1000. if timeout else 60,
@@ -452,12 +432,6 @@ class MatrixFederationHttpClient(object):
defer.returnValue((length, headers))
- def _getEndpoint(self, reactor, destination):
- return matrix_federation_endpoint(
- reactor, destination, timeout=10,
- ssl_context_factory=self.hs.tls_context_factory
- )
-
class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(self, stream, deferred, max_size):
|