summary refs log tree commit diff
path: root/synapse/http/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/server.py')
-rw-r--r--synapse/http/server.py157
1 files changed, 76 insertions, 81 deletions
diff --git a/synapse/http/server.py b/synapse/http/server.py
index bc09b8b2be..2d5c23e673 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -13,35 +13,38 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import cgi
-from six.moves import http_client
+import collections
+import logging
 
-from synapse.api.errors import (
-    cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError, Codes
-)
-from synapse.http.request_metrics import (
-    requests_counter,
-)
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
-from synapse.util.caches import intern_dict
-from synapse.util.metrics import Measure
-import synapse.metrics
-import synapse.events
+from six import PY3
+from six.moves import http_client, urllib
 
-from canonicaljson import (
-    encode_canonical_json, encode_pretty_printed_json
-)
+from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
 
 from twisted.internet import defer
 from twisted.python import failure
-from twisted.web import server, resource
+from twisted.web import resource
 from twisted.web.server import NOT_DONE_YET
+from twisted.web.static import NoRangeStaticProducer
 from twisted.web.util import redirectTo
 
-import collections
-import logging
-import urllib
-import simplejson
+import synapse.events
+import synapse.metrics
+from synapse.api.errors import (
+    CodeMessageException,
+    Codes,
+    SynapseError,
+    UnrecognizedRequestError,
+)
+from synapse.util.caches import intern_dict
+from synapse.util.logcontext import preserve_fn
+
+if PY3:
+    from io import BytesIO
+else:
+    from cStringIO import StringIO as BytesIO
 
 logger = logging.getLogger(__name__)
 
@@ -61,11 +64,10 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
 def wrap_json_request_handler(h):
     """Wraps a request handler method with exception handling.
 
-    Also adds logging as per wrap_request_handler_with_logging.
+    Also does the wrapping with request.processing as per wrap_async_request_handler.
 
     The handler method must have a signature of "handle_foo(self, request)",
-    where "self" must have a "clock" attribute (and "request" must be a
-    SynapseRequest).
+    where "request" must be a SynapseRequest.
 
     The handler must return a deferred. If the deferred succeeds we assume that
     a response has been sent. If the deferred fails with a SynapseError we use
@@ -77,16 +79,13 @@ def wrap_json_request_handler(h):
     def wrapped_request_handler(self, request):
         try:
             yield h(self, request)
-        except CodeMessageException as e:
+        except SynapseError as e:
             code = e.code
-            if isinstance(e, SynapseError):
-                logger.info(
-                    "%s SynapseError: %s - %s", request, code, e.msg
-                )
-            else:
-                logger.exception(e)
+            logger.info(
+                "%s SynapseError: %s - %s", request, code, e.msg
+            )
             respond_with_json(
-                request, code, cs_exception(e), send_cors=True,
+                request, code, e.error_dict(), send_cors=True,
                 pretty_print=_request_user_agent_is_curl(request),
             )
 
@@ -112,24 +111,23 @@ def wrap_json_request_handler(h):
                 pretty_print=_request_user_agent_is_curl(request),
             )
 
-    return wrap_request_handler_with_logging(wrapped_request_handler)
+    return wrap_async_request_handler(wrapped_request_handler)
 
 
 def wrap_html_request_handler(h):
     """Wraps a request handler method with exception handling.
 
-    Also adds logging as per wrap_request_handler_with_logging.
+    Also does the wrapping with request.processing as per wrap_async_request_handler.
 
     The handler method must have a signature of "handle_foo(self, request)",
-    where "self" must have a "clock" attribute (and "request" must be a
-    SynapseRequest).
+    where "request" must be a SynapseRequest.
     """
     def wrapped_request_handler(self, request):
         d = defer.maybeDeferred(h, self, request)
         d.addErrback(_return_html_error, request)
         return d
 
-    return wrap_request_handler_with_logging(wrapped_request_handler)
+    return wrap_async_request_handler(wrapped_request_handler)
 
 
 def _return_html_error(f, request):
@@ -174,46 +172,26 @@ def _return_html_error(f, request):
     finish_request(request)
 
 
-def wrap_request_handler_with_logging(h):
-    """Wraps a request handler to provide logging and metrics
+def wrap_async_request_handler(h):
+    """Wraps an async request handler so that it calls request.processing.
+
+    This helps ensure that work done by the request handler after the request is completed
+    is correctly recorded against the request metrics/logs.
 
     The handler method must have a signature of "handle_foo(self, request)",
-    where "self" must have a "clock" attribute (and "request" must be a
-    SynapseRequest).
+    where "request" must be a SynapseRequest.
 
-    As well as calling `request.processing` (which will log the response and
-    duration for this request), the wrapped request handler will insert the
-    request id into the logging context.
+    The handler may return a deferred, in which case the completion of the request isn't
+    logged until the deferred completes.
     """
     @defer.inlineCallbacks
-    def wrapped_request_handler(self, request):
-        """
-        Args:
-            self:
-            request (synapse.http.site.SynapseRequest):
-        """
+    def wrapped_async_request_handler(self, request):
+        with request.processing():
+            yield h(self, request)
 
-        request_id = request.get_request_id()
-        with LoggingContext(request_id) as request_context:
-            request_context.request = request_id
-            with Measure(self.clock, "wrapped_request_handler"):
-                # we start the request metrics timer here with an initial stab
-                # at the servlet name. For most requests that name will be
-                # JsonResource (or a subclass), and JsonResource._async_render
-                # will update it once it picks a servlet.
-                servlet_name = self.__class__.__name__
-                with request.processing(servlet_name):
-                    with PreserveLoggingContext(request_context):
-                        d = defer.maybeDeferred(h, self, request)
-
-                        # record the arrival of the request *after*
-                        # dispatching to the handler, so that the handler
-                        # can update the servlet name in the request
-                        # metrics
-                        requests_counter.labels(request.method,
-                                                request.request_metrics.name).inc()
-                        yield d
-    return wrapped_request_handler
+    # we need to preserve_fn here, because the synchronous render method won't yield for
+    # us (obviously)
+    return preserve_fn(wrapped_async_request_handler)
 
 
 class HttpServer(object):
@@ -265,6 +243,7 @@ class JsonResource(HttpServer, resource.Resource):
         self.hs = hs
 
     def register_paths(self, method, path_patterns, callback):
+        method = method.encode("utf-8")  # method is bytes on py3
         for path_pattern in path_patterns:
             logger.debug("Registering for %s %s", method, path_pattern.pattern)
             self.path_regexs.setdefault(method, []).append(
@@ -275,7 +254,7 @@ class JsonResource(HttpServer, resource.Resource):
         """ This gets called by twisted every time someone sends us a request.
         """
         self._async_render(request)
-        return server.NOT_DONE_YET
+        return NOT_DONE_YET
 
     @wrap_json_request_handler
     @defer.inlineCallbacks
@@ -297,8 +276,19 @@ class JsonResource(HttpServer, resource.Resource):
         # here. If it throws an exception, that is handled by the wrapper
         # installed by @request_handler.
 
+        def _unquote(s):
+            if PY3:
+                # On Python 3, unquote is unicode -> unicode
+                return urllib.parse.unquote(s)
+            else:
+                # On Python 2, unquote is bytes -> bytes We need to encode the
+                # URL again (as it was decoded by _get_handler_for request), as
+                # ASCII because it's a URL, and then decode it to get the UTF-8
+                # characters that were quoted.
+                return urllib.parse.unquote(s.encode('ascii')).decode('utf8')
+
         kwargs = intern_dict({
-            name: urllib.unquote(value).decode("UTF-8") if value else value
+            name: _unquote(value) if value else value
             for name, value in group_dict.items()
         })
 
@@ -314,9 +304,9 @@ class JsonResource(HttpServer, resource.Resource):
             request (twisted.web.http.Request):
 
         Returns:
-            Tuple[Callable, dict[str, str]]: callback method, and the dict
-                mapping keys to path components as specified in the handler's
-                path match regexp.
+            Tuple[Callable, dict[unicode, unicode]]: callback method, and the
+                dict mapping keys to path components as specified in the
+                handler's path match regexp.
 
                 The callback will normally be a method registered via
                 register_paths, so will return (possibly via Deferred) either
@@ -328,7 +318,7 @@ class JsonResource(HttpServer, resource.Resource):
         # Loop through all the registered callbacks to check if the method
         # and path regex match
         for path_entry in self.path_regexs.get(request.method, []):
-            m = path_entry.pattern.match(request.path)
+            m = path_entry.pattern.match(request.path.decode('ascii'))
             if m:
                 # We found a match!
                 return path_entry.callback, m.groupdict()
@@ -384,7 +374,7 @@ class RootRedirect(resource.Resource):
         self.url = path
 
     def render_GET(self, request):
-        return redirectTo(self.url, request)
+        return redirectTo(self.url.encode('ascii'), request)
 
     def getChild(self, name, request):
         if len(name) == 0:
@@ -405,12 +395,13 @@ def respond_with_json(request, code, json_object, send_cors=False,
         return
 
     if pretty_print:
-        json_bytes = encode_pretty_printed_json(json_object) + "\n"
+        json_bytes = encode_pretty_printed_json(json_object) + b"\n"
     else:
         if canonical_json or synapse.events.USE_FROZEN_DICTS:
+            # canonicaljson already encodes to bytes
             json_bytes = encode_canonical_json(json_object)
         else:
-            json_bytes = simplejson.dumps(json_object)
+            json_bytes = json.dumps(json_object).encode("utf-8")
 
     return respond_with_json_bytes(
         request, code, json_bytes,
@@ -440,8 +431,12 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
     if send_cors:
         set_cors_headers(request)
 
-    request.write(json_bytes)
-    finish_request(request)
+    # todo: we can almost certainly avoid this copy and encode the json straight into
+    # the bytesIO, but it would involve faffing around with string->bytes wrappers.
+    bytes_io = BytesIO(json_bytes)
+
+    producer = NoRangeStaticProducer(request, bytes_io)
+    producer.start()
     return NOT_DONE_YET