diff options
Diffstat (limited to 'synapse/http')
-rw-r--r-- | synapse/http/__init__.py | 6 | ||||
-rw-r--r-- | synapse/http/additional_resource.py | 12 | ||||
-rw-r--r-- | synapse/http/client.py | 15 | ||||
-rw-r--r-- | synapse/http/matrixfederationclient.py | 3 | ||||
-rw-r--r-- | synapse/http/server.py | 90 | ||||
-rw-r--r-- | synapse/http/servlet.py | 50 | ||||
-rw-r--r-- | synapse/http/site.py | 8 |
7 files changed, 116 insertions, 68 deletions
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index 578fc48ef4..efecb089c1 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -25,7 +25,7 @@ from synapse.api.errors import SynapseError class RequestTimedOutError(SynapseError): """Exception representing timeout of an outbound request""" - def __init__(self, msg): + def __init__(self, msg: str): super().__init__(504, msg) @@ -33,7 +33,7 @@ ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$") CLIENT_SECRET_RE = re.compile(r"(\?.*client(_|%5[Ff])secret=)[^&]*(.*)$") -def redact_uri(uri): +def redact_uri(uri: str) -> str: """Strips sensitive information from the uri replaces with <redacted>""" uri = ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri) return CLIENT_SECRET_RE.sub(r"\1<redacted>\3", uri) @@ -46,7 +46,7 @@ class QuieterFileBodyProducer(FileBodyProducer): https://twistedmatrix.com/trac/ticket/6528 """ - def stopProducing(self): + def stopProducing(self) -> None: try: FileBodyProducer.stopProducing(self) except task.TaskStopped: diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py index 9a2684aca4..6a9f6635d2 100644 --- a/synapse/http/additional_resource.py +++ b/synapse/http/additional_resource.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple from twisted.web.server import Request @@ -32,7 +32,11 @@ class AdditionalResource(DirectServeJsonResource): and exception handling. """ - def __init__(self, hs: "HomeServer", handler): + def __init__( + self, + hs: "HomeServer", + handler: Callable[[Request], Awaitable[Optional[Tuple[int, Any]]]], + ): """Initialise AdditionalResource The ``handler`` should return a deferred which completes when it has @@ -47,7 +51,7 @@ class AdditionalResource(DirectServeJsonResource): super().__init__() self._handler = handler - def _async_render(self, request: Request): + async def _async_render(self, request: Request) -> Optional[Tuple[int, Any]]: # Cheekily pass the result straight through, so we don't need to worry # if its an awaitable or not. - return self._handler(request) + return await self._handler(request) diff --git a/synapse/http/client.py b/synapse/http/client.py index b5a2d333a6..fbbeceabeb 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -14,6 +14,7 @@ # limitations under the License. import logging import urllib.parse +from http import HTTPStatus from io import BytesIO from typing import ( TYPE_CHECKING, @@ -280,7 +281,9 @@ class BlacklistingAgentWrapper(Agent): ip_address, self._ip_whitelist, self._ip_blacklist ): logger.info("Blocking access to %s due to blacklist" % (ip_address,)) - e = SynapseError(403, "IP address blocked by IP blacklist entry") + e = SynapseError( + HTTPStatus.FORBIDDEN, "IP address blocked by IP blacklist entry" + ) return defer.fail(Failure(e)) return self._agent.request( @@ -719,7 +722,9 @@ class SimpleHttpClient: if response.code > 299: logger.warning("Got %d when downloading %s" % (response.code, url)) - raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN) + raise SynapseError( + HTTPStatus.BAD_GATEWAY, "Got error %d" % (response.code,), Codes.UNKNOWN + ) # TODO: if our Content-Type is HTML or something, just read the first # N bytes into RAM rather than saving it all to disk only to read it @@ -731,12 +736,14 @@ class SimpleHttpClient: ) except BodyExceededMaxSize: raise SynapseError( - 502, + HTTPStatus.BAD_GATEWAY, "Requested file is too large > %r bytes" % (max_size,), Codes.TOO_LARGE, ) except Exception as e: - raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e + raise SynapseError( + HTTPStatus.BAD_GATEWAY, ("Failed to download remote body: %s" % e) + ) from e return ( length, diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 203d723d41..deedde0b5b 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -19,6 +19,7 @@ import random import sys import typing import urllib.parse +from http import HTTPStatus from io import BytesIO, StringIO from typing import ( TYPE_CHECKING, @@ -1154,7 +1155,7 @@ class MatrixFederationHttpClient: request.destination, msg, ) - raise SynapseError(502, msg, Codes.TOO_LARGE) + raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE) except defer.TimeoutError as e: logger.warning( "{%s} [%s] Timed out reading response - %s %s", diff --git a/synapse/http/server.py b/synapse/http/server.py index 91badb0b0a..4fd5660a08 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -30,6 +30,7 @@ from typing import ( Iterable, Iterator, List, + NoReturn, Optional, Pattern, Tuple, @@ -170,7 +171,9 @@ def return_html_error( respond_with_html(request, code, body) -def wrap_async_request_handler(h): +def wrap_async_request_handler( + h: Callable[["_AsyncResource", SynapseRequest], Awaitable[None]] +) -> Callable[["_AsyncResource", SynapseRequest], "defer.Deferred[None]"]: """Wraps an async request handler so that it calls request.processing. This helps ensure that work done by the request handler after the request is completed @@ -183,7 +186,9 @@ def wrap_async_request_handler(h): logged until the deferred completes. """ - async def wrapped_async_request_handler(self, request): + async def wrapped_async_request_handler( + self: "_AsyncResource", request: SynapseRequest + ) -> None: with request.processing(): await h(self, request) @@ -240,18 +245,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): context from the request the servlet is handling. """ - def __init__(self, extract_context=False): + def __init__(self, extract_context: bool = False): super().__init__() self._extract_context = extract_context - def render(self, request): + def render(self, request: SynapseRequest) -> int: """This gets called by twisted every time someone sends us a request.""" defer.ensureDeferred(self._async_render_wrapper(request)) return NOT_DONE_YET @wrap_async_request_handler - async def _async_render_wrapper(self, request: SynapseRequest): + async def _async_render_wrapper(self, request: SynapseRequest) -> None: """This is a wrapper that delegates to `_async_render` and handles exceptions, return values, metrics, etc. """ @@ -271,7 +276,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): f = failure.Failure() self._send_error_response(f, request) - async def _async_render(self, request: Request): + async def _async_render(self, request: SynapseRequest) -> Optional[Tuple[int, Any]]: """Delegates to `_async_render_<METHOD>` methods, or returns a 400 if no appropriate method exists. Can be overridden in sub classes for different routing. @@ -318,7 +323,7 @@ class DirectServeJsonResource(_AsyncResource): formatting responses and errors as JSON. """ - def __init__(self, canonical_json=False, extract_context=False): + def __init__(self, canonical_json: bool = False, extract_context: bool = False): super().__init__(extract_context) self.canonical_json = canonical_json @@ -327,7 +332,7 @@ class DirectServeJsonResource(_AsyncResource): request: SynapseRequest, code: int, response_object: Any, - ): + ) -> None: """Implements _AsyncResource._send_response""" # TODO: Only enable CORS for the requests that need it. respond_with_json( @@ -368,34 +373,45 @@ class JsonResource(DirectServeJsonResource): isLeaf = True - def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False): + def __init__( + self, + hs: "HomeServer", + canonical_json: bool = True, + extract_context: bool = False, + ): super().__init__(canonical_json, extract_context) self.clock = hs.get_clock() self.path_regexs: Dict[bytes, List[_PathEntry]] = {} self.hs = hs - def register_paths(self, method, path_patterns, callback, servlet_classname): + def register_paths( + self, + method: str, + path_patterns: Iterable[Pattern], + callback: ServletCallback, + servlet_classname: str, + ) -> None: """ Registers a request handler against a regular expression. Later request URLs are checked against these regular expressions in order to identify an appropriate handler for that request. Args: - method (str): GET, POST etc + method: GET, POST etc - path_patterns (Iterable[str]): A list of regular expressions to which - the request URLs are compared. + path_patterns: A list of regular expressions to which the request + URLs are compared. - callback (function): The handler for the request. Usually a Servlet + callback: The handler for the request. Usually a Servlet - servlet_classname (str): The name of the handler to be used in prometheus + servlet_classname: The name of the handler to be used in prometheus and opentracing logs. """ - method = method.encode("utf-8") # method is bytes on py3 + method_bytes = method.encode("utf-8") for path_pattern in path_patterns: logger.debug("Registering for %s %s", method, path_pattern.pattern) - self.path_regexs.setdefault(method, []).append( + self.path_regexs.setdefault(method_bytes, []).append( _PathEntry(path_pattern, callback, servlet_classname) ) @@ -427,7 +443,7 @@ class JsonResource(DirectServeJsonResource): # Huh. No one wanted to handle that? Fiiiiiine. Send 400. return _unrecognised_request_handler, "unrecognised_request_handler", {} - async def _async_render(self, request): + async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]: callback, servlet_classname, group_dict = self._get_handler_for_request(request) # Make sure we have an appropriate name for this handler in prometheus @@ -468,7 +484,7 @@ class DirectServeHtmlResource(_AsyncResource): request: SynapseRequest, code: int, response_object: Any, - ): + ) -> None: """Implements _AsyncResource._send_response""" # We expect to get bytes for us to write assert isinstance(response_object, bytes) @@ -492,12 +508,12 @@ class StaticResource(File): Differs from the File resource by adding clickjacking protection. """ - def render_GET(self, request: Request): + def render_GET(self, request: Request) -> bytes: set_clickjacking_protection_headers(request) return super().render_GET(request) -def _unrecognised_request_handler(request): +def _unrecognised_request_handler(request: Request) -> NoReturn: """Request handler for unrecognised requests This is a request handler suitable for return from @@ -505,7 +521,7 @@ def _unrecognised_request_handler(request): UnrecognizedRequestError. Args: - request (twisted.web.http.Request): + request: Unused, but passed in to match the signature of ServletCallback. """ raise UnrecognizedRequestError() @@ -513,14 +529,14 @@ def _unrecognised_request_handler(request): class RootRedirect(resource.Resource): """Redirects the root '/' path to another path.""" - def __init__(self, path): + def __init__(self, path: str): resource.Resource.__init__(self) self.url = path - def render_GET(self, request): + def render_GET(self, request: Request) -> bytes: return redirectTo(self.url.encode("ascii"), request) - def getChild(self, name, request): + def getChild(self, name: str, request: Request) -> resource.Resource: if len(name) == 0: return self # select ourselves as the child to render return resource.Resource.getChild(self, name, request) @@ -529,7 +545,7 @@ class RootRedirect(resource.Resource): class OptionsResource(resource.Resource): """Responds to OPTION requests for itself and all children.""" - def render_OPTIONS(self, request): + def render_OPTIONS(self, request: Request) -> bytes: request.setResponseCode(204) request.setHeader(b"Content-Length", b"0") @@ -537,7 +553,7 @@ class OptionsResource(resource.Resource): return b"" - def getChildWithDefault(self, path, request): + def getChildWithDefault(self, path: str, request: Request) -> resource.Resource: if request.method == b"OPTIONS": return self # select ourselves as the child to render return resource.Resource.getChildWithDefault(self, path, request) @@ -649,7 +665,7 @@ def respond_with_json( json_object: Any, send_cors: bool = False, canonical_json: bool = True, -): +) -> Optional[int]: """Sends encoded JSON in response to the given request. Args: @@ -696,7 +712,7 @@ def respond_with_json_bytes( code: int, json_bytes: bytes, send_cors: bool = False, -): +) -> Optional[int]: """Sends encoded JSON in response to the given request. Args: @@ -713,7 +729,7 @@ def respond_with_json_bytes( logger.warning( "Not sending response to request %s, already disconnected.", request ) - return + return None request.setResponseCode(code) request.setHeader(b"Content-Type", b"application/json") @@ -731,7 +747,7 @@ async def _async_write_json_to_request_in_thread( request: SynapseRequest, json_encoder: Callable[[Any], bytes], json_object: Any, -): +) -> None: """Encodes the given JSON object on a thread and then writes it to the request. @@ -773,7 +789,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None: _ByteProducer(request, bytes_generator) -def set_cors_headers(request: Request): +def set_cors_headers(request: Request) -> None: """Set the CORS headers so that javascript running in a web browsers can use this API @@ -790,14 +806,14 @@ def set_cors_headers(request: Request): ) -def respond_with_html(request: Request, code: int, html: str): +def respond_with_html(request: Request, code: int, html: str) -> None: """ Wraps `respond_with_html_bytes` by first encoding HTML from a str to UTF-8 bytes. """ respond_with_html_bytes(request, code, html.encode("utf-8")) -def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes): +def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> None: """ Sends HTML (encoded as UTF-8 bytes) as the response to the given request. @@ -815,7 +831,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes): logger.warning( "Not sending response to request %s, already disconnected.", request ) - return + return None request.setResponseCode(code) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") @@ -828,7 +844,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes): finish_request(request) -def set_clickjacking_protection_headers(request: Request): +def set_clickjacking_protection_headers(request: Request) -> None: """ Set headers to guard against clickjacking of embedded content. @@ -850,7 +866,7 @@ def respond_with_redirect(request: Request, url: bytes) -> None: finish_request(request) -def finish_request(request: Request): +def finish_request(request: Request) -> None: """Finish writing the response to the request. Twisted throws a RuntimeException if the connection closed before the diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 6dd9b9ad03..4ff840ca0e 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -14,6 +14,7 @@ """ This module contains base REST classes for constructing REST servlets. """ import logging +from http import HTTPStatus from typing import ( TYPE_CHECKING, Iterable, @@ -30,6 +31,7 @@ from typing_extensions import Literal from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer from synapse.types import JsonDict, RoomAlias, RoomID from synapse.util import json_decoder @@ -137,11 +139,15 @@ def parse_integer_from_args( return int(args[name_bytes][0]) except Exception: message = "Query parameter %r must be an integer" % (name,) - raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) else: if required: message = "Missing integer query parameter %r" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM + ) else: return default @@ -246,11 +252,15 @@ def parse_boolean_from_args( message = ( "Boolean query parameter %r must be one of ['true', 'false']" ) % (name,) - raise SynapseError(400, message) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) else: if required: message = "Missing boolean query parameter %r" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM + ) else: return default @@ -313,7 +323,7 @@ def parse_bytes_from_args( return args[name_bytes][0] elif required: message = "Missing string query parameter %s" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM) return default @@ -407,14 +417,16 @@ def _parse_string_value( try: value_str = value.decode(encoding) except ValueError: - raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding)) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Query parameter %r must be %s" % (name, encoding) + ) if allowed_values is not None and value_str not in allowed_values: message = "Query parameter %r must be one of [%s]" % ( name, ", ".join(repr(v) for v in allowed_values), ) - raise SynapseError(400, message) + raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM) else: return value_str @@ -510,7 +522,9 @@ def parse_strings_from_args( else: if required: message = "Missing string query parameter %r" % (name,) - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM + ) return default @@ -638,7 +652,7 @@ def parse_json_value_from_request( try: content_bytes = request.content.read() # type: ignore except Exception: - raise SynapseError(400, "Error reading JSON content.") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Error reading JSON content.") if not content_bytes and allow_empty_body: return None @@ -647,7 +661,9 @@ def parse_json_value_from_request( content = json_decoder.decode(content_bytes.decode("utf-8")) except Exception as e: logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes) - raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Content not JSON.", errcode=Codes.NOT_JSON + ) return content @@ -673,7 +689,7 @@ def parse_json_object_from_request( if not isinstance(content, dict): message = "Content must be a JSON object." - raise SynapseError(400, message, errcode=Codes.BAD_JSON) + raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.BAD_JSON) return content @@ -685,7 +701,9 @@ def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None: absent.append(k) if len(absent) > 0: - raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Missing params: %r" % absent, Codes.MISSING_PARAM + ) class RestServlet: @@ -709,7 +727,7 @@ class RestServlet: into the appropriate HTTP response. """ - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: """Register this servlet with the given HTTP server.""" patterns = getattr(self, "PATTERNS", None) if patterns: @@ -758,10 +776,12 @@ class ResolveRoomIdMixin: resolved_room_id = room_id.to_string() else: raise SynapseError( - 400, "%s was not legal room ID or room alias" % (room_identifier,) + HTTPStatus.BAD_REQUEST, + "%s was not legal room ID or room alias" % (room_identifier,), ) if not resolved_room_id: raise SynapseError( - 400, "Unknown room ID or room alias %s" % room_identifier + HTTPStatus.BAD_REQUEST, + "Unknown room ID or room alias %s" % room_identifier, ) return resolved_room_id, remote_room_hosts diff --git a/synapse/http/site.py b/synapse/http/site.py index 755ad56637..9f68d7e191 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -14,7 +14,7 @@ import contextlib import logging import time -from typing import Generator, Optional, Tuple, Union +from typing import Any, Generator, Optional, Tuple, Union import attr from zope.interface import implementer @@ -66,9 +66,9 @@ class SynapseRequest(Request): self, channel: HTTPChannel, site: "SynapseSite", - *args, + *args: Any, max_request_body_size: int = 1024, - **kw, + **kw: Any, ): super().__init__(channel, *args, **kw) self._max_request_body_size = max_request_body_size @@ -557,7 +557,7 @@ class SynapseSite(Site): proxied = config.http_options.x_forwarded request_class = XForwardedForRequest if proxied else SynapseRequest - def request_factory(channel, queued: bool) -> Request: + def request_factory(channel: HTTPChannel, queued: bool) -> Request: return request_class( channel, self, |