summary refs log tree commit diff
path: root/synapse/http/servlet.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/servlet.py')
-rw-r--r--synapse/http/servlet.py179
1 files changed, 111 insertions, 68 deletions
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 3f4f2411fc..d61563d39b 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -15,10 +15,12 @@
 """ This module contains base REST classes for constructing REST servlets. """
 
 import logging
-from typing import Iterable, List, Optional, Union, overload
+from typing import Dict, Iterable, List, Optional, overload
 
 from typing_extensions import Literal
 
+from twisted.web.server import Request
+
 from synapse.api.errors import Codes, SynapseError
 from synapse.util import json_decoder
 
@@ -108,13 +110,66 @@ def parse_boolean_from_args(args, name, default=None, required=False):
             return default
 
 
+@overload
+def parse_bytes_from_args(
+    args: Dict[bytes, List[bytes]],
+    name: str,
+    default: Literal[None] = None,
+    required: Literal[True] = True,
+) -> bytes:
+    ...
+
+
+@overload
+def parse_bytes_from_args(
+    args: Dict[bytes, List[bytes]],
+    name: str,
+    default: Optional[bytes] = None,
+    required: bool = False,
+) -> Optional[bytes]:
+    ...
+
+
+def parse_bytes_from_args(
+    args: Dict[bytes, List[bytes]],
+    name: str,
+    default: Optional[bytes] = None,
+    required: bool = False,
+) -> Optional[bytes]:
+    """
+    Parse a string parameter as bytes from the request query string.
+
+    Args:
+        args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent,
+            defaults to None. Must be bytes if encoding is None.
+        required: whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+    Returns:
+        Bytes or the default value.
+
+    Raises:
+        SynapseError if the parameter is absent and required.
+    """
+    name_bytes = name.encode("ascii")
+
+    if name_bytes in args:
+        return args[name_bytes][0]
+    elif required:
+        message = "Missing string query parameter %s" % (name,)
+        raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+
+    return default
+
+
 def parse_string(
-    request,
-    name: Union[bytes, str],
+    request: Request,
+    name: str,
     default: Optional[str] = None,
     required: bool = False,
     allowed_values: Optional[Iterable[str]] = None,
-    encoding: Optional[str] = "ascii",
+    encoding: str = "ascii",
 ):
     """
     Parse a string parameter from the request query string.
@@ -125,66 +180,65 @@ def parse_string(
     Args:
         request: the twisted HTTP request.
         name: the name of the query parameter.
-        default: value to use if the parameter is absent,
-            defaults to None. Must be bytes if encoding is None.
+        default: value to use if the parameter is absent, defaults to None.
         required: whether to raise a 400 SynapseError if the
             parameter is absent, defaults to False.
         allowed_values: List of allowed values for the
             string, or None if any value is allowed, defaults to None. Must be
             the same type as name, if given.
-        encoding : The encoding to decode the string content with.
+        encoding: The encoding to decode the string content with.
+
     Returns:
-        A string value or the default. Unicode if encoding
-        was given, bytes otherwise.
+        A string value or the default.
 
     Raises:
         SynapseError if the parameter is absent and required, or if the
             parameter is present, must be one of a list of allowed values and
             is not one of those allowed values.
     """
+    args = request.args  # type: Dict[bytes, List[bytes]]  # type: ignore
     return parse_string_from_args(
-        request.args, name, default, required, allowed_values, encoding
+        args, name, default, required, allowed_values, encoding
     )
 
 
 def _parse_string_value(
-    value: Union[str, bytes],
+    value: bytes,
     allowed_values: Optional[Iterable[str]],
     name: str,
-    encoding: Optional[str],
-) -> Union[str, bytes]:
-    if encoding:
-        try:
-            value = value.decode(encoding)
-        except ValueError:
-            raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
+    encoding: str,
+) -> str:
+    try:
+        value_str = value.decode(encoding)
+    except ValueError:
+        raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
 
-    if allowed_values is not None and value not in allowed_values:
+    if allowed_values is not None and value_str not in allowed_values:
         message = "Query parameter %r must be one of [%s]" % (
             name,
             ", ".join(repr(v) for v in allowed_values),
         )
         raise SynapseError(400, message)
     else:
-        return value
+        return value_str
 
 
 @overload
 def parse_strings_from_args(
-    args: List[str],
-    name: Union[bytes, str],
+    args: Dict[bytes, List[bytes]],
+    name: str,
     default: Optional[List[str]] = None,
-    required: bool = False,
+    required: Literal[True] = True,
     allowed_values: Optional[Iterable[str]] = None,
-    encoding: Literal[None] = None,
-) -> Optional[List[bytes]]:
+    encoding: str = "ascii",
+) -> List[str]:
     ...
 
 
 @overload
 def parse_strings_from_args(
-    args: List[str],
-    name: Union[bytes, str],
+    args: Dict[bytes, List[bytes]],
+    name: str,
     default: Optional[List[str]] = None,
     required: bool = False,
     allowed_values: Optional[Iterable[str]] = None,
@@ -194,46 +248,40 @@ def parse_strings_from_args(
 
 
 def parse_strings_from_args(
-    args: List[str],
-    name: Union[bytes, str],
+    args: Dict[bytes, List[bytes]],
+    name: str,
     default: Optional[List[str]] = None,
     required: bool = False,
     allowed_values: Optional[Iterable[str]] = None,
-    encoding: Optional[str] = "ascii",
-) -> Optional[List[Union[bytes, str]]]:
+    encoding: str = "ascii",
+) -> Optional[List[str]]:
     """
     Parse a string parameter from the request query string list.
 
-    If encoding is not None, the content of the query param will be
-    decoded to Unicode using the encoding, otherwise it will be encoded
+    The content of the query param will be decoded to Unicode using the encoding.
 
     Args:
-        args: the twisted HTTP request.args list.
+        args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
         name: the name of the query parameter.
-        default: value to use if the parameter is absent,
-            defaults to None. Must be bytes if encoding is None.
-        required : whether to raise a 400 SynapseError if the
+        default: value to use if the parameter is absent, defaults to None.
+        required: whether to raise a 400 SynapseError if the
             parameter is absent, defaults to False.
-        allowed_values (list[bytes|unicode]): List of allowed values for the
-            string, or None if any value is allowed, defaults to None. Must be
-            the same type as name, if given.
+        allowed_values: List of allowed values for the
+            string, or None if any value is allowed, defaults to None.
         encoding: The encoding to decode the string content with.
 
     Returns:
-        A string value or the default. Unicode if encoding
-        was given, bytes otherwise.
+        A string value or the default.
 
     Raises:
         SynapseError if the parameter is absent and required, or if the
             parameter is present, must be one of a list of allowed values and
             is not one of those allowed values.
     """
+    name_bytes = name.encode("ascii")
 
-    if not isinstance(name, bytes):
-        name = name.encode("ascii")
-
-    if name in args:
-        values = args[name]
+    if name_bytes in args:
+        values = args[name_bytes]
 
         return [
             _parse_string_value(value, allowed_values, name=name, encoding=encoding)
@@ -241,36 +289,30 @@ def parse_strings_from_args(
         ]
     else:
         if required:
-            message = "Missing string query parameter %r" % (name)
+            message = "Missing string query parameter %r" % (name,)
             raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
-        else:
-
-            if encoding and isinstance(default, bytes):
-                return default.decode(encoding)
 
-            return default
+        return default
 
 
 def parse_string_from_args(
-    args: List[str],
-    name: Union[bytes, str],
+    args: Dict[bytes, List[bytes]],
+    name: str,
     default: Optional[str] = None,
     required: bool = False,
     allowed_values: Optional[Iterable[str]] = None,
-    encoding: Optional[str] = "ascii",
-) -> Optional[Union[bytes, str]]:
+    encoding: str = "ascii",
+) -> Optional[str]:
     """
     Parse the string parameter from the request query string list
     and return the first result.
 
-    If encoding is not None, the content of the query param will be
-    decoded to Unicode using the encoding, otherwise it will be encoded
+    The content of the query param will be decoded to Unicode using the encoding.
 
     Args:
-        args: the twisted HTTP request.args list.
+        args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
         name: the name of the query parameter.
-        default: value to use if the parameter is absent,
-            defaults to None. Must be bytes if encoding is None.
+        default: value to use if the parameter is absent, defaults to None.
         required: whether to raise a 400 SynapseError if the
             parameter is absent, defaults to False.
         allowed_values: List of allowed values for the
@@ -279,8 +321,7 @@ def parse_string_from_args(
         encoding: The encoding to decode the string content with.
 
     Returns:
-        A string value or the default. Unicode if encoding
-        was given, bytes otherwise.
+        A string value or the default.
 
     Raises:
         SynapseError if the parameter is absent and required, or if the
@@ -291,12 +332,15 @@ def parse_string_from_args(
     strings = parse_strings_from_args(
         args,
         name,
-        default=[default],
+        default=[default] if default is not None else None,
         required=required,
         allowed_values=allowed_values,
         encoding=encoding,
     )
 
+    if strings is None:
+        return None
+
     return strings[0]
 
 
@@ -388,9 +432,8 @@ class RestServlet:
 
     def register(self, http_server):
         """ Register this servlet with the given HTTP server. """
-        if hasattr(self, "PATTERNS"):
-            patterns = self.PATTERNS
-
+        patterns = getattr(self, "PATTERNS", None)
+        if patterns:
             for method in ("GET", "PUT", "POST", "DELETE"):
                 if hasattr(self, "on_%s" % (method,)):
                     servlet_classname = self.__class__.__name__