diff options
Diffstat (limited to 'synapse/http/server.py')
-rw-r--r-- | synapse/http/server.py | 123 |
1 files changed, 88 insertions, 35 deletions
diff --git a/synapse/http/server.py b/synapse/http/server.py index e464bfe6c7..845db9b78d 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -22,10 +22,22 @@ import types import urllib from http import HTTPStatus from io import BytesIO -from typing import Any, Callable, Dict, Iterator, List, Tuple, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Iterable, + Iterator, + List, + Pattern, + Tuple, + Union, +) import jinja2 from canonicaljson import iterencode_canonical_json +from typing_extensions import Protocol from zope.interface import implementer from twisted.internet import defer, interfaces @@ -64,8 +76,7 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html> def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: - """Sends a JSON error response to clients. - """ + """Sends a JSON error response to clients.""" if f.check(SynapseError): error_code = f.value.code @@ -94,12 +105,17 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: pass else: respond_with_json( - request, error_code, error_dict, send_cors=True, + request, + error_code, + error_dict, + send_cors=True, ) def return_html_error( - f: failure.Failure, request: Request, error_template: Union[str, jinja2.Template], + f: failure.Failure, + request: Request, + error_template: Union[str, jinja2.Template], ) -> None: """Sends an HTML error page corresponding to the given failure. @@ -168,24 +184,39 @@ def wrap_async_request_handler(h): return preserve_fn(wrapped_async_request_handler) -class HttpServer: - """ Interface for registering callbacks on a HTTP server - """ +# Type of a callback method for processing requests +# it is actually called with a SynapseRequest and a kwargs dict for the params, +# but I can't figure out how to represent that. +ServletCallback = Callable[ + ..., Union[None, Awaitable[None], Tuple[int, Any], Awaitable[Tuple[int, Any]]] +] - def register_paths(self, method, path_patterns, callback): - """ Register a callback that gets fired if we receive a http request + +class HttpServer(Protocol): + """Interface for registering callbacks on a HTTP server""" + + def register_paths( + self, + method: str, + path_patterns: Iterable[Pattern], + callback: ServletCallback, + servlet_classname: str, + ) -> None: + """Register a callback that gets fired if we receive a http request with the given method for a path that matches the given regex. If the regex contains groups these gets passed to the callback via an unpacked tuple. Args: - method (str): The method to listen to. - path_patterns (list<SRE_Pattern>): The regex used to match requests. - callback (function): The function to fire if we receive a matched + method: The HTTP method to listen to. + path_patterns: The regex used to match requests. + callback: The function to fire if we receive a matched request. The first argument will be the request object and subsequent arguments will be any matched groups from the regex. - This should return a tuple of (code, response). + This should return either tuple of (code, response), or None. + servlet_classname (str): The name of the handler to be used in prometheus + and opentracing logs. """ pass @@ -207,8 +238,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): self._extract_context = extract_context def render(self, request): - """ This gets called by twisted every time someone sends us a request. - """ + """This gets called by twisted every time someone sends us a request.""" defer.ensureDeferred(self._async_render_wrapper(request)) return NOT_DONE_YET @@ -259,13 +289,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): @abc.abstractmethod def _send_response( - self, request: SynapseRequest, code: int, response_object: Any, + self, + request: SynapseRequest, + code: int, + response_object: Any, ) -> None: raise NotImplementedError() @abc.abstractmethod def _send_error_response( - self, f: failure.Failure, request: SynapseRequest, + self, + f: failure.Failure, + request: SynapseRequest, ) -> None: raise NotImplementedError() @@ -280,10 +315,12 @@ class DirectServeJsonResource(_AsyncResource): self.canonical_json = canonical_json def _send_response( - self, request: Request, code: int, response_object: Any, + self, + request: Request, + code: int, + response_object: Any, ): - """Implements _AsyncResource._send_response - """ + """Implements _AsyncResource._send_response""" # TODO: Only enable CORS for the requests that need it. respond_with_json( request, @@ -294,15 +331,16 @@ class DirectServeJsonResource(_AsyncResource): ) def _send_error_response( - self, f: failure.Failure, request: SynapseRequest, + self, + f: failure.Failure, + request: SynapseRequest, ) -> None: - """Implements _AsyncResource._send_error_response - """ + """Implements _AsyncResource._send_error_response""" return_json_error(f, request) class JsonResource(DirectServeJsonResource): - """ This implements the HttpServer interface and provides JSON support for + """This implements the HttpServer interface and provides JSON support for Resources. Register callbacks via register_paths() @@ -354,7 +392,7 @@ class JsonResource(DirectServeJsonResource): def _get_handler_for_request( self, request: SynapseRequest - ) -> Tuple[Callable, str, Dict[str, str]]: + ) -> Tuple[ServletCallback, str, Dict[str, str]]: """Finds a callback method to handle the given request. Returns: @@ -415,10 +453,12 @@ class DirectServeHtmlResource(_AsyncResource): ERROR_TEMPLATE = HTML_ERROR_TEMPLATE def _send_response( - self, request: SynapseRequest, code: int, response_object: Any, + self, + request: SynapseRequest, + code: int, + response_object: Any, ): - """Implements _AsyncResource._send_response - """ + """Implements _AsyncResource._send_response""" # We expect to get bytes for us to write assert isinstance(response_object, bytes) html_bytes = response_object @@ -426,10 +466,11 @@ class DirectServeHtmlResource(_AsyncResource): respond_with_html_bytes(request, 200, html_bytes) def _send_error_response( - self, f: failure.Failure, request: SynapseRequest, + self, + f: failure.Failure, + request: SynapseRequest, ) -> None: - """Implements _AsyncResource._send_error_response - """ + """Implements _AsyncResource._send_error_response""" return_html_error(f, request, self.ERROR_TEMPLATE) @@ -506,7 +547,9 @@ class _ByteProducer: min_chunk_size = 1024 def __init__( - self, request: Request, iterator: Iterator[bytes], + self, + request: Request, + iterator: Iterator[bytes], ): self._request = request self._iterator = iterator @@ -626,7 +669,10 @@ def respond_with_json( def respond_with_json_bytes( - request: Request, code: int, json_bytes: bytes, send_cors: bool = False, + request: Request, + code: int, + json_bytes: bytes, + send_cors: bool = False, ): """Sends encoded JSON in response to the given request. @@ -733,8 +779,15 @@ def set_clickjacking_protection_headers(request: Request): request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';") +def respond_with_redirect(request: Request, url: bytes) -> None: + """Write a 302 response to the request, if it is still alive.""" + logger.debug("Redirect to %s", url.decode("utf-8")) + request.redirect(url) + finish_request(request) + + def finish_request(request: Request): - """ Finish writing the response to the request. + """Finish writing the response to the request. Twisted throws a RuntimeException if the connection closed before the response was written but doesn't provide a convenient or reliable way to |