diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 32bfd68ed1..64c780a341 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -22,7 +22,7 @@ from twisted.web.client import URI, Agent, HTTPConnectionPool
from twisted.web.iweb import IAgent
from synapse.http.endpoint import parse_server_name
-from synapse.http.federation.srv_resolver import pick_server_from_list, resolve_service
+from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
from synapse.util.logcontext import make_deferred_yieldable
logger = logging.getLogger(__name__)
@@ -37,13 +37,23 @@ class MatrixFederationAgent(object):
Args:
reactor (IReactor): twisted reactor to use for underlying requests
+
tls_client_options_factory (ClientTLSOptionsFactory|None):
factory to use for fetching client tls options, or none to disable TLS.
+
+ srv_resolver (SrvResolver|None):
+ SRVResolver impl to use for looking up SRV records. None to use a default
+ implementation.
"""
- def __init__(self, reactor, tls_client_options_factory):
+ def __init__(
+ self, reactor, tls_client_options_factory, _srv_resolver=None,
+ ):
self._reactor = reactor
self._tls_client_options_factory = tls_client_options_factory
+ if _srv_resolver is None:
+ _srv_resolver = SrvResolver()
+ self._srv_resolver = _srv_resolver
self._pool = HTTPConnectionPool(reactor)
self._pool.retryAutomatically = False
@@ -91,7 +101,7 @@ class MatrixFederationAgent(object):
if port is not None:
target = (host, port)
else:
- server_list = yield resolve_service(server_name_bytes)
+ server_list = yield self._srv_resolver.resolve_service(server_name_bytes)
if not server_list:
target = (host, 8448)
logger.debug("No SRV record for %s, using %s", host, target)
|