summary refs log tree commit diff
path: root/synapse/http
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-03-03 15:47:38 -0500
committerGitHub <noreply@github.com>2021-03-03 15:47:38 -0500
commit33a02f0f52a5cd793c0d123463d8a7a1e3f4f198 (patch)
tree0d53722ffb90376b2998bdf2f206e3b9d76da9d3 /synapse/http
parentSet X-Forwarded-Proto header when frontend-proxy proxies a request (#9539) (diff)
downloadsynapse-33a02f0f52a5cd793c0d123463d8a7a1e3f4f198.tar.xz
Fix additional type hints from Twisted upgrade. (#9518)
Diffstat (limited to 'synapse/http')
-rw-r--r--synapse/http/federation/matrix_federation_agent.py18
-rw-r--r--synapse/http/matrixfederationclient.py6
-rw-r--r--synapse/http/server.py29
-rw-r--r--synapse/http/site.py35
4 files changed, 53 insertions, 35 deletions
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 2e83fa6773..b07aa59c08 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import urllib.parse
-from typing import List, Optional
+from typing import Any, Generator, List, Optional
 
 from netaddr import AddrFormatError, IPAddress, IPSet
 from zope.interface import implementer
@@ -116,7 +116,7 @@ class MatrixFederationAgent:
         uri: bytes,
         headers: Optional[Headers] = None,
         bodyProducer: Optional[IBodyProducer] = None,
-    ) -> defer.Deferred:
+    ) -> Generator[defer.Deferred, Any, defer.Deferred]:
         """
         Args:
             method: HTTP method: GET/POST/etc
@@ -177,17 +177,17 @@ class MatrixFederationAgent:
         # We need to make sure the host header is set to the netloc of the
         # server and that a user-agent is provided.
         if headers is None:
-            headers = Headers()
+            request_headers = Headers()
         else:
-            headers = headers.copy()
+            request_headers = headers.copy()
 
-        if not headers.hasHeader(b"host"):
-            headers.addRawHeader(b"host", parsed_uri.netloc)
-        if not headers.hasHeader(b"user-agent"):
-            headers.addRawHeader(b"user-agent", self.user_agent)
+        if not request_headers.hasHeader(b"host"):
+            request_headers.addRawHeader(b"host", parsed_uri.netloc)
+        if not request_headers.hasHeader(b"user-agent"):
+            request_headers.addRawHeader(b"user-agent", self.user_agent)
 
         res = yield make_deferred_yieldable(
-            self._agent.request(method, uri, headers, bodyProducer)
+            self._agent.request(method, uri, request_headers, bodyProducer)
         )
 
         return res
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index cde42e9f5e..0f107714ea 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -1049,14 +1049,14 @@ def check_content_type_is_json(headers: Headers) -> None:
         RequestSendFailed: if the Content-Type header is missing or isn't JSON
 
     """
-    c_type = headers.getRawHeaders(b"Content-Type")
-    if c_type is None:
+    content_type_headers = headers.getRawHeaders(b"Content-Type")
+    if content_type_headers is None:
         raise RequestSendFailed(
             RuntimeError("No Content-Type header received from remote server"),
             can_retry=False,
         )
 
-    c_type = c_type[0].decode("ascii")  # only the first header
+    c_type = content_type_headers[0].decode("ascii")  # only the first header
     val, options = cgi.parse_header(c_type)
     if val != "application/json":
         raise RequestSendFailed(
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 845db9b78d..fa89260850 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -21,6 +21,7 @@ import logging
 import types
 import urllib
 from http import HTTPStatus
+from inspect import isawaitable
 from io import BytesIO
 from typing import (
     Any,
@@ -30,6 +31,7 @@ from typing import (
     Iterable,
     Iterator,
     List,
+    Optional,
     Pattern,
     Tuple,
     Union,
@@ -79,10 +81,12 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
     """Sends a JSON error response to clients."""
 
     if f.check(SynapseError):
-        error_code = f.value.code
-        error_dict = f.value.error_dict()
+        # mypy doesn't understand that f.check asserts the type.
+        exc = f.value  # type: SynapseError  # type: ignore
+        error_code = exc.code
+        error_dict = exc.error_dict()
 
-        logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
+        logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
     else:
         error_code = 500
         error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
@@ -91,7 +95,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
             "Failed handle request via %r: %r",
             request.request_metrics.name,
             request,
-            exc_info=(f.type, f.value, f.getTracebackObject()),
+            exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
         )
 
     # Only respond with an error response if we haven't already started writing,
@@ -128,7 +132,8 @@ def return_html_error(
             `{msg}` placeholders), or a jinja2 template
     """
     if f.check(CodeMessageException):
-        cme = f.value
+        # mypy doesn't understand that f.check asserts the type.
+        cme = f.value  # type: CodeMessageException  # type: ignore
         code = cme.code
         msg = cme.msg
 
@@ -142,7 +147,7 @@ def return_html_error(
             logger.error(
                 "Failed handle request %r",
                 request,
-                exc_info=(f.type, f.value, f.getTracebackObject()),
+                exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
             )
     else:
         code = HTTPStatus.INTERNAL_SERVER_ERROR
@@ -151,7 +156,7 @@ def return_html_error(
         logger.error(
             "Failed handle request %r",
             request,
-            exc_info=(f.type, f.value, f.getTracebackObject()),
+            exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
         )
 
     if isinstance(error_template, str):
@@ -278,7 +283,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
             raw_callback_return = method_handler(request)
 
             # Is it synchronous? We'll allow this for now.
-            if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
+            if isawaitable(raw_callback_return):
                 callback_return = await raw_callback_return
             else:
                 callback_return = raw_callback_return  # type: ignore
@@ -399,8 +404,10 @@ class JsonResource(DirectServeJsonResource):
             A tuple of the callback to use, the name of the servlet, and the
             key word arguments to pass to the callback
         """
+        # At this point the path must be bytes.
+        request_path_bytes = request.path  # type: bytes  # type: ignore
+        request_path = request_path_bytes.decode("ascii")
         # Treat HEAD requests as GET requests.
-        request_path = request.path.decode("ascii")
         request_method = request.method
         if request_method == b"HEAD":
             request_method = b"GET"
@@ -551,7 +558,7 @@ class _ByteProducer:
         request: Request,
         iterator: Iterator[bytes],
     ):
-        self._request = request
+        self._request = request  # type: Optional[Request]
         self._iterator = iterator
         self._paused = False
 
@@ -563,7 +570,7 @@ class _ByteProducer:
         """
         Send a list of bytes as a chunk of a response.
         """
-        if not data:
+        if not data or not self._request:
             return
         self._request.write(b"".join(data))
 
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 30153237e3..47754aff43 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,7 +14,7 @@
 import contextlib
 import logging
 import time
-from typing import Optional, Union
+from typing import Optional, Type, Union
 
 import attr
 from zope.interface import implementer
@@ -57,7 +57,7 @@ class SynapseRequest(Request):
 
     def __init__(self, channel, *args, **kw):
         Request.__init__(self, channel, *args, **kw)
-        self.site = channel.site
+        self.site = channel.site  # type: SynapseSite
         self._channel = channel  # this is used by the tests
         self.start_time = 0.0
 
@@ -96,25 +96,34 @@ class SynapseRequest(Request):
     def get_request_id(self):
         return "%s-%i" % (self.get_method(), self.request_seq)
 
-    def get_redacted_uri(self):
-        uri = self.uri
+    def get_redacted_uri(self) -> str:
+        """Gets the redacted URI associated with the request (or placeholder if the URI
+        has not yet been received).
+
+        Note: This is necessary as the placeholder value in twisted is str
+        rather than bytes, so we need to sanitise `self.uri`.
+
+        Returns:
+            The redacted URI as a string.
+        """
+        uri = self.uri  # type: Union[bytes, str]
         if isinstance(uri, bytes):
-            uri = self.uri.decode("ascii", errors="replace")
+            uri = uri.decode("ascii", errors="replace")
         return redact_uri(uri)
 
-    def get_method(self):
-        """Gets the method associated with the request (or placeholder if not
-        method has yet been received).
+    def get_method(self) -> str:
+        """Gets the method associated with the request (or placeholder if method
+        has not yet been received).
 
         Note: This is necessary as the placeholder value in twisted is str
         rather than bytes, so we need to sanitise `self.method`.
 
         Returns:
-            str
+            The request method as a string.
         """
-        method = self.method
+        method = self.method  # type: Union[bytes, str]
         if isinstance(method, bytes):
-            method = self.method.decode("ascii")
+            return self.method.decode("ascii")
         return method
 
     def render(self, resrc):
@@ -432,7 +441,9 @@ class SynapseSite(Site):
 
         assert config.http_options is not None
         proxied = config.http_options.x_forwarded
-        self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
+        self.requestFactory = (
+            XForwardedForRequest if proxied else SynapseRequest
+        )  # type: Type[Request]
         self.access_logger = logging.getLogger(logger_name)
         self.server_version_string = server_version_string.encode("ascii")