diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 69f7085291..a1e4b88e6d 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -29,7 +29,7 @@ def parse_integer(request, name, default=None, required=False):
Args:
request: the twisted HTTP request.
- name (str): the name of the query parameter.
+ name (bytes/unicode): the name of the query parameter.
default (int|None): value to use if the parameter is absent, defaults
to None.
required (bool): whether to raise a 400 SynapseError if the
@@ -46,6 +46,10 @@ def parse_integer(request, name, default=None, required=False):
def parse_integer_from_args(args, name, default=None, required=False):
+
+ if not isinstance(name, bytes):
+ name = name.encode('ascii')
+
if name in args:
try:
return int(args[name][0])
@@ -65,7 +69,7 @@ def parse_boolean(request, name, default=None, required=False):
Args:
request: the twisted HTTP request.
- name (str): the name of the query parameter.
+ name (bytes/unicode): the name of the query parameter.
default (bool|None): value to use if the parameter is absent, defaults
to None.
required (bool): whether to raise a 400 SynapseError if the
@@ -83,11 +87,15 @@ def parse_boolean(request, name, default=None, required=False):
def parse_boolean_from_args(args, name, default=None, required=False):
+
+ if not isinstance(name, bytes):
+ name = name.encode('ascii')
+
if name in args:
try:
return {
- "true": True,
- "false": False,
+ b"true": True,
+ b"false": False,
}[args[name][0]]
except Exception:
message = (
@@ -104,21 +112,29 @@ def parse_boolean_from_args(args, name, default=None, required=False):
def parse_string(request, name, default=None, required=False,
- allowed_values=None, param_type="string"):
- """Parse a string parameter from the request query string.
+ allowed_values=None, param_type="string", encoding='ascii'):
+ """
+ Parse a string parameter from the request query string.
+
+ If encoding is not None, the content of the query param will be
+ decoded to Unicode using the encoding, otherwise it will be encoded
Args:
request: the twisted HTTP request.
- name (str): the name of the query parameter.
- default (str|None): value to use if the parameter is absent, defaults
- to None.
+ name (bytes/unicode): the name of the query parameter.
+ default (bytes/unicode|None): value to use if the parameter is absent,
+ defaults to None. Must be bytes if encoding is None.
required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
- allowed_values (list[str]): List of allowed values for the string,
- or None if any value is allowed, defaults to None
+ allowed_values (list[bytes/unicode]): List of allowed values for the
+ string, or None if any value is allowed, defaults to None. Must be
+ the same type as name, if given.
+ encoding: The encoding to decode the name to, and decode the string
+ content with.
Returns:
- str|None: A string value or the default.
+ bytes/unicode|None: A string value or the default. Unicode if encoding
+ was given, bytes otherwise.
Raises:
SynapseError if the parameter is absent and required, or if the
@@ -126,14 +142,22 @@ def parse_string(request, name, default=None, required=False,
is not one of those allowed values.
"""
return parse_string_from_args(
- request.args, name, default, required, allowed_values, param_type,
+ request.args, name, default, required, allowed_values, param_type, encoding
)
def parse_string_from_args(args, name, default=None, required=False,
- allowed_values=None, param_type="string"):
+ allowed_values=None, param_type="string", encoding='ascii'):
+
+ if not isinstance(name, bytes):
+ name = name.encode('ascii')
+
if name in args:
value = args[name][0]
+
+ if encoding:
+ value = value.decode(encoding)
+
if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values)
@@ -146,6 +170,10 @@ def parse_string_from_args(args, name, default=None, required=False,
message = "Missing %s query parameter %r" % (param_type, name)
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
else:
+
+ if encoding and isinstance(default, bytes):
+ return default.decode(encoding)
+
return default
|