diff options
Diffstat (limited to 'synapse/http')
-rw-r--r-- | synapse/http/__init__.py | 9 | ||||
-rw-r--r-- | synapse/http/additional_resource.py | 1 | ||||
-rw-r--r-- | synapse/http/client.py | 31 | ||||
-rw-r--r-- | synapse/http/endpoint.py | 20 | ||||
-rw-r--r-- | synapse/http/federation/matrix_federation_agent.py | 98 | ||||
-rw-r--r-- | synapse/http/federation/srv_resolver.py | 37 | ||||
-rw-r--r-- | synapse/http/matrixfederationclient.py | 242 | ||||
-rw-r--r-- | synapse/http/server.py | 153 | ||||
-rw-r--r-- | synapse/http/servlet.py | 45 | ||||
-rw-r--r-- | synapse/http/site.py | 48 |
10 files changed, 364 insertions, 320 deletions
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index d36bcd6336..3acf772cd1 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError class RequestTimedOutError(SynapseError): """Exception representing timeout of an outbound request""" + def __init__(self): super(RequestTimedOutError, self).__init__(504, "Timed out") @@ -40,15 +41,12 @@ def cancelled_to_request_timed_out_error(value, timeout): return value -ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') +ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$") def redact_uri(uri): """Strips access tokens from the uri replaces with <redacted>""" - return ACCESS_TOKEN_RE.sub( - r'\1<redacted>\3', - uri - ) + return ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri) class QuieterFileBodyProducer(FileBodyProducer): @@ -57,6 +55,7 @@ class QuieterFileBodyProducer(FileBodyProducer): Workaround for https://github.com/matrix-org/synapse/issues/4003 / https://twistedmatrix.com/trac/ticket/6528 """ + def stopProducing(self): try: FileBodyProducer.stopProducing(self) diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py index 0e10e3f8f7..096619a8c2 100644 --- a/synapse/http/additional_resource.py +++ b/synapse/http/additional_resource.py @@ -28,6 +28,7 @@ class AdditionalResource(Resource): This class is also where we wrap the request handler with logging, metrics, and exception handling. """ + def __init__(self, hs, handler): """Initialise AdditionalResource diff --git a/synapse/http/client.py b/synapse/http/client.py index 77fe68818b..9bc7035c8d 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -17,7 +17,7 @@ import logging from io import BytesIO -from six import text_type +from six import raise_from, text_type from six.moves import urllib import treq @@ -103,8 +103,8 @@ class IPBlacklistingResolver(object): ip_address, self._ip_whitelist, self._ip_blacklist ): logger.info( - "Dropped %s from DNS resolution to %s due to blacklist" % - (ip_address, hostname) + "Dropped %s from DNS resolution to %s due to blacklist" + % (ip_address, hostname) ) has_bad_ip = True @@ -156,7 +156,7 @@ class BlacklistingAgentWrapper(Agent): self._ip_blacklist = ip_blacklist def request(self, method, uri, headers=None, bodyProducer=None): - h = urllib.parse.urlparse(uri.decode('ascii')) + h = urllib.parse.urlparse(uri.decode("ascii")) try: ip_address = IPAddress(h.hostname) @@ -164,10 +164,7 @@ class BlacklistingAgentWrapper(Agent): if check_against_blacklist( ip_address, self._ip_whitelist, self._ip_blacklist ): - logger.info( - "Blocking access to %s due to blacklist" % - (ip_address,) - ) + logger.info("Blocking access to %s due to blacklist" % (ip_address,)) e = SynapseError(403, "IP address blocked by IP blacklist entry") return defer.fail(Failure(e)) except Exception: @@ -206,7 +203,7 @@ class SimpleHttpClient(object): if hs.config.user_agent_suffix: self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix) - self.user_agent = self.user_agent.encode('ascii') + self.user_agent = self.user_agent.encode("ascii") if self._ip_blacklist: real_reactor = hs.get_reactor() @@ -520,8 +517,8 @@ class SimpleHttpClient(object): resp_headers = dict(response.headers.getAllRawHeaders()) if ( - b'Content-Length' in resp_headers - and int(resp_headers[b'Content-Length'][0]) > max_size + b"Content-Length" in resp_headers + and int(resp_headers[b"Content-Length"][0]) > max_size ): logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( @@ -542,17 +539,17 @@ class SimpleHttpClient(object): length = yield make_deferred_yieldable( _readBodyToFile(response, output_stream, max_size) ) + except SynapseError: + # This can happen e.g. because the body is too large. + raise except Exception as e: - logger.exception("Failed to download body") - raise SynapseError( - 502, ("Failed to download remote body: %s" % e), Codes.UNKNOWN - ) + raise_from(SynapseError(502, ("Failed to download remote body: %s" % e)), e) defer.returnValue( ( length, resp_headers, - response.request.absoluteURI.decode('ascii'), + response.request.absoluteURI.decode("ascii"), response.code, ) ) @@ -642,7 +639,7 @@ def encode_urlencode_args(args): def encode_urlencode_arg(arg): if isinstance(arg, text_type): - return arg.encode('utf-8') + return arg.encode("utf-8") elif isinstance(arg, list): return [encode_urlencode_arg(i) for i in arg] else: diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index cd79ebab62..92a5b606c8 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -31,7 +31,7 @@ def parse_server_name(server_name): ValueError if the server name could not be parsed. """ try: - if server_name[-1] == ']': + if server_name[-1] == "]": # ipv6 literal, hopefully return server_name, None @@ -43,9 +43,7 @@ def parse_server_name(server_name): raise ValueError("Invalid server name '%s'" % server_name) -VALID_HOST_REGEX = re.compile( - "\\A[0-9a-zA-Z.-]+\\Z", -) +VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") def parse_and_validate_server_name(server_name): @@ -67,17 +65,15 @@ def parse_and_validate_server_name(server_name): # that nobody is sneaking IP literals in that look like hostnames, etc. # look for ipv6 literals - if host[0] == '[': - if host[-1] != ']': - raise ValueError("Mismatched [...] in server name '%s'" % ( - server_name, - )) + if host[0] == "[": + if host[-1] != "]": + raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) return host, port # otherwise it should only be alphanumerics. if not VALID_HOST_REGEX.match(host): - raise ValueError("Server name '%s' contains invalid characters" % ( - server_name, - )) + raise ValueError( + "Server name '%s' contains invalid characters" % (server_name,) + ) return host, port diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index b4cbe97b41..414cde0777 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -48,7 +48,7 @@ WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600 WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600 logger = logging.getLogger(__name__) -well_known_cache = TTLCache('well-known') +well_known_cache = TTLCache("well-known") @implementer(IAgent) @@ -78,7 +78,9 @@ class MatrixFederationAgent(object): """ def __init__( - self, reactor, tls_client_options_factory, + self, + reactor, + tls_client_options_factory, _well_known_tls_policy=None, _srv_resolver=None, _well_known_cache=well_known_cache, @@ -100,9 +102,9 @@ class MatrixFederationAgent(object): if _well_known_tls_policy is not None: # the param is called 'contextFactory', but actually passing a # contextfactory is deprecated, and it expects an IPolicyForHTTPS. - agent_args['contextFactory'] = _well_known_tls_policy + agent_args["contextFactory"] = _well_known_tls_policy _well_known_agent = RedirectAgent( - Agent(self._reactor, pool=self._pool, **agent_args), + Agent(self._reactor, pool=self._pool, **agent_args) ) self._well_known_agent = _well_known_agent @@ -149,7 +151,7 @@ class MatrixFederationAgent(object): tls_options = None else: tls_options = self._tls_client_options_factory.get_options( - res.tls_server_name.decode("ascii"), + res.tls_server_name.decode("ascii") ) # make sure that the Host header is set correctly @@ -158,14 +160,14 @@ class MatrixFederationAgent(object): else: headers = headers.copy() - if not headers.hasHeader(b'host'): - headers.addRawHeader(b'host', res.host_header) + if not headers.hasHeader(b"host"): + headers.addRawHeader(b"host", res.host_header) class EndpointFactory(object): @staticmethod def endpointForURI(_uri): ep = LoggingHostnameEndpoint( - self._reactor, res.target_host, res.target_port, + self._reactor, res.target_host, res.target_port ) if tls_options is not None: ep = wrapClientTLS(tls_options, ep) @@ -203,21 +205,25 @@ class MatrixFederationAgent(object): port = parsed_uri.port if port == -1: port = 8448 - defer.returnValue(_RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=port, - )) + defer.returnValue( + _RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=parsed_uri.host, + target_port=port, + ) + ) if parsed_uri.port != -1: # there is an explicit port - defer.returnValue(_RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=parsed_uri.port, - )) + defer.returnValue( + _RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=parsed_uri.host, + target_port=parsed_uri.port, + ) + ) if lookup_well_known: # try a .well-known lookup @@ -229,8 +235,8 @@ class MatrixFederationAgent(object): # parse the server name in the .well-known response into host/port. # (This code is lifted from twisted.web.client.URI.fromBytes). - if b':' in well_known_server: - well_known_host, well_known_port = well_known_server.rsplit(b':', 1) + if b":" in well_known_server: + well_known_host, well_known_port = well_known_server.rsplit(b":", 1) try: well_known_port = int(well_known_port) except ValueError: @@ -264,21 +270,27 @@ class MatrixFederationAgent(object): port = 8448 logger.debug( "No SRV record for %s, using %s:%i", - parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port, + parsed_uri.host.decode("ascii"), + target_host.decode("ascii"), + port, ) else: target_host, port = pick_server_from_list(server_list) logger.debug( "Picked %s:%i from SRV records for %s", - target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"), + target_host.decode("ascii"), + port, + parsed_uri.host.decode("ascii"), ) - defer.returnValue(_RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=target_host, - target_port=port, - )) + defer.returnValue( + _RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=target_host, + target_port=port, + ) + ) @defer.inlineCallbacks def _get_well_known(self, server_name): @@ -318,18 +330,18 @@ class MatrixFederationAgent(object): - None if there was no .well-known file. - INVALID_WELL_KNOWN if the .well-known was invalid """ - uri = b"https://%s/.well-known/matrix/server" % (server_name, ) + uri = b"https://%s/.well-known/matrix/server" % (server_name,) uri_str = uri.decode("ascii") logger.info("Fetching %s", uri_str) try: response = yield make_deferred_yieldable( - self._well_known_agent.request(b"GET", uri), + self._well_known_agent.request(b"GET", uri) ) body = yield make_deferred_yieldable(readBody(response)) if response.code != 200: - raise Exception("Non-200 response %s" % (response.code, )) + raise Exception("Non-200 response %s" % (response.code,)) - parsed_body = json.loads(body.decode('utf-8')) + parsed_body = json.loads(body.decode("utf-8")) logger.info("Response from .well-known: %s", parsed_body) if not isinstance(parsed_body, dict): raise Exception("not a dict") @@ -347,8 +359,7 @@ class MatrixFederationAgent(object): result = parsed_body["m.server"].encode("ascii") cache_period = _cache_period_from_headers( - response.headers, - time_now=self._reactor.seconds, + response.headers, time_now=self._reactor.seconds ) if cache_period is None: cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD @@ -364,6 +375,7 @@ class MatrixFederationAgent(object): @implementer(IStreamClientEndpoint) class LoggingHostnameEndpoint(object): """A wrapper for HostnameEndpint which logs when it connects""" + def __init__(self, reactor, host, port, *args, **kwargs): self.host = host self.port = port @@ -377,17 +389,17 @@ class LoggingHostnameEndpoint(object): def _cache_period_from_headers(headers, time_now=time.time): cache_controls = _parse_cache_control(headers) - if b'no-store' in cache_controls: + if b"no-store" in cache_controls: return 0 - if b'max-age' in cache_controls: + if b"max-age" in cache_controls: try: - max_age = int(cache_controls[b'max-age']) + max_age = int(cache_controls[b"max-age"]) return max_age except ValueError: pass - expires = headers.getRawHeaders(b'expires') + expires = headers.getRawHeaders(b"expires") if expires is not None: try: expires_date = stringToDatetime(expires[-1]) @@ -403,9 +415,9 @@ def _cache_period_from_headers(headers, time_now=time.time): def _parse_cache_control(headers): cache_controls = {} - for hdr in headers.getRawHeaders(b'cache-control', []): - for directive in hdr.split(b','): - splits = [x.strip() for x in directive.split(b'=', 1)] + for hdr in headers.getRawHeaders(b"cache-control", []): + for directive in hdr.split(b","): + splits = [x.strip() for x in directive.split(b"=", 1)] k = splits[0].lower() v = splits[1] if len(splits) > 1 else None cache_controls[k] = v diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index 71830c549d..1f22f78a75 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -45,6 +45,7 @@ class Server(object): expires (int): when the cache should expire this record - in *seconds* since the epoch """ + host = attr.ib() port = attr.ib() priority = attr.ib(default=0) @@ -79,9 +80,7 @@ def pick_server_from_list(server_list): return s.host, s.port # this should be impossible. - raise RuntimeError( - "pick_server_from_list got to end of eligible server list.", - ) + raise RuntimeError("pick_server_from_list got to end of eligible server list.") class SrvResolver(object): @@ -95,6 +94,7 @@ class SrvResolver(object): cache (dict): cache object get_time (callable): clock implementation. Should return seconds since the epoch """ + def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time): self._dns_client = dns_client self._cache = cache @@ -124,7 +124,7 @@ class SrvResolver(object): try: answers, _, _ = yield make_deferred_yieldable( - self._dns_client.lookupService(service_name), + self._dns_client.lookupService(service_name) ) except DNSNameError: # TODO: cache this. We can get the SOA out of the exception, and use @@ -136,17 +136,18 @@ class SrvResolver(object): cache_entry = self._cache.get(service_name, None) if cache_entry: logger.warn( - "Failed to resolve %r, falling back to cache. %r", - service_name, e + "Failed to resolve %r, falling back to cache. %r", service_name, e ) defer.returnValue(list(cache_entry)) else: raise e - if (len(answers) == 1 - and answers[0].type == dns.SRV - and answers[0].payload - and answers[0].payload.target == dns.Name(b'.')): + if ( + len(answers) == 1 + and answers[0].type == dns.SRV + and answers[0].payload + and answers[0].payload.target == dns.Name(b".") + ): raise ConnectError("Service %s unavailable" % service_name) servers = [] @@ -157,13 +158,15 @@ class SrvResolver(object): payload = answer.payload - servers.append(Server( - host=payload.target.name, - port=payload.port, - priority=payload.priority, - weight=payload.weight, - expires=now + answer.ttl, - )) + servers.append( + Server( + host=payload.target.name, + port=payload.port, + priority=payload.priority, + weight=payload.weight, + expires=now + answer.ttl, + ) + ) self._cache[service_name] = list(servers) defer.returnValue(servers) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 663ea72a7a..5ef8bb60a3 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -54,10 +54,12 @@ from synapse.util.metrics import Measure logger = logging.getLogger(__name__) -outgoing_requests_counter = Counter("synapse_http_matrixfederationclient_requests", - "", ["method"]) -incoming_responses_counter = Counter("synapse_http_matrixfederationclient_responses", - "", ["method", "code"]) +outgoing_requests_counter = Counter( + "synapse_http_matrixfederationclient_requests", "", ["method"] +) +incoming_responses_counter = Counter( + "synapse_http_matrixfederationclient_responses", "", ["method", "code"] +) MAX_LONG_RETRIES = 10 @@ -137,11 +139,7 @@ def _handle_json_response(reactor, timeout_sec, request, response): check_content_type_is_json(response.headers) d = treq.json_content(response) - d = timeout_deferred( - d, - timeout=timeout_sec, - reactor=reactor, - ) + d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) body = yield make_deferred_yieldable(d) except Exception as e: @@ -157,7 +155,7 @@ def _handle_json_response(reactor, timeout_sec, request, response): request.txn_id, request.destination, response.code, - response.phrase.decode('ascii', errors='replace'), + response.phrase.decode("ascii", errors="replace"), ) defer.returnValue(body) @@ -181,7 +179,7 @@ class MatrixFederationHttpClient(object): # We need to use a DNS resolver which filters out blacklisted IP # addresses, to prevent DNS rebinding. nameResolver = IPBlacklistingResolver( - real_reactor, None, hs.config.federation_ip_range_blacklist, + real_reactor, None, hs.config.federation_ip_range_blacklist ) @implementer(IReactorPluggableNameResolver) @@ -194,21 +192,19 @@ class MatrixFederationHttpClient(object): self.reactor = Reactor() - self.agent = MatrixFederationAgent( - self.reactor, - tls_client_options_factory, - ) + self.agent = MatrixFederationAgent(self.reactor, tls_client_options_factory) # Use a BlacklistingAgentWrapper to prevent circumventing the IP # blacklist via IP literals in server names self.agent = BlacklistingAgentWrapper( - self.agent, self.reactor, + self.agent, + self.reactor, ip_blacklist=hs.config.federation_ip_range_blacklist, ) self.clock = hs.get_clock() self._store = hs.get_datastore() - self.version_string_bytes = hs.version_string.encode('ascii') + self.version_string_bytes = hs.version_string.encode("ascii") self.default_timeout = 60 def schedule(x): @@ -218,10 +214,7 @@ class MatrixFederationHttpClient(object): @defer.inlineCallbacks def _send_request_with_optional_trailing_slash( - self, - request, - try_trailing_slash_on_400=False, - **send_request_args + self, request, try_trailing_slash_on_400=False, **send_request_args ): """Wrapper for _send_request which can optionally retry the request upon receiving a combination of a 400 HTTP response code and a @@ -244,9 +237,7 @@ class MatrixFederationHttpClient(object): Deferred[Dict]: Parsed JSON response body. """ try: - response = yield self._send_request( - request, **send_request_args - ) + response = yield self._send_request(request, **send_request_args) except HttpResponseException as e: # Received an HTTP error > 300. Check if it meets the requirements # to retry with a trailing slash @@ -262,9 +253,7 @@ class MatrixFederationHttpClient(object): logger.info("Retrying request with trailing slash") request.path += "/" - response = yield self._send_request( - request, **send_request_args - ) + response = yield self._send_request(request, **send_request_args) defer.returnValue(response) @@ -329,8 +318,8 @@ class MatrixFederationHttpClient(object): _sec_timeout = self.default_timeout if ( - self.hs.config.federation_domain_whitelist is not None and - request.destination not in self.hs.config.federation_domain_whitelist + self.hs.config.federation_domain_whitelist is not None + and request.destination not in self.hs.config.federation_domain_whitelist ): raise FederationDeniedError(request.destination) @@ -350,9 +339,7 @@ class MatrixFederationHttpClient(object): else: query_bytes = b"" - headers_dict = { - b"User-Agent": [self.version_string_bytes], - } + headers_dict = {b"User-Agent": [self.version_string_bytes]} with limiter: # XXX: Would be much nicer to retry only at the transaction-layer @@ -362,16 +349,14 @@ class MatrixFederationHttpClient(object): else: retries_left = MAX_SHORT_RETRIES - url_bytes = urllib.parse.urlunparse(( - b"matrix", destination_bytes, - path_bytes, None, query_bytes, b"", - )) - url_str = url_bytes.decode('ascii') + url_bytes = urllib.parse.urlunparse( + (b"matrix", destination_bytes, path_bytes, None, query_bytes, b"") + ) + url_str = url_bytes.decode("ascii") - url_to_sign_bytes = urllib.parse.urlunparse(( - b"", b"", - path_bytes, None, query_bytes, b"", - )) + url_to_sign_bytes = urllib.parse.urlunparse( + (b"", b"", path_bytes, None, query_bytes, b"") + ) while True: try: @@ -379,26 +364,27 @@ class MatrixFederationHttpClient(object): if json: headers_dict[b"Content-Type"] = [b"application/json"] auth_headers = self.build_auth_headers( - destination_bytes, method_bytes, url_to_sign_bytes, - json, + destination_bytes, method_bytes, url_to_sign_bytes, json ) data = encode_canonical_json(json) producer = QuieterFileBodyProducer( - BytesIO(data), - cooperator=self._cooperator, + BytesIO(data), cooperator=self._cooperator ) else: producer = None auth_headers = self.build_auth_headers( - destination_bytes, method_bytes, url_to_sign_bytes, + destination_bytes, method_bytes, url_to_sign_bytes ) headers_dict[b"Authorization"] = auth_headers logger.info( "{%s} [%s] Sending request: %s %s; timeout %fs", - request.txn_id, request.destination, request.method, - url_str, _sec_timeout, + request.txn_id, + request.destination, + request.method, + url_str, + _sec_timeout, ) try: @@ -430,7 +416,7 @@ class MatrixFederationHttpClient(object): request.txn_id, request.destination, response.code, - response.phrase.decode('ascii', errors='replace'), + response.phrase.decode("ascii", errors="replace"), ) if 200 <= response.code < 300: @@ -440,9 +426,7 @@ class MatrixFederationHttpClient(object): # Update transactions table? d = treq.content(response) d = timeout_deferred( - d, - timeout=_sec_timeout, - reactor=self.reactor, + d, timeout=_sec_timeout, reactor=self.reactor ) try: @@ -460,9 +444,7 @@ class MatrixFederationHttpClient(object): ) body = None - e = HttpResponseException( - response.code, response.phrase, body - ) + e = HttpResponseException(response.code, response.phrase, body) # Retry if the error is a 429 (Too Many Requests), # otherwise just raise a standard HttpResponseException @@ -521,7 +503,7 @@ class MatrixFederationHttpClient(object): defer.returnValue(response) def build_auth_headers( - self, destination, method, url_bytes, content=None, destination_is=None, + self, destination, method, url_bytes, content=None, destination_is=None ): """ Builds the Authorization headers for a federation request @@ -538,11 +520,7 @@ class MatrixFederationHttpClient(object): Returns: list[bytes]: a list of headers to be added as "Authorization:" headers """ - request = { - "method": method, - "uri": url_bytes, - "origin": self.server_name, - } + request = {"method": method, "uri": url_bytes, "origin": self.server_name} if destination is not None: request["destination"] = destination @@ -558,20 +536,28 @@ class MatrixFederationHttpClient(object): auth_headers = [] for key, sig in request["signatures"][self.server_name].items(): - auth_headers.append(( - "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( - self.server_name, key, sig, - )).encode('ascii') + auth_headers.append( + ( + 'X-Matrix origin=%s,key="%s",sig="%s"' + % (self.server_name, key, sig) + ).encode("ascii") ) return auth_headers @defer.inlineCallbacks - def put_json(self, destination, path, args={}, data={}, - json_data_callback=None, - long_retries=False, timeout=None, - ignore_backoff=False, - backoff_on_404=False, - try_trailing_slash_on_400=False): + def put_json( + self, + destination, + path, + args={}, + data={}, + json_data_callback=None, + long_retries=False, + timeout=None, + ignore_backoff=False, + backoff_on_404=False, + try_trailing_slash_on_400=False, + ): """ Sends the specifed json data using PUT Args: @@ -635,14 +621,22 @@ class MatrixFederationHttpClient(object): ) body = yield _handle_json_response( - self.reactor, self.default_timeout, request, response, + self.reactor, self.default_timeout, request, response ) defer.returnValue(body) @defer.inlineCallbacks - def post_json(self, destination, path, data={}, long_retries=False, - timeout=None, ignore_backoff=False, args={}): + def post_json( + self, + destination, + path, + data={}, + long_retries=False, + timeout=None, + ignore_backoff=False, + args={}, + ): """ Sends the specifed json data using POST Args: @@ -681,11 +675,7 @@ class MatrixFederationHttpClient(object): """ request = MatrixFederationRequest( - method="POST", - destination=destination, - path=path, - query=args, - json=data, + method="POST", destination=destination, path=path, query=args, json=data ) response = yield self._send_request( @@ -701,14 +691,21 @@ class MatrixFederationHttpClient(object): _sec_timeout = self.default_timeout body = yield _handle_json_response( - self.reactor, _sec_timeout, request, response, + self.reactor, _sec_timeout, request, response ) defer.returnValue(body) @defer.inlineCallbacks - def get_json(self, destination, path, args=None, retry_on_dns_fail=True, - timeout=None, ignore_backoff=False, - try_trailing_slash_on_400=False): + def get_json( + self, + destination, + path, + args=None, + retry_on_dns_fail=True, + timeout=None, + ignore_backoff=False, + try_trailing_slash_on_400=False, + ): """ GETs some json from the given host homeserver and path Args: @@ -745,10 +742,7 @@ class MatrixFederationHttpClient(object): remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( - method="GET", - destination=destination, - path=path, - query=args, + method="GET", destination=destination, path=path, query=args ) response = yield self._send_request_with_optional_trailing_slash( @@ -761,14 +755,21 @@ class MatrixFederationHttpClient(object): ) body = yield _handle_json_response( - self.reactor, self.default_timeout, request, response, + self.reactor, self.default_timeout, request, response ) defer.returnValue(body) @defer.inlineCallbacks - def delete_json(self, destination, path, long_retries=False, - timeout=None, ignore_backoff=False, args={}): + def delete_json( + self, + destination, + path, + long_retries=False, + timeout=None, + ignore_backoff=False, + args={}, + ): """Send a DELETE request to the remote expecting some json response Args: @@ -802,10 +803,7 @@ class MatrixFederationHttpClient(object): remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( - method="DELETE", - destination=destination, - path=path, - query=args, + method="DELETE", destination=destination, path=path, query=args ) response = yield self._send_request( @@ -816,14 +814,21 @@ class MatrixFederationHttpClient(object): ) body = yield _handle_json_response( - self.reactor, self.default_timeout, request, response, + self.reactor, self.default_timeout, request, response ) defer.returnValue(body) @defer.inlineCallbacks - def get_file(self, destination, path, output_stream, args={}, - retry_on_dns_fail=True, max_size=None, - ignore_backoff=False): + def get_file( + self, + destination, + path, + output_stream, + args={}, + retry_on_dns_fail=True, + max_size=None, + ignore_backoff=False, + ): """GETs a file from a given homeserver Args: destination (str): The remote server to send the HTTP request to. @@ -848,16 +853,11 @@ class MatrixFederationHttpClient(object): remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( - method="GET", - destination=destination, - path=path, - query=args, + method="GET", destination=destination, path=path, query=args ) response = yield self._send_request( - request, - retry_on_dns_fail=retry_on_dns_fail, - ignore_backoff=ignore_backoff, + request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff ) headers = dict(response.headers.getAllRawHeaders()) @@ -879,7 +879,7 @@ class MatrixFederationHttpClient(object): request.txn_id, request.destination, response.code, - response.phrase.decode('ascii', errors='replace'), + response.phrase.decode("ascii", errors="replace"), length, ) defer.returnValue((length, headers)) @@ -896,11 +896,13 @@ class _ReadBodyToFileProtocol(protocol.Protocol): self.stream.write(data) self.length += len(data) if self.max_size is not None and self.length >= self.max_size: - self.deferred.errback(SynapseError( - 502, - "Requested file is too large > %r bytes" % (self.max_size,), - Codes.TOO_LARGE, - )) + self.deferred.errback( + SynapseError( + 502, + "Requested file is too large > %r bytes" % (self.max_size,), + Codes.TOO_LARGE, + ) + ) self.deferred = defer.Deferred() self.transport.loseConnection() @@ -920,8 +922,7 @@ def _readBodyToFile(response, stream, max_size): def _flatten_response_never_received(e): if hasattr(e, "reasons"): reasons = ", ".join( - _flatten_response_never_received(f.value) - for f in e.reasons + _flatten_response_never_received(f.value) for f in e.reasons ) return "%s:[%s]" % (type(e).__name__, reasons) @@ -943,16 +944,15 @@ def check_content_type_is_json(headers): """ c_type = headers.getRawHeaders(b"Content-Type") if c_type is None: - raise RequestSendFailed(RuntimeError( - "No Content-Type header" - ), can_retry=False) + raise RequestSendFailed(RuntimeError("No Content-Type header"), can_retry=False) - c_type = c_type[0].decode('ascii') # only the first header + c_type = c_type[0].decode("ascii") # only the first header val, options = cgi.parse_header(c_type) if val != "application/json": - raise RequestSendFailed(RuntimeError( - "Content-Type not application/json: was '%s'" % c_type - ), can_retry=False) + raise RequestSendFailed( + RuntimeError("Content-Type not application/json: was '%s'" % c_type), + can_retry=False, + ) def encode_query_args(args): @@ -967,4 +967,4 @@ def encode_query_args(args): query_bytes = urllib.parse.urlencode(encoded_args, True) - return query_bytes.encode('utf8') + return query_bytes.encode("utf8") diff --git a/synapse/http/server.py b/synapse/http/server.py index 16fb7935da..f067c163c1 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -16,10 +16,11 @@ import cgi import collections +import http.client import logging - -from six import PY3 -from six.moves import http_client, urllib +import types +import urllib +from io import BytesIO from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json @@ -41,11 +42,6 @@ from synapse.api.errors import ( from synapse.util.caches import intern_dict from synapse.util.logcontext import preserve_fn -if PY3: - from io import BytesIO -else: - from cStringIO import StringIO as BytesIO - logger = logging.getLogger(__name__) HTML_ERROR_TEMPLATE = """<!DOCTYPE html> @@ -75,15 +71,12 @@ def wrap_json_request_handler(h): deferred fails with any other type of error we send a 500 reponse. """ - @defer.inlineCallbacks - def wrapped_request_handler(self, request): + async def wrapped_request_handler(self, request): try: - yield h(self, request) + await h(self, request) except SynapseError as e: code = e.code - logger.info( - "%s SynapseError: %s - %s", request, code, e.msg - ) + logger.info("%s SynapseError: %s - %s", request, code, e.msg) # Only respond with an error response if we haven't already started # writing, otherwise lets just kill the connection @@ -96,7 +89,10 @@ def wrap_json_request_handler(h): pass else: respond_with_json( - request, code, e.error_dict(), send_cors=True, + request, + code, + e.error_dict(), + send_cors=True, pretty_print=_request_user_agent_is_curl(request), ) @@ -124,10 +120,7 @@ def wrap_json_request_handler(h): respond_with_json( request, 500, - { - "error": "Internal server error", - "errcode": Codes.UNKNOWN, - }, + {"error": "Internal server error", "errcode": Codes.UNKNOWN}, send_cors=True, pretty_print=_request_user_agent_is_curl(request), ) @@ -143,10 +136,13 @@ def wrap_html_request_handler(h): The handler method must have a signature of "handle_foo(self, request)", where "request" must be a SynapseRequest. """ - def wrapped_request_handler(self, request): - d = defer.maybeDeferred(h, self, request) - d.addErrback(_return_html_error, request) - return d + + async def wrapped_request_handler(self, request): + try: + return await h(self, request) + except Exception: + f = failure.Failure() + return _return_html_error(f, request) return wrap_async_request_handler(wrapped_request_handler) @@ -164,9 +160,7 @@ def _return_html_error(f, request): msg = cme.msg if isinstance(cme, SynapseError): - logger.info( - "%s SynapseError: %s - %s", request, code, msg - ) + logger.info("%s SynapseError: %s - %s", request, code, msg) else: logger.error( "Failed handle request %r", @@ -174,7 +168,7 @@ def _return_html_error(f, request): exc_info=(f.type, f.value, f.getTracebackObject()), ) else: - code = http_client.INTERNAL_SERVER_ERROR + code = http.client.INTERNAL_SERVER_ERROR msg = "Internal server error" logger.error( @@ -183,9 +177,7 @@ def _return_html_error(f, request): exc_info=(f.type, f.value, f.getTracebackObject()), ) - body = HTML_ERROR_TEMPLATE.format( - code=code, msg=cgi.escape(msg), - ).encode("utf-8") + body = HTML_ERROR_TEMPLATE.format(code=code, msg=cgi.escape(msg)).encode("utf-8") request.setResponseCode(code) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") request.setHeader(b"Content-Length", b"%i" % (len(body),)) @@ -205,10 +197,10 @@ def wrap_async_request_handler(h): The handler may return a deferred, in which case the completion of the request isn't logged until the deferred completes. """ - @defer.inlineCallbacks - def wrapped_async_request_handler(self, request): + + async def wrapped_async_request_handler(self, request): with request.processing(): - yield h(self, request) + await h(self, request) # we need to preserve_fn here, because the synchronous render method won't yield for # us (obviously) @@ -274,12 +266,11 @@ class JsonResource(HttpServer, resource.Resource): def render(self, request): """ This gets called by twisted every time someone sends us a request. """ - self._async_render(request) + defer.ensureDeferred(self._async_render(request)) return NOT_DONE_YET @wrap_json_request_handler - @defer.inlineCallbacks - def _async_render(self, request): + async def _async_render(self, request): """ This gets called from render() every time someone sends us a request. This checks if anyone has registered a callback for that method and path. @@ -296,24 +287,19 @@ class JsonResource(HttpServer, resource.Resource): # Now trigger the callback. If it returns a response, we send it # here. If it throws an exception, that is handled by the wrapper # installed by @request_handler. + kwargs = intern_dict( + { + name: urllib.parse.unquote(value) if value else value + for name, value in group_dict.items() + } + ) + + callback_return = callback(request, **kwargs) + + # Is it synchronous? We'll allow this for now. + if isinstance(callback_return, (defer.Deferred, types.CoroutineType)): + callback_return = await callback_return - def _unquote(s): - if PY3: - # On Python 3, unquote is unicode -> unicode - return urllib.parse.unquote(s) - else: - # On Python 2, unquote is bytes -> bytes We need to encode the - # URL again (as it was decoded by _get_handler_for request), as - # ASCII because it's a URL, and then decode it to get the UTF-8 - # characters that were quoted. - return urllib.parse.unquote(s.encode('ascii')).decode('utf8') - - kwargs = intern_dict({ - name: _unquote(value) if value else value - for name, value in group_dict.items() - }) - - callback_return = yield callback(request, **kwargs) if callback_return is not None: code, response = callback_return self._send_response(request, code, response) @@ -339,7 +325,7 @@ class JsonResource(HttpServer, resource.Resource): # Loop through all the registered callbacks to check if the method # and path regex match for path_entry in self.path_regexs.get(request.method, []): - m = path_entry.pattern.match(request.path.decode('ascii')) + m = path_entry.pattern.match(request.path.decode("ascii")) if m: # We found a match! return path_entry.callback, m.groupdict() @@ -347,11 +333,14 @@ class JsonResource(HttpServer, resource.Resource): # Huh. No one wanted to handle that? Fiiiiiine. Send 400. return _unrecognised_request_handler, {} - def _send_response(self, request, code, response_json_object, - response_code_message=None): + def _send_response( + self, request, code, response_json_object, response_code_message=None + ): # TODO: Only enable CORS for the requests that need it. respond_with_json( - request, code, response_json_object, + request, + code, + response_json_object, send_cors=True, response_code_message=response_code_message, pretty_print=_request_user_agent_is_curl(request), @@ -359,6 +348,23 @@ class JsonResource(HttpServer, resource.Resource): ) +class DirectServeResource(resource.Resource): + def render(self, request): + """ + Render the request, using an asynchronous render handler if it exists. + """ + render_callback_name = "_async_render_" + request.method.decode("ascii") + + if hasattr(self, render_callback_name): + # Call the handler + callback = getattr(self, render_callback_name) + defer.ensureDeferred(callback(request)) + + return NOT_DONE_YET + else: + super().render(request) + + def _options_handler(request): """Request handler for OPTIONS requests @@ -395,7 +401,7 @@ class RootRedirect(resource.Resource): self.url = path def render_GET(self, request): - return redirectTo(self.url.encode('ascii'), request) + return redirectTo(self.url.encode("ascii"), request) def getChild(self, name, request): if len(name) == 0: @@ -403,16 +409,22 @@ class RootRedirect(resource.Resource): return resource.Resource.getChild(self, name, request) -def respond_with_json(request, code, json_object, send_cors=False, - response_code_message=None, pretty_print=False, - canonical_json=True): +def respond_with_json( + request, + code, + json_object, + send_cors=False, + response_code_message=None, + pretty_print=False, + canonical_json=True, +): # could alternatively use request.notifyFinish() and flip a flag when # the Deferred fires, but since the flag is RIGHT THERE it seems like # a waste. if request._disconnected: logger.warn( - "Not sending response to request %s, already disconnected.", - request) + "Not sending response to request %s, already disconnected.", request + ) return if pretty_print: @@ -425,14 +437,17 @@ def respond_with_json(request, code, json_object, send_cors=False, json_bytes = json.dumps(json_object).encode("utf-8") return respond_with_json_bytes( - request, code, json_bytes, + request, + code, + json_bytes, send_cors=send_cors, response_code_message=response_code_message, ) -def respond_with_json_bytes(request, code, json_bytes, send_cors=False, - response_code_message=None): +def respond_with_json_bytes( + request, code, json_bytes, send_cors=False, response_code_message=None +): """Sends encoded JSON in response to the given request. Args: @@ -474,7 +489,7 @@ def set_cors_headers(request): ) request.setHeader( b"Access-Control-Allow-Headers", - b"Origin, X-Requested-With, Content-Type, Accept, Authorization" + b"Origin, X-Requested-With, Content-Type, Accept, Authorization", ) @@ -498,9 +513,7 @@ def finish_request(request): def _request_user_agent_is_curl(request): - user_agents = request.requestHeaders.getRawHeaders( - b"User-Agent", default=[] - ) + user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[]) for user_agent in user_agents: if b"curl" in user_agent: return True diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 197c652850..cd8415acd5 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -48,7 +48,7 @@ def parse_integer(request, name, default=None, required=False): def parse_integer_from_args(args, name, default=None, required=False): if not isinstance(name, bytes): - name = name.encode('ascii') + name = name.encode("ascii") if name in args: try: @@ -89,18 +89,14 @@ def parse_boolean(request, name, default=None, required=False): def parse_boolean_from_args(args, name, default=None, required=False): if not isinstance(name, bytes): - name = name.encode('ascii') + name = name.encode("ascii") if name in args: try: - return { - b"true": True, - b"false": False, - }[args[name][0]] + return {b"true": True, b"false": False}[args[name][0]] except Exception: message = ( - "Boolean query parameter %r must be one of" - " ['true', 'false']" + "Boolean query parameter %r must be one of" " ['true', 'false']" ) % (name,) raise SynapseError(400, message) else: @@ -111,8 +107,15 @@ def parse_boolean_from_args(args, name, default=None, required=False): return default -def parse_string(request, name, default=None, required=False, - allowed_values=None, param_type="string", encoding='ascii'): +def parse_string( + request, + name, + default=None, + required=False, + allowed_values=None, + param_type="string", + encoding="ascii", +): """ Parse a string parameter from the request query string. @@ -145,11 +148,18 @@ def parse_string(request, name, default=None, required=False, ) -def parse_string_from_args(args, name, default=None, required=False, - allowed_values=None, param_type="string", encoding='ascii'): +def parse_string_from_args( + args, + name, + default=None, + required=False, + allowed_values=None, + param_type="string", + encoding="ascii", +): if not isinstance(name, bytes): - name = name.encode('ascii') + name = name.encode("ascii") if name in args: value = args[name][0] @@ -159,7 +169,8 @@ def parse_string_from_args(args, name, default=None, required=False, if allowed_values is not None and value not in allowed_values: message = "Query parameter %r must be one of [%s]" % ( - name, ", ".join(repr(v) for v in allowed_values) + name, + ", ".join(repr(v) for v in allowed_values), ) raise SynapseError(400, message) else: @@ -201,7 +212,7 @@ def parse_json_value_from_request(request, allow_empty_body=False): # Decode to Unicode so that simplejson will return Unicode strings on # Python 2 try: - content_unicode = content_bytes.decode('utf8') + content_unicode = content_bytes.decode("utf8") except UnicodeDecodeError: logger.warn("Unable to decode UTF-8") raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) @@ -227,9 +238,7 @@ def parse_json_object_from_request(request, allow_empty_body=False): SynapseError if the request body couldn't be decoded as JSON or if it wasn't a JSON object. """ - content = parse_json_value_from_request( - request, allow_empty_body=allow_empty_body, - ) + content = parse_json_value_from_request(request, allow_empty_body=allow_empty_body) if allow_empty_body and content is None: return {} diff --git a/synapse/http/site.py b/synapse/http/site.py index e508c0bd4f..93f679ea48 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -46,10 +46,11 @@ class SynapseRequest(Request): Attributes: logcontext(LoggingContext) : the log context for this request """ + def __init__(self, site, channel, *args, **kw): Request.__init__(self, channel, *args, **kw) self.site = site - self._channel = channel # this is used by the tests + self._channel = channel # this is used by the tests self.authenticated_entity = None self.start_time = 0 @@ -72,12 +73,12 @@ class SynapseRequest(Request): def __repr__(self): # We overwrite this so that we don't log ``access_token`` - return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % ( + return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % ( self.__class__.__name__, id(self), self.get_method(), self.get_redacted_uri(), - self.clientproto.decode('ascii', errors='replace'), + self.clientproto.decode("ascii", errors="replace"), self.site.site_tag, ) @@ -87,7 +88,7 @@ class SynapseRequest(Request): def get_redacted_uri(self): uri = self.uri if isinstance(uri, bytes): - uri = self.uri.decode('ascii') + uri = self.uri.decode("ascii") return redact_uri(uri) def get_method(self): @@ -102,7 +103,7 @@ class SynapseRequest(Request): """ method = self.method if isinstance(method, bytes): - method = self.method.decode('ascii') + method = self.method.decode("ascii") return method def get_user_agent(self): @@ -134,8 +135,7 @@ class SynapseRequest(Request): # dispatching to the handler, so that the handler # can update the servlet name in the request # metrics - requests_counter.labels(self.get_method(), - self.request_metrics.name).inc() + requests_counter.labels(self.get_method(), self.request_metrics.name).inc() @contextlib.contextmanager def processing(self): @@ -200,7 +200,7 @@ class SynapseRequest(Request): # the client disconnects. with PreserveLoggingContext(self.logcontext): logger.warn( - "Error processing request %r: %s %s", self, reason.type, reason.value, + "Error processing request %r: %s %s", self, reason.type, reason.value ) if not self._is_processing: @@ -222,7 +222,7 @@ class SynapseRequest(Request): self.start_time = time.time() self.request_metrics = RequestMetrics() self.request_metrics.start( - self.start_time, name=servlet_name, method=self.get_method(), + self.start_time, name=servlet_name, method=self.get_method() ) self.site.access_logger.info( @@ -230,7 +230,7 @@ class SynapseRequest(Request): self.getClientIP(), self.site.site_tag, self.get_method(), - self.get_redacted_uri() + self.get_redacted_uri(), ) def _finished_processing(self): @@ -282,7 +282,7 @@ class SynapseRequest(Request): self.site.access_logger.info( "%s - %s - {%s}" " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" - " %sB %s \"%s %s %s\" \"%s\" [%d dbevts]", + ' %sB %s "%s %s %s" "%s" [%d dbevts]', self.getClientIP(), self.site.site_tag, authenticated_entity, @@ -297,7 +297,7 @@ class SynapseRequest(Request): code, self.get_method(), self.get_redacted_uri(), - self.clientproto.decode('ascii', errors='replace'), + self.clientproto.decode("ascii", errors="replace"), user_agent, usage.evt_db_fetch_count, ) @@ -316,14 +316,19 @@ class XForwardedForRequest(SynapseRequest): Add a layer on top of another request that only uses the value of an X-Forwarded-For header as the result of C{getClientIP}. """ + def getClientIP(self): """ @return: The client address (the first address) in the value of the I{X-Forwarded-For header}. If the header is not present, return C{b"-"}. """ - return self.requestHeaders.getRawHeaders( - b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip().decode('ascii') + return ( + self.requestHeaders.getRawHeaders(b"x-forwarded-for", [b"-"])[0] + .split(b",")[0] + .strip() + .decode("ascii") + ) class SynapseRequestFactory(object): @@ -343,8 +348,17 @@ class SynapseSite(Site): Subclass of a twisted http Site that does access logging with python's standard logging """ - def __init__(self, logger_name, site_tag, config, resource, - server_version_string, *args, **kwargs): + + def __init__( + self, + logger_name, + site_tag, + config, + resource, + server_version_string, + *args, + **kwargs + ): Site.__init__(self, resource, *args, **kwargs) self.site_tag = site_tag @@ -352,7 +366,7 @@ class SynapseSite(Site): proxied = config.get("x_forwarded", False) self.requestFactory = SynapseRequestFactory(self, proxied) self.access_logger = logging.getLogger(logger_name) - self.server_version_string = server_version_string.encode('ascii') + self.server_version_string = server_version_string.encode("ascii") def log(self, request): pass |