summary refs log tree commit diff
path: root/synapse/http/servlet.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/servlet.py')
-rw-r--r--synapse/http/servlet.py56
1 files changed, 42 insertions, 14 deletions
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 69f7085291..a1e4b88e6d 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -29,7 +29,7 @@ def parse_integer(request, name, default=None, required=False):
 
     Args:
         request: the twisted HTTP request.
-        name (str): the name of the query parameter.
+        name (bytes/unicode): the name of the query parameter.
         default (int|None): value to use if the parameter is absent, defaults
             to None.
         required (bool): whether to raise a 400 SynapseError if the
@@ -46,6 +46,10 @@ def parse_integer(request, name, default=None, required=False):
 
 
 def parse_integer_from_args(args, name, default=None, required=False):
+
+    if not isinstance(name, bytes):
+        name = name.encode('ascii')
+
     if name in args:
         try:
             return int(args[name][0])
@@ -65,7 +69,7 @@ def parse_boolean(request, name, default=None, required=False):
 
     Args:
         request: the twisted HTTP request.
-        name (str): the name of the query parameter.
+        name (bytes/unicode): the name of the query parameter.
         default (bool|None): value to use if the parameter is absent, defaults
             to None.
         required (bool): whether to raise a 400 SynapseError if the
@@ -83,11 +87,15 @@ def parse_boolean(request, name, default=None, required=False):
 
 
 def parse_boolean_from_args(args, name, default=None, required=False):
+
+    if not isinstance(name, bytes):
+        name = name.encode('ascii')
+
     if name in args:
         try:
             return {
-                "true": True,
-                "false": False,
+                b"true": True,
+                b"false": False,
             }[args[name][0]]
         except Exception:
             message = (
@@ -104,21 +112,29 @@ def parse_boolean_from_args(args, name, default=None, required=False):
 
 
 def parse_string(request, name, default=None, required=False,
-                 allowed_values=None, param_type="string"):
-    """Parse a string parameter from the request query string.
+                 allowed_values=None, param_type="string", encoding='ascii'):
+    """
+    Parse a string parameter from the request query string.
+
+    If encoding is not None, the content of the query param will be
+    decoded to Unicode using the encoding, otherwise it will be encoded
 
     Args:
         request: the twisted HTTP request.
-        name (str): the name of the query parameter.
-        default (str|None): value to use if the parameter is absent, defaults
-            to None.
+        name (bytes/unicode): the name of the query parameter.
+        default (bytes/unicode|None): value to use if the parameter is absent,
+            defaults to None. Must be bytes if encoding is None.
         required (bool): whether to raise a 400 SynapseError if the
             parameter is absent, defaults to False.
-        allowed_values (list[str]): List of allowed values for the string,
-            or None if any value is allowed, defaults to None
+        allowed_values (list[bytes/unicode]): List of allowed values for the
+            string, or None if any value is allowed, defaults to None. Must be
+            the same type as name, if given.
+        encoding: The encoding to decode the name to, and decode the string
+            content with.
 
     Returns:
-        str|None: A string value or the default.
+        bytes/unicode|None: A string value or the default. Unicode if encoding
+        was given, bytes otherwise.
 
     Raises:
         SynapseError if the parameter is absent and required, or if the
@@ -126,14 +142,22 @@ def parse_string(request, name, default=None, required=False,
             is not one of those allowed values.
     """
     return parse_string_from_args(
-        request.args, name, default, required, allowed_values, param_type,
+        request.args, name, default, required, allowed_values, param_type, encoding
     )
 
 
 def parse_string_from_args(args, name, default=None, required=False,
-                           allowed_values=None, param_type="string"):
+                           allowed_values=None, param_type="string", encoding='ascii'):
+
+    if not isinstance(name, bytes):
+        name = name.encode('ascii')
+
     if name in args:
         value = args[name][0]
+
+        if encoding:
+            value = value.decode(encoding)
+
         if allowed_values is not None and value not in allowed_values:
             message = "Query parameter %r must be one of [%s]" % (
                 name, ", ".join(repr(v) for v in allowed_values)
@@ -146,6 +170,10 @@ def parse_string_from_args(args, name, default=None, required=False,
             message = "Missing %s query parameter %r" % (param_type, name)
             raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
         else:
+
+            if encoding and isinstance(default, bytes):
+                return default.decode(encoding)
+
             return default