diff options
Diffstat (limited to 'synapse/http/endpoint.py')
-rw-r--r-- | synapse/http/endpoint.py | 158 |
1 files changed, 115 insertions, 43 deletions
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index d8923c9abb..d65daa72bb 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -12,30 +12,97 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS -from twisted.internet import defer, reactor -from twisted.internet.error import ConnectError -from twisted.names import client, dns -from twisted.names.error import DNSNameError, DomainError - import collections import logging import random +import re import time +from twisted.internet import defer +from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS +from twisted.internet.error import ConnectError +from twisted.names import client, dns +from twisted.names.error import DNSNameError, DomainError logger = logging.getLogger(__name__) SERVER_CACHE = {} - +# our record of an individual server which can be tried to reach a destination. +# +# "host" is the hostname acquired from the SRV record. Except when there's +# no SRV record, in which case it is the original hostname. _Server = collections.namedtuple( "_Server", "priority weight host port expires" ) +def parse_server_name(server_name): + """Split a server name into host/port parts. + + Args: + server_name (str): server name to parse + + Returns: + Tuple[str, int|None]: host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + try: + if server_name[-1] == ']': + # ipv6 literal, hopefully + return server_name, None + + domain_port = server_name.rsplit(":", 1) + domain = domain_port[0] + port = int(domain_port[1]) if domain_port[1:] else None + return domain, port + except Exception: + raise ValueError("Invalid server name '%s'" % server_name) + + +VALID_HOST_REGEX = re.compile( + "\\A[0-9a-zA-Z.-]+\\Z", +) + + +def parse_and_validate_server_name(server_name): + """Split a server name into host/port parts and do some basic validation. + + Args: + server_name (str): server name to parse + + Returns: + Tuple[str, int|None]: host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + host, port = parse_server_name(server_name) + + # these tests don't need to be bulletproof as we'll find out soon enough + # if somebody is giving us invalid data. What we *do* need is to be sure + # that nobody is sneaking IP literals in that look like hostnames, etc. + + # look for ipv6 literals + if host[0] == '[': + if host[-1] != ']': + raise ValueError("Mismatched [...] in server name '%s'" % ( + server_name, + )) + return host, port + + # otherwise it should only be alphanumerics. + if not VALID_HOST_REGEX.match(host): + raise ValueError("Server name '%s' contains invalid characters" % ( + server_name, + )) + + return host, port + + def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, timeout=None): """Construct an endpoint for the given matrix destination. @@ -48,9 +115,7 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, timeout (int): connection timeout in seconds """ - domain_port = destination.split(":") - domain = domain_port[0] - port = int(domain_port[1]) if domain_port[1:] else None + domain, port = parse_server_name(destination) endpoint_kw_args = {} @@ -72,21 +137,22 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, reactor, "matrix", domain, protocol="tcp", default_port=default_port, endpoint=transport_endpoint, endpoint_kw_args=endpoint_kw_args - )) + ), reactor) else: return _WrappingEndpointFac(transport_endpoint( reactor, domain, port, **endpoint_kw_args - )) + ), reactor) class _WrappingEndpointFac(object): - def __init__(self, endpoint_fac): + def __init__(self, endpoint_fac, reactor): self.endpoint_fac = endpoint_fac + self.reactor = reactor @defer.inlineCallbacks def connect(self, protocolFactory): conn = yield self.endpoint_fac.connect(protocolFactory) - conn = _WrappedConnection(conn) + conn = _WrappedConnection(conn, self.reactor) defer.returnValue(conn) @@ -96,9 +162,10 @@ class _WrappedConnection(object): """ __slots__ = ["conn", "last_request"] - def __init__(self, conn): + def __init__(self, conn, reactor): object.__setattr__(self, "conn", conn) object.__setattr__(self, "last_request", time.time()) + self._reactor = reactor def __getattr__(self, name): return getattr(self.conn, name) @@ -113,10 +180,15 @@ class _WrappedConnection(object): if time.time() - self.last_request >= 2.5 * 60: self.abort() # Abort the underlying TLS connection. The abort() method calls - # loseConnection() on the underlying TLS connection which tries to + # loseConnection() on the TLS connection which tries to # shutdown the connection cleanly. We call abortConnection() - # since that will promptly close the underlying TCP connection. - self.transport.abortConnection() + # since that will promptly close the TLS connection. + # + # In Twisted >18.4; the TLS connection will be None if it has closed + # which will make abortConnection() throw. Check that the TLS connection + # is not None before trying to close it. + if self.transport.getHandle() is not None: + self.transport.abortConnection() def request(self, request): self.last_request = time.time() @@ -124,14 +196,14 @@ class _WrappedConnection(object): # Time this connection out if we haven't send a request in the last # N minutes # TODO: Cancel the previous callLater? - reactor.callLater(3 * 60, self._time_things_out_maybe) + self._reactor.callLater(3 * 60, self._time_things_out_maybe) d = self.conn.request(request) def update_request_time(res): self.last_request = time.time() # TODO: Cancel the previous callLater? - reactor.callLater(3 * 60, self._time_things_out_maybe) + self._reactor.callLater(3 * 60, self._time_things_out_maybe) return res d.addCallback(update_request_time) @@ -219,9 +291,10 @@ class SRVClientEndpoint(object): return self.default_server else: raise ConnectError( - "Not server available for %s" % self.service_name + "No server available for %s" % self.service_name ) + # look for all servers with the same priority min_priority = self.servers[0].priority weight_indexes = list( (index, server.weight + 1) @@ -231,11 +304,22 @@ class SRVClientEndpoint(object): total_weight = sum(weight for index, weight in weight_indexes) target_weight = random.randint(0, total_weight) - for index, weight in weight_indexes: target_weight -= weight if target_weight <= 0: server = self.servers[index] + # XXX: this looks totally dubious: + # + # (a) we never reuse a server until we have been through + # all of the servers at the same priority, so if the + # weights are A: 100, B:1, we always do ABABAB instead of + # AAAA...AAAB (approximately). + # + # (b) After using all the servers at the lowest priority, + # we move onto the next priority. We should only use the + # second priority if servers at the top priority are + # unreachable. + # del self.servers[index] self.used_servers.append(server) return server @@ -272,7 +356,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t if (len(answers) == 1 and answers[0].type == dns.SRV and answers[0].payload - and answers[0].payload.target == dns.Name('.')): + and answers[0].payload.target == dns.Name(b'.')): raise ConnectError("Service %s unavailable" % service_name) for answer in answers: @@ -280,26 +364,14 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t continue payload = answer.payload - host = str(payload.target) - srv_ttl = answer.ttl - - try: - answers, _, _ = yield dns_client.lookupAddress(host) - except DNSNameError: - continue - for answer in answers: - if answer.type == dns.A and answer.payload: - ip = answer.payload.dottedQuad() - host_ttl = min(srv_ttl, answer.ttl) - - servers.append(_Server( - host=ip, - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight), - expires=int(clock.time()) + host_ttl, - )) + servers.append(_Server( + host=str(payload.target), + port=int(payload.port), + priority=int(payload.priority), + weight=int(payload.weight), + expires=int(clock.time()) + answer.ttl, + )) servers.sort() cache[service_name] = list(servers) |