summary refs log tree commit diff
path: root/synapse/http/endpoint.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/endpoint.py')
-rw-r--r--synapse/http/endpoint.py71
1 files changed, 68 insertions, 3 deletions
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py

index 80da870584..1b1123b292 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py
@@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re + from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet import defer from twisted.internet.error import ConnectError @@ -38,6 +40,71 @@ _Server = collections.namedtuple( ) +def parse_server_name(server_name): + """Split a server name into host/port parts. + + Args: + server_name (str): server name to parse + + Returns: + Tuple[str, int|None]: host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + try: + if server_name[-1] == ']': + # ipv6 literal, hopefully + return server_name, None + + domain_port = server_name.rsplit(":", 1) + domain = domain_port[0] + port = int(domain_port[1]) if domain_port[1:] else None + return domain, port + except Exception: + raise ValueError("Invalid server name '%s'" % server_name) + + +VALID_HOST_REGEX = re.compile( + "\\A[0-9a-zA-Z.-]+\\Z", +) + + +def parse_and_validate_server_name(server_name): + """Split a server name into host/port parts and do some basic validation. + + Args: + server_name (str): server name to parse + + Returns: + Tuple[str, int|None]: host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + host, port = parse_server_name(server_name) + + # these tests don't need to be bulletproof as we'll find out soon enough + # if somebody is giving us invalid data. What we *do* need is to be sure + # that nobody is sneaking IP literals in that look like hostnames, etc. + + # look for ipv6 literals + if host[0] == '[': + if host[-1] != ']': + raise ValueError("Mismatched [...] in server name '%s'" % ( + server_name, + )) + return host, port + + # otherwise it should only be alphanumerics. + if not VALID_HOST_REGEX.match(host): + raise ValueError("Server name '%s' contains invalid characters" % ( + server_name, + )) + + return host, port + + def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, timeout=None): """Construct an endpoint for the given matrix destination. @@ -50,9 +117,7 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, timeout (int): connection timeout in seconds """ - domain_port = destination.split(":") - domain = domain_port[0] - port = int(domain_port[1]) if domain_port[1:] else None + domain, port = parse_server_name(destination) endpoint_kw_args = {}