diff options
Diffstat (limited to 'synapse/http/server.py')
-rw-r--r-- | synapse/http/server.py | 153 |
1 files changed, 83 insertions, 70 deletions
diff --git a/synapse/http/server.py b/synapse/http/server.py index 16fb7935da..f067c163c1 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -16,10 +16,11 @@ import cgi import collections +import http.client import logging - -from six import PY3 -from six.moves import http_client, urllib +import types +import urllib +from io import BytesIO from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json @@ -41,11 +42,6 @@ from synapse.api.errors import ( from synapse.util.caches import intern_dict from synapse.util.logcontext import preserve_fn -if PY3: - from io import BytesIO -else: - from cStringIO import StringIO as BytesIO - logger = logging.getLogger(__name__) HTML_ERROR_TEMPLATE = """<!DOCTYPE html> @@ -75,15 +71,12 @@ def wrap_json_request_handler(h): deferred fails with any other type of error we send a 500 reponse. """ - @defer.inlineCallbacks - def wrapped_request_handler(self, request): + async def wrapped_request_handler(self, request): try: - yield h(self, request) + await h(self, request) except SynapseError as e: code = e.code - logger.info( - "%s SynapseError: %s - %s", request, code, e.msg - ) + logger.info("%s SynapseError: %s - %s", request, code, e.msg) # Only respond with an error response if we haven't already started # writing, otherwise lets just kill the connection @@ -96,7 +89,10 @@ def wrap_json_request_handler(h): pass else: respond_with_json( - request, code, e.error_dict(), send_cors=True, + request, + code, + e.error_dict(), + send_cors=True, pretty_print=_request_user_agent_is_curl(request), ) @@ -124,10 +120,7 @@ def wrap_json_request_handler(h): respond_with_json( request, 500, - { - "error": "Internal server error", - "errcode": Codes.UNKNOWN, - }, + {"error": "Internal server error", "errcode": Codes.UNKNOWN}, send_cors=True, pretty_print=_request_user_agent_is_curl(request), ) @@ -143,10 +136,13 @@ def wrap_html_request_handler(h): The handler method must have a signature of "handle_foo(self, request)", where "request" must be a SynapseRequest. """ - def wrapped_request_handler(self, request): - d = defer.maybeDeferred(h, self, request) - d.addErrback(_return_html_error, request) - return d + + async def wrapped_request_handler(self, request): + try: + return await h(self, request) + except Exception: + f = failure.Failure() + return _return_html_error(f, request) return wrap_async_request_handler(wrapped_request_handler) @@ -164,9 +160,7 @@ def _return_html_error(f, request): msg = cme.msg if isinstance(cme, SynapseError): - logger.info( - "%s SynapseError: %s - %s", request, code, msg - ) + logger.info("%s SynapseError: %s - %s", request, code, msg) else: logger.error( "Failed handle request %r", @@ -174,7 +168,7 @@ def _return_html_error(f, request): exc_info=(f.type, f.value, f.getTracebackObject()), ) else: - code = http_client.INTERNAL_SERVER_ERROR + code = http.client.INTERNAL_SERVER_ERROR msg = "Internal server error" logger.error( @@ -183,9 +177,7 @@ def _return_html_error(f, request): exc_info=(f.type, f.value, f.getTracebackObject()), ) - body = HTML_ERROR_TEMPLATE.format( - code=code, msg=cgi.escape(msg), - ).encode("utf-8") + body = HTML_ERROR_TEMPLATE.format(code=code, msg=cgi.escape(msg)).encode("utf-8") request.setResponseCode(code) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") request.setHeader(b"Content-Length", b"%i" % (len(body),)) @@ -205,10 +197,10 @@ def wrap_async_request_handler(h): The handler may return a deferred, in which case the completion of the request isn't logged until the deferred completes. """ - @defer.inlineCallbacks - def wrapped_async_request_handler(self, request): + + async def wrapped_async_request_handler(self, request): with request.processing(): - yield h(self, request) + await h(self, request) # we need to preserve_fn here, because the synchronous render method won't yield for # us (obviously) @@ -274,12 +266,11 @@ class JsonResource(HttpServer, resource.Resource): def render(self, request): """ This gets called by twisted every time someone sends us a request. """ - self._async_render(request) + defer.ensureDeferred(self._async_render(request)) return NOT_DONE_YET @wrap_json_request_handler - @defer.inlineCallbacks - def _async_render(self, request): + async def _async_render(self, request): """ This gets called from render() every time someone sends us a request. This checks if anyone has registered a callback for that method and path. @@ -296,24 +287,19 @@ class JsonResource(HttpServer, resource.Resource): # Now trigger the callback. If it returns a response, we send it # here. If it throws an exception, that is handled by the wrapper # installed by @request_handler. + kwargs = intern_dict( + { + name: urllib.parse.unquote(value) if value else value + for name, value in group_dict.items() + } + ) + + callback_return = callback(request, **kwargs) + + # Is it synchronous? We'll allow this for now. + if isinstance(callback_return, (defer.Deferred, types.CoroutineType)): + callback_return = await callback_return - def _unquote(s): - if PY3: - # On Python 3, unquote is unicode -> unicode - return urllib.parse.unquote(s) - else: - # On Python 2, unquote is bytes -> bytes We need to encode the - # URL again (as it was decoded by _get_handler_for request), as - # ASCII because it's a URL, and then decode it to get the UTF-8 - # characters that were quoted. - return urllib.parse.unquote(s.encode('ascii')).decode('utf8') - - kwargs = intern_dict({ - name: _unquote(value) if value else value - for name, value in group_dict.items() - }) - - callback_return = yield callback(request, **kwargs) if callback_return is not None: code, response = callback_return self._send_response(request, code, response) @@ -339,7 +325,7 @@ class JsonResource(HttpServer, resource.Resource): # Loop through all the registered callbacks to check if the method # and path regex match for path_entry in self.path_regexs.get(request.method, []): - m = path_entry.pattern.match(request.path.decode('ascii')) + m = path_entry.pattern.match(request.path.decode("ascii")) if m: # We found a match! return path_entry.callback, m.groupdict() @@ -347,11 +333,14 @@ class JsonResource(HttpServer, resource.Resource): # Huh. No one wanted to handle that? Fiiiiiine. Send 400. return _unrecognised_request_handler, {} - def _send_response(self, request, code, response_json_object, - response_code_message=None): + def _send_response( + self, request, code, response_json_object, response_code_message=None + ): # TODO: Only enable CORS for the requests that need it. respond_with_json( - request, code, response_json_object, + request, + code, + response_json_object, send_cors=True, response_code_message=response_code_message, pretty_print=_request_user_agent_is_curl(request), @@ -359,6 +348,23 @@ class JsonResource(HttpServer, resource.Resource): ) +class DirectServeResource(resource.Resource): + def render(self, request): + """ + Render the request, using an asynchronous render handler if it exists. + """ + render_callback_name = "_async_render_" + request.method.decode("ascii") + + if hasattr(self, render_callback_name): + # Call the handler + callback = getattr(self, render_callback_name) + defer.ensureDeferred(callback(request)) + + return NOT_DONE_YET + else: + super().render(request) + + def _options_handler(request): """Request handler for OPTIONS requests @@ -395,7 +401,7 @@ class RootRedirect(resource.Resource): self.url = path def render_GET(self, request): - return redirectTo(self.url.encode('ascii'), request) + return redirectTo(self.url.encode("ascii"), request) def getChild(self, name, request): if len(name) == 0: @@ -403,16 +409,22 @@ class RootRedirect(resource.Resource): return resource.Resource.getChild(self, name, request) -def respond_with_json(request, code, json_object, send_cors=False, - response_code_message=None, pretty_print=False, - canonical_json=True): +def respond_with_json( + request, + code, + json_object, + send_cors=False, + response_code_message=None, + pretty_print=False, + canonical_json=True, +): # could alternatively use request.notifyFinish() and flip a flag when # the Deferred fires, but since the flag is RIGHT THERE it seems like # a waste. if request._disconnected: logger.warn( - "Not sending response to request %s, already disconnected.", - request) + "Not sending response to request %s, already disconnected.", request + ) return if pretty_print: @@ -425,14 +437,17 @@ def respond_with_json(request, code, json_object, send_cors=False, json_bytes = json.dumps(json_object).encode("utf-8") return respond_with_json_bytes( - request, code, json_bytes, + request, + code, + json_bytes, send_cors=send_cors, response_code_message=response_code_message, ) -def respond_with_json_bytes(request, code, json_bytes, send_cors=False, - response_code_message=None): +def respond_with_json_bytes( + request, code, json_bytes, send_cors=False, response_code_message=None +): """Sends encoded JSON in response to the given request. Args: @@ -474,7 +489,7 @@ def set_cors_headers(request): ) request.setHeader( b"Access-Control-Allow-Headers", - b"Origin, X-Requested-With, Content-Type, Accept, Authorization" + b"Origin, X-Requested-With, Content-Type, Accept, Authorization", ) @@ -498,9 +513,7 @@ def finish_request(request): def _request_user_agent_is_curl(request): - user_agents = request.requestHeaders.getRawHeaders( - b"User-Agent", default=[] - ) + user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[]) for user_agent in user_agents: if b"curl" in user_agent: return True |