diff --git a/synapse/http/server.py b/synapse/http/server.py
index e464bfe6c7..845db9b78d 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -22,10 +22,22 @@ import types
import urllib
from http import HTTPStatus
from io import BytesIO
-from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
+from typing import (
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Pattern,
+ Tuple,
+ Union,
+)
import jinja2
from canonicaljson import iterencode_canonical_json
+from typing_extensions import Protocol
from zope.interface import implementer
from twisted.internet import defer, interfaces
@@ -64,8 +76,7 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
- """Sends a JSON error response to clients.
- """
+ """Sends a JSON error response to clients."""
if f.check(SynapseError):
error_code = f.value.code
@@ -94,12 +105,17 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
pass
else:
respond_with_json(
- request, error_code, error_dict, send_cors=True,
+ request,
+ error_code,
+ error_dict,
+ send_cors=True,
)
def return_html_error(
- f: failure.Failure, request: Request, error_template: Union[str, jinja2.Template],
+ f: failure.Failure,
+ request: Request,
+ error_template: Union[str, jinja2.Template],
) -> None:
"""Sends an HTML error page corresponding to the given failure.
@@ -168,24 +184,39 @@ def wrap_async_request_handler(h):
return preserve_fn(wrapped_async_request_handler)
-class HttpServer:
- """ Interface for registering callbacks on a HTTP server
- """
+# Type of a callback method for processing requests
+# it is actually called with a SynapseRequest and a kwargs dict for the params,
+# but I can't figure out how to represent that.
+ServletCallback = Callable[
+ ..., Union[None, Awaitable[None], Tuple[int, Any], Awaitable[Tuple[int, Any]]]
+]
- def register_paths(self, method, path_patterns, callback):
- """ Register a callback that gets fired if we receive a http request
+
+class HttpServer(Protocol):
+ """Interface for registering callbacks on a HTTP server"""
+
+ def register_paths(
+ self,
+ method: str,
+ path_patterns: Iterable[Pattern],
+ callback: ServletCallback,
+ servlet_classname: str,
+ ) -> None:
+ """Register a callback that gets fired if we receive a http request
with the given method for a path that matches the given regex.
If the regex contains groups these gets passed to the callback via
an unpacked tuple.
Args:
- method (str): The method to listen to.
- path_patterns (list<SRE_Pattern>): The regex used to match requests.
- callback (function): The function to fire if we receive a matched
+ method: The HTTP method to listen to.
+ path_patterns: The regex used to match requests.
+ callback: The function to fire if we receive a matched
request. The first argument will be the request object and
subsequent arguments will be any matched groups from the regex.
- This should return a tuple of (code, response).
+ This should return either tuple of (code, response), or None.
+ servlet_classname (str): The name of the handler to be used in prometheus
+ and opentracing logs.
"""
pass
@@ -207,8 +238,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
self._extract_context = extract_context
def render(self, request):
- """ This gets called by twisted every time someone sends us a request.
- """
+ """This gets called by twisted every time someone sends us a request."""
defer.ensureDeferred(self._async_render_wrapper(request))
return NOT_DONE_YET
@@ -259,13 +289,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
@abc.abstractmethod
def _send_response(
- self, request: SynapseRequest, code: int, response_object: Any,
+ self,
+ request: SynapseRequest,
+ code: int,
+ response_object: Any,
) -> None:
raise NotImplementedError()
@abc.abstractmethod
def _send_error_response(
- self, f: failure.Failure, request: SynapseRequest,
+ self,
+ f: failure.Failure,
+ request: SynapseRequest,
) -> None:
raise NotImplementedError()
@@ -280,10 +315,12 @@ class DirectServeJsonResource(_AsyncResource):
self.canonical_json = canonical_json
def _send_response(
- self, request: Request, code: int, response_object: Any,
+ self,
+ request: Request,
+ code: int,
+ response_object: Any,
):
- """Implements _AsyncResource._send_response
- """
+ """Implements _AsyncResource._send_response"""
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
request,
@@ -294,15 +331,16 @@ class DirectServeJsonResource(_AsyncResource):
)
def _send_error_response(
- self, f: failure.Failure, request: SynapseRequest,
+ self,
+ f: failure.Failure,
+ request: SynapseRequest,
) -> None:
- """Implements _AsyncResource._send_error_response
- """
+ """Implements _AsyncResource._send_error_response"""
return_json_error(f, request)
class JsonResource(DirectServeJsonResource):
- """ This implements the HttpServer interface and provides JSON support for
+ """This implements the HttpServer interface and provides JSON support for
Resources.
Register callbacks via register_paths()
@@ -354,7 +392,7 @@ class JsonResource(DirectServeJsonResource):
def _get_handler_for_request(
self, request: SynapseRequest
- ) -> Tuple[Callable, str, Dict[str, str]]:
+ ) -> Tuple[ServletCallback, str, Dict[str, str]]:
"""Finds a callback method to handle the given request.
Returns:
@@ -415,10 +453,12 @@ class DirectServeHtmlResource(_AsyncResource):
ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
def _send_response(
- self, request: SynapseRequest, code: int, response_object: Any,
+ self,
+ request: SynapseRequest,
+ code: int,
+ response_object: Any,
):
- """Implements _AsyncResource._send_response
- """
+ """Implements _AsyncResource._send_response"""
# We expect to get bytes for us to write
assert isinstance(response_object, bytes)
html_bytes = response_object
@@ -426,10 +466,11 @@ class DirectServeHtmlResource(_AsyncResource):
respond_with_html_bytes(request, 200, html_bytes)
def _send_error_response(
- self, f: failure.Failure, request: SynapseRequest,
+ self,
+ f: failure.Failure,
+ request: SynapseRequest,
) -> None:
- """Implements _AsyncResource._send_error_response
- """
+ """Implements _AsyncResource._send_error_response"""
return_html_error(f, request, self.ERROR_TEMPLATE)
@@ -506,7 +547,9 @@ class _ByteProducer:
min_chunk_size = 1024
def __init__(
- self, request: Request, iterator: Iterator[bytes],
+ self,
+ request: Request,
+ iterator: Iterator[bytes],
):
self._request = request
self._iterator = iterator
@@ -626,7 +669,10 @@ def respond_with_json(
def respond_with_json_bytes(
- request: Request, code: int, json_bytes: bytes, send_cors: bool = False,
+ request: Request,
+ code: int,
+ json_bytes: bytes,
+ send_cors: bool = False,
):
"""Sends encoded JSON in response to the given request.
@@ -733,8 +779,15 @@ def set_clickjacking_protection_headers(request: Request):
request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';")
+def respond_with_redirect(request: Request, url: bytes) -> None:
+ """Write a 302 response to the request, if it is still alive."""
+ logger.debug("Redirect to %s", url.decode("utf-8"))
+ request.redirect(url)
+ finish_request(request)
+
+
def finish_request(request: Request):
- """ Finish writing the response to the request.
+ """Finish writing the response to the request.
Twisted throws a RuntimeException if the connection closed before the
response was written but doesn't provide a convenient or reliable way to
|