summary refs log tree commit diff
path: root/synapse/http/servlet.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/servlet.py')
-rw-r--r--synapse/http/servlet.py258
1 files changed, 206 insertions, 52 deletions
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 04560fb589..732a1e6aeb 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -14,47 +14,86 @@
 
 """ This module contains base REST classes for constructing REST servlets. """
 import logging
-from typing import Dict, Iterable, List, Optional, overload
+from typing import Iterable, List, Mapping, Optional, Sequence, overload
 
 from typing_extensions import Literal
 
 from twisted.web.server import Request
 
 from synapse.api.errors import Codes, SynapseError
+from synapse.types import JsonDict
 from synapse.util import json_decoder
 
 logger = logging.getLogger(__name__)
 
 
-def parse_integer(request, name, default=None, required=False):
+@overload
+def parse_integer(request: Request, name: str, default: int) -> int:
+    ...
+
+
+@overload
+def parse_integer(request: Request, name: str, *, required: Literal[True]) -> int:
+    ...
+
+
+@overload
+def parse_integer(
+    request: Request, name: str, default: Optional[int] = None, required: bool = False
+) -> Optional[int]:
+    ...
+
+
+def parse_integer(
+    request: Request, name: str, default: Optional[int] = None, required: bool = False
+) -> Optional[int]:
     """Parse an integer parameter from the request string
 
     Args:
         request: the twisted HTTP request.
-        name (bytes/unicode): the name of the query parameter.
-        default (int|None): value to use if the parameter is absent, defaults
-            to None.
-        required (bool): whether to raise a 400 SynapseError if the
-            parameter is absent, defaults to False.
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent, defaults to None.
+        required: whether to raise a 400 SynapseError if the parameter is absent,
+            defaults to False.
 
     Returns:
-        int|None: An int value or the default.
+        An int value or the default.
 
     Raises:
         SynapseError: if the parameter is absent and required, or if the
             parameter is present and not an integer.
     """
-    return parse_integer_from_args(request.args, name, default, required)
+    args: Mapping[bytes, Sequence[bytes]] = request.args  # type: ignore
+    return parse_integer_from_args(args, name, default, required)
+
 
+def parse_integer_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    default: Optional[int] = None,
+    required: bool = False,
+) -> Optional[int]:
+    """Parse an integer parameter from the request string
+
+    Args:
+        args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent, defaults to None.
+        required: whether to raise a 400 SynapseError if the parameter is absent,
+            defaults to False.
 
-def parse_integer_from_args(args, name, default=None, required=False):
+    Returns:
+        An int value or the default.
 
-    if not isinstance(name, bytes):
-        name = name.encode("ascii")
+    Raises:
+        SynapseError: if the parameter is absent and required, or if the
+            parameter is present and not an integer.
+    """
+    name_bytes = name.encode("ascii")
 
-    if name in args:
+    if name_bytes in args:
         try:
-            return int(args[name][0])
+            return int(args[name_bytes][0])
         except Exception:
             message = "Query parameter %r must be an integer" % (name,)
             raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
@@ -66,36 +105,102 @@ def parse_integer_from_args(args, name, default=None, required=False):
             return default
 
 
-def parse_boolean(request, name, default=None, required=False):
+@overload
+def parse_boolean(request: Request, name: str, default: bool) -> bool:
+    ...
+
+
+@overload
+def parse_boolean(request: Request, name: str, *, required: Literal[True]) -> bool:
+    ...
+
+
+@overload
+def parse_boolean(
+    request: Request, name: str, default: Optional[bool] = None, required: bool = False
+) -> Optional[bool]:
+    ...
+
+
+def parse_boolean(
+    request: Request, name: str, default: Optional[bool] = None, required: bool = False
+) -> Optional[bool]:
     """Parse a boolean parameter from the request query string
 
     Args:
         request: the twisted HTTP request.
-        name (bytes/unicode): the name of the query parameter.
-        default (bool|None): value to use if the parameter is absent, defaults
-            to None.
-        required (bool): whether to raise a 400 SynapseError if the
-            parameter is absent, defaults to False.
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent, defaults to None.
+        required: whether to raise a 400 SynapseError if the parameter is absent,
+            defaults to False.
 
     Returns:
-        bool|None: A bool value or the default.
+        A bool value or the default.
 
     Raises:
         SynapseError: if the parameter is absent and required, or if the
             parameter is present and not one of "true" or "false".
     """
+    args: Mapping[bytes, Sequence[bytes]] = request.args  # type: ignore
+    return parse_boolean_from_args(args, name, default, required)
+
+
+@overload
+def parse_boolean_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    default: bool,
+) -> bool:
+    ...
+
 
-    return parse_boolean_from_args(request.args, name, default, required)
+@overload
+def parse_boolean_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    *,
+    required: Literal[True],
+) -> bool:
+    ...
 
 
-def parse_boolean_from_args(args, name, default=None, required=False):
+@overload
+def parse_boolean_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    default: Optional[bool] = None,
+    required: bool = False,
+) -> Optional[bool]:
+    ...
 
-    if not isinstance(name, bytes):
-        name = name.encode("ascii")
 
-    if name in args:
+def parse_boolean_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    default: Optional[bool] = None,
+    required: bool = False,
+) -> Optional[bool]:
+    """Parse a boolean parameter from the request query string
+
+    Args:
+        args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent, defaults to None.
+        required: whether to raise a 400 SynapseError if the parameter is absent,
+            defaults to False.
+
+    Returns:
+        A bool value or the default.
+
+    Raises:
+        SynapseError: if the parameter is absent and required, or if the
+            parameter is present and not one of "true" or "false".
+    """
+    name_bytes = name.encode("ascii")
+
+    if name_bytes in args:
         try:
-            return {b"true": True, b"false": False}[args[name][0]]
+            return {b"true": True, b"false": False}[args[name_bytes][0]]
         except Exception:
             message = (
                 "Boolean query parameter %r must be one of ['true', 'false']"
@@ -111,7 +216,7 @@ def parse_boolean_from_args(args, name, default=None, required=False):
 
 @overload
 def parse_bytes_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[bytes] = None,
 ) -> Optional[bytes]:
@@ -120,7 +225,7 @@ def parse_bytes_from_args(
 
 @overload
 def parse_bytes_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Literal[None] = None,
     *,
@@ -131,7 +236,7 @@ def parse_bytes_from_args(
 
 @overload
 def parse_bytes_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[bytes] = None,
     required: bool = False,
@@ -140,7 +245,7 @@ def parse_bytes_from_args(
 
 
 def parse_bytes_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[bytes] = None,
     required: bool = False,
@@ -172,6 +277,42 @@ def parse_bytes_from_args(
     return default
 
 
+@overload
+def parse_string(
+    request: Request,
+    name: str,
+    default: str,
+    *,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: str = "ascii",
+) -> str:
+    ...
+
+
+@overload
+def parse_string(
+    request: Request,
+    name: str,
+    *,
+    required: Literal[True],
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: str = "ascii",
+) -> str:
+    ...
+
+
+@overload
+def parse_string(
+    request: Request,
+    name: str,
+    *,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: str = "ascii",
+) -> Optional[str]:
+    ...
+
+
 def parse_string(
     request: Request,
     name: str,
@@ -179,7 +320,7 @@ def parse_string(
     required: bool = False,
     allowed_values: Optional[Iterable[str]] = None,
     encoding: str = "ascii",
-):
+) -> Optional[str]:
     """
     Parse a string parameter from the request query string.
 
@@ -205,7 +346,7 @@ def parse_string(
             parameter is present, must be one of a list of allowed values and
             is not one of those allowed values.
     """
-    args: Dict[bytes, List[bytes]] = request.args  # type: ignore
+    args: Mapping[bytes, Sequence[bytes]] = request.args  # type: ignore
     return parse_string_from_args(
         args,
         name,
@@ -239,9 +380,8 @@ def _parse_string_value(
 
 @overload
 def parse_strings_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
-    default: Optional[List[str]] = None,
     *,
     allowed_values: Optional[Iterable[str]] = None,
     encoding: str = "ascii",
@@ -251,9 +391,20 @@ def parse_strings_from_args(
 
 @overload
 def parse_strings_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    default: List[str],
+    *,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: str = "ascii",
+) -> List[str]:
+    ...
+
+
+@overload
+def parse_strings_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
-    default: Optional[List[str]] = None,
     *,
     required: Literal[True],
     allowed_values: Optional[Iterable[str]] = None,
@@ -264,7 +415,7 @@ def parse_strings_from_args(
 
 @overload
 def parse_strings_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[List[str]] = None,
     *,
@@ -276,7 +427,7 @@ def parse_strings_from_args(
 
 
 def parse_strings_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[List[str]] = None,
     required: bool = False,
@@ -325,7 +476,7 @@ def parse_strings_from_args(
 
 @overload
 def parse_string_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[str] = None,
     *,
@@ -337,7 +488,7 @@ def parse_string_from_args(
 
 @overload
 def parse_string_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[str] = None,
     *,
@@ -350,7 +501,7 @@ def parse_string_from_args(
 
 @overload
 def parse_string_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[str] = None,
     required: bool = False,
@@ -361,7 +512,7 @@ def parse_string_from_args(
 
 
 def parse_string_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[str] = None,
     required: bool = False,
@@ -409,13 +560,14 @@ def parse_string_from_args(
     return strings[0]
 
 
-def parse_json_value_from_request(request, allow_empty_body=False):
+def parse_json_value_from_request(
+    request: Request, allow_empty_body: bool = False
+) -> Optional[JsonDict]:
     """Parse a JSON value from the body of a twisted HTTP request.
 
     Args:
         request: the twisted HTTP request.
-        allow_empty_body (bool): if True, an empty body will be accepted and
-            turned into None
+        allow_empty_body: if True, an empty body will be accepted and turned into None
 
     Returns:
         The JSON value.
@@ -424,7 +576,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
         SynapseError if the request body couldn't be decoded as JSON.
     """
     try:
-        content_bytes = request.content.read()
+        content_bytes = request.content.read()  # type: ignore
     except Exception:
         raise SynapseError(400, "Error reading JSON content.")
 
@@ -440,13 +592,15 @@ def parse_json_value_from_request(request, allow_empty_body=False):
     return content
 
 
-def parse_json_object_from_request(request, allow_empty_body=False):
+def parse_json_object_from_request(
+    request: Request, allow_empty_body: bool = False
+) -> JsonDict:
     """Parse a JSON object from the body of a twisted HTTP request.
 
     Args:
         request: the twisted HTTP request.
-        allow_empty_body (bool): if True, an empty body will be accepted and
-            turned into an empty dict.
+        allow_empty_body: if True, an empty body will be accepted and turned into
+            an empty dict.
 
     Raises:
         SynapseError if the request body couldn't be decoded as JSON or
@@ -457,14 +611,14 @@ def parse_json_object_from_request(request, allow_empty_body=False):
     if allow_empty_body and content is None:
         return {}
 
-    if type(content) != dict:
+    if not isinstance(content, dict):
         message = "Content must be a JSON object."
         raise SynapseError(400, message, errcode=Codes.BAD_JSON)
 
     return content
 
 
-def assert_params_in_dict(body, required):
+def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
     absent = []
     for k in required:
         if k not in body: