diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index d36bcd6336..3acf772cd1 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError
class RequestTimedOutError(SynapseError):
"""Exception representing timeout of an outbound request"""
+
def __init__(self):
super(RequestTimedOutError, self).__init__(504, "Timed out")
@@ -40,15 +41,12 @@ def cancelled_to_request_timed_out_error(value, timeout):
return value
-ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
+ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$")
def redact_uri(uri):
"""Strips access tokens from the uri replaces with <redacted>"""
- return ACCESS_TOKEN_RE.sub(
- r'\1<redacted>\3',
- uri
- )
+ return ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri)
class QuieterFileBodyProducer(FileBodyProducer):
@@ -57,6 +55,7 @@ class QuieterFileBodyProducer(FileBodyProducer):
Workaround for https://github.com/matrix-org/synapse/issues/4003 /
https://twistedmatrix.com/trac/ticket/6528
"""
+
def stopProducing(self):
try:
FileBodyProducer.stopProducing(self)
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
index 0e10e3f8f7..096619a8c2 100644
--- a/synapse/http/additional_resource.py
+++ b/synapse/http/additional_resource.py
@@ -28,6 +28,7 @@ class AdditionalResource(Resource):
This class is also where we wrap the request handler with logging, metrics,
and exception handling.
"""
+
def __init__(self, hs, handler):
"""Initialise AdditionalResource
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 5c073fff07..9bc7035c8d 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -103,8 +103,8 @@ class IPBlacklistingResolver(object):
ip_address, self._ip_whitelist, self._ip_blacklist
):
logger.info(
- "Dropped %s from DNS resolution to %s due to blacklist" %
- (ip_address, hostname)
+ "Dropped %s from DNS resolution to %s due to blacklist"
+ % (ip_address, hostname)
)
has_bad_ip = True
@@ -156,7 +156,7 @@ class BlacklistingAgentWrapper(Agent):
self._ip_blacklist = ip_blacklist
def request(self, method, uri, headers=None, bodyProducer=None):
- h = urllib.parse.urlparse(uri.decode('ascii'))
+ h = urllib.parse.urlparse(uri.decode("ascii"))
try:
ip_address = IPAddress(h.hostname)
@@ -164,10 +164,7 @@ class BlacklistingAgentWrapper(Agent):
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
):
- logger.info(
- "Blocking access to %s due to blacklist" %
- (ip_address,)
- )
+ logger.info("Blocking access to %s due to blacklist" % (ip_address,))
e = SynapseError(403, "IP address blocked by IP blacklist entry")
return defer.fail(Failure(e))
except Exception:
@@ -206,7 +203,7 @@ class SimpleHttpClient(object):
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
- self.user_agent = self.user_agent.encode('ascii')
+ self.user_agent = self.user_agent.encode("ascii")
if self._ip_blacklist:
real_reactor = hs.get_reactor()
@@ -520,8 +517,8 @@ class SimpleHttpClient(object):
resp_headers = dict(response.headers.getAllRawHeaders())
if (
- b'Content-Length' in resp_headers
- and int(resp_headers[b'Content-Length'][0]) > max_size
+ b"Content-Length" in resp_headers
+ and int(resp_headers[b"Content-Length"][0]) > max_size
):
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError(
@@ -546,18 +543,13 @@ class SimpleHttpClient(object):
# This can happen e.g. because the body is too large.
raise
except Exception as e:
- raise_from(
- SynapseError(
- 502, ("Failed to download remote body: %s" % e),
- ),
- e
- )
+ raise_from(SynapseError(502, ("Failed to download remote body: %s" % e)), e)
defer.returnValue(
(
length,
resp_headers,
- response.request.absoluteURI.decode('ascii'),
+ response.request.absoluteURI.decode("ascii"),
response.code,
)
)
@@ -647,7 +639,7 @@ def encode_urlencode_args(args):
def encode_urlencode_arg(arg):
if isinstance(arg, text_type):
- return arg.encode('utf-8')
+ return arg.encode("utf-8")
elif isinstance(arg, list):
return [encode_urlencode_arg(i) for i in arg]
else:
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index cd79ebab62..92a5b606c8 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -31,7 +31,7 @@ def parse_server_name(server_name):
ValueError if the server name could not be parsed.
"""
try:
- if server_name[-1] == ']':
+ if server_name[-1] == "]":
# ipv6 literal, hopefully
return server_name, None
@@ -43,9 +43,7 @@ def parse_server_name(server_name):
raise ValueError("Invalid server name '%s'" % server_name)
-VALID_HOST_REGEX = re.compile(
- "\\A[0-9a-zA-Z.-]+\\Z",
-)
+VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
def parse_and_validate_server_name(server_name):
@@ -67,17 +65,15 @@ def parse_and_validate_server_name(server_name):
# that nobody is sneaking IP literals in that look like hostnames, etc.
# look for ipv6 literals
- if host[0] == '[':
- if host[-1] != ']':
- raise ValueError("Mismatched [...] in server name '%s'" % (
- server_name,
- ))
+ if host[0] == "[":
+ if host[-1] != "]":
+ raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
return host, port
# otherwise it should only be alphanumerics.
if not VALID_HOST_REGEX.match(host):
- raise ValueError("Server name '%s' contains invalid characters" % (
- server_name,
- ))
+ raise ValueError(
+ "Server name '%s' contains invalid characters" % (server_name,)
+ )
return host, port
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index b4cbe97b41..414cde0777 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -48,7 +48,7 @@ WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
logger = logging.getLogger(__name__)
-well_known_cache = TTLCache('well-known')
+well_known_cache = TTLCache("well-known")
@implementer(IAgent)
@@ -78,7 +78,9 @@ class MatrixFederationAgent(object):
"""
def __init__(
- self, reactor, tls_client_options_factory,
+ self,
+ reactor,
+ tls_client_options_factory,
_well_known_tls_policy=None,
_srv_resolver=None,
_well_known_cache=well_known_cache,
@@ -100,9 +102,9 @@ class MatrixFederationAgent(object):
if _well_known_tls_policy is not None:
# the param is called 'contextFactory', but actually passing a
# contextfactory is deprecated, and it expects an IPolicyForHTTPS.
- agent_args['contextFactory'] = _well_known_tls_policy
+ agent_args["contextFactory"] = _well_known_tls_policy
_well_known_agent = RedirectAgent(
- Agent(self._reactor, pool=self._pool, **agent_args),
+ Agent(self._reactor, pool=self._pool, **agent_args)
)
self._well_known_agent = _well_known_agent
@@ -149,7 +151,7 @@ class MatrixFederationAgent(object):
tls_options = None
else:
tls_options = self._tls_client_options_factory.get_options(
- res.tls_server_name.decode("ascii"),
+ res.tls_server_name.decode("ascii")
)
# make sure that the Host header is set correctly
@@ -158,14 +160,14 @@ class MatrixFederationAgent(object):
else:
headers = headers.copy()
- if not headers.hasHeader(b'host'):
- headers.addRawHeader(b'host', res.host_header)
+ if not headers.hasHeader(b"host"):
+ headers.addRawHeader(b"host", res.host_header)
class EndpointFactory(object):
@staticmethod
def endpointForURI(_uri):
ep = LoggingHostnameEndpoint(
- self._reactor, res.target_host, res.target_port,
+ self._reactor, res.target_host, res.target_port
)
if tls_options is not None:
ep = wrapClientTLS(tls_options, ep)
@@ -203,21 +205,25 @@ class MatrixFederationAgent(object):
port = parsed_uri.port
if port == -1:
port = 8448
- defer.returnValue(_RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=parsed_uri.host,
- target_port=port,
- ))
+ defer.returnValue(
+ _RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=parsed_uri.host,
+ target_port=port,
+ )
+ )
if parsed_uri.port != -1:
# there is an explicit port
- defer.returnValue(_RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=parsed_uri.host,
- target_port=parsed_uri.port,
- ))
+ defer.returnValue(
+ _RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=parsed_uri.host,
+ target_port=parsed_uri.port,
+ )
+ )
if lookup_well_known:
# try a .well-known lookup
@@ -229,8 +235,8 @@ class MatrixFederationAgent(object):
# parse the server name in the .well-known response into host/port.
# (This code is lifted from twisted.web.client.URI.fromBytes).
- if b':' in well_known_server:
- well_known_host, well_known_port = well_known_server.rsplit(b':', 1)
+ if b":" in well_known_server:
+ well_known_host, well_known_port = well_known_server.rsplit(b":", 1)
try:
well_known_port = int(well_known_port)
except ValueError:
@@ -264,21 +270,27 @@ class MatrixFederationAgent(object):
port = 8448
logger.debug(
"No SRV record for %s, using %s:%i",
- parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port,
+ parsed_uri.host.decode("ascii"),
+ target_host.decode("ascii"),
+ port,
)
else:
target_host, port = pick_server_from_list(server_list)
logger.debug(
"Picked %s:%i from SRV records for %s",
- target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"),
+ target_host.decode("ascii"),
+ port,
+ parsed_uri.host.decode("ascii"),
)
- defer.returnValue(_RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=target_host,
- target_port=port,
- ))
+ defer.returnValue(
+ _RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=target_host,
+ target_port=port,
+ )
+ )
@defer.inlineCallbacks
def _get_well_known(self, server_name):
@@ -318,18 +330,18 @@ class MatrixFederationAgent(object):
- None if there was no .well-known file.
- INVALID_WELL_KNOWN if the .well-known was invalid
"""
- uri = b"https://%s/.well-known/matrix/server" % (server_name, )
+ uri = b"https://%s/.well-known/matrix/server" % (server_name,)
uri_str = uri.decode("ascii")
logger.info("Fetching %s", uri_str)
try:
response = yield make_deferred_yieldable(
- self._well_known_agent.request(b"GET", uri),
+ self._well_known_agent.request(b"GET", uri)
)
body = yield make_deferred_yieldable(readBody(response))
if response.code != 200:
- raise Exception("Non-200 response %s" % (response.code, ))
+ raise Exception("Non-200 response %s" % (response.code,))
- parsed_body = json.loads(body.decode('utf-8'))
+ parsed_body = json.loads(body.decode("utf-8"))
logger.info("Response from .well-known: %s", parsed_body)
if not isinstance(parsed_body, dict):
raise Exception("not a dict")
@@ -347,8 +359,7 @@ class MatrixFederationAgent(object):
result = parsed_body["m.server"].encode("ascii")
cache_period = _cache_period_from_headers(
- response.headers,
- time_now=self._reactor.seconds,
+ response.headers, time_now=self._reactor.seconds
)
if cache_period is None:
cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD
@@ -364,6 +375,7 @@ class MatrixFederationAgent(object):
@implementer(IStreamClientEndpoint)
class LoggingHostnameEndpoint(object):
"""A wrapper for HostnameEndpint which logs when it connects"""
+
def __init__(self, reactor, host, port, *args, **kwargs):
self.host = host
self.port = port
@@ -377,17 +389,17 @@ class LoggingHostnameEndpoint(object):
def _cache_period_from_headers(headers, time_now=time.time):
cache_controls = _parse_cache_control(headers)
- if b'no-store' in cache_controls:
+ if b"no-store" in cache_controls:
return 0
- if b'max-age' in cache_controls:
+ if b"max-age" in cache_controls:
try:
- max_age = int(cache_controls[b'max-age'])
+ max_age = int(cache_controls[b"max-age"])
return max_age
except ValueError:
pass
- expires = headers.getRawHeaders(b'expires')
+ expires = headers.getRawHeaders(b"expires")
if expires is not None:
try:
expires_date = stringToDatetime(expires[-1])
@@ -403,9 +415,9 @@ def _cache_period_from_headers(headers, time_now=time.time):
def _parse_cache_control(headers):
cache_controls = {}
- for hdr in headers.getRawHeaders(b'cache-control', []):
- for directive in hdr.split(b','):
- splits = [x.strip() for x in directive.split(b'=', 1)]
+ for hdr in headers.getRawHeaders(b"cache-control", []):
+ for directive in hdr.split(b","):
+ splits = [x.strip() for x in directive.split(b"=", 1)]
k = splits[0].lower()
v = splits[1] if len(splits) > 1 else None
cache_controls[k] = v
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index 71830c549d..1f22f78a75 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -45,6 +45,7 @@ class Server(object):
expires (int): when the cache should expire this record - in *seconds* since
the epoch
"""
+
host = attr.ib()
port = attr.ib()
priority = attr.ib(default=0)
@@ -79,9 +80,7 @@ def pick_server_from_list(server_list):
return s.host, s.port
# this should be impossible.
- raise RuntimeError(
- "pick_server_from_list got to end of eligible server list.",
- )
+ raise RuntimeError("pick_server_from_list got to end of eligible server list.")
class SrvResolver(object):
@@ -95,6 +94,7 @@ class SrvResolver(object):
cache (dict): cache object
get_time (callable): clock implementation. Should return seconds since the epoch
"""
+
def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
self._dns_client = dns_client
self._cache = cache
@@ -124,7 +124,7 @@ class SrvResolver(object):
try:
answers, _, _ = yield make_deferred_yieldable(
- self._dns_client.lookupService(service_name),
+ self._dns_client.lookupService(service_name)
)
except DNSNameError:
# TODO: cache this. We can get the SOA out of the exception, and use
@@ -136,17 +136,18 @@ class SrvResolver(object):
cache_entry = self._cache.get(service_name, None)
if cache_entry:
logger.warn(
- "Failed to resolve %r, falling back to cache. %r",
- service_name, e
+ "Failed to resolve %r, falling back to cache. %r", service_name, e
)
defer.returnValue(list(cache_entry))
else:
raise e
- if (len(answers) == 1
- and answers[0].type == dns.SRV
- and answers[0].payload
- and answers[0].payload.target == dns.Name(b'.')):
+ if (
+ len(answers) == 1
+ and answers[0].type == dns.SRV
+ and answers[0].payload
+ and answers[0].payload.target == dns.Name(b".")
+ ):
raise ConnectError("Service %s unavailable" % service_name)
servers = []
@@ -157,13 +158,15 @@ class SrvResolver(object):
payload = answer.payload
- servers.append(Server(
- host=payload.target.name,
- port=payload.port,
- priority=payload.priority,
- weight=payload.weight,
- expires=now + answer.ttl,
- ))
+ servers.append(
+ Server(
+ host=payload.target.name,
+ port=payload.port,
+ priority=payload.priority,
+ weight=payload.weight,
+ expires=now + answer.ttl,
+ )
+ )
self._cache[service_name] = list(servers)
defer.returnValue(servers)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 663ea72a7a..5ef8bb60a3 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -54,10 +54,12 @@ from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
-outgoing_requests_counter = Counter("synapse_http_matrixfederationclient_requests",
- "", ["method"])
-incoming_responses_counter = Counter("synapse_http_matrixfederationclient_responses",
- "", ["method", "code"])
+outgoing_requests_counter = Counter(
+ "synapse_http_matrixfederationclient_requests", "", ["method"]
+)
+incoming_responses_counter = Counter(
+ "synapse_http_matrixfederationclient_responses", "", ["method", "code"]
+)
MAX_LONG_RETRIES = 10
@@ -137,11 +139,7 @@ def _handle_json_response(reactor, timeout_sec, request, response):
check_content_type_is_json(response.headers)
d = treq.json_content(response)
- d = timeout_deferred(
- d,
- timeout=timeout_sec,
- reactor=reactor,
- )
+ d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
body = yield make_deferred_yieldable(d)
except Exception as e:
@@ -157,7 +155,7 @@ def _handle_json_response(reactor, timeout_sec, request, response):
request.txn_id,
request.destination,
response.code,
- response.phrase.decode('ascii', errors='replace'),
+ response.phrase.decode("ascii", errors="replace"),
)
defer.returnValue(body)
@@ -181,7 +179,7 @@ class MatrixFederationHttpClient(object):
# We need to use a DNS resolver which filters out blacklisted IP
# addresses, to prevent DNS rebinding.
nameResolver = IPBlacklistingResolver(
- real_reactor, None, hs.config.federation_ip_range_blacklist,
+ real_reactor, None, hs.config.federation_ip_range_blacklist
)
@implementer(IReactorPluggableNameResolver)
@@ -194,21 +192,19 @@ class MatrixFederationHttpClient(object):
self.reactor = Reactor()
- self.agent = MatrixFederationAgent(
- self.reactor,
- tls_client_options_factory,
- )
+ self.agent = MatrixFederationAgent(self.reactor, tls_client_options_factory)
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
- self.agent, self.reactor,
+ self.agent,
+ self.reactor,
ip_blacklist=hs.config.federation_ip_range_blacklist,
)
self.clock = hs.get_clock()
self._store = hs.get_datastore()
- self.version_string_bytes = hs.version_string.encode('ascii')
+ self.version_string_bytes = hs.version_string.encode("ascii")
self.default_timeout = 60
def schedule(x):
@@ -218,10 +214,7 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks
def _send_request_with_optional_trailing_slash(
- self,
- request,
- try_trailing_slash_on_400=False,
- **send_request_args
+ self, request, try_trailing_slash_on_400=False, **send_request_args
):
"""Wrapper for _send_request which can optionally retry the request
upon receiving a combination of a 400 HTTP response code and a
@@ -244,9 +237,7 @@ class MatrixFederationHttpClient(object):
Deferred[Dict]: Parsed JSON response body.
"""
try:
- response = yield self._send_request(
- request, **send_request_args
- )
+ response = yield self._send_request(request, **send_request_args)
except HttpResponseException as e:
# Received an HTTP error > 300. Check if it meets the requirements
# to retry with a trailing slash
@@ -262,9 +253,7 @@ class MatrixFederationHttpClient(object):
logger.info("Retrying request with trailing slash")
request.path += "/"
- response = yield self._send_request(
- request, **send_request_args
- )
+ response = yield self._send_request(request, **send_request_args)
defer.returnValue(response)
@@ -329,8 +318,8 @@ class MatrixFederationHttpClient(object):
_sec_timeout = self.default_timeout
if (
- self.hs.config.federation_domain_whitelist is not None and
- request.destination not in self.hs.config.federation_domain_whitelist
+ self.hs.config.federation_domain_whitelist is not None
+ and request.destination not in self.hs.config.federation_domain_whitelist
):
raise FederationDeniedError(request.destination)
@@ -350,9 +339,7 @@ class MatrixFederationHttpClient(object):
else:
query_bytes = b""
- headers_dict = {
- b"User-Agent": [self.version_string_bytes],
- }
+ headers_dict = {b"User-Agent": [self.version_string_bytes]}
with limiter:
# XXX: Would be much nicer to retry only at the transaction-layer
@@ -362,16 +349,14 @@ class MatrixFederationHttpClient(object):
else:
retries_left = MAX_SHORT_RETRIES
- url_bytes = urllib.parse.urlunparse((
- b"matrix", destination_bytes,
- path_bytes, None, query_bytes, b"",
- ))
- url_str = url_bytes.decode('ascii')
+ url_bytes = urllib.parse.urlunparse(
+ (b"matrix", destination_bytes, path_bytes, None, query_bytes, b"")
+ )
+ url_str = url_bytes.decode("ascii")
- url_to_sign_bytes = urllib.parse.urlunparse((
- b"", b"",
- path_bytes, None, query_bytes, b"",
- ))
+ url_to_sign_bytes = urllib.parse.urlunparse(
+ (b"", b"", path_bytes, None, query_bytes, b"")
+ )
while True:
try:
@@ -379,26 +364,27 @@ class MatrixFederationHttpClient(object):
if json:
headers_dict[b"Content-Type"] = [b"application/json"]
auth_headers = self.build_auth_headers(
- destination_bytes, method_bytes, url_to_sign_bytes,
- json,
+ destination_bytes, method_bytes, url_to_sign_bytes, json
)
data = encode_canonical_json(json)
producer = QuieterFileBodyProducer(
- BytesIO(data),
- cooperator=self._cooperator,
+ BytesIO(data), cooperator=self._cooperator
)
else:
producer = None
auth_headers = self.build_auth_headers(
- destination_bytes, method_bytes, url_to_sign_bytes,
+ destination_bytes, method_bytes, url_to_sign_bytes
)
headers_dict[b"Authorization"] = auth_headers
logger.info(
"{%s} [%s] Sending request: %s %s; timeout %fs",
- request.txn_id, request.destination, request.method,
- url_str, _sec_timeout,
+ request.txn_id,
+ request.destination,
+ request.method,
+ url_str,
+ _sec_timeout,
)
try:
@@ -430,7 +416,7 @@ class MatrixFederationHttpClient(object):
request.txn_id,
request.destination,
response.code,
- response.phrase.decode('ascii', errors='replace'),
+ response.phrase.decode("ascii", errors="replace"),
)
if 200 <= response.code < 300:
@@ -440,9 +426,7 @@ class MatrixFederationHttpClient(object):
# Update transactions table?
d = treq.content(response)
d = timeout_deferred(
- d,
- timeout=_sec_timeout,
- reactor=self.reactor,
+ d, timeout=_sec_timeout, reactor=self.reactor
)
try:
@@ -460,9 +444,7 @@ class MatrixFederationHttpClient(object):
)
body = None
- e = HttpResponseException(
- response.code, response.phrase, body
- )
+ e = HttpResponseException(response.code, response.phrase, body)
# Retry if the error is a 429 (Too Many Requests),
# otherwise just raise a standard HttpResponseException
@@ -521,7 +503,7 @@ class MatrixFederationHttpClient(object):
defer.returnValue(response)
def build_auth_headers(
- self, destination, method, url_bytes, content=None, destination_is=None,
+ self, destination, method, url_bytes, content=None, destination_is=None
):
"""
Builds the Authorization headers for a federation request
@@ -538,11 +520,7 @@ class MatrixFederationHttpClient(object):
Returns:
list[bytes]: a list of headers to be added as "Authorization:" headers
"""
- request = {
- "method": method,
- "uri": url_bytes,
- "origin": self.server_name,
- }
+ request = {"method": method, "uri": url_bytes, "origin": self.server_name}
if destination is not None:
request["destination"] = destination
@@ -558,20 +536,28 @@ class MatrixFederationHttpClient(object):
auth_headers = []
for key, sig in request["signatures"][self.server_name].items():
- auth_headers.append((
- "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
- self.server_name, key, sig,
- )).encode('ascii')
+ auth_headers.append(
+ (
+ 'X-Matrix origin=%s,key="%s",sig="%s"'
+ % (self.server_name, key, sig)
+ ).encode("ascii")
)
return auth_headers
@defer.inlineCallbacks
- def put_json(self, destination, path, args={}, data={},
- json_data_callback=None,
- long_retries=False, timeout=None,
- ignore_backoff=False,
- backoff_on_404=False,
- try_trailing_slash_on_400=False):
+ def put_json(
+ self,
+ destination,
+ path,
+ args={},
+ data={},
+ json_data_callback=None,
+ long_retries=False,
+ timeout=None,
+ ignore_backoff=False,
+ backoff_on_404=False,
+ try_trailing_slash_on_400=False,
+ ):
""" Sends the specifed json data using PUT
Args:
@@ -635,14 +621,22 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
- self.reactor, self.default_timeout, request, response,
+ self.reactor, self.default_timeout, request, response
)
defer.returnValue(body)
@defer.inlineCallbacks
- def post_json(self, destination, path, data={}, long_retries=False,
- timeout=None, ignore_backoff=False, args={}):
+ def post_json(
+ self,
+ destination,
+ path,
+ data={},
+ long_retries=False,
+ timeout=None,
+ ignore_backoff=False,
+ args={},
+ ):
""" Sends the specifed json data using POST
Args:
@@ -681,11 +675,7 @@ class MatrixFederationHttpClient(object):
"""
request = MatrixFederationRequest(
- method="POST",
- destination=destination,
- path=path,
- query=args,
- json=data,
+ method="POST", destination=destination, path=path, query=args, json=data
)
response = yield self._send_request(
@@ -701,14 +691,21 @@ class MatrixFederationHttpClient(object):
_sec_timeout = self.default_timeout
body = yield _handle_json_response(
- self.reactor, _sec_timeout, request, response,
+ self.reactor, _sec_timeout, request, response
)
defer.returnValue(body)
@defer.inlineCallbacks
- def get_json(self, destination, path, args=None, retry_on_dns_fail=True,
- timeout=None, ignore_backoff=False,
- try_trailing_slash_on_400=False):
+ def get_json(
+ self,
+ destination,
+ path,
+ args=None,
+ retry_on_dns_fail=True,
+ timeout=None,
+ ignore_backoff=False,
+ try_trailing_slash_on_400=False,
+ ):
""" GETs some json from the given host homeserver and path
Args:
@@ -745,10 +742,7 @@ class MatrixFederationHttpClient(object):
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
- method="GET",
- destination=destination,
- path=path,
- query=args,
+ method="GET", destination=destination, path=path, query=args
)
response = yield self._send_request_with_optional_trailing_slash(
@@ -761,14 +755,21 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
- self.reactor, self.default_timeout, request, response,
+ self.reactor, self.default_timeout, request, response
)
defer.returnValue(body)
@defer.inlineCallbacks
- def delete_json(self, destination, path, long_retries=False,
- timeout=None, ignore_backoff=False, args={}):
+ def delete_json(
+ self,
+ destination,
+ path,
+ long_retries=False,
+ timeout=None,
+ ignore_backoff=False,
+ args={},
+ ):
"""Send a DELETE request to the remote expecting some json response
Args:
@@ -802,10 +803,7 @@ class MatrixFederationHttpClient(object):
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
- method="DELETE",
- destination=destination,
- path=path,
- query=args,
+ method="DELETE", destination=destination, path=path, query=args
)
response = yield self._send_request(
@@ -816,14 +814,21 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
- self.reactor, self.default_timeout, request, response,
+ self.reactor, self.default_timeout, request, response
)
defer.returnValue(body)
@defer.inlineCallbacks
- def get_file(self, destination, path, output_stream, args={},
- retry_on_dns_fail=True, max_size=None,
- ignore_backoff=False):
+ def get_file(
+ self,
+ destination,
+ path,
+ output_stream,
+ args={},
+ retry_on_dns_fail=True,
+ max_size=None,
+ ignore_backoff=False,
+ ):
"""GETs a file from a given homeserver
Args:
destination (str): The remote server to send the HTTP request to.
@@ -848,16 +853,11 @@ class MatrixFederationHttpClient(object):
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
- method="GET",
- destination=destination,
- path=path,
- query=args,
+ method="GET", destination=destination, path=path, query=args
)
response = yield self._send_request(
- request,
- retry_on_dns_fail=retry_on_dns_fail,
- ignore_backoff=ignore_backoff,
+ request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff
)
headers = dict(response.headers.getAllRawHeaders())
@@ -879,7 +879,7 @@ class MatrixFederationHttpClient(object):
request.txn_id,
request.destination,
response.code,
- response.phrase.decode('ascii', errors='replace'),
+ response.phrase.decode("ascii", errors="replace"),
length,
)
defer.returnValue((length, headers))
@@ -896,11 +896,13 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
- self.deferred.errback(SynapseError(
- 502,
- "Requested file is too large > %r bytes" % (self.max_size,),
- Codes.TOO_LARGE,
- ))
+ self.deferred.errback(
+ SynapseError(
+ 502,
+ "Requested file is too large > %r bytes" % (self.max_size,),
+ Codes.TOO_LARGE,
+ )
+ )
self.deferred = defer.Deferred()
self.transport.loseConnection()
@@ -920,8 +922,7 @@ def _readBodyToFile(response, stream, max_size):
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
reasons = ", ".join(
- _flatten_response_never_received(f.value)
- for f in e.reasons
+ _flatten_response_never_received(f.value) for f in e.reasons
)
return "%s:[%s]" % (type(e).__name__, reasons)
@@ -943,16 +944,15 @@ def check_content_type_is_json(headers):
"""
c_type = headers.getRawHeaders(b"Content-Type")
if c_type is None:
- raise RequestSendFailed(RuntimeError(
- "No Content-Type header"
- ), can_retry=False)
+ raise RequestSendFailed(RuntimeError("No Content-Type header"), can_retry=False)
- c_type = c_type[0].decode('ascii') # only the first header
+ c_type = c_type[0].decode("ascii") # only the first header
val, options = cgi.parse_header(c_type)
if val != "application/json":
- raise RequestSendFailed(RuntimeError(
- "Content-Type not application/json: was '%s'" % c_type
- ), can_retry=False)
+ raise RequestSendFailed(
+ RuntimeError("Content-Type not application/json: was '%s'" % c_type),
+ can_retry=False,
+ )
def encode_query_args(args):
@@ -967,4 +967,4 @@ def encode_query_args(args):
query_bytes = urllib.parse.urlencode(encoded_args, True)
- return query_bytes.encode('utf8')
+ return query_bytes.encode("utf8")
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 16fb7935da..6fd13e87d1 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -81,9 +81,7 @@ def wrap_json_request_handler(h):
yield h(self, request)
except SynapseError as e:
code = e.code
- logger.info(
- "%s SynapseError: %s - %s", request, code, e.msg
- )
+ logger.info("%s SynapseError: %s - %s", request, code, e.msg)
# Only respond with an error response if we haven't already started
# writing, otherwise lets just kill the connection
@@ -96,7 +94,10 @@ def wrap_json_request_handler(h):
pass
else:
respond_with_json(
- request, code, e.error_dict(), send_cors=True,
+ request,
+ code,
+ e.error_dict(),
+ send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
)
@@ -124,10 +125,7 @@ def wrap_json_request_handler(h):
respond_with_json(
request,
500,
- {
- "error": "Internal server error",
- "errcode": Codes.UNKNOWN,
- },
+ {"error": "Internal server error", "errcode": Codes.UNKNOWN},
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
)
@@ -143,6 +141,7 @@ def wrap_html_request_handler(h):
The handler method must have a signature of "handle_foo(self, request)",
where "request" must be a SynapseRequest.
"""
+
def wrapped_request_handler(self, request):
d = defer.maybeDeferred(h, self, request)
d.addErrback(_return_html_error, request)
@@ -164,9 +163,7 @@ def _return_html_error(f, request):
msg = cme.msg
if isinstance(cme, SynapseError):
- logger.info(
- "%s SynapseError: %s - %s", request, code, msg
- )
+ logger.info("%s SynapseError: %s - %s", request, code, msg)
else:
logger.error(
"Failed handle request %r",
@@ -183,9 +180,7 @@ def _return_html_error(f, request):
exc_info=(f.type, f.value, f.getTracebackObject()),
)
- body = HTML_ERROR_TEMPLATE.format(
- code=code, msg=cgi.escape(msg),
- ).encode("utf-8")
+ body = HTML_ERROR_TEMPLATE.format(code=code, msg=cgi.escape(msg)).encode("utf-8")
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%i" % (len(body),))
@@ -205,6 +200,7 @@ def wrap_async_request_handler(h):
The handler may return a deferred, in which case the completion of the request isn't
logged until the deferred completes.
"""
+
@defer.inlineCallbacks
def wrapped_async_request_handler(self, request):
with request.processing():
@@ -306,12 +302,14 @@ class JsonResource(HttpServer, resource.Resource):
# URL again (as it was decoded by _get_handler_for request), as
# ASCII because it's a URL, and then decode it to get the UTF-8
# characters that were quoted.
- return urllib.parse.unquote(s.encode('ascii')).decode('utf8')
+ return urllib.parse.unquote(s.encode("ascii")).decode("utf8")
- kwargs = intern_dict({
- name: _unquote(value) if value else value
- for name, value in group_dict.items()
- })
+ kwargs = intern_dict(
+ {
+ name: _unquote(value) if value else value
+ for name, value in group_dict.items()
+ }
+ )
callback_return = yield callback(request, **kwargs)
if callback_return is not None:
@@ -339,7 +337,7 @@ class JsonResource(HttpServer, resource.Resource):
# Loop through all the registered callbacks to check if the method
# and path regex match
for path_entry in self.path_regexs.get(request.method, []):
- m = path_entry.pattern.match(request.path.decode('ascii'))
+ m = path_entry.pattern.match(request.path.decode("ascii"))
if m:
# We found a match!
return path_entry.callback, m.groupdict()
@@ -347,11 +345,14 @@ class JsonResource(HttpServer, resource.Resource):
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
return _unrecognised_request_handler, {}
- def _send_response(self, request, code, response_json_object,
- response_code_message=None):
+ def _send_response(
+ self, request, code, response_json_object, response_code_message=None
+ ):
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
- request, code, response_json_object,
+ request,
+ code,
+ response_json_object,
send_cors=True,
response_code_message=response_code_message,
pretty_print=_request_user_agent_is_curl(request),
@@ -395,7 +396,7 @@ class RootRedirect(resource.Resource):
self.url = path
def render_GET(self, request):
- return redirectTo(self.url.encode('ascii'), request)
+ return redirectTo(self.url.encode("ascii"), request)
def getChild(self, name, request):
if len(name) == 0:
@@ -403,16 +404,22 @@ class RootRedirect(resource.Resource):
return resource.Resource.getChild(self, name, request)
-def respond_with_json(request, code, json_object, send_cors=False,
- response_code_message=None, pretty_print=False,
- canonical_json=True):
+def respond_with_json(
+ request,
+ code,
+ json_object,
+ send_cors=False,
+ response_code_message=None,
+ pretty_print=False,
+ canonical_json=True,
+):
# could alternatively use request.notifyFinish() and flip a flag when
# the Deferred fires, but since the flag is RIGHT THERE it seems like
# a waste.
if request._disconnected:
logger.warn(
- "Not sending response to request %s, already disconnected.",
- request)
+ "Not sending response to request %s, already disconnected.", request
+ )
return
if pretty_print:
@@ -425,14 +432,17 @@ def respond_with_json(request, code, json_object, send_cors=False,
json_bytes = json.dumps(json_object).encode("utf-8")
return respond_with_json_bytes(
- request, code, json_bytes,
+ request,
+ code,
+ json_bytes,
send_cors=send_cors,
response_code_message=response_code_message,
)
-def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
- response_code_message=None):
+def respond_with_json_bytes(
+ request, code, json_bytes, send_cors=False, response_code_message=None
+):
"""Sends encoded JSON in response to the given request.
Args:
@@ -474,7 +484,7 @@ def set_cors_headers(request):
)
request.setHeader(
b"Access-Control-Allow-Headers",
- b"Origin, X-Requested-With, Content-Type, Accept, Authorization"
+ b"Origin, X-Requested-With, Content-Type, Accept, Authorization",
)
@@ -498,9 +508,7 @@ def finish_request(request):
def _request_user_agent_is_curl(request):
- user_agents = request.requestHeaders.getRawHeaders(
- b"User-Agent", default=[]
- )
+ user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[])
for user_agent in user_agents:
if b"curl" in user_agent:
return True
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 197c652850..cd8415acd5 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -48,7 +48,7 @@ def parse_integer(request, name, default=None, required=False):
def parse_integer_from_args(args, name, default=None, required=False):
if not isinstance(name, bytes):
- name = name.encode('ascii')
+ name = name.encode("ascii")
if name in args:
try:
@@ -89,18 +89,14 @@ def parse_boolean(request, name, default=None, required=False):
def parse_boolean_from_args(args, name, default=None, required=False):
if not isinstance(name, bytes):
- name = name.encode('ascii')
+ name = name.encode("ascii")
if name in args:
try:
- return {
- b"true": True,
- b"false": False,
- }[args[name][0]]
+ return {b"true": True, b"false": False}[args[name][0]]
except Exception:
message = (
- "Boolean query parameter %r must be one of"
- " ['true', 'false']"
+ "Boolean query parameter %r must be one of" " ['true', 'false']"
) % (name,)
raise SynapseError(400, message)
else:
@@ -111,8 +107,15 @@ def parse_boolean_from_args(args, name, default=None, required=False):
return default
-def parse_string(request, name, default=None, required=False,
- allowed_values=None, param_type="string", encoding='ascii'):
+def parse_string(
+ request,
+ name,
+ default=None,
+ required=False,
+ allowed_values=None,
+ param_type="string",
+ encoding="ascii",
+):
"""
Parse a string parameter from the request query string.
@@ -145,11 +148,18 @@ def parse_string(request, name, default=None, required=False,
)
-def parse_string_from_args(args, name, default=None, required=False,
- allowed_values=None, param_type="string", encoding='ascii'):
+def parse_string_from_args(
+ args,
+ name,
+ default=None,
+ required=False,
+ allowed_values=None,
+ param_type="string",
+ encoding="ascii",
+):
if not isinstance(name, bytes):
- name = name.encode('ascii')
+ name = name.encode("ascii")
if name in args:
value = args[name][0]
@@ -159,7 +169,8 @@ def parse_string_from_args(args, name, default=None, required=False,
if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
- name, ", ".join(repr(v) for v in allowed_values)
+ name,
+ ", ".join(repr(v) for v in allowed_values),
)
raise SynapseError(400, message)
else:
@@ -201,7 +212,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
# Decode to Unicode so that simplejson will return Unicode strings on
# Python 2
try:
- content_unicode = content_bytes.decode('utf8')
+ content_unicode = content_bytes.decode("utf8")
except UnicodeDecodeError:
logger.warn("Unable to decode UTF-8")
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
@@ -227,9 +238,7 @@ def parse_json_object_from_request(request, allow_empty_body=False):
SynapseError if the request body couldn't be decoded as JSON or
if it wasn't a JSON object.
"""
- content = parse_json_value_from_request(
- request, allow_empty_body=allow_empty_body,
- )
+ content = parse_json_value_from_request(request, allow_empty_body=allow_empty_body)
if allow_empty_body and content is None:
return {}
diff --git a/synapse/http/site.py b/synapse/http/site.py
index e508c0bd4f..93f679ea48 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -46,10 +46,11 @@ class SynapseRequest(Request):
Attributes:
logcontext(LoggingContext) : the log context for this request
"""
+
def __init__(self, site, channel, *args, **kw):
Request.__init__(self, channel, *args, **kw)
self.site = site
- self._channel = channel # this is used by the tests
+ self._channel = channel # this is used by the tests
self.authenticated_entity = None
self.start_time = 0
@@ -72,12 +73,12 @@ class SynapseRequest(Request):
def __repr__(self):
# We overwrite this so that we don't log ``access_token``
- return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % (
+ return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % (
self.__class__.__name__,
id(self),
self.get_method(),
self.get_redacted_uri(),
- self.clientproto.decode('ascii', errors='replace'),
+ self.clientproto.decode("ascii", errors="replace"),
self.site.site_tag,
)
@@ -87,7 +88,7 @@ class SynapseRequest(Request):
def get_redacted_uri(self):
uri = self.uri
if isinstance(uri, bytes):
- uri = self.uri.decode('ascii')
+ uri = self.uri.decode("ascii")
return redact_uri(uri)
def get_method(self):
@@ -102,7 +103,7 @@ class SynapseRequest(Request):
"""
method = self.method
if isinstance(method, bytes):
- method = self.method.decode('ascii')
+ method = self.method.decode("ascii")
return method
def get_user_agent(self):
@@ -134,8 +135,7 @@ class SynapseRequest(Request):
# dispatching to the handler, so that the handler
# can update the servlet name in the request
# metrics
- requests_counter.labels(self.get_method(),
- self.request_metrics.name).inc()
+ requests_counter.labels(self.get_method(), self.request_metrics.name).inc()
@contextlib.contextmanager
def processing(self):
@@ -200,7 +200,7 @@ class SynapseRequest(Request):
# the client disconnects.
with PreserveLoggingContext(self.logcontext):
logger.warn(
- "Error processing request %r: %s %s", self, reason.type, reason.value,
+ "Error processing request %r: %s %s", self, reason.type, reason.value
)
if not self._is_processing:
@@ -222,7 +222,7 @@ class SynapseRequest(Request):
self.start_time = time.time()
self.request_metrics = RequestMetrics()
self.request_metrics.start(
- self.start_time, name=servlet_name, method=self.get_method(),
+ self.start_time, name=servlet_name, method=self.get_method()
)
self.site.access_logger.info(
@@ -230,7 +230,7 @@ class SynapseRequest(Request):
self.getClientIP(),
self.site.site_tag,
self.get_method(),
- self.get_redacted_uri()
+ self.get_redacted_uri(),
)
def _finished_processing(self):
@@ -282,7 +282,7 @@ class SynapseRequest(Request):
self.site.access_logger.info(
"%s - %s - {%s}"
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
- " %sB %s \"%s %s %s\" \"%s\" [%d dbevts]",
+ ' %sB %s "%s %s %s" "%s" [%d dbevts]',
self.getClientIP(),
self.site.site_tag,
authenticated_entity,
@@ -297,7 +297,7 @@ class SynapseRequest(Request):
code,
self.get_method(),
self.get_redacted_uri(),
- self.clientproto.decode('ascii', errors='replace'),
+ self.clientproto.decode("ascii", errors="replace"),
user_agent,
usage.evt_db_fetch_count,
)
@@ -316,14 +316,19 @@ class XForwardedForRequest(SynapseRequest):
Add a layer on top of another request that only uses the value of an
X-Forwarded-For header as the result of C{getClientIP}.
"""
+
def getClientIP(self):
"""
@return: The client address (the first address) in the value of the
I{X-Forwarded-For header}. If the header is not present, return
C{b"-"}.
"""
- return self.requestHeaders.getRawHeaders(
- b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip().decode('ascii')
+ return (
+ self.requestHeaders.getRawHeaders(b"x-forwarded-for", [b"-"])[0]
+ .split(b",")[0]
+ .strip()
+ .decode("ascii")
+ )
class SynapseRequestFactory(object):
@@ -343,8 +348,17 @@ class SynapseSite(Site):
Subclass of a twisted http Site that does access logging with python's
standard logging
"""
- def __init__(self, logger_name, site_tag, config, resource,
- server_version_string, *args, **kwargs):
+
+ def __init__(
+ self,
+ logger_name,
+ site_tag,
+ config,
+ resource,
+ server_version_string,
+ *args,
+ **kwargs
+ ):
Site.__init__(self, resource, *args, **kwargs)
self.site_tag = site_tag
@@ -352,7 +366,7 @@ class SynapseSite(Site):
proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name)
- self.server_version_string = server_version_string.encode('ascii')
+ self.server_version_string = server_version_string.encode("ascii")
def log(self, request):
pass
|