summary refs log tree commit diff
path: root/synapse/http/servlet.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/servlet.py')
-rw-r--r--synapse/http/servlet.py196
1 files changed, 153 insertions, 43 deletions
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 31897546a9..3f4f2411fc 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -15,6 +15,9 @@
 """ This module contains base REST classes for constructing REST servlets. """
 
 import logging
+from typing import Iterable, List, Optional, Union, overload
+
+from typing_extensions import Literal
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.util import json_decoder
@@ -107,12 +110,11 @@ def parse_boolean_from_args(args, name, default=None, required=False):
 
 def parse_string(
     request,
-    name,
-    default=None,
-    required=False,
-    allowed_values=None,
-    param_type="string",
-    encoding="ascii",
+    name: Union[bytes, str],
+    default: Optional[str] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: Optional[str] = "ascii",
 ):
     """
     Parse a string parameter from the request query string.
@@ -122,18 +124,17 @@ def parse_string(
 
     Args:
         request: the twisted HTTP request.
-        name (bytes|unicode): the name of the query parameter.
-        default (bytes|unicode|None): value to use if the parameter is absent,
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent,
             defaults to None. Must be bytes if encoding is None.
-        required (bool): whether to raise a 400 SynapseError if the
+        required: whether to raise a 400 SynapseError if the
             parameter is absent, defaults to False.
-        allowed_values (list[bytes|unicode]): List of allowed values for the
+        allowed_values: List of allowed values for the
             string, or None if any value is allowed, defaults to None. Must be
             the same type as name, if given.
-        encoding (str|None): The encoding to decode the string content with.
-
+        encoding : The encoding to decode the string content with.
     Returns:
-        bytes/unicode|None: A string value or the default. Unicode if encoding
+        A string value or the default. Unicode if encoding
         was given, bytes otherwise.
 
     Raises:
@@ -142,45 +143,105 @@ def parse_string(
             is not one of those allowed values.
     """
     return parse_string_from_args(
-        request.args, name, default, required, allowed_values, param_type, encoding
+        request.args, name, default, required, allowed_values, encoding
     )
 
 
-def parse_string_from_args(
-    args,
-    name,
-    default=None,
-    required=False,
-    allowed_values=None,
-    param_type="string",
-    encoding="ascii",
-):
+def _parse_string_value(
+    value: Union[str, bytes],
+    allowed_values: Optional[Iterable[str]],
+    name: str,
+    encoding: Optional[str],
+) -> Union[str, bytes]:
+    if encoding:
+        try:
+            value = value.decode(encoding)
+        except ValueError:
+            raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
+
+    if allowed_values is not None and value not in allowed_values:
+        message = "Query parameter %r must be one of [%s]" % (
+            name,
+            ", ".join(repr(v) for v in allowed_values),
+        )
+        raise SynapseError(400, message)
+    else:
+        return value
+
+
+@overload
+def parse_strings_from_args(
+    args: List[str],
+    name: Union[bytes, str],
+    default: Optional[List[str]] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: Literal[None] = None,
+) -> Optional[List[bytes]]:
+    ...
+
+
+@overload
+def parse_strings_from_args(
+    args: List[str],
+    name: Union[bytes, str],
+    default: Optional[List[str]] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: str = "ascii",
+) -> Optional[List[str]]:
+    ...
+
+
+def parse_strings_from_args(
+    args: List[str],
+    name: Union[bytes, str],
+    default: Optional[List[str]] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: Optional[str] = "ascii",
+) -> Optional[List[Union[bytes, str]]]:
+    """
+    Parse a string parameter from the request query string list.
+
+    If encoding is not None, the content of the query param will be
+    decoded to Unicode using the encoding, otherwise it will be encoded
+
+    Args:
+        args: the twisted HTTP request.args list.
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent,
+            defaults to None. Must be bytes if encoding is None.
+        required : whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+        allowed_values (list[bytes|unicode]): List of allowed values for the
+            string, or None if any value is allowed, defaults to None. Must be
+            the same type as name, if given.
+        encoding: The encoding to decode the string content with.
+
+    Returns:
+        A string value or the default. Unicode if encoding
+        was given, bytes otherwise.
+
+    Raises:
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present, must be one of a list of allowed values and
+            is not one of those allowed values.
+    """
 
     if not isinstance(name, bytes):
         name = name.encode("ascii")
 
     if name in args:
-        value = args[name][0]
-
-        if encoding:
-            try:
-                value = value.decode(encoding)
-            except ValueError:
-                raise SynapseError(
-                    400, "Query parameter %r must be %s" % (name, encoding)
-                )
-
-        if allowed_values is not None and value not in allowed_values:
-            message = "Query parameter %r must be one of [%s]" % (
-                name,
-                ", ".join(repr(v) for v in allowed_values),
-            )
-            raise SynapseError(400, message)
-        else:
-            return value
+        values = args[name]
+
+        return [
+            _parse_string_value(value, allowed_values, name=name, encoding=encoding)
+            for value in values
+        ]
     else:
         if required:
-            message = "Missing %s query parameter %r" % (param_type, name)
+            message = "Missing string query parameter %r" % (name)
             raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
         else:
 
@@ -190,6 +251,55 @@ def parse_string_from_args(
             return default
 
 
+def parse_string_from_args(
+    args: List[str],
+    name: Union[bytes, str],
+    default: Optional[str] = None,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: Optional[str] = "ascii",
+) -> Optional[Union[bytes, str]]:
+    """
+    Parse the string parameter from the request query string list
+    and return the first result.
+
+    If encoding is not None, the content of the query param will be
+    decoded to Unicode using the encoding, otherwise it will be encoded
+
+    Args:
+        args: the twisted HTTP request.args list.
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent,
+            defaults to None. Must be bytes if encoding is None.
+        required: whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+        allowed_values: List of allowed values for the
+            string, or None if any value is allowed, defaults to None. Must be
+            the same type as name, if given.
+        encoding: The encoding to decode the string content with.
+
+    Returns:
+        A string value or the default. Unicode if encoding
+        was given, bytes otherwise.
+
+    Raises:
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present, must be one of a list of allowed values and
+            is not one of those allowed values.
+    """
+
+    strings = parse_strings_from_args(
+        args,
+        name,
+        default=[default],
+        required=required,
+        allowed_values=allowed_values,
+        encoding=encoding,
+    )
+
+    return strings[0]
+
+
 def parse_json_value_from_request(request, allow_empty_body=False):
     """Parse a JSON value from the body of a twisted HTTP request.
 
@@ -215,7 +325,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
     try:
         content = json_decoder.decode(content_bytes.decode("utf-8"))
     except Exception as e:
-        logger.warning("Unable to parse JSON: %s", e)
+        logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes)
         raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
 
     return content