summary refs log tree commit diff
path: root/synapse/http
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http')
-rw-r--r--synapse/http/request_metrics.py13
-rw-r--r--synapse/http/servlet.py56
-rw-r--r--synapse/http/site.py12
3 files changed, 61 insertions, 20 deletions
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 588e280571..72c2654678 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+import threading
 
 from prometheus_client.core import Counter, Histogram
 
@@ -111,6 +112,9 @@ in_flight_requests_db_sched_duration = Counter(
 # The set of all in flight requests, set[RequestMetrics]
 _in_flight_requests = set()
 
+# Protects the _in_flight_requests set from concurrent accesss
+_in_flight_requests_lock = threading.Lock()
+
 
 def _get_in_flight_counts():
     """Returns a count of all in flight requests by (method, server_name)
@@ -120,7 +124,8 @@ def _get_in_flight_counts():
     """
     # Cast to a list to prevent it changing while the Prometheus
     # thread is collecting metrics
-    reqs = list(_in_flight_requests)
+    with _in_flight_requests_lock:
+        reqs = list(_in_flight_requests)
 
     for rm in reqs:
         rm.update_metrics()
@@ -154,10 +159,12 @@ class RequestMetrics(object):
         # to the "in flight" metrics.
         self._request_stats = self.start_context.get_resource_usage()
 
-        _in_flight_requests.add(self)
+        with _in_flight_requests_lock:
+            _in_flight_requests.add(self)
 
     def stop(self, time_sec, request):
-        _in_flight_requests.discard(self)
+        with _in_flight_requests_lock:
+            _in_flight_requests.discard(self)
 
         context = LoggingContext.current_context()
 
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 69f7085291..a1e4b88e6d 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -29,7 +29,7 @@ def parse_integer(request, name, default=None, required=False):
 
     Args:
         request: the twisted HTTP request.
-        name (str): the name of the query parameter.
+        name (bytes/unicode): the name of the query parameter.
         default (int|None): value to use if the parameter is absent, defaults
             to None.
         required (bool): whether to raise a 400 SynapseError if the
@@ -46,6 +46,10 @@ def parse_integer(request, name, default=None, required=False):
 
 
 def parse_integer_from_args(args, name, default=None, required=False):
+
+    if not isinstance(name, bytes):
+        name = name.encode('ascii')
+
     if name in args:
         try:
             return int(args[name][0])
@@ -65,7 +69,7 @@ def parse_boolean(request, name, default=None, required=False):
 
     Args:
         request: the twisted HTTP request.
-        name (str): the name of the query parameter.
+        name (bytes/unicode): the name of the query parameter.
         default (bool|None): value to use if the parameter is absent, defaults
             to None.
         required (bool): whether to raise a 400 SynapseError if the
@@ -83,11 +87,15 @@ def parse_boolean(request, name, default=None, required=False):
 
 
 def parse_boolean_from_args(args, name, default=None, required=False):
+
+    if not isinstance(name, bytes):
+        name = name.encode('ascii')
+
     if name in args:
         try:
             return {
-                "true": True,
-                "false": False,
+                b"true": True,
+                b"false": False,
             }[args[name][0]]
         except Exception:
             message = (
@@ -104,21 +112,29 @@ def parse_boolean_from_args(args, name, default=None, required=False):
 
 
 def parse_string(request, name, default=None, required=False,
-                 allowed_values=None, param_type="string"):
-    """Parse a string parameter from the request query string.
+                 allowed_values=None, param_type="string", encoding='ascii'):
+    """
+    Parse a string parameter from the request query string.
+
+    If encoding is not None, the content of the query param will be
+    decoded to Unicode using the encoding, otherwise it will be encoded
 
     Args:
         request: the twisted HTTP request.
-        name (str): the name of the query parameter.
-        default (str|None): value to use if the parameter is absent, defaults
-            to None.
+        name (bytes/unicode): the name of the query parameter.
+        default (bytes/unicode|None): value to use if the parameter is absent,
+            defaults to None. Must be bytes if encoding is None.
         required (bool): whether to raise a 400 SynapseError if the
             parameter is absent, defaults to False.
-        allowed_values (list[str]): List of allowed values for the string,
-            or None if any value is allowed, defaults to None
+        allowed_values (list[bytes/unicode]): List of allowed values for the
+            string, or None if any value is allowed, defaults to None. Must be
+            the same type as name, if given.
+        encoding: The encoding to decode the name to, and decode the string
+            content with.
 
     Returns:
-        str|None: A string value or the default.
+        bytes/unicode|None: A string value or the default. Unicode if encoding
+        was given, bytes otherwise.
 
     Raises:
         SynapseError if the parameter is absent and required, or if the
@@ -126,14 +142,22 @@ def parse_string(request, name, default=None, required=False,
             is not one of those allowed values.
     """
     return parse_string_from_args(
-        request.args, name, default, required, allowed_values, param_type,
+        request.args, name, default, required, allowed_values, param_type, encoding
     )
 
 
 def parse_string_from_args(args, name, default=None, required=False,
-                           allowed_values=None, param_type="string"):
+                           allowed_values=None, param_type="string", encoding='ascii'):
+
+    if not isinstance(name, bytes):
+        name = name.encode('ascii')
+
     if name in args:
         value = args[name][0]
+
+        if encoding:
+            value = value.decode(encoding)
+
         if allowed_values is not None and value not in allowed_values:
             message = "Query parameter %r must be one of [%s]" % (
                 name, ", ".join(repr(v) for v in allowed_values)
@@ -146,6 +170,10 @@ def parse_string_from_args(args, name, default=None, required=False,
             message = "Missing %s query parameter %r" % (param_type, name)
             raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
         else:
+
+            if encoding and isinstance(default, bytes):
+                return default.decode(encoding)
+
             return default
 
 
diff --git a/synapse/http/site.py b/synapse/http/site.py
index f5a8f78406..88ed3714f9 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -182,7 +182,7 @@ class SynapseRequest(Request):
         # the client disconnects.
         with PreserveLoggingContext(self.logcontext):
             logger.warn(
-                "Error processing request: %s %s", reason.type, reason.value,
+                "Error processing request %r: %s %s", self, reason.type, reason.value,
             )
 
             if not self._is_processing:
@@ -219,6 +219,12 @@ class SynapseRequest(Request):
         """Log the completion of this request and update the metrics
         """
 
+        if self.logcontext is None:
+            # this can happen if the connection closed before we read the
+            # headers (so render was never called). In that case we'll already
+            # have logged a warning, so just bail out.
+            return
+
         usage = self.logcontext.get_resource_usage()
 
         if self._processing_finished_time is None:
@@ -235,7 +241,7 @@ class SynapseRequest(Request):
         # need to decode as it could be raw utf-8 bytes
         # from a IDN servname in an auth header
         authenticated_entity = self.authenticated_entity
-        if authenticated_entity is not None:
+        if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
             authenticated_entity = authenticated_entity.decode("utf-8", "replace")
 
         # ...or could be raw utf-8 bytes in the User-Agent header.
@@ -328,7 +334,7 @@ class SynapseSite(Site):
         proxied = config.get("x_forwarded", False)
         self.requestFactory = SynapseRequestFactory(self, proxied)
         self.access_logger = logging.getLogger(logger_name)
-        self.server_version_string = server_version_string
+        self.server_version_string = server_version_string.encode('ascii')
 
     def log(self, request):
         pass