diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 19d09f5422..1180d4b69d 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError, FederationDeniedError
+from synapse.http.endpoint import parse_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
@@ -99,26 +100,6 @@ class Authenticator(object):
origin = None
- def parse_auth_header(header_str):
- try:
- params = auth.split(" ")[1].split(",")
- param_dict = dict(kv.split("=") for kv in params)
-
- def strip_quotes(value):
- if value.startswith("\""):
- return value[1:-1]
- else:
- return value
-
- origin = strip_quotes(param_dict["origin"])
- key = strip_quotes(param_dict["key"])
- sig = strip_quotes(param_dict["sig"])
- return (origin, key, sig)
- except Exception:
- raise AuthenticationError(
- 400, "Malformed Authorization header", Codes.UNAUTHORIZED
- )
-
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers:
@@ -127,8 +108,8 @@ class Authenticator(object):
)
for auth in auth_headers:
- if auth.startswith("X-Matrix"):
- (origin, key, sig) = parse_auth_header(auth)
+ if auth.startswith(b"X-Matrix"):
+ (origin, key, sig) = _parse_auth_header(auth)
json_request["origin"] = origin
json_request["signatures"].setdefault(origin, {})[key] = sig
@@ -165,6 +146,47 @@ class Authenticator(object):
logger.exception("Error resetting retry timings on %s", origin)
+def _parse_auth_header(header_bytes):
+ """Parse an X-Matrix auth header
+
+ Args:
+ header_bytes (bytes): header value
+
+ Returns:
+ Tuple[str, str, str]: origin, key id, signature.
+
+ Raises:
+ AuthenticationError if the header could not be parsed
+ """
+ try:
+ header_str = header_bytes.decode('utf-8')
+ params = header_str.split(" ")[1].split(",")
+ param_dict = dict(kv.split("=") for kv in params)
+
+ def strip_quotes(value):
+ if value.startswith(b"\""):
+ return value[1:-1]
+ else:
+ return value
+
+ origin = strip_quotes(param_dict["origin"])
+ # ensure that the origin is a valid server name
+ parse_server_name(origin)
+
+ key = strip_quotes(param_dict["key"])
+ sig = strip_quotes(param_dict["sig"])
+ return origin, key, sig
+ except Exception as e:
+ logger.warn(
+ "Error parsing auth header '%s': %s",
+ header_bytes.decode('ascii', 'replace'),
+ e,
+ )
+ raise AuthenticationError(
+ 400, "Malformed Authorization header", Codes.UNAUTHORIZED,
+ )
+
+
class BaseFederationServlet(object):
REQUIRE_AUTH = True
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 80da870584..5a9cbb3324 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -38,6 +38,36 @@ _Server = collections.namedtuple(
)
+def parse_server_name(server_name):
+ """Split a server name into host/port parts.
+
+ Does some basic sanity checking of the
+
+ Args:
+ server_name (str): server name to parse
+
+ Returns:
+ Tuple[str, int|None]: host/port parts.
+
+ Raises:
+ ValueError if the server name could not be parsed.
+ """
+ try:
+ if server_name[-1] == ']':
+ # ipv6 literal, hopefully
+ if server_name[0] != '[':
+ raise Exception()
+
+ return server_name, None
+
+ domain_port = server_name.rsplit(":", 1)
+ domain = domain_port[0]
+ port = int(domain_port[1]) if domain_port[1:] else None
+ return domain, port
+ except Exception:
+ raise ValueError("Invalid server name '%s'" % server_name)
+
+
def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout=None):
"""Construct an endpoint for the given matrix destination.
@@ -50,9 +80,7 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout (int): connection timeout in seconds
"""
- domain_port = destination.split(":")
- domain = domain_port[0]
- port = int(domain_port[1]) if domain_port[1:] else None
+ domain, port = parse_server_name(destination)
endpoint_kw_args = {}
|