diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index d8923c9abb..d65daa72bb 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -12,30 +12,97 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet import defer, reactor
-from twisted.internet.error import ConnectError
-from twisted.names import client, dns
-from twisted.names.error import DNSNameError, DomainError
-
import collections
import logging
import random
+import re
import time
+from twisted.internet import defer
+from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.internet.error import ConnectError
+from twisted.names import client, dns
+from twisted.names.error import DNSNameError, DomainError
logger = logging.getLogger(__name__)
SERVER_CACHE = {}
-
+# our record of an individual server which can be tried to reach a destination.
+#
+# "host" is the hostname acquired from the SRV record. Except when there's
+# no SRV record, in which case it is the original hostname.
_Server = collections.namedtuple(
"_Server", "priority weight host port expires"
)
+def parse_server_name(server_name):
+ """Split a server name into host/port parts.
+
+ Args:
+ server_name (str): server name to parse
+
+ Returns:
+ Tuple[str, int|None]: host/port parts.
+
+ Raises:
+ ValueError if the server name could not be parsed.
+ """
+ try:
+ if server_name[-1] == ']':
+ # ipv6 literal, hopefully
+ return server_name, None
+
+ domain_port = server_name.rsplit(":", 1)
+ domain = domain_port[0]
+ port = int(domain_port[1]) if domain_port[1:] else None
+ return domain, port
+ except Exception:
+ raise ValueError("Invalid server name '%s'" % server_name)
+
+
+VALID_HOST_REGEX = re.compile(
+ "\\A[0-9a-zA-Z.-]+\\Z",
+)
+
+
+def parse_and_validate_server_name(server_name):
+ """Split a server name into host/port parts and do some basic validation.
+
+ Args:
+ server_name (str): server name to parse
+
+ Returns:
+ Tuple[str, int|None]: host/port parts.
+
+ Raises:
+ ValueError if the server name could not be parsed.
+ """
+ host, port = parse_server_name(server_name)
+
+ # these tests don't need to be bulletproof as we'll find out soon enough
+ # if somebody is giving us invalid data. What we *do* need is to be sure
+ # that nobody is sneaking IP literals in that look like hostnames, etc.
+
+ # look for ipv6 literals
+ if host[0] == '[':
+ if host[-1] != ']':
+ raise ValueError("Mismatched [...] in server name '%s'" % (
+ server_name,
+ ))
+ return host, port
+
+ # otherwise it should only be alphanumerics.
+ if not VALID_HOST_REGEX.match(host):
+ raise ValueError("Server name '%s' contains invalid characters" % (
+ server_name,
+ ))
+
+ return host, port
+
+
def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout=None):
"""Construct an endpoint for the given matrix destination.
@@ -48,9 +115,7 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout (int): connection timeout in seconds
"""
- domain_port = destination.split(":")
- domain = domain_port[0]
- port = int(domain_port[1]) if domain_port[1:] else None
+ domain, port = parse_server_name(destination)
endpoint_kw_args = {}
@@ -72,21 +137,22 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
reactor, "matrix", domain, protocol="tcp",
default_port=default_port, endpoint=transport_endpoint,
endpoint_kw_args=endpoint_kw_args
- ))
+ ), reactor)
else:
return _WrappingEndpointFac(transport_endpoint(
reactor, domain, port, **endpoint_kw_args
- ))
+ ), reactor)
class _WrappingEndpointFac(object):
- def __init__(self, endpoint_fac):
+ def __init__(self, endpoint_fac, reactor):
self.endpoint_fac = endpoint_fac
+ self.reactor = reactor
@defer.inlineCallbacks
def connect(self, protocolFactory):
conn = yield self.endpoint_fac.connect(protocolFactory)
- conn = _WrappedConnection(conn)
+ conn = _WrappedConnection(conn, self.reactor)
defer.returnValue(conn)
@@ -96,9 +162,10 @@ class _WrappedConnection(object):
"""
__slots__ = ["conn", "last_request"]
- def __init__(self, conn):
+ def __init__(self, conn, reactor):
object.__setattr__(self, "conn", conn)
object.__setattr__(self, "last_request", time.time())
+ self._reactor = reactor
def __getattr__(self, name):
return getattr(self.conn, name)
@@ -113,10 +180,15 @@ class _WrappedConnection(object):
if time.time() - self.last_request >= 2.5 * 60:
self.abort()
# Abort the underlying TLS connection. The abort() method calls
- # loseConnection() on the underlying TLS connection which tries to
+ # loseConnection() on the TLS connection which tries to
# shutdown the connection cleanly. We call abortConnection()
- # since that will promptly close the underlying TCP connection.
- self.transport.abortConnection()
+ # since that will promptly close the TLS connection.
+ #
+ # In Twisted >18.4; the TLS connection will be None if it has closed
+ # which will make abortConnection() throw. Check that the TLS connection
+ # is not None before trying to close it.
+ if self.transport.getHandle() is not None:
+ self.transport.abortConnection()
def request(self, request):
self.last_request = time.time()
@@ -124,14 +196,14 @@ class _WrappedConnection(object):
# Time this connection out if we haven't send a request in the last
# N minutes
# TODO: Cancel the previous callLater?
- reactor.callLater(3 * 60, self._time_things_out_maybe)
+ self._reactor.callLater(3 * 60, self._time_things_out_maybe)
d = self.conn.request(request)
def update_request_time(res):
self.last_request = time.time()
# TODO: Cancel the previous callLater?
- reactor.callLater(3 * 60, self._time_things_out_maybe)
+ self._reactor.callLater(3 * 60, self._time_things_out_maybe)
return res
d.addCallback(update_request_time)
@@ -219,9 +291,10 @@ class SRVClientEndpoint(object):
return self.default_server
else:
raise ConnectError(
- "Not server available for %s" % self.service_name
+ "No server available for %s" % self.service_name
)
+ # look for all servers with the same priority
min_priority = self.servers[0].priority
weight_indexes = list(
(index, server.weight + 1)
@@ -231,11 +304,22 @@ class SRVClientEndpoint(object):
total_weight = sum(weight for index, weight in weight_indexes)
target_weight = random.randint(0, total_weight)
-
for index, weight in weight_indexes:
target_weight -= weight
if target_weight <= 0:
server = self.servers[index]
+ # XXX: this looks totally dubious:
+ #
+ # (a) we never reuse a server until we have been through
+ # all of the servers at the same priority, so if the
+ # weights are A: 100, B:1, we always do ABABAB instead of
+ # AAAA...AAAB (approximately).
+ #
+ # (b) After using all the servers at the lowest priority,
+ # we move onto the next priority. We should only use the
+ # second priority if servers at the top priority are
+ # unreachable.
+ #
del self.servers[index]
self.used_servers.append(server)
return server
@@ -272,7 +356,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
- and answers[0].payload.target == dns.Name('.')):
+ and answers[0].payload.target == dns.Name(b'.')):
raise ConnectError("Service %s unavailable" % service_name)
for answer in answers:
@@ -280,26 +364,14 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
continue
payload = answer.payload
- host = str(payload.target)
- srv_ttl = answer.ttl
-
- try:
- answers, _, _ = yield dns_client.lookupAddress(host)
- except DNSNameError:
- continue
- for answer in answers:
- if answer.type == dns.A and answer.payload:
- ip = answer.payload.dottedQuad()
- host_ttl = min(srv_ttl, answer.ttl)
-
- servers.append(_Server(
- host=ip,
- port=int(payload.port),
- priority=int(payload.priority),
- weight=int(payload.weight),
- expires=int(clock.time()) + host_ttl,
- ))
+ servers.append(_Server(
+ host=str(payload.target),
+ port=int(payload.port),
+ priority=int(payload.priority),
+ weight=int(payload.weight),
+ expires=int(clock.time()) + answer.ttl,
+ ))
servers.sort()
cache[service_name] = list(servers)
|