summary refs log tree commit diff
path: root/synapse/http/endpoint.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/endpoint.py')
-rw-r--r--synapse/http/endpoint.py158
1 files changed, 115 insertions, 43 deletions
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index d8923c9abb..d65daa72bb 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -12,30 +12,97 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet import defer, reactor
-from twisted.internet.error import ConnectError
-from twisted.names import client, dns
-from twisted.names.error import DNSNameError, DomainError
-
 import collections
 import logging
 import random
+import re
 import time
 
+from twisted.internet import defer
+from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.internet.error import ConnectError
+from twisted.names import client, dns
+from twisted.names.error import DNSNameError, DomainError
 
 logger = logging.getLogger(__name__)
 
 
 SERVER_CACHE = {}
 
-
+# our record of an individual server which can be tried to reach a destination.
+#
+# "host" is the hostname acquired from the SRV record. Except when there's
+# no SRV record, in which case it is the original hostname.
 _Server = collections.namedtuple(
     "_Server", "priority weight host port expires"
 )
 
 
+def parse_server_name(server_name):
+    """Split a server name into host/port parts.
+
+    Args:
+        server_name (str): server name to parse
+
+    Returns:
+        Tuple[str, int|None]: host/port parts.
+
+    Raises:
+        ValueError if the server name could not be parsed.
+    """
+    try:
+        if server_name[-1] == ']':
+            # ipv6 literal, hopefully
+            return server_name, None
+
+        domain_port = server_name.rsplit(":", 1)
+        domain = domain_port[0]
+        port = int(domain_port[1]) if domain_port[1:] else None
+        return domain, port
+    except Exception:
+        raise ValueError("Invalid server name '%s'" % server_name)
+
+
+VALID_HOST_REGEX = re.compile(
+    "\\A[0-9a-zA-Z.-]+\\Z",
+)
+
+
+def parse_and_validate_server_name(server_name):
+    """Split a server name into host/port parts and do some basic validation.
+
+    Args:
+        server_name (str): server name to parse
+
+    Returns:
+        Tuple[str, int|None]: host/port parts.
+
+    Raises:
+        ValueError if the server name could not be parsed.
+    """
+    host, port = parse_server_name(server_name)
+
+    # these tests don't need to be bulletproof as we'll find out soon enough
+    # if somebody is giving us invalid data. What we *do* need is to be sure
+    # that nobody is sneaking IP literals in that look like hostnames, etc.
+
+    # look for ipv6 literals
+    if host[0] == '[':
+        if host[-1] != ']':
+            raise ValueError("Mismatched [...] in server name '%s'" % (
+                server_name,
+            ))
+        return host, port
+
+    # otherwise it should only be alphanumerics.
+    if not VALID_HOST_REGEX.match(host):
+        raise ValueError("Server name '%s' contains invalid characters" % (
+            server_name,
+        ))
+
+    return host, port
+
+
 def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
                                timeout=None):
     """Construct an endpoint for the given matrix destination.
@@ -48,9 +115,7 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
         timeout (int): connection timeout in seconds
     """
 
-    domain_port = destination.split(":")
-    domain = domain_port[0]
-    port = int(domain_port[1]) if domain_port[1:] else None
+    domain, port = parse_server_name(destination)
 
     endpoint_kw_args = {}
 
@@ -72,21 +137,22 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
             reactor, "matrix", domain, protocol="tcp",
             default_port=default_port, endpoint=transport_endpoint,
             endpoint_kw_args=endpoint_kw_args
-        ))
+        ), reactor)
     else:
         return _WrappingEndpointFac(transport_endpoint(
             reactor, domain, port, **endpoint_kw_args
-        ))
+        ), reactor)
 
 
 class _WrappingEndpointFac(object):
-    def __init__(self, endpoint_fac):
+    def __init__(self, endpoint_fac, reactor):
         self.endpoint_fac = endpoint_fac
+        self.reactor = reactor
 
     @defer.inlineCallbacks
     def connect(self, protocolFactory):
         conn = yield self.endpoint_fac.connect(protocolFactory)
-        conn = _WrappedConnection(conn)
+        conn = _WrappedConnection(conn, self.reactor)
         defer.returnValue(conn)
 
 
@@ -96,9 +162,10 @@ class _WrappedConnection(object):
     """
     __slots__ = ["conn", "last_request"]
 
-    def __init__(self, conn):
+    def __init__(self, conn, reactor):
         object.__setattr__(self, "conn", conn)
         object.__setattr__(self, "last_request", time.time())
+        self._reactor = reactor
 
     def __getattr__(self, name):
         return getattr(self.conn, name)
@@ -113,10 +180,15 @@ class _WrappedConnection(object):
         if time.time() - self.last_request >= 2.5 * 60:
             self.abort()
             # Abort the underlying TLS connection. The abort() method calls
-            # loseConnection() on the underlying TLS connection which tries to
+            # loseConnection() on the TLS connection which tries to
             # shutdown the connection cleanly. We call abortConnection()
-            # since that will promptly close the underlying TCP connection.
-            self.transport.abortConnection()
+            # since that will promptly close the TLS connection.
+            #
+            # In Twisted >18.4; the TLS connection will be None if it has closed
+            # which will make abortConnection() throw. Check that the TLS connection
+            # is not None before trying to close it.
+            if self.transport.getHandle() is not None:
+                self.transport.abortConnection()
 
     def request(self, request):
         self.last_request = time.time()
@@ -124,14 +196,14 @@ class _WrappedConnection(object):
         # Time this connection out if we haven't send a request in the last
         # N minutes
         # TODO: Cancel the previous callLater?
-        reactor.callLater(3 * 60, self._time_things_out_maybe)
+        self._reactor.callLater(3 * 60, self._time_things_out_maybe)
 
         d = self.conn.request(request)
 
         def update_request_time(res):
             self.last_request = time.time()
             # TODO: Cancel the previous callLater?
-            reactor.callLater(3 * 60, self._time_things_out_maybe)
+            self._reactor.callLater(3 * 60, self._time_things_out_maybe)
             return res
 
         d.addCallback(update_request_time)
@@ -219,9 +291,10 @@ class SRVClientEndpoint(object):
                 return self.default_server
             else:
                 raise ConnectError(
-                    "Not server available for %s" % self.service_name
+                    "No server available for %s" % self.service_name
                 )
 
+        # look for all servers with the same priority
         min_priority = self.servers[0].priority
         weight_indexes = list(
             (index, server.weight + 1)
@@ -231,11 +304,22 @@ class SRVClientEndpoint(object):
 
         total_weight = sum(weight for index, weight in weight_indexes)
         target_weight = random.randint(0, total_weight)
-
         for index, weight in weight_indexes:
             target_weight -= weight
             if target_weight <= 0:
                 server = self.servers[index]
+                # XXX: this looks totally dubious:
+                #
+                # (a) we never reuse a server until we have been through
+                #     all of the servers at the same priority, so if the
+                #     weights are A: 100, B:1, we always do ABABAB instead of
+                #     AAAA...AAAB (approximately).
+                #
+                # (b) After using all the servers at the lowest priority,
+                #     we move onto the next priority. We should only use the
+                #     second priority if servers at the top priority are
+                #     unreachable.
+                #
                 del self.servers[index]
                 self.used_servers.append(server)
                 return server
@@ -272,7 +356,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
         if (len(answers) == 1
                 and answers[0].type == dns.SRV
                 and answers[0].payload
-                and answers[0].payload.target == dns.Name('.')):
+                and answers[0].payload.target == dns.Name(b'.')):
             raise ConnectError("Service %s unavailable" % service_name)
 
         for answer in answers:
@@ -280,26 +364,14 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
                 continue
 
             payload = answer.payload
-            host = str(payload.target)
-            srv_ttl = answer.ttl
-
-            try:
-                answers, _, _ = yield dns_client.lookupAddress(host)
-            except DNSNameError:
-                continue
 
-            for answer in answers:
-                if answer.type == dns.A and answer.payload:
-                    ip = answer.payload.dottedQuad()
-                    host_ttl = min(srv_ttl, answer.ttl)
-
-                    servers.append(_Server(
-                        host=ip,
-                        port=int(payload.port),
-                        priority=int(payload.priority),
-                        weight=int(payload.weight),
-                        expires=int(clock.time()) + host_ttl,
-                    ))
+            servers.append(_Server(
+                host=str(payload.target),
+                port=int(payload.port),
+                priority=int(payload.priority),
+                weight=int(payload.weight),
+                expires=int(clock.time()) + answer.ttl,
+            ))
 
         servers.sort()
         cache[service_name] = list(servers)