summary refs log tree commit diff
path: root/synapse/http
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http')
-rw-r--r--synapse/http/client.py17
-rw-r--r--synapse/http/matrixfederationclient.py2
-rw-r--r--synapse/http/server.py77
-rw-r--r--synapse/http/servlet.py90
4 files changed, 159 insertions, 27 deletions
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 25d319f126..127690e534 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -103,7 +103,7 @@ class SimpleHttpClient(object):
         # TODO: Do we ever want to log message contents?
         logger.debug("post_urlencoded_get_json args: %s", args)
 
-        query_bytes = urllib.urlencode(args, True)
+        query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
 
         response = yield self.request(
             "POST",
@@ -330,7 +330,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
 
     @defer.inlineCallbacks
     def post_urlencoded_get_raw(self, url, args={}):
-        query_bytes = urllib.urlencode(args, True)
+        query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
 
         response = yield self.request(
             "POST",
@@ -350,6 +350,19 @@ class CaptchaServerHttpClient(SimpleHttpClient):
             defer.returnValue(e.response)
 
 
+def encode_urlencode_args(args):
+    return {k: encode_urlencode_arg(v) for k, v in args.items()}
+
+
+def encode_urlencode_arg(arg):
+    if isinstance(arg, unicode):
+        return arg.encode('utf-8')
+    elif isinstance(arg, list):
+        return [encode_urlencode_arg(i) for i in arg]
+    else:
+        return arg
+
+
 def _print_ex(e):
     if hasattr(e, "reasons") and e.reasons:
         for ex in e.reasons:
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index da13e32e78..c3589534f8 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -152,7 +152,7 @@ class MatrixFederationHttpClient(object):
 
                         return self.clock.time_bound_deferred(
                             request_deferred,
-                            time_out=timeout/1000. if timeout else 60,
+                            time_out=timeout / 1000. if timeout else 60,
                         )
 
                     response = yield preserve_context_over_fn(
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 10d1fcd3f6..b17b190ee5 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -41,7 +41,7 @@ metrics = synapse.metrics.get_metrics_for(__name__)
 
 incoming_requests_counter = metrics.register_counter(
     "requests",
-    labels=["method", "servlet"],
+    labels=["method", "servlet", "tag"],
 )
 outgoing_responses_counter = metrics.register_counter(
     "responses",
@@ -50,23 +50,23 @@ outgoing_responses_counter = metrics.register_counter(
 
 response_timer = metrics.register_distribution(
     "response_time",
-    labels=["method", "servlet"]
+    labels=["method", "servlet", "tag"]
 )
 
 response_ru_utime = metrics.register_distribution(
-    "response_ru_utime", labels=["method", "servlet"]
+    "response_ru_utime", labels=["method", "servlet", "tag"]
 )
 
 response_ru_stime = metrics.register_distribution(
-    "response_ru_stime", labels=["method", "servlet"]
+    "response_ru_stime", labels=["method", "servlet", "tag"]
 )
 
 response_db_txn_count = metrics.register_distribution(
-    "response_db_txn_count", labels=["method", "servlet"]
+    "response_db_txn_count", labels=["method", "servlet", "tag"]
 )
 
 response_db_txn_duration = metrics.register_distribution(
-    "response_db_txn_duration", labels=["method", "servlet"]
+    "response_db_txn_duration", labels=["method", "servlet", "tag"]
 )
 
 
@@ -99,9 +99,8 @@ def request_handler(request_handler):
             request_context.request = request_id
             with request.processing():
                 try:
-                    d = request_handler(self, request)
-                    with PreserveLoggingContext():
-                        yield d
+                    with PreserveLoggingContext(request_context):
+                        yield request_handler(self, request)
                 except CodeMessageException as e:
                     code = e.code
                     if isinstance(e, SynapseError):
@@ -208,6 +207,9 @@ class JsonResource(HttpServer, resource.Resource):
         if request.method == "OPTIONS":
             self._send_response(request, 200, {})
             return
+
+        start_context = LoggingContext.current_context()
+
         # Loop through all the registered callbacks to check if the method
         # and path regex match
         for path_entry in self.path_regexs.get(request.method, []):
@@ -226,7 +228,6 @@ class JsonResource(HttpServer, resource.Resource):
                 servlet_classname = servlet_instance.__class__.__name__
             else:
                 servlet_classname = "%r" % callback
-            incoming_requests_counter.inc(request.method, servlet_classname)
 
             args = [
                 urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups()
@@ -237,21 +238,40 @@ class JsonResource(HttpServer, resource.Resource):
                 code, response = callback_return
                 self._send_response(request, code, response)
 
-            response_timer.inc_by(
-                self.clock.time_msec() - start, request.method, servlet_classname
-            )
-
             try:
                 context = LoggingContext.current_context()
+
+                tag = ""
+                if context:
+                    tag = context.tag
+
+                    if context != start_context:
+                        logger.warn(
+                            "Context have unexpectedly changed %r, %r",
+                            context, self.start_context
+                        )
+                        return
+
+                incoming_requests_counter.inc(request.method, servlet_classname, tag)
+
+                response_timer.inc_by(
+                    self.clock.time_msec() - start, request.method,
+                    servlet_classname, tag
+                )
+
                 ru_utime, ru_stime = context.get_resource_usage()
 
-                response_ru_utime.inc_by(ru_utime, request.method, servlet_classname)
-                response_ru_stime.inc_by(ru_stime, request.method, servlet_classname)
+                response_ru_utime.inc_by(
+                    ru_utime, request.method, servlet_classname, tag
+                )
+                response_ru_stime.inc_by(
+                    ru_stime, request.method, servlet_classname, tag
+                )
                 response_db_txn_count.inc_by(
-                    context.db_txn_count, request.method, servlet_classname
+                    context.db_txn_count, request.method, servlet_classname, tag
                 )
                 response_db_txn_duration.inc_by(
-                    context.db_txn_duration, request.method, servlet_classname
+                    context.db_txn_duration, request.method, servlet_classname, tag
                 )
             except:
                 pass
@@ -347,10 +367,29 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
                           "Origin, X-Requested-With, Content-Type, Accept")
 
     request.write(json_bytes)
-    request.finish()
+    finish_request(request)
     return NOT_DONE_YET
 
 
+def finish_request(request):
+    """ Finish writing the response to the request.
+
+    Twisted throws a RuntimeException if the connection closed before the
+    response was written but doesn't provide a convenient or reliable way to
+    determine if the connection was closed. So we catch and log the RuntimeException
+
+    You might think that ``request.notifyFinish`` could be used to tell if the
+    request was finished. However the deferred it returns won't fire if the
+    connection was already closed, meaning we'd have to have called the method
+    right at the start of the request. By the time we want to write the response
+    it will already be too late.
+    """
+    try:
+        request.finish()
+    except RuntimeError as e:
+        logger.info("Connection disconnected before response was written: %r", e)
+
+
 def _request_user_agent_is_curl(request):
     user_agents = request.requestHeaders.getRawHeaders(
         "User-Agent", default=[]
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 7bd87940b4..1c8bd8666f 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -15,14 +15,27 @@
 
 """ This module contains base REST classes for constructing REST servlets. """
 
-from synapse.api.errors import SynapseError
+from synapse.api.errors import SynapseError, Codes
 
 import logging
+import simplejson
 
 logger = logging.getLogger(__name__)
 
 
 def parse_integer(request, name, default=None, required=False):
+    """Parse an integer parameter from the request string
+
+    :param request: the twisted HTTP request.
+    :param name (str): the name of the query parameter.
+    :param default: value to use if the parameter is absent, defaults to None.
+    :param required (bool): whether to raise a 400 SynapseError if the
+        parameter is absent, defaults to False.
+    :return: An int value or the default.
+    :raises
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present and not an integer.
+    """
     if name in request.args:
         try:
             return int(request.args[name][0])
@@ -32,12 +45,25 @@ def parse_integer(request, name, default=None, required=False):
     else:
         if required:
             message = "Missing integer query parameter %r" % (name,)
-            raise SynapseError(400, message)
+            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
         else:
             return default
 
 
 def parse_boolean(request, name, default=None, required=False):
+    """Parse a boolean parameter from the request query string
+
+    :param request: the twisted HTTP request.
+    :param name (str): the name of the query parameter.
+    :param default: value to use if the parameter is absent, defaults to None.
+    :param required (bool): whether to raise a 400 SynapseError if the
+        parameter is absent, defaults to False.
+    :return: A bool value or the default.
+    :raises
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present and not one of "true" or "false".
+    """
+
     if name in request.args:
         try:
             return {
@@ -53,30 +79,84 @@ def parse_boolean(request, name, default=None, required=False):
     else:
         if required:
             message = "Missing boolean query parameter %r" % (name,)
-            raise SynapseError(400, message)
+            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
         else:
             return default
 
 
 def parse_string(request, name, default=None, required=False,
                  allowed_values=None, param_type="string"):
+    """Parse a string parameter from the request query string.
+
+    :param request: the twisted HTTP request.
+    :param name (str): the name of the query parameter.
+    :param default: value to use if the parameter is absent, defaults to None.
+    :param required (bool): whether to raise a 400 SynapseError if the
+        parameter is absent, defaults to False.
+    :param allowed_values (list): List of allowed values for the string,
+        or None if any value is allowed, defaults to None
+    :return: A string value or the default.
+    :raises
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present, must be one of a list of allowed values and
+            is not one of those allowed values.
+    """
+
     if name in request.args:
         value = request.args[name][0]
         if allowed_values is not None and value not in allowed_values:
             message = "Query parameter %r must be one of [%s]" % (
                 name, ", ".join(repr(v) for v in allowed_values)
             )
-            raise SynapseError(message)
+            raise SynapseError(400, message)
         else:
             return value
     else:
         if required:
             message = "Missing %s query parameter %r" % (param_type, name)
-            raise SynapseError(400, message)
+            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
         else:
             return default
 
 
+def parse_json_value_from_request(request):
+    """Parse a JSON value from the body of a twisted HTTP request.
+
+    :param request: the twisted HTTP request.
+    :returns: The JSON value.
+    :raises
+        SynapseError if the request body couldn't be decoded as JSON.
+    """
+    try:
+        content_bytes = request.content.read()
+    except:
+        raise SynapseError(400, "Error reading JSON content.")
+
+    try:
+        content = simplejson.loads(content_bytes)
+    except simplejson.JSONDecodeError:
+        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
+
+    return content
+
+
+def parse_json_object_from_request(request):
+    """Parse a JSON object from the body of a twisted HTTP request.
+
+    :param request: the twisted HTTP request.
+    :raises
+        SynapseError if the request body couldn't be decoded as JSON or
+            if it wasn't a JSON object.
+    """
+    content = parse_json_value_from_request(request)
+
+    if type(content) != dict:
+        message = "Content must be a JSON object."
+        raise SynapseError(400, message, errcode=Codes.BAD_JSON)
+
+    return content
+
+
 class RestServlet(object):
 
     """ A Synapse REST Servlet.