summary refs log tree commit diff
path: root/synapse/http/endpoint.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/endpoint.py')
-rw-r--r--synapse/http/endpoint.py47
1 files changed, 42 insertions, 5 deletions
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 5a9cbb3324..1b1123b292 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -12,6 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
+
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.internet import defer
 from twisted.internet.error import ConnectError
@@ -41,8 +43,6 @@ _Server = collections.namedtuple(
 def parse_server_name(server_name):
     """Split a server name into host/port parts.
 
-    Does some basic sanity checking of the
-
     Args:
         server_name (str): server name to parse
 
@@ -55,9 +55,6 @@ def parse_server_name(server_name):
     try:
         if server_name[-1] == ']':
             # ipv6 literal, hopefully
-            if server_name[0] != '[':
-                raise Exception()
-
             return server_name, None
 
         domain_port = server_name.rsplit(":", 1)
@@ -68,6 +65,46 @@ def parse_server_name(server_name):
         raise ValueError("Invalid server name '%s'" % server_name)
 
 
+VALID_HOST_REGEX = re.compile(
+    "\\A[0-9a-zA-Z.-]+\\Z",
+)
+
+
+def parse_and_validate_server_name(server_name):
+    """Split a server name into host/port parts and do some basic validation.
+
+    Args:
+        server_name (str): server name to parse
+
+    Returns:
+        Tuple[str, int|None]: host/port parts.
+
+    Raises:
+        ValueError if the server name could not be parsed.
+    """
+    host, port = parse_server_name(server_name)
+
+    # these tests don't need to be bulletproof as we'll find out soon enough
+    # if somebody is giving us invalid data. What we *do* need is to be sure
+    # that nobody is sneaking IP literals in that look like hostnames, etc.
+
+    # look for ipv6 literals
+    if host[0] == '[':
+        if host[-1] != ']':
+            raise ValueError("Mismatched [...] in server name '%s'" % (
+                server_name,
+            ))
+        return host, port
+
+    # otherwise it should only be alphanumerics.
+    if not VALID_HOST_REGEX.match(host):
+        raise ValueError("Server name '%s' contains invalid characters" % (
+            server_name,
+        ))
+
+    return host, port
+
+
 def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
                                timeout=None):
     """Construct an endpoint for the given matrix destination.