diff options
Diffstat (limited to 'synapse/http/federation')
-rw-r--r-- | synapse/http/federation/matrix_federation_agent.py | 125 | ||||
-rw-r--r-- | synapse/http/federation/srv_resolver.py | 167 |
2 files changed, 231 insertions, 61 deletions
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py new file mode 100644 index 0000000000..0ec28c6696 --- /dev/null +++ b/synapse/http/federation/matrix_federation_agent.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from zope.interface import implementer + +from twisted.internet import defer +from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS +from twisted.web.client import URI, Agent, HTTPConnectionPool +from twisted.web.iweb import IAgent + +from synapse.http.endpoint import parse_server_name +from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list +from synapse.util.logcontext import make_deferred_yieldable + +logger = logging.getLogger(__name__) + + +@implementer(IAgent) +class MatrixFederationAgent(object): + """An Agent-like thing which provides a `request` method which will look up a matrix + server and send an HTTP request to it. + + Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.) + + Args: + reactor (IReactor): twisted reactor to use for underlying requests + + tls_client_options_factory (ClientTLSOptionsFactory|None): + factory to use for fetching client tls options, or none to disable TLS. + + srv_resolver (SrvResolver|None): + SRVResolver impl to use for looking up SRV records. None to use a default + implementation. + """ + + def __init__( + self, reactor, tls_client_options_factory, _srv_resolver=None, + ): + self._reactor = reactor + self._tls_client_options_factory = tls_client_options_factory + if _srv_resolver is None: + _srv_resolver = SrvResolver() + self._srv_resolver = _srv_resolver + + self._pool = HTTPConnectionPool(reactor) + self._pool.retryAutomatically = False + self._pool.maxPersistentPerHost = 5 + self._pool.cachedConnectionTimeout = 2 * 60 + + @defer.inlineCallbacks + def request(self, method, uri, headers=None, bodyProducer=None): + """ + Args: + method (bytes): HTTP method: GET/POST/etc + + uri (bytes): Absolute URI to be retrieved + + headers (twisted.web.http_headers.Headers|None): + HTTP headers to send with the request, or None to + send no extra headers. + + bodyProducer (twisted.web.iweb.IBodyProducer|None): + An object which can generate bytes to make up the + body of this request (for example, the properly encoded contents of + a file for a file upload). Or None if the request is to have + no body. + + Returns: + Deferred[twisted.web.iweb.IResponse]: + fires when the header of the response has been received (regardless of the + response status code). Fails if there is any problem which prevents that + response from being received (including problems that prevent the request + from being sent). + """ + + parsed_uri = URI.fromBytes(uri) + server_name_bytes = parsed_uri.netloc + host, port = parse_server_name(server_name_bytes.decode("ascii")) + + # XXX disabling TLS is really only supported here for the benefit of the + # unit tests. We should make the UTs cope with TLS rather than having to make + # the code support the unit tests. + if self._tls_client_options_factory is None: + tls_options = None + else: + tls_options = self._tls_client_options_factory.get_options(host) + + if port is not None: + target = (host, port) + else: + service_name = b"_matrix._tcp.%s" % (server_name_bytes, ) + server_list = yield self._srv_resolver.resolve_service(service_name) + if not server_list: + target = (host, 8448) + logger.debug("No SRV record for %s, using %s", host, target) + else: + target = pick_server_from_list(server_list) + + class EndpointFactory(object): + @staticmethod + def endpointForURI(_uri): + logger.info("Connecting to %s:%s", target[0], target[1]) + ep = HostnameEndpoint(self._reactor, host=target[0], port=target[1]) + if tls_options is not None: + ep = wrapClientTLS(tls_options, ep) + return ep + + agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool) + res = yield make_deferred_yieldable( + agent.request(method, uri, headers, bodyProducer) + ) + defer.returnValue(res) diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index c49b82c394..71830c549d 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +import random import time import attr @@ -51,74 +52,118 @@ class Server(object): expires = attr.ib(default=0) -@defer.inlineCallbacks -def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time): - """Look up a SRV record, with caching +def pick_server_from_list(server_list): + """Randomly choose a server from the server list + + Args: + server_list (list[Server]): list of candidate servers + + Returns: + Tuple[bytes, int]: (host, port) pair for the chosen server + """ + if not server_list: + raise RuntimeError("pick_server_from_list called with empty list") + + # TODO: currently we only use the lowest-priority servers. We should maintain a + # cache of servers known to be "down" and filter them out + + min_priority = min(s.priority for s in server_list) + eligible_servers = list(s for s in server_list if s.priority == min_priority) + total_weight = sum(s.weight for s in eligible_servers) + target_weight = random.randint(0, total_weight) + + for s in eligible_servers: + target_weight -= s.weight + + if target_weight <= 0: + return s.host, s.port + + # this should be impossible. + raise RuntimeError( + "pick_server_from_list got to end of eligible server list.", + ) + + +class SrvResolver(object): + """Interface to the dns client to do SRV lookups, with result caching. The default resolver in twisted.names doesn't do any caching (it has a CacheResolver, but the cache never gets populated), so we add our own caching layer here. Args: - service_name (unicode|bytes): record to look up dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl cache (dict): cache object - clock (object): clock implementation. must provide a time() method. - - Returns: - Deferred[list[Server]]: a list of the SRV records, or an empty list if none found + get_time (callable): clock implementation. Should return seconds since the epoch """ - # TODO: the dns client handles both unicode names (encoding via idna) and pre-encoded - # byteses; however they will obviously end up as separate entries in the cache. We - # should pick one form and stick with it. - cache_entry = cache.get(service_name, None) - if cache_entry: - if all(s.expires > int(clock.time()) for s in cache_entry): - servers = list(cache_entry) - defer.returnValue(servers) - - try: - answers, _, _ = yield make_deferred_yieldable( - dns_client.lookupService(service_name), - ) - except DNSNameError: - # TODO: cache this. We can get the SOA out of the exception, and use - # the negative-TTL value. - defer.returnValue([]) - except DomainError as e: - # We failed to resolve the name (other than a NameError) - # Try something in the cache, else rereaise - cache_entry = cache.get(service_name, None) + def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time): + self._dns_client = dns_client + self._cache = cache + self._get_time = get_time + + @defer.inlineCallbacks + def resolve_service(self, service_name): + """Look up a SRV record + + Args: + service_name (bytes): record to look up + + Returns: + Deferred[list[Server]]: + a list of the SRV records, or an empty list if none found + """ + now = int(self._get_time()) + + if not isinstance(service_name, bytes): + raise TypeError("%r is not a byte string" % (service_name,)) + + cache_entry = self._cache.get(service_name, None) if cache_entry: - logger.warn( - "Failed to resolve %r, falling back to cache. %r", - service_name, e + if all(s.expires > now for s in cache_entry): + servers = list(cache_entry) + defer.returnValue(servers) + + try: + answers, _, _ = yield make_deferred_yieldable( + self._dns_client.lookupService(service_name), ) - defer.returnValue(list(cache_entry)) - else: - raise e - - if (len(answers) == 1 - and answers[0].type == dns.SRV - and answers[0].payload - and answers[0].payload.target == dns.Name(b'.')): - raise ConnectError("Service %s unavailable" % service_name) - - servers = [] - - for answer in answers: - if answer.type != dns.SRV or not answer.payload: - continue - - payload = answer.payload - - servers.append(Server( - host=payload.target.name, - port=payload.port, - priority=payload.priority, - weight=payload.weight, - expires=int(clock.time()) + answer.ttl, - )) - - servers.sort() # FIXME: get rid of this (it's broken by the attrs change) - cache[service_name] = list(servers) - defer.returnValue(servers) + except DNSNameError: + # TODO: cache this. We can get the SOA out of the exception, and use + # the negative-TTL value. + defer.returnValue([]) + except DomainError as e: + # We failed to resolve the name (other than a NameError) + # Try something in the cache, else rereaise + cache_entry = self._cache.get(service_name, None) + if cache_entry: + logger.warn( + "Failed to resolve %r, falling back to cache. %r", + service_name, e + ) + defer.returnValue(list(cache_entry)) + else: + raise e + + if (len(answers) == 1 + and answers[0].type == dns.SRV + and answers[0].payload + and answers[0].payload.target == dns.Name(b'.')): + raise ConnectError("Service %s unavailable" % service_name) + + servers = [] + + for answer in answers: + if answer.type != dns.SRV or not answer.payload: + continue + + payload = answer.payload + + servers.append(Server( + host=payload.target.name, + port=payload.port, + priority=payload.priority, + weight=payload.weight, + expires=now + answer.ttl, + )) + + self._cache[service_name] = list(servers) + defer.returnValue(servers) |