summary refs log tree commit diff
path: root/synapse/http/endpoint.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/endpoint.py')
-rw-r--r--synapse/http/endpoint.py97
1 files changed, 81 insertions, 16 deletions
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 87a482650d..d65daa72bb 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -12,17 +12,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet import defer, reactor
-from twisted.internet.error import ConnectError
-from twisted.names import client, dns
-from twisted.names.error import DNSNameError, DomainError
-
 import collections
 import logging
 import random
+import re
 import time
 
+from twisted.internet import defer
+from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.internet.error import ConnectError
+from twisted.names import client, dns
+from twisted.names.error import DNSNameError, DomainError
 
 logger = logging.getLogger(__name__)
 
@@ -38,6 +38,71 @@ _Server = collections.namedtuple(
 )
 
 
+def parse_server_name(server_name):
+    """Split a server name into host/port parts.
+
+    Args:
+        server_name (str): server name to parse
+
+    Returns:
+        Tuple[str, int|None]: host/port parts.
+
+    Raises:
+        ValueError if the server name could not be parsed.
+    """
+    try:
+        if server_name[-1] == ']':
+            # ipv6 literal, hopefully
+            return server_name, None
+
+        domain_port = server_name.rsplit(":", 1)
+        domain = domain_port[0]
+        port = int(domain_port[1]) if domain_port[1:] else None
+        return domain, port
+    except Exception:
+        raise ValueError("Invalid server name '%s'" % server_name)
+
+
+VALID_HOST_REGEX = re.compile(
+    "\\A[0-9a-zA-Z.-]+\\Z",
+)
+
+
+def parse_and_validate_server_name(server_name):
+    """Split a server name into host/port parts and do some basic validation.
+
+    Args:
+        server_name (str): server name to parse
+
+    Returns:
+        Tuple[str, int|None]: host/port parts.
+
+    Raises:
+        ValueError if the server name could not be parsed.
+    """
+    host, port = parse_server_name(server_name)
+
+    # these tests don't need to be bulletproof as we'll find out soon enough
+    # if somebody is giving us invalid data. What we *do* need is to be sure
+    # that nobody is sneaking IP literals in that look like hostnames, etc.
+
+    # look for ipv6 literals
+    if host[0] == '[':
+        if host[-1] != ']':
+            raise ValueError("Mismatched [...] in server name '%s'" % (
+                server_name,
+            ))
+        return host, port
+
+    # otherwise it should only be alphanumerics.
+    if not VALID_HOST_REGEX.match(host):
+        raise ValueError("Server name '%s' contains invalid characters" % (
+            server_name,
+        ))
+
+    return host, port
+
+
 def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
                                timeout=None):
     """Construct an endpoint for the given matrix destination.
@@ -50,9 +115,7 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
         timeout (int): connection timeout in seconds
     """
 
-    domain_port = destination.split(":")
-    domain = domain_port[0]
-    port = int(domain_port[1]) if domain_port[1:] else None
+    domain, port = parse_server_name(destination)
 
     endpoint_kw_args = {}
 
@@ -74,21 +137,22 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
             reactor, "matrix", domain, protocol="tcp",
             default_port=default_port, endpoint=transport_endpoint,
             endpoint_kw_args=endpoint_kw_args
-        ))
+        ), reactor)
     else:
         return _WrappingEndpointFac(transport_endpoint(
             reactor, domain, port, **endpoint_kw_args
-        ))
+        ), reactor)
 
 
 class _WrappingEndpointFac(object):
-    def __init__(self, endpoint_fac):
+    def __init__(self, endpoint_fac, reactor):
         self.endpoint_fac = endpoint_fac
+        self.reactor = reactor
 
     @defer.inlineCallbacks
     def connect(self, protocolFactory):
         conn = yield self.endpoint_fac.connect(protocolFactory)
-        conn = _WrappedConnection(conn)
+        conn = _WrappedConnection(conn, self.reactor)
         defer.returnValue(conn)
 
 
@@ -98,9 +162,10 @@ class _WrappedConnection(object):
     """
     __slots__ = ["conn", "last_request"]
 
-    def __init__(self, conn):
+    def __init__(self, conn, reactor):
         object.__setattr__(self, "conn", conn)
         object.__setattr__(self, "last_request", time.time())
+        self._reactor = reactor
 
     def __getattr__(self, name):
         return getattr(self.conn, name)
@@ -131,14 +196,14 @@ class _WrappedConnection(object):
         # Time this connection out if we haven't send a request in the last
         # N minutes
         # TODO: Cancel the previous callLater?
-        reactor.callLater(3 * 60, self._time_things_out_maybe)
+        self._reactor.callLater(3 * 60, self._time_things_out_maybe)
 
         d = self.conn.request(request)
 
         def update_request_time(res):
             self.last_request = time.time()
             # TODO: Cancel the previous callLater?
-            reactor.callLater(3 * 60, self._time_things_out_maybe)
+            self._reactor.callLater(3 * 60, self._time_things_out_maybe)
             return res
 
         d.addCallback(update_request_time)