diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 32bfd68ed1..64c780a341 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -22,7 +22,7 @@ from twisted.web.client import URI, Agent, HTTPConnectionPool
from twisted.web.iweb import IAgent
from synapse.http.endpoint import parse_server_name
-from synapse.http.federation.srv_resolver import pick_server_from_list, resolve_service
+from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
from synapse.util.logcontext import make_deferred_yieldable
logger = logging.getLogger(__name__)
@@ -37,13 +37,23 @@ class MatrixFederationAgent(object):
Args:
reactor (IReactor): twisted reactor to use for underlying requests
+
tls_client_options_factory (ClientTLSOptionsFactory|None):
factory to use for fetching client tls options, or none to disable TLS.
+
+ srv_resolver (SrvResolver|None):
+ SRVResolver impl to use for looking up SRV records. None to use a default
+ implementation.
"""
- def __init__(self, reactor, tls_client_options_factory):
+ def __init__(
+ self, reactor, tls_client_options_factory, _srv_resolver=None,
+ ):
self._reactor = reactor
self._tls_client_options_factory = tls_client_options_factory
+ if _srv_resolver is None:
+ _srv_resolver = SrvResolver()
+ self._srv_resolver = _srv_resolver
self._pool = HTTPConnectionPool(reactor)
self._pool.retryAutomatically = False
@@ -91,7 +101,7 @@ class MatrixFederationAgent(object):
if port is not None:
target = (host, port)
else:
- server_list = yield resolve_service(server_name_bytes)
+ server_list = yield self._srv_resolver.resolve_service(server_name_bytes)
if not server_list:
target = (host, 8448)
logger.debug("No SRV record for %s, using %s", host, target)
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index e05f934d0b..71830c549d 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -84,73 +84,86 @@ def pick_server_from_list(server_list):
)
-@defer.inlineCallbacks
-def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
- """Look up a SRV record, with caching
+class SrvResolver(object):
+ """Interface to the dns client to do SRV lookups, with result caching.
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
but the cache never gets populated), so we add our own caching layer here.
Args:
- service_name (bytes): record to look up
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
cache (dict): cache object
- clock (object): clock implementation. must provide a time() method.
-
- Returns:
- Deferred[list[Server]]: a list of the SRV records, or an empty list if none found
+ get_time (callable): clock implementation. Should return seconds since the epoch
"""
- if not isinstance(service_name, bytes):
- raise TypeError("%r is not a byte string" % (service_name,))
-
- cache_entry = cache.get(service_name, None)
- if cache_entry:
- if all(s.expires > int(clock.time()) for s in cache_entry):
- servers = list(cache_entry)
- defer.returnValue(servers)
-
- try:
- answers, _, _ = yield make_deferred_yieldable(
- dns_client.lookupService(service_name),
- )
- except DNSNameError:
- # TODO: cache this. We can get the SOA out of the exception, and use
- # the negative-TTL value.
- defer.returnValue([])
- except DomainError as e:
- # We failed to resolve the name (other than a NameError)
- # Try something in the cache, else rereaise
- cache_entry = cache.get(service_name, None)
+ def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
+ self._dns_client = dns_client
+ self._cache = cache
+ self._get_time = get_time
+
+ @defer.inlineCallbacks
+ def resolve_service(self, service_name):
+ """Look up a SRV record
+
+ Args:
+ service_name (bytes): record to look up
+
+ Returns:
+ Deferred[list[Server]]:
+ a list of the SRV records, or an empty list if none found
+ """
+ now = int(self._get_time())
+
+ if not isinstance(service_name, bytes):
+ raise TypeError("%r is not a byte string" % (service_name,))
+
+ cache_entry = self._cache.get(service_name, None)
if cache_entry:
- logger.warn(
- "Failed to resolve %r, falling back to cache. %r",
- service_name, e
+ if all(s.expires > now for s in cache_entry):
+ servers = list(cache_entry)
+ defer.returnValue(servers)
+
+ try:
+ answers, _, _ = yield make_deferred_yieldable(
+ self._dns_client.lookupService(service_name),
)
- defer.returnValue(list(cache_entry))
- else:
- raise e
-
- if (len(answers) == 1
- and answers[0].type == dns.SRV
- and answers[0].payload
- and answers[0].payload.target == dns.Name(b'.')):
- raise ConnectError("Service %s unavailable" % service_name)
-
- servers = []
-
- for answer in answers:
- if answer.type != dns.SRV or not answer.payload:
- continue
-
- payload = answer.payload
-
- servers.append(Server(
- host=payload.target.name,
- port=payload.port,
- priority=payload.priority,
- weight=payload.weight,
- expires=int(clock.time()) + answer.ttl,
- ))
-
- cache[service_name] = list(servers)
- defer.returnValue(servers)
+ except DNSNameError:
+ # TODO: cache this. We can get the SOA out of the exception, and use
+ # the negative-TTL value.
+ defer.returnValue([])
+ except DomainError as e:
+ # We failed to resolve the name (other than a NameError)
+ # Try something in the cache, else rereaise
+ cache_entry = self._cache.get(service_name, None)
+ if cache_entry:
+ logger.warn(
+ "Failed to resolve %r, falling back to cache. %r",
+ service_name, e
+ )
+ defer.returnValue(list(cache_entry))
+ else:
+ raise e
+
+ if (len(answers) == 1
+ and answers[0].type == dns.SRV
+ and answers[0].payload
+ and answers[0].payload.target == dns.Name(b'.')):
+ raise ConnectError("Service %s unavailable" % service_name)
+
+ servers = []
+
+ for answer in answers:
+ if answer.type != dns.SRV or not answer.payload:
+ continue
+
+ payload = answer.payload
+
+ servers.append(Server(
+ host=payload.target.name,
+ port=payload.port,
+ priority=payload.priority,
+ weight=payload.weight,
+ expires=now + answer.ttl,
+ ))
+
+ self._cache[service_name] = list(servers)
+ defer.returnValue(servers)
|