diff options
Diffstat (limited to 'synapse/http')
-rw-r--r-- | synapse/http/federation/matrix_federation_agent.py | 18 | ||||
-rw-r--r-- | synapse/http/matrixfederationclient.py | 6 | ||||
-rw-r--r-- | synapse/http/server.py | 29 | ||||
-rw-r--r-- | synapse/http/site.py | 35 |
4 files changed, 53 insertions, 35 deletions
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 2e83fa6773..b07aa59c08 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import urllib.parse -from typing import List, Optional +from typing import Any, Generator, List, Optional from netaddr import AddrFormatError, IPAddress, IPSet from zope.interface import implementer @@ -116,7 +116,7 @@ class MatrixFederationAgent: uri: bytes, headers: Optional[Headers] = None, bodyProducer: Optional[IBodyProducer] = None, - ) -> defer.Deferred: + ) -> Generator[defer.Deferred, Any, defer.Deferred]: """ Args: method: HTTP method: GET/POST/etc @@ -177,17 +177,17 @@ class MatrixFederationAgent: # We need to make sure the host header is set to the netloc of the # server and that a user-agent is provided. if headers is None: - headers = Headers() + request_headers = Headers() else: - headers = headers.copy() + request_headers = headers.copy() - if not headers.hasHeader(b"host"): - headers.addRawHeader(b"host", parsed_uri.netloc) - if not headers.hasHeader(b"user-agent"): - headers.addRawHeader(b"user-agent", self.user_agent) + if not request_headers.hasHeader(b"host"): + request_headers.addRawHeader(b"host", parsed_uri.netloc) + if not request_headers.hasHeader(b"user-agent"): + request_headers.addRawHeader(b"user-agent", self.user_agent) res = yield make_deferred_yieldable( - self._agent.request(method, uri, headers, bodyProducer) + self._agent.request(method, uri, request_headers, bodyProducer) ) return res diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index cde42e9f5e..0f107714ea 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -1049,14 +1049,14 @@ def check_content_type_is_json(headers: Headers) -> None: RequestSendFailed: if the Content-Type header is missing or isn't JSON """ - c_type = headers.getRawHeaders(b"Content-Type") - if c_type is None: + content_type_headers = headers.getRawHeaders(b"Content-Type") + if content_type_headers is None: raise RequestSendFailed( RuntimeError("No Content-Type header received from remote server"), can_retry=False, ) - c_type = c_type[0].decode("ascii") # only the first header + c_type = content_type_headers[0].decode("ascii") # only the first header val, options = cgi.parse_header(c_type) if val != "application/json": raise RequestSendFailed( diff --git a/synapse/http/server.py b/synapse/http/server.py index 845db9b78d..fa89260850 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -21,6 +21,7 @@ import logging import types import urllib from http import HTTPStatus +from inspect import isawaitable from io import BytesIO from typing import ( Any, @@ -30,6 +31,7 @@ from typing import ( Iterable, Iterator, List, + Optional, Pattern, Tuple, Union, @@ -79,10 +81,12 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: """Sends a JSON error response to clients.""" if f.check(SynapseError): - error_code = f.value.code - error_dict = f.value.error_dict() + # mypy doesn't understand that f.check asserts the type. + exc = f.value # type: SynapseError # type: ignore + error_code = exc.code + error_dict = exc.error_dict() - logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg) + logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg) else: error_code = 500 error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN} @@ -91,7 +95,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: "Failed handle request via %r: %r", request.request_metrics.name, request, - exc_info=(f.type, f.value, f.getTracebackObject()), + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore ) # Only respond with an error response if we haven't already started writing, @@ -128,7 +132,8 @@ def return_html_error( `{msg}` placeholders), or a jinja2 template """ if f.check(CodeMessageException): - cme = f.value + # mypy doesn't understand that f.check asserts the type. + cme = f.value # type: CodeMessageException # type: ignore code = cme.code msg = cme.msg @@ -142,7 +147,7 @@ def return_html_error( logger.error( "Failed handle request %r", request, - exc_info=(f.type, f.value, f.getTracebackObject()), + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore ) else: code = HTTPStatus.INTERNAL_SERVER_ERROR @@ -151,7 +156,7 @@ def return_html_error( logger.error( "Failed handle request %r", request, - exc_info=(f.type, f.value, f.getTracebackObject()), + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore ) if isinstance(error_template, str): @@ -278,7 +283,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): raw_callback_return = method_handler(request) # Is it synchronous? We'll allow this for now. - if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)): + if isawaitable(raw_callback_return): callback_return = await raw_callback_return else: callback_return = raw_callback_return # type: ignore @@ -399,8 +404,10 @@ class JsonResource(DirectServeJsonResource): A tuple of the callback to use, the name of the servlet, and the key word arguments to pass to the callback """ + # At this point the path must be bytes. + request_path_bytes = request.path # type: bytes # type: ignore + request_path = request_path_bytes.decode("ascii") # Treat HEAD requests as GET requests. - request_path = request.path.decode("ascii") request_method = request.method if request_method == b"HEAD": request_method = b"GET" @@ -551,7 +558,7 @@ class _ByteProducer: request: Request, iterator: Iterator[bytes], ): - self._request = request + self._request = request # type: Optional[Request] self._iterator = iterator self._paused = False @@ -563,7 +570,7 @@ class _ByteProducer: """ Send a list of bytes as a chunk of a response. """ - if not data: + if not data or not self._request: return self._request.write(b"".join(data)) diff --git a/synapse/http/site.py b/synapse/http/site.py index 30153237e3..47754aff43 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -14,7 +14,7 @@ import contextlib import logging import time -from typing import Optional, Union +from typing import Optional, Type, Union import attr from zope.interface import implementer @@ -57,7 +57,7 @@ class SynapseRequest(Request): def __init__(self, channel, *args, **kw): Request.__init__(self, channel, *args, **kw) - self.site = channel.site + self.site = channel.site # type: SynapseSite self._channel = channel # this is used by the tests self.start_time = 0.0 @@ -96,25 +96,34 @@ class SynapseRequest(Request): def get_request_id(self): return "%s-%i" % (self.get_method(), self.request_seq) - def get_redacted_uri(self): - uri = self.uri + def get_redacted_uri(self) -> str: + """Gets the redacted URI associated with the request (or placeholder if the URI + has not yet been received). + + Note: This is necessary as the placeholder value in twisted is str + rather than bytes, so we need to sanitise `self.uri`. + + Returns: + The redacted URI as a string. + """ + uri = self.uri # type: Union[bytes, str] if isinstance(uri, bytes): - uri = self.uri.decode("ascii", errors="replace") + uri = uri.decode("ascii", errors="replace") return redact_uri(uri) - def get_method(self): - """Gets the method associated with the request (or placeholder if not - method has yet been received). + def get_method(self) -> str: + """Gets the method associated with the request (or placeholder if method + has not yet been received). Note: This is necessary as the placeholder value in twisted is str rather than bytes, so we need to sanitise `self.method`. Returns: - str + The request method as a string. """ - method = self.method + method = self.method # type: Union[bytes, str] if isinstance(method, bytes): - method = self.method.decode("ascii") + return self.method.decode("ascii") return method def render(self, resrc): @@ -432,7 +441,9 @@ class SynapseSite(Site): assert config.http_options is not None proxied = config.http_options.x_forwarded - self.requestFactory = XForwardedForRequest if proxied else SynapseRequest + self.requestFactory = ( + XForwardedForRequest if proxied else SynapseRequest + ) # type: Type[Request] self.access_logger = logging.getLogger(logger_name) self.server_version_string = server_version_string.encode("ascii") |