summary refs log tree commit diff
path: root/synapse/http/federation/matrix_federation_agent.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/federation/matrix_federation_agent.py')
-rw-r--r--synapse/http/federation/matrix_federation_agent.py16
1 files changed, 13 insertions, 3 deletions
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 32bfd68ed1..64c780a341 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -22,7 +22,7 @@ from twisted.web.client import URI, Agent, HTTPConnectionPool
 from twisted.web.iweb import IAgent
 
 from synapse.http.endpoint import parse_server_name
-from synapse.http.federation.srv_resolver import pick_server_from_list, resolve_service
+from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
 from synapse.util.logcontext import make_deferred_yieldable
 
 logger = logging.getLogger(__name__)
@@ -37,13 +37,23 @@ class MatrixFederationAgent(object):
 
     Args:
         reactor (IReactor): twisted reactor to use for underlying requests
+
         tls_client_options_factory (ClientTLSOptionsFactory|None):
             factory to use for fetching client tls options, or none to disable TLS.
+
+        srv_resolver (SrvResolver|None):
+            SRVResolver impl to use for looking up SRV records. None to use a default
+            implementation.
     """
 
-    def __init__(self, reactor, tls_client_options_factory):
+    def __init__(
+        self, reactor, tls_client_options_factory, _srv_resolver=None,
+    ):
         self._reactor = reactor
         self._tls_client_options_factory = tls_client_options_factory
+        if _srv_resolver is None:
+            _srv_resolver = SrvResolver()
+        self._srv_resolver = _srv_resolver
 
         self._pool = HTTPConnectionPool(reactor)
         self._pool.retryAutomatically = False
@@ -91,7 +101,7 @@ class MatrixFederationAgent(object):
         if port is not None:
             target = (host, port)
         else:
-            server_list = yield resolve_service(server_name_bytes)
+            server_list = yield self._srv_resolver.resolve_service(server_name_bytes)
             if not server_list:
                 target = (host, 8448)
                 logger.debug("No SRV record for %s, using %s", host, target)