summary refs log tree commit diff
path: root/synapse/http
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http')
-rw-r--r--synapse/http/__init__.py6
-rw-r--r--synapse/http/additional_resource.py12
-rw-r--r--synapse/http/client.py15
-rw-r--r--synapse/http/matrixfederationclient.py3
-rw-r--r--synapse/http/server.py90
-rw-r--r--synapse/http/servlet.py50
-rw-r--r--synapse/http/site.py8
7 files changed, 116 insertions, 68 deletions
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index 578fc48ef4..efecb089c1 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -25,7 +25,7 @@ from synapse.api.errors import SynapseError
 class RequestTimedOutError(SynapseError):
     """Exception representing timeout of an outbound request"""
 
-    def __init__(self, msg):
+    def __init__(self, msg: str):
         super().__init__(504, msg)
 
 
@@ -33,7 +33,7 @@ ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$")
 CLIENT_SECRET_RE = re.compile(r"(\?.*client(_|%5[Ff])secret=)[^&]*(.*)$")
 
 
-def redact_uri(uri):
+def redact_uri(uri: str) -> str:
     """Strips sensitive information from the uri replaces with <redacted>"""
     uri = ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri)
     return CLIENT_SECRET_RE.sub(r"\1<redacted>\3", uri)
@@ -46,7 +46,7 @@ class QuieterFileBodyProducer(FileBodyProducer):
     https://twistedmatrix.com/trac/ticket/6528
     """
 
-    def stopProducing(self):
+    def stopProducing(self) -> None:
         try:
             FileBodyProducer.stopProducing(self)
         except task.TaskStopped:
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
index 9a2684aca4..6a9f6635d2 100644
--- a/synapse/http/additional_resource.py
+++ b/synapse/http/additional_resource.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple
 
 from twisted.web.server import Request
 
@@ -32,7 +32,11 @@ class AdditionalResource(DirectServeJsonResource):
     and exception handling.
     """
 
-    def __init__(self, hs: "HomeServer", handler):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        handler: Callable[[Request], Awaitable[Optional[Tuple[int, Any]]]],
+    ):
         """Initialise AdditionalResource
 
         The ``handler`` should return a deferred which completes when it has
@@ -47,7 +51,7 @@ class AdditionalResource(DirectServeJsonResource):
         super().__init__()
         self._handler = handler
 
-    def _async_render(self, request: Request):
+    async def _async_render(self, request: Request) -> Optional[Tuple[int, Any]]:
         # Cheekily pass the result straight through, so we don't need to worry
         # if its an awaitable or not.
-        return self._handler(request)
+        return await self._handler(request)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index b5a2d333a6..fbbeceabeb 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 import logging
 import urllib.parse
+from http import HTTPStatus
 from io import BytesIO
 from typing import (
     TYPE_CHECKING,
@@ -280,7 +281,9 @@ class BlacklistingAgentWrapper(Agent):
                 ip_address, self._ip_whitelist, self._ip_blacklist
             ):
                 logger.info("Blocking access to %s due to blacklist" % (ip_address,))
-                e = SynapseError(403, "IP address blocked by IP blacklist entry")
+                e = SynapseError(
+                    HTTPStatus.FORBIDDEN, "IP address blocked by IP blacklist entry"
+                )
                 return defer.fail(Failure(e))
 
         return self._agent.request(
@@ -719,7 +722,9 @@ class SimpleHttpClient:
 
         if response.code > 299:
             logger.warning("Got %d when downloading %s" % (response.code, url))
-            raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
+            raise SynapseError(
+                HTTPStatus.BAD_GATEWAY, "Got error %d" % (response.code,), Codes.UNKNOWN
+            )
 
         # TODO: if our Content-Type is HTML or something, just read the first
         # N bytes into RAM rather than saving it all to disk only to read it
@@ -731,12 +736,14 @@ class SimpleHttpClient:
             )
         except BodyExceededMaxSize:
             raise SynapseError(
-                502,
+                HTTPStatus.BAD_GATEWAY,
                 "Requested file is too large > %r bytes" % (max_size,),
                 Codes.TOO_LARGE,
             )
         except Exception as e:
-            raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e
+            raise SynapseError(
+                HTTPStatus.BAD_GATEWAY, ("Failed to download remote body: %s" % e)
+            ) from e
 
         return (
             length,
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 203d723d41..deedde0b5b 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -19,6 +19,7 @@ import random
 import sys
 import typing
 import urllib.parse
+from http import HTTPStatus
 from io import BytesIO, StringIO
 from typing import (
     TYPE_CHECKING,
@@ -1154,7 +1155,7 @@ class MatrixFederationHttpClient:
                 request.destination,
                 msg,
             )
-            raise SynapseError(502, msg, Codes.TOO_LARGE)
+            raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)
         except defer.TimeoutError as e:
             logger.warning(
                 "{%s} [%s] Timed out reading response - %s %s",
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 91badb0b0a..4fd5660a08 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -30,6 +30,7 @@ from typing import (
     Iterable,
     Iterator,
     List,
+    NoReturn,
     Optional,
     Pattern,
     Tuple,
@@ -170,7 +171,9 @@ def return_html_error(
     respond_with_html(request, code, body)
 
 
-def wrap_async_request_handler(h):
+def wrap_async_request_handler(
+    h: Callable[["_AsyncResource", SynapseRequest], Awaitable[None]]
+) -> Callable[["_AsyncResource", SynapseRequest], "defer.Deferred[None]"]:
     """Wraps an async request handler so that it calls request.processing.
 
     This helps ensure that work done by the request handler after the request is completed
@@ -183,7 +186,9 @@ def wrap_async_request_handler(h):
     logged until the deferred completes.
     """
 
-    async def wrapped_async_request_handler(self, request):
+    async def wrapped_async_request_handler(
+        self: "_AsyncResource", request: SynapseRequest
+    ) -> None:
         with request.processing():
             await h(self, request)
 
@@ -240,18 +245,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
             context from the request the servlet is handling.
     """
 
-    def __init__(self, extract_context=False):
+    def __init__(self, extract_context: bool = False):
         super().__init__()
 
         self._extract_context = extract_context
 
-    def render(self, request):
+    def render(self, request: SynapseRequest) -> int:
         """This gets called by twisted every time someone sends us a request."""
         defer.ensureDeferred(self._async_render_wrapper(request))
         return NOT_DONE_YET
 
     @wrap_async_request_handler
-    async def _async_render_wrapper(self, request: SynapseRequest):
+    async def _async_render_wrapper(self, request: SynapseRequest) -> None:
         """This is a wrapper that delegates to `_async_render` and handles
         exceptions, return values, metrics, etc.
         """
@@ -271,7 +276,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
             f = failure.Failure()
             self._send_error_response(f, request)
 
-    async def _async_render(self, request: Request):
+    async def _async_render(self, request: SynapseRequest) -> Optional[Tuple[int, Any]]:
         """Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
         no appropriate method exists. Can be overridden in sub classes for
         different routing.
@@ -318,7 +323,7 @@ class DirectServeJsonResource(_AsyncResource):
     formatting responses and errors as JSON.
     """
 
-    def __init__(self, canonical_json=False, extract_context=False):
+    def __init__(self, canonical_json: bool = False, extract_context: bool = False):
         super().__init__(extract_context)
         self.canonical_json = canonical_json
 
@@ -327,7 +332,7 @@ class DirectServeJsonResource(_AsyncResource):
         request: SynapseRequest,
         code: int,
         response_object: Any,
-    ):
+    ) -> None:
         """Implements _AsyncResource._send_response"""
         # TODO: Only enable CORS for the requests that need it.
         respond_with_json(
@@ -368,34 +373,45 @@ class JsonResource(DirectServeJsonResource):
 
     isLeaf = True
 
-    def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        canonical_json: bool = True,
+        extract_context: bool = False,
+    ):
         super().__init__(canonical_json, extract_context)
         self.clock = hs.get_clock()
         self.path_regexs: Dict[bytes, List[_PathEntry]] = {}
         self.hs = hs
 
-    def register_paths(self, method, path_patterns, callback, servlet_classname):
+    def register_paths(
+        self,
+        method: str,
+        path_patterns: Iterable[Pattern],
+        callback: ServletCallback,
+        servlet_classname: str,
+    ) -> None:
         """
         Registers a request handler against a regular expression. Later request URLs are
         checked against these regular expressions in order to identify an appropriate
         handler for that request.
 
         Args:
-            method (str): GET, POST etc
+            method: GET, POST etc
 
-            path_patterns (Iterable[str]): A list of regular expressions to which
-                the request URLs are compared.
+            path_patterns: A list of regular expressions to which the request
+                URLs are compared.
 
-            callback (function): The handler for the request. Usually a Servlet
+            callback: The handler for the request. Usually a Servlet
 
-            servlet_classname (str): The name of the handler to be used in prometheus
+            servlet_classname: The name of the handler to be used in prometheus
                 and opentracing logs.
         """
-        method = method.encode("utf-8")  # method is bytes on py3
+        method_bytes = method.encode("utf-8")
 
         for path_pattern in path_patterns:
             logger.debug("Registering for %s %s", method, path_pattern.pattern)
-            self.path_regexs.setdefault(method, []).append(
+            self.path_regexs.setdefault(method_bytes, []).append(
                 _PathEntry(path_pattern, callback, servlet_classname)
             )
 
@@ -427,7 +443,7 @@ class JsonResource(DirectServeJsonResource):
         # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
         return _unrecognised_request_handler, "unrecognised_request_handler", {}
 
-    async def _async_render(self, request):
+    async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
         callback, servlet_classname, group_dict = self._get_handler_for_request(request)
 
         # Make sure we have an appropriate name for this handler in prometheus
@@ -468,7 +484,7 @@ class DirectServeHtmlResource(_AsyncResource):
         request: SynapseRequest,
         code: int,
         response_object: Any,
-    ):
+    ) -> None:
         """Implements _AsyncResource._send_response"""
         # We expect to get bytes for us to write
         assert isinstance(response_object, bytes)
@@ -492,12 +508,12 @@ class StaticResource(File):
     Differs from the File resource by adding clickjacking protection.
     """
 
-    def render_GET(self, request: Request):
+    def render_GET(self, request: Request) -> bytes:
         set_clickjacking_protection_headers(request)
         return super().render_GET(request)
 
 
-def _unrecognised_request_handler(request):
+def _unrecognised_request_handler(request: Request) -> NoReturn:
     """Request handler for unrecognised requests
 
     This is a request handler suitable for return from
@@ -505,7 +521,7 @@ def _unrecognised_request_handler(request):
     UnrecognizedRequestError.
 
     Args:
-        request (twisted.web.http.Request):
+        request: Unused, but passed in to match the signature of ServletCallback.
     """
     raise UnrecognizedRequestError()
 
@@ -513,14 +529,14 @@ def _unrecognised_request_handler(request):
 class RootRedirect(resource.Resource):
     """Redirects the root '/' path to another path."""
 
-    def __init__(self, path):
+    def __init__(self, path: str):
         resource.Resource.__init__(self)
         self.url = path
 
-    def render_GET(self, request):
+    def render_GET(self, request: Request) -> bytes:
         return redirectTo(self.url.encode("ascii"), request)
 
-    def getChild(self, name, request):
+    def getChild(self, name: str, request: Request) -> resource.Resource:
         if len(name) == 0:
             return self  # select ourselves as the child to render
         return resource.Resource.getChild(self, name, request)
@@ -529,7 +545,7 @@ class RootRedirect(resource.Resource):
 class OptionsResource(resource.Resource):
     """Responds to OPTION requests for itself and all children."""
 
-    def render_OPTIONS(self, request):
+    def render_OPTIONS(self, request: Request) -> bytes:
         request.setResponseCode(204)
         request.setHeader(b"Content-Length", b"0")
 
@@ -537,7 +553,7 @@ class OptionsResource(resource.Resource):
 
         return b""
 
-    def getChildWithDefault(self, path, request):
+    def getChildWithDefault(self, path: str, request: Request) -> resource.Resource:
         if request.method == b"OPTIONS":
             return self  # select ourselves as the child to render
         return resource.Resource.getChildWithDefault(self, path, request)
@@ -649,7 +665,7 @@ def respond_with_json(
     json_object: Any,
     send_cors: bool = False,
     canonical_json: bool = True,
-):
+) -> Optional[int]:
     """Sends encoded JSON in response to the given request.
 
     Args:
@@ -696,7 +712,7 @@ def respond_with_json_bytes(
     code: int,
     json_bytes: bytes,
     send_cors: bool = False,
-):
+) -> Optional[int]:
     """Sends encoded JSON in response to the given request.
 
     Args:
@@ -713,7 +729,7 @@ def respond_with_json_bytes(
         logger.warning(
             "Not sending response to request %s, already disconnected.", request
         )
-        return
+        return None
 
     request.setResponseCode(code)
     request.setHeader(b"Content-Type", b"application/json")
@@ -731,7 +747,7 @@ async def _async_write_json_to_request_in_thread(
     request: SynapseRequest,
     json_encoder: Callable[[Any], bytes],
     json_object: Any,
-):
+) -> None:
     """Encodes the given JSON object on a thread and then writes it to the
     request.
 
@@ -773,7 +789,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None:
     _ByteProducer(request, bytes_generator)
 
 
-def set_cors_headers(request: Request):
+def set_cors_headers(request: Request) -> None:
     """Set the CORS headers so that javascript running in a web browsers can
     use this API
 
@@ -790,14 +806,14 @@ def set_cors_headers(request: Request):
     )
 
 
-def respond_with_html(request: Request, code: int, html: str):
+def respond_with_html(request: Request, code: int, html: str) -> None:
     """
     Wraps `respond_with_html_bytes` by first encoding HTML from a str to UTF-8 bytes.
     """
     respond_with_html_bytes(request, code, html.encode("utf-8"))
 
 
-def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
+def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> None:
     """
     Sends HTML (encoded as UTF-8 bytes) as the response to the given request.
 
@@ -815,7 +831,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
         logger.warning(
             "Not sending response to request %s, already disconnected.", request
         )
-        return
+        return None
 
     request.setResponseCode(code)
     request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
@@ -828,7 +844,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
     finish_request(request)
 
 
-def set_clickjacking_protection_headers(request: Request):
+def set_clickjacking_protection_headers(request: Request) -> None:
     """
     Set headers to guard against clickjacking of embedded content.
 
@@ -850,7 +866,7 @@ def respond_with_redirect(request: Request, url: bytes) -> None:
     finish_request(request)
 
 
-def finish_request(request: Request):
+def finish_request(request: Request) -> None:
     """Finish writing the response to the request.
 
     Twisted throws a RuntimeException if the connection closed before the
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 6dd9b9ad03..4ff840ca0e 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -14,6 +14,7 @@
 
 """ This module contains base REST classes for constructing REST servlets. """
 import logging
+from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
     Iterable,
@@ -30,6 +31,7 @@ from typing_extensions import Literal
 from twisted.web.server import Request
 
 from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
 from synapse.types import JsonDict, RoomAlias, RoomID
 from synapse.util import json_decoder
 
@@ -137,11 +139,15 @@ def parse_integer_from_args(
             return int(args[name_bytes][0])
         except Exception:
             message = "Query parameter %r must be an integer" % (name,)
-            raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+            )
     else:
         if required:
             message = "Missing integer query parameter %r" % (name,)
-            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
+            )
         else:
             return default
 
@@ -246,11 +252,15 @@ def parse_boolean_from_args(
             message = (
                 "Boolean query parameter %r must be one of ['true', 'false']"
             ) % (name,)
-            raise SynapseError(400, message)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+            )
     else:
         if required:
             message = "Missing boolean query parameter %r" % (name,)
-            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
+            )
         else:
             return default
 
@@ -313,7 +323,7 @@ def parse_bytes_from_args(
         return args[name_bytes][0]
     elif required:
         message = "Missing string query parameter %s" % (name,)
-        raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM)
 
     return default
 
@@ -407,14 +417,16 @@ def _parse_string_value(
     try:
         value_str = value.decode(encoding)
     except ValueError:
-        raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
+        raise SynapseError(
+            HTTPStatus.BAD_REQUEST, "Query parameter %r must be %s" % (name, encoding)
+        )
 
     if allowed_values is not None and value_str not in allowed_values:
         message = "Query parameter %r must be one of [%s]" % (
             name,
             ", ".join(repr(v) for v in allowed_values),
         )
-        raise SynapseError(400, message)
+        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM)
     else:
         return value_str
 
@@ -510,7 +522,9 @@ def parse_strings_from_args(
     else:
         if required:
             message = "Missing string query parameter %r" % (name,)
-            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
+            )
 
         return default
 
@@ -638,7 +652,7 @@ def parse_json_value_from_request(
     try:
         content_bytes = request.content.read()  # type: ignore
     except Exception:
-        raise SynapseError(400, "Error reading JSON content.")
+        raise SynapseError(HTTPStatus.BAD_REQUEST, "Error reading JSON content.")
 
     if not content_bytes and allow_empty_body:
         return None
@@ -647,7 +661,9 @@ def parse_json_value_from_request(
         content = json_decoder.decode(content_bytes.decode("utf-8"))
     except Exception as e:
         logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes)
-        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
+        raise SynapseError(
+            HTTPStatus.BAD_REQUEST, "Content not JSON.", errcode=Codes.NOT_JSON
+        )
 
     return content
 
@@ -673,7 +689,7 @@ def parse_json_object_from_request(
 
     if not isinstance(content, dict):
         message = "Content must be a JSON object."
-        raise SynapseError(400, message, errcode=Codes.BAD_JSON)
+        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.BAD_JSON)
 
     return content
 
@@ -685,7 +701,9 @@ def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
             absent.append(k)
 
     if len(absent) > 0:
-        raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+        raise SynapseError(
+            HTTPStatus.BAD_REQUEST, "Missing params: %r" % absent, Codes.MISSING_PARAM
+        )
 
 
 class RestServlet:
@@ -709,7 +727,7 @@ class RestServlet:
     into the appropriate HTTP response.
     """
 
-    def register(self, http_server):
+    def register(self, http_server: HttpServer) -> None:
         """Register this servlet with the given HTTP server."""
         patterns = getattr(self, "PATTERNS", None)
         if patterns:
@@ -758,10 +776,12 @@ class ResolveRoomIdMixin:
             resolved_room_id = room_id.to_string()
         else:
             raise SynapseError(
-                400, "%s was not legal room ID or room alias" % (room_identifier,)
+                HTTPStatus.BAD_REQUEST,
+                "%s was not legal room ID or room alias" % (room_identifier,),
             )
         if not resolved_room_id:
             raise SynapseError(
-                400, "Unknown room ID or room alias %s" % room_identifier
+                HTTPStatus.BAD_REQUEST,
+                "Unknown room ID or room alias %s" % room_identifier,
             )
         return resolved_room_id, remote_room_hosts
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 755ad56637..9f68d7e191 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,7 +14,7 @@
 import contextlib
 import logging
 import time
-from typing import Generator, Optional, Tuple, Union
+from typing import Any, Generator, Optional, Tuple, Union
 
 import attr
 from zope.interface import implementer
@@ -66,9 +66,9 @@ class SynapseRequest(Request):
         self,
         channel: HTTPChannel,
         site: "SynapseSite",
-        *args,
+        *args: Any,
         max_request_body_size: int = 1024,
-        **kw,
+        **kw: Any,
     ):
         super().__init__(channel, *args, **kw)
         self._max_request_body_size = max_request_body_size
@@ -557,7 +557,7 @@ class SynapseSite(Site):
         proxied = config.http_options.x_forwarded
         request_class = XForwardedForRequest if proxied else SynapseRequest
 
-        def request_factory(channel, queued: bool) -> Request:
+        def request_factory(channel: HTTPChannel, queued: bool) -> Request:
             return request_class(
                 channel,
                 self,