summary refs log tree commit diff
path: root/synapse/http/server.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/http/server.py29
1 files changed, 18 insertions, 11 deletions
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 845db9b78d..fa89260850 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -21,6 +21,7 @@ import logging
 import types
 import urllib
 from http import HTTPStatus
+from inspect import isawaitable
 from io import BytesIO
 from typing import (
     Any,
@@ -30,6 +31,7 @@ from typing import (
     Iterable,
     Iterator,
     List,
+    Optional,
     Pattern,
     Tuple,
     Union,
@@ -79,10 +81,12 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
     """Sends a JSON error response to clients."""
 
     if f.check(SynapseError):
-        error_code = f.value.code
-        error_dict = f.value.error_dict()
+        # mypy doesn't understand that f.check asserts the type.
+        exc = f.value  # type: SynapseError  # type: ignore
+        error_code = exc.code
+        error_dict = exc.error_dict()
 
-        logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
+        logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
     else:
         error_code = 500
         error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
@@ -91,7 +95,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
             "Failed handle request via %r: %r",
             request.request_metrics.name,
             request,
-            exc_info=(f.type, f.value, f.getTracebackObject()),
+            exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
         )
 
     # Only respond with an error response if we haven't already started writing,
@@ -128,7 +132,8 @@ def return_html_error(
             `{msg}` placeholders), or a jinja2 template
     """
     if f.check(CodeMessageException):
-        cme = f.value
+        # mypy doesn't understand that f.check asserts the type.
+        cme = f.value  # type: CodeMessageException  # type: ignore
         code = cme.code
         msg = cme.msg
 
@@ -142,7 +147,7 @@ def return_html_error(
             logger.error(
                 "Failed handle request %r",
                 request,
-                exc_info=(f.type, f.value, f.getTracebackObject()),
+                exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
             )
     else:
         code = HTTPStatus.INTERNAL_SERVER_ERROR
@@ -151,7 +156,7 @@ def return_html_error(
         logger.error(
             "Failed handle request %r",
             request,
-            exc_info=(f.type, f.value, f.getTracebackObject()),
+            exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
         )
 
     if isinstance(error_template, str):
@@ -278,7 +283,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
             raw_callback_return = method_handler(request)
 
             # Is it synchronous? We'll allow this for now.
-            if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
+            if isawaitable(raw_callback_return):
                 callback_return = await raw_callback_return
             else:
                 callback_return = raw_callback_return  # type: ignore
@@ -399,8 +404,10 @@ class JsonResource(DirectServeJsonResource):
             A tuple of the callback to use, the name of the servlet, and the
             key word arguments to pass to the callback
         """
+        # At this point the path must be bytes.
+        request_path_bytes = request.path  # type: bytes  # type: ignore
+        request_path = request_path_bytes.decode("ascii")
         # Treat HEAD requests as GET requests.
-        request_path = request.path.decode("ascii")
         request_method = request.method
         if request_method == b"HEAD":
             request_method = b"GET"
@@ -551,7 +558,7 @@ class _ByteProducer:
         request: Request,
         iterator: Iterator[bytes],
     ):
-        self._request = request
+        self._request = request  # type: Optional[Request]
         self._iterator = iterator
         self._paused = False
 
@@ -563,7 +570,7 @@ class _ByteProducer:
         """
         Send a list of bytes as a chunk of a response.
         """
-        if not data:
+        if not data or not self._request:
             return
         self._request.write(b"".join(data))