diff --git a/synapse/http/client.py b/synapse/http/client.py
index 46ffb41de1..5bdc484c15 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -42,7 +42,7 @@ from twisted.web._newclient import ResponseDone
from six import StringIO
from prometheus_client import Counter
-import simplejson as json
+from canonicaljson import json
import logging
import urllib
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 39432da452..6056df6226 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -12,8 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import re
+
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from twisted.internet.error import ConnectError
from twisted.names import client, dns
from twisted.names.error import DNSNameError, DomainError
@@ -37,6 +39,71 @@ _Server = collections.namedtuple(
)
+def parse_server_name(server_name):
+    """Break a server name into its hostname and optional port.
+
+    No validation of the hostname is performed here. A name ending in ']'
+    is assumed to be a bare IPv6 literal without a port and is returned
+    unchanged.
+
+    Args:
+        server_name (str): the server name to split
+
+    Returns:
+        Tuple[str, int|None]: the host, and the port (None if absent).
+
+    Raises:
+        ValueError if the server name could not be parsed.
+    """
+    try:
+        # Indexing (rather than endswith) is deliberate: an empty name
+        # raises here and is reported as invalid below.
+        final_char = server_name[-1]
+        if final_char == ']':
+            # ipv6 literal, hopefully
+            return server_name, None
+
+        # Split on the last colon only, so that "[::1]:8448" keeps its
+        # bracketed host intact.
+        pieces = server_name.rsplit(":", 1)
+        if len(pieces) == 1:
+            return pieces[0], None
+
+        host, port_text = pieces
+        return host, int(port_text)
+    except Exception:
+        raise ValueError("Invalid server name '%s'" % server_name)
+
+
+# A non-literal hostname may consist only of ASCII letters, digits, dots and
+# hyphens; \A and \Z anchor the pattern at the very start and end of the string.
+VALID_HOST_REGEX = re.compile(r"\A[0-9a-zA-Z.-]+\Z")
+
+
+def parse_and_validate_server_name(server_name):
+    """Split a server name into host/port parts and do some basic validation.
+
+    The host must either be a bracketed literal ("[...]") or consist solely
+    of characters accepted by VALID_HOST_REGEX. An empty host (for example
+    from ":8448") is rejected.
+
+    Args:
+        server_name (str): server name to parse
+
+    Returns:
+        Tuple[str, int|None]: host/port parts.
+
+    Raises:
+        ValueError if the server name could not be parsed.
+    """
+    host, port = parse_server_name(server_name)
+
+    # these tests don't need to be bulletproof as we'll find out soon enough
+    # if somebody is giving us invalid data. What we *do* need is to be sure
+    # that nobody is sneaking IP literals in that look like hostnames, etc.
+
+    # look for ipv6 literals. The emptiness check matters: parse_server_name
+    # returns an empty host for input such as ":8448", and indexing it would
+    # raise IndexError instead of the documented ValueError.
+    if host and host[0] == '[':
+        if host[-1] != ']':
+            raise ValueError("Mismatched [...] in server name '%s'" % (
+                server_name,
+            ))
+        return host, port
+
+    # otherwise it may only contain ASCII letters, digits, dots and hyphens.
+    # An empty host is rejected here too, as the regex needs at least one
+    # character.
+    if not VALID_HOST_REGEX.match(host):
+        raise ValueError("Server name '%s' contains invalid characters" % (
+            server_name,
+        ))
+
+    return host, port
+
+
def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None,
timeout=None):
"""Construct an endpoint for the given matrix destination.
@@ -50,9 +117,7 @@ def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=
timeout (int): connection timeout in seconds
"""
- domain_port = destination.split(":")
- domain = domain_port[0]
- port = int(domain_port[1]) if domain_port[1:] else None
+ domain, port = parse_server_name(destination)
endpoint_kw_args = {}
@@ -74,21 +139,22 @@ def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=
reactor, "matrix", domain, protocol="tcp",
default_port=default_port, endpoint=transport_endpoint,
endpoint_kw_args=endpoint_kw_args
- ))
+ ), reactor)
else:
return _WrappingEndpointFac(transport_endpoint(
reactor, domain, port, **endpoint_kw_args
- ))
+ ), reactor)
class _WrappingEndpointFac(object):
- def __init__(self, endpoint_fac):
+ def __init__(self, endpoint_fac, reactor):
self.endpoint_fac = endpoint_fac
+ self.reactor = reactor
@defer.inlineCallbacks
def connect(self, protocolFactory):
conn = yield self.endpoint_fac.connect(protocolFactory)
- conn = _WrappedConnection(conn)
+ conn = _WrappedConnection(conn, self.reactor)
defer.returnValue(conn)
@@ -98,9 +164,10 @@ class _WrappedConnection(object):
"""
__slots__ = ["conn", "last_request"]
- def __init__(self, conn):
+ def __init__(self, conn, reactor):
object.__setattr__(self, "conn", conn)
object.__setattr__(self, "last_request", time.time())
+ self._reactor = reactor
def __getattr__(self, name):
return getattr(self.conn, name)
@@ -131,14 +198,14 @@ class _WrappedConnection(object):
# Time this connection out if we haven't send a request in the last
# N minutes
# TODO: Cancel the previous callLater?
- reactor.callLater(3 * 60, self._time_things_out_maybe)
+ self._reactor.callLater(3 * 60, self._time_things_out_maybe)
d = self.conn.request(request)
def update_request_time(res):
self.last_request = time.time()
# TODO: Cancel the previous callLater?
- reactor.callLater(3 * 60, self._time_things_out_maybe)
+ self._reactor.callLater(3 * 60, self._time_things_out_maybe)
return res
d.addCallback(update_request_time)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index b48d05fcd2..6a398c9645 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -27,7 +27,7 @@ from synapse.util import logcontext
from synapse.util.logcontext import make_deferred_yieldable
import synapse.util.retryutils
-from canonicaljson import encode_canonical_json
+from canonicaljson import encode_canonical_json, json
from synapse.api.errors import (
SynapseError, Codes, HttpResponseException, FederationDeniedError,
@@ -36,7 +36,6 @@ from synapse.api.errors import (
from signedjson.sign import sign_json
import cgi
-import simplejson as json
import logging
import random
import sys
diff --git a/synapse/http/server.py b/synapse/http/server.py
index bc09b8b2be..517aaf7b5a 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -29,7 +29,7 @@ import synapse.metrics
import synapse.events
from canonicaljson import (
- encode_canonical_json, encode_pretty_printed_json
+ encode_canonical_json, encode_pretty_printed_json, json
)
from twisted.internet import defer
@@ -41,7 +41,6 @@ from twisted.web.util import redirectTo
import collections
import logging
import urllib
-import simplejson
logger = logging.getLogger(__name__)
@@ -410,7 +409,7 @@ def respond_with_json(request, code, json_object, send_cors=False,
if canonical_json or synapse.events.USE_FROZEN_DICTS:
json_bytes = encode_canonical_json(json_object)
else:
- json_bytes = simplejson.dumps(json_object)
+ json_bytes = json.dumps(json_object)
return respond_with_json_bytes(
request, code, json_bytes,
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index ef8e62901b..ef3a01ddc7 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -18,7 +18,9 @@
from synapse.api.errors import SynapseError, Codes
import logging
-import simplejson
+
+from canonicaljson import json
+
logger = logging.getLogger(__name__)
@@ -171,7 +173,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
return None
try:
- content = simplejson.loads(content_bytes)
+ content = json.loads(content_bytes)
except Exception as e:
logger.warn("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 74a752d6cf..fe93643b1e 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -107,13 +107,28 @@ class SynapseRequest(Request):
end_time = time.time()
+ # need to decode as it could be raw utf-8 bytes
+ # from an IDN server name in an auth header
+ authenticated_entity = self.authenticated_entity
+ if authenticated_entity is not None:
+ authenticated_entity = authenticated_entity.decode("utf-8", "replace")
+
+ # ...or could be raw utf-8 bytes in the User-Agent header.
+ # N.B. if you don't do this, the logger explodes cryptically
+ # with maximum recursion trying to log errors about
+ # the charset problem.
+ # c.f. https://github.com/matrix-org/synapse/issues/3471
+ user_agent = self.get_user_agent()
+ if user_agent is not None:
+ user_agent = user_agent.decode("utf-8", "replace")
+
self.site.access_logger.info(
"%s - %s - {%s}"
" Processed request: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
" %sB %s \"%s %s %s\" \"%s\" [%d dbevts]",
self.getClientIP(),
self.site.site_tag,
- self.authenticated_entity,
+ authenticated_entity,
end_time - self.start_time,
ru_utime,
ru_stime,
@@ -125,7 +140,7 @@ class SynapseRequest(Request):
self.method,
self.get_redacted_uri(),
self.clientproto,
- self.get_user_agent(),
+ user_agent,
evt_db_fetch_count,
)
|