diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 2e83fa6773..b07aa59c08 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib.parse
-from typing import List, Optional
+from typing import Any, Generator, List, Optional
from netaddr import AddrFormatError, IPAddress, IPSet
from zope.interface import implementer
@@ -116,7 +116,7 @@ class MatrixFederationAgent:
uri: bytes,
headers: Optional[Headers] = None,
bodyProducer: Optional[IBodyProducer] = None,
- ) -> defer.Deferred:
+ ) -> Generator[defer.Deferred, Any, defer.Deferred]:
"""
Args:
method: HTTP method: GET/POST/etc
@@ -177,17 +177,17 @@ class MatrixFederationAgent:
# We need to make sure the host header is set to the netloc of the
# server and that a user-agent is provided.
if headers is None:
- headers = Headers()
+ request_headers = Headers()
else:
- headers = headers.copy()
+ request_headers = headers.copy()
- if not headers.hasHeader(b"host"):
- headers.addRawHeader(b"host", parsed_uri.netloc)
- if not headers.hasHeader(b"user-agent"):
- headers.addRawHeader(b"user-agent", self.user_agent)
+ if not request_headers.hasHeader(b"host"):
+ request_headers.addRawHeader(b"host", parsed_uri.netloc)
+ if not request_headers.hasHeader(b"user-agent"):
+ request_headers.addRawHeader(b"user-agent", self.user_agent)
res = yield make_deferred_yieldable(
- self._agent.request(method, uri, headers, bodyProducer)
+ self._agent.request(method, uri, request_headers, bodyProducer)
)
return res
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index cde42e9f5e..0f107714ea 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -1049,14 +1049,14 @@ def check_content_type_is_json(headers: Headers) -> None:
RequestSendFailed: if the Content-Type header is missing or isn't JSON
"""
- c_type = headers.getRawHeaders(b"Content-Type")
- if c_type is None:
+ content_type_headers = headers.getRawHeaders(b"Content-Type")
+ if content_type_headers is None:
raise RequestSendFailed(
RuntimeError("No Content-Type header received from remote server"),
can_retry=False,
)
- c_type = c_type[0].decode("ascii") # only the first header
+ c_type = content_type_headers[0].decode("ascii") # only the first header
val, options = cgi.parse_header(c_type)
if val != "application/json":
raise RequestSendFailed(
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 845db9b78d..fa89260850 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -21,6 +21,7 @@ import logging
import types
import urllib
from http import HTTPStatus
+from inspect import isawaitable
from io import BytesIO
from typing import (
Any,
@@ -30,6 +31,7 @@ from typing import (
Iterable,
Iterator,
List,
+ Optional,
Pattern,
Tuple,
Union,
@@ -79,10 +81,12 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"""Sends a JSON error response to clients."""
if f.check(SynapseError):
- error_code = f.value.code
- error_dict = f.value.error_dict()
+ # mypy doesn't understand that f.check asserts the type.
+ exc = f.value # type: SynapseError # type: ignore
+ error_code = exc.code
+ error_dict = exc.error_dict()
- logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
+ logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
else:
error_code = 500
error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
@@ -91,7 +95,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"Failed handle request via %r: %r",
request.request_metrics.name,
request,
- exc_info=(f.type, f.value, f.getTracebackObject()),
+ exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
# Only respond with an error response if we haven't already started writing,
@@ -128,7 +132,8 @@ def return_html_error(
`{msg}` placeholders), or a jinja2 template
"""
if f.check(CodeMessageException):
- cme = f.value
+ # mypy doesn't understand that f.check asserts the type.
+ cme = f.value # type: CodeMessageException # type: ignore
code = cme.code
msg = cme.msg
@@ -142,7 +147,7 @@ def return_html_error(
logger.error(
"Failed handle request %r",
request,
- exc_info=(f.type, f.value, f.getTracebackObject()),
+ exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
else:
code = HTTPStatus.INTERNAL_SERVER_ERROR
@@ -151,7 +156,7 @@ def return_html_error(
logger.error(
"Failed handle request %r",
request,
- exc_info=(f.type, f.value, f.getTracebackObject()),
+ exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
if isinstance(error_template, str):
@@ -278,7 +283,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
raw_callback_return = method_handler(request)
# Is it synchronous? We'll allow this for now.
- if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
+ if isawaitable(raw_callback_return):
callback_return = await raw_callback_return
else:
callback_return = raw_callback_return # type: ignore
@@ -399,8 +404,10 @@ class JsonResource(DirectServeJsonResource):
A tuple of the callback to use, the name of the servlet, and the
key word arguments to pass to the callback
"""
+ # At this point the path must be bytes.
+ request_path_bytes = request.path # type: bytes # type: ignore
+ request_path = request_path_bytes.decode("ascii")
# Treat HEAD requests as GET requests.
- request_path = request.path.decode("ascii")
request_method = request.method
if request_method == b"HEAD":
request_method = b"GET"
@@ -551,7 +558,7 @@ class _ByteProducer:
request: Request,
iterator: Iterator[bytes],
):
- self._request = request
+ self._request = request # type: Optional[Request]
self._iterator = iterator
self._paused = False
@@ -563,7 +570,7 @@ class _ByteProducer:
"""
Send a list of bytes as a chunk of a response.
"""
- if not data:
+ if not data or not self._request:
return
self._request.write(b"".join(data))
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 30153237e3..47754aff43 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,7 +14,7 @@
import contextlib
import logging
import time
-from typing import Optional, Union
+from typing import Optional, Type, Union
import attr
from zope.interface import implementer
@@ -57,7 +57,7 @@ class SynapseRequest(Request):
def __init__(self, channel, *args, **kw):
Request.__init__(self, channel, *args, **kw)
- self.site = channel.site
+ self.site = channel.site # type: SynapseSite
self._channel = channel # this is used by the tests
self.start_time = 0.0
@@ -96,25 +96,34 @@ class SynapseRequest(Request):
def get_request_id(self):
return "%s-%i" % (self.get_method(), self.request_seq)
- def get_redacted_uri(self):
- uri = self.uri
+ def get_redacted_uri(self) -> str:
+ """Gets the redacted URI associated with the request (or placeholder if the URI
+ has not yet been received).
+
+ Note: This is necessary as the placeholder value in twisted is str
+ rather than bytes, so we need to sanitise `self.uri`.
+
+ Returns:
+ The redacted URI as a string.
+ """
+ uri = self.uri # type: Union[bytes, str]
if isinstance(uri, bytes):
- uri = self.uri.decode("ascii", errors="replace")
+ uri = uri.decode("ascii", errors="replace")
return redact_uri(uri)
- def get_method(self):
- """Gets the method associated with the request (or placeholder if not
- method has yet been received).
+ def get_method(self) -> str:
+ """Gets the method associated with the request (or placeholder if method
+ has not yet been received).
Note: This is necessary as the placeholder value in twisted is str
rather than bytes, so we need to sanitise `self.method`.
Returns:
- str
+ The request method as a string.
"""
- method = self.method
+ method = self.method # type: Union[bytes, str]
if isinstance(method, bytes):
- method = self.method.decode("ascii")
+ return self.method.decode("ascii")
return method
def render(self, resrc):
@@ -432,7 +441,9 @@ class SynapseSite(Site):
assert config.http_options is not None
proxied = config.http_options.x_forwarded
- self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
+ self.requestFactory = (
+ XForwardedForRequest if proxied else SynapseRequest
+ ) # type: Type[Request]
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")
|