diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
new file mode 100644
index 0000000000..0ec28c6696
--- /dev/null
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -0,0 +1,125 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from zope.interface import implementer
+
+from twisted.internet import defer
+from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.web.client import URI, Agent, HTTPConnectionPool
+from twisted.web.iweb import IAgent
+
+from synapse.http.endpoint import parse_server_name
+from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
+from synapse.util.logcontext import make_deferred_yieldable
+
+logger = logging.getLogger(__name__)
+
+
+@implementer(IAgent)
+class MatrixFederationAgent(object):
+ """An Agent-like thing which provides a `request` method which will look up a matrix
+ server and send an HTTP request to it.
+
+ Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
+
+ Args:
+ reactor (IReactor): twisted reactor to use for underlying requests
+
+ tls_client_options_factory (ClientTLSOptionsFactory|None):
+ factory to use for fetching client tls options, or none to disable TLS.
+
+ srv_resolver (SrvResolver|None):
+ SRVResolver impl to use for looking up SRV records. None to use a default
+ implementation.
+ """
+
+ def __init__(
+ self, reactor, tls_client_options_factory, _srv_resolver=None,
+ ):
+ self._reactor = reactor
+ self._tls_client_options_factory = tls_client_options_factory
+ if _srv_resolver is None:
+ _srv_resolver = SrvResolver()
+ self._srv_resolver = _srv_resolver
+
+ self._pool = HTTPConnectionPool(reactor)
+ self._pool.retryAutomatically = False
+ self._pool.maxPersistentPerHost = 5
+ self._pool.cachedConnectionTimeout = 2 * 60
+
+ @defer.inlineCallbacks
+ def request(self, method, uri, headers=None, bodyProducer=None):
+ """
+ Args:
+ method (bytes): HTTP method: GET/POST/etc
+
+ uri (bytes): Absolute URI to be retrieved
+
+ headers (twisted.web.http_headers.Headers|None):
+ HTTP headers to send with the request, or None to
+ send no extra headers.
+
+ bodyProducer (twisted.web.iweb.IBodyProducer|None):
+ An object which can generate bytes to make up the
+ body of this request (for example, the properly encoded contents of
+ a file for a file upload). Or None if the request is to have
+ no body.
+
+ Returns:
+ Deferred[twisted.web.iweb.IResponse]:
+ fires when the header of the response has been received (regardless of the
+ response status code). Fails if there is any problem which prevents that
+ response from being received (including problems that prevent the request
+ from being sent).
+ """
+
+ parsed_uri = URI.fromBytes(uri)
+ server_name_bytes = parsed_uri.netloc
+ host, port = parse_server_name(server_name_bytes.decode("ascii"))
+
+ # XXX disabling TLS is really only supported here for the benefit of the
+ # unit tests. We should make the UTs cope with TLS rather than having to make
+ # the code support the unit tests.
+ if self._tls_client_options_factory is None:
+ tls_options = None
+ else:
+ tls_options = self._tls_client_options_factory.get_options(host)
+
+ if port is not None:
+ target = (host, port)
+ else:
+ service_name = b"_matrix._tcp.%s" % (server_name_bytes, )
+ server_list = yield self._srv_resolver.resolve_service(service_name)
+ if not server_list:
+ target = (host, 8448)
+ logger.debug("No SRV record for %s, using %s", host, target)
+ else:
+ target = pick_server_from_list(server_list)
+
+ class EndpointFactory(object):
+ @staticmethod
+ def endpointForURI(_uri):
+ logger.info("Connecting to %s:%s", target[0], target[1])
+ ep = HostnameEndpoint(self._reactor, host=target[0], port=target[1])
+ if tls_options is not None:
+ ep = wrapClientTLS(tls_options, ep)
+ return ep
+
+ agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
+ res = yield make_deferred_yieldable(
+ agent.request(method, uri, headers, bodyProducer)
+ )
+ defer.returnValue(res)
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index c49b82c394..71830c549d 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+import random
import time
import attr
@@ -51,74 +52,118 @@ class Server(object):
expires = attr.ib(default=0)
-@defer.inlineCallbacks
-def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
- """Look up a SRV record, with caching
+def pick_server_from_list(server_list):
+ """Randomly choose a server from the server list
+
+ Args:
+ server_list (list[Server]): list of candidate servers
+
+ Returns:
+ Tuple[bytes, int]: (host, port) pair for the chosen server
+ """
+ if not server_list:
+ raise RuntimeError("pick_server_from_list called with empty list")
+
+ # TODO: currently we only use the lowest-priority servers. We should maintain a
+ # cache of servers known to be "down" and filter them out
+
+ min_priority = min(s.priority for s in server_list)
+ eligible_servers = list(s for s in server_list if s.priority == min_priority)
+ total_weight = sum(s.weight for s in eligible_servers)
+ target_weight = random.randint(0, total_weight)
+
+ for s in eligible_servers:
+ target_weight -= s.weight
+
+ if target_weight <= 0:
+ return s.host, s.port
+
+ # this should be impossible.
+ raise RuntimeError(
+ "pick_server_from_list got to end of eligible server list.",
+ )
+
+
+class SrvResolver(object):
+ """Interface to the dns client to do SRV lookups, with result caching.
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
but the cache never gets populated), so we add our own caching layer here.
Args:
- service_name (unicode|bytes): record to look up
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
cache (dict): cache object
- clock (object): clock implementation. must provide a time() method.
-
- Returns:
- Deferred[list[Server]]: a list of the SRV records, or an empty list if none found
+ get_time (callable): clock implementation. Should return seconds since the epoch
"""
- # TODO: the dns client handles both unicode names (encoding via idna) and pre-encoded
- # byteses; however they will obviously end up as separate entries in the cache. We
- # should pick one form and stick with it.
- cache_entry = cache.get(service_name, None)
- if cache_entry:
- if all(s.expires > int(clock.time()) for s in cache_entry):
- servers = list(cache_entry)
- defer.returnValue(servers)
-
- try:
- answers, _, _ = yield make_deferred_yieldable(
- dns_client.lookupService(service_name),
- )
- except DNSNameError:
- # TODO: cache this. We can get the SOA out of the exception, and use
- # the negative-TTL value.
- defer.returnValue([])
- except DomainError as e:
- # We failed to resolve the name (other than a NameError)
- # Try something in the cache, else rereaise
- cache_entry = cache.get(service_name, None)
+ def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
+ self._dns_client = dns_client
+ self._cache = cache
+ self._get_time = get_time
+
+ @defer.inlineCallbacks
+ def resolve_service(self, service_name):
+ """Look up a SRV record
+
+ Args:
+ service_name (bytes): record to look up
+
+ Returns:
+ Deferred[list[Server]]:
+ a list of the SRV records, or an empty list if none found
+ """
+ now = int(self._get_time())
+
+ if not isinstance(service_name, bytes):
+ raise TypeError("%r is not a byte string" % (service_name,))
+
+ cache_entry = self._cache.get(service_name, None)
if cache_entry:
- logger.warn(
- "Failed to resolve %r, falling back to cache. %r",
- service_name, e
+ if all(s.expires > now for s in cache_entry):
+ servers = list(cache_entry)
+ defer.returnValue(servers)
+
+ try:
+ answers, _, _ = yield make_deferred_yieldable(
+ self._dns_client.lookupService(service_name),
)
- defer.returnValue(list(cache_entry))
- else:
- raise e
-
- if (len(answers) == 1
- and answers[0].type == dns.SRV
- and answers[0].payload
- and answers[0].payload.target == dns.Name(b'.')):
- raise ConnectError("Service %s unavailable" % service_name)
-
- servers = []
-
- for answer in answers:
- if answer.type != dns.SRV or not answer.payload:
- continue
-
- payload = answer.payload
-
- servers.append(Server(
- host=payload.target.name,
- port=payload.port,
- priority=payload.priority,
- weight=payload.weight,
- expires=int(clock.time()) + answer.ttl,
- ))
-
- servers.sort() # FIXME: get rid of this (it's broken by the attrs change)
- cache[service_name] = list(servers)
- defer.returnValue(servers)
+ except DNSNameError:
+ # TODO: cache this. We can get the SOA out of the exception, and use
+ # the negative-TTL value.
+ defer.returnValue([])
+ except DomainError as e:
+ # We failed to resolve the name (other than a NameError)
+ # Try something in the cache, else rereaise
+ cache_entry = self._cache.get(service_name, None)
+ if cache_entry:
+ logger.warn(
+ "Failed to resolve %r, falling back to cache. %r",
+ service_name, e
+ )
+ defer.returnValue(list(cache_entry))
+ else:
+ raise e
+
+ if (len(answers) == 1
+ and answers[0].type == dns.SRV
+ and answers[0].payload
+ and answers[0].payload.target == dns.Name(b'.')):
+ raise ConnectError("Service %s unavailable" % service_name)
+
+ servers = []
+
+ for answer in answers:
+ if answer.type != dns.SRV or not answer.payload:
+ continue
+
+ payload = answer.payload
+
+ servers.append(Server(
+ host=payload.target.name,
+ port=payload.port,
+ priority=payload.priority,
+ weight=payload.weight,
+ expires=now + answer.ttl,
+ ))
+
+ self._cache[service_name] = list(servers)
+ defer.returnValue(servers)
|