diff --git a/changelog.d/3480.feature b/changelog.d/3480.feature
new file mode 100644
index 0000000000..a21580943d
--- /dev/null
+++ b/changelog.d/3480.feature
@@ -0,0 +1 @@
+Reject invalid server names in federation requests
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 19d09f5422..1180d4b69d 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError, FederationDeniedError
+from synapse.http.endpoint import parse_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
@@ -99,26 +100,6 @@ class Authenticator(object):
origin = None
- def parse_auth_header(header_str):
- try:
- params = auth.split(" ")[1].split(",")
- param_dict = dict(kv.split("=") for kv in params)
-
- def strip_quotes(value):
- if value.startswith("\""):
- return value[1:-1]
- else:
- return value
-
- origin = strip_quotes(param_dict["origin"])
- key = strip_quotes(param_dict["key"])
- sig = strip_quotes(param_dict["sig"])
- return (origin, key, sig)
- except Exception:
- raise AuthenticationError(
- 400, "Malformed Authorization header", Codes.UNAUTHORIZED
- )
-
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers:
@@ -127,8 +108,8 @@ class Authenticator(object):
)
for auth in auth_headers:
- if auth.startswith("X-Matrix"):
- (origin, key, sig) = parse_auth_header(auth)
+ if auth.startswith(b"X-Matrix"):
+ (origin, key, sig) = _parse_auth_header(auth)
json_request["origin"] = origin
json_request["signatures"].setdefault(origin, {})[key] = sig
@@ -165,6 +146,47 @@ class Authenticator(object):
logger.exception("Error resetting retry timings on %s", origin)
+def _parse_auth_header(header_bytes):
+ """Parse an X-Matrix auth header
+
+ Args:
+ header_bytes (bytes): header value
+
+ Returns:
+ Tuple[str, str, str]: origin, key id, signature.
+
+ Raises:
+ AuthenticationError if the header could not be parsed
+ """
+ try:
+ header_str = header_bytes.decode('utf-8')
+ params = header_str.split(" ")[1].split(",")
+ param_dict = dict(kv.split("=") for kv in params)
+
+ def strip_quotes(value):
+ if value.startswith(b"\""):
+ return value[1:-1]
+ else:
+ return value
+
+ origin = strip_quotes(param_dict["origin"])
+ # ensure that the origin is a valid server name
+ parse_server_name(origin)
+
+ key = strip_quotes(param_dict["key"])
+ sig = strip_quotes(param_dict["sig"])
+ return origin, key, sig
+ except Exception as e:
+ logger.warn(
+ "Error parsing auth header '%s': %s",
+ header_bytes.decode('ascii', 'replace'),
+ e,
+ )
+ raise AuthenticationError(
+ 400, "Malformed Authorization header", Codes.UNAUTHORIZED,
+ )
+
+
class BaseFederationServlet(object):
REQUIRE_AUTH = True
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 80da870584..5a9cbb3324 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -38,6 +38,36 @@ _Server = collections.namedtuple(
)
+def parse_server_name(server_name):
+ """Split a server name into host/port parts.
+
+ Does some basic sanity checking of the
+
+ Args:
+ server_name (str): server name to parse
+
+ Returns:
+ Tuple[str, int|None]: host/port parts.
+
+ Raises:
+ ValueError if the server name could not be parsed.
+ """
+ try:
+ if server_name[-1] == ']':
+ # ipv6 literal, hopefully
+ if server_name[0] != '[':
+ raise Exception()
+
+ return server_name, None
+
+ domain_port = server_name.rsplit(":", 1)
+ domain = domain_port[0]
+ port = int(domain_port[1]) if domain_port[1:] else None
+ return domain, port
+ except Exception:
+ raise ValueError("Invalid server name '%s'" % server_name)
+
+
def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout=None):
"""Construct an endpoint for the given matrix destination.
@@ -50,9 +80,7 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout (int): connection timeout in seconds
"""
- domain_port = destination.split(":")
- domain = domain_port[0]
- port = int(domain_port[1]) if domain_port[1:] else None
+ domain, port = parse_server_name(destination)
endpoint_kw_args = {}
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/http/__init__.py
diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py
new file mode 100644
index 0000000000..cd74825c85
--- /dev/null
+++ b/tests/http/test_endpoint.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.http.endpoint import parse_server_name
+from tests import unittest
+
+
+class ServerNameTestCase(unittest.TestCase):
+ def test_parse_server_name(self):
+ test_data = {
+ 'localhost': ('localhost', None),
+ 'my-example.com:1234': ('my-example.com', 1234),
+ '1.2.3.4': ('1.2.3.4', None),
+ '[0abc:1def::1234]': ('[0abc:1def::1234]', None),
+ '1.2.3.4:1': ('1.2.3.4', 1),
+ '[0abc:1def::1234]:8080': ('[0abc:1def::1234]', 8080),
+ }
+
+ for i, o in test_data.items():
+ self.assertEqual(parse_server_name(i), o)
+
+ def test_parse_bad_server_names(self):
+ test_data = [
+ "", # empty
+ "localhost:http", # non-numeric port
+ "1234]", # smells like ipv6 literal but isn't
+ ]
+ for i in test_data:
+ try:
+ parse_server_name(i)
+ self.fail(
+ "Expected parse_server_name(\"%s\") to throw" % i,
+ )
+ except ValueError:
+ pass
|