summary refs log tree commit diff
path: root/synapse/http/servlet.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/servlet.py')
-rw-r--r--synapse/http/servlet.py125
1 files changed, 102 insertions, 23 deletions
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index dead02cd5c..5d79d31579 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -13,11 +13,11 @@
 # limitations under the License.
 
 """ This module contains base REST classes for constructing REST servlets. """
+import enum
 import logging
 from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
-    Iterable,
     List,
     Mapping,
     Optional,
@@ -37,7 +37,7 @@ from twisted.web.server import Request
 from synapse.api.errors import Codes, SynapseError
 from synapse.http import redact_uri
 from synapse.http.server import HttpServer
-from synapse.types import JsonDict, RoomAlias, RoomID
+from synapse.types import JsonDict, RoomAlias, RoomID, StrCollection
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
@@ -339,7 +339,7 @@ def parse_string(
     name: str,
     default: str,
     *,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> str:
     ...
@@ -351,7 +351,7 @@ def parse_string(
     name: str,
     *,
     required: Literal[True],
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> str:
     ...
@@ -362,8 +362,9 @@ def parse_string(
     request: Request,
     name: str,
     *,
+    default: Optional[str] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     ...
@@ -374,7 +375,7 @@ def parse_string(
     name: str,
     default: Optional[str] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     """
@@ -413,9 +414,77 @@ def parse_string(
     )
 
 
+EnumT = TypeVar("EnumT", bound=enum.Enum)
+
+
+@overload
+def parse_enum(
+    request: Request,
+    name: str,
+    E: Type[EnumT],
+    default: EnumT,
+) -> EnumT:
+    ...
+
+
+@overload
+def parse_enum(
+    request: Request,
+    name: str,
+    E: Type[EnumT],
+    *,
+    required: Literal[True],
+) -> EnumT:
+    ...
+
+
+def parse_enum(
+    request: Request,
+    name: str,
+    E: Type[EnumT],
+    default: Optional[EnumT] = None,
+    required: bool = False,
+) -> Optional[EnumT]:
+    """
+    Parse an enum parameter from the request query string.
+
+    Note that the enum *must only have string values*.
+
+    Args:
+        request: the twisted HTTP request.
+        name: the name of the query parameter.
+        E: the enum which represents valid values
+        default: enum value to use if the parameter is absent, defaults to None.
+        required: whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+
+    Returns:
+        An enum value.
+
+    Raises:
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present, must be one of a list of allowed values and
+            is not one of those allowed values.
+    """
+    # Assert the enum values are strings.
+    assert all(
+        isinstance(e.value, str) for e in E
+    ), "parse_enum only works with string values"
+    str_value = parse_string(
+        request,
+        name,
+        default=default.value if default is not None else None,
+        required=required,
+        allowed_values=[e.value for e in E],
+    )
+    if str_value is None:
+        return None
+    return E(str_value)
+
+
 def _parse_string_value(
     value: bytes,
-    allowed_values: Optional[Iterable[str]],
+    allowed_values: Optional[StrCollection],
     name: str,
     encoding: str,
 ) -> str:
@@ -441,7 +510,7 @@ def parse_strings_from_args(
     args: Mapping[bytes, Sequence[bytes]],
     name: str,
     *,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[List[str]]:
     ...
@@ -453,7 +522,7 @@ def parse_strings_from_args(
     name: str,
     default: List[str],
     *,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> List[str]:
     ...
@@ -465,7 +534,7 @@ def parse_strings_from_args(
     name: str,
     *,
     required: Literal[True],
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> List[str]:
     ...
@@ -478,7 +547,7 @@ def parse_strings_from_args(
     default: Optional[List[str]] = None,
     *,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[List[str]]:
     ...
@@ -489,7 +558,7 @@ def parse_strings_from_args(
     name: str,
     default: Optional[List[str]] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[List[str]]:
     """
@@ -540,7 +609,7 @@ def parse_string_from_args(
     name: str,
     default: Optional[str] = None,
     *,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     ...
@@ -553,7 +622,7 @@ def parse_string_from_args(
     default: Optional[str] = None,
     *,
     required: Literal[True],
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> str:
     ...
@@ -565,7 +634,7 @@ def parse_string_from_args(
     name: str,
     default: Optional[str] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     ...
@@ -576,7 +645,7 @@ def parse_string_from_args(
     name: str,
     default: Optional[str] = None,
     required: bool = False,
-    allowed_values: Optional[Iterable[str]] = None,
+    allowed_values: Optional[StrCollection] = None,
     encoding: str = "ascii",
 ) -> Optional[str]:
     """
@@ -708,17 +777,13 @@ def parse_json_object_from_request(
 Model = TypeVar("Model", bound=BaseModel)
 
 
-def parse_and_validate_json_object_from_request(
-    request: Request, model_type: Type[Model]
-) -> Model:
-    """Parse a JSON object from the body of a twisted HTTP request, then deserialise and
-    validate using the given pydantic model.
+def validate_json_object(content: JsonDict, model_type: Type[Model]) -> Model:
+    """Validate a deserialized JSON object using the given pydantic model.
 
     Raises:
         SynapseError if the request body couldn't be decoded as JSON or
             if it wasn't a JSON object.
     """
-    content = parse_json_object_from_request(request, allow_empty_body=False)
     try:
         instance = model_type.parse_obj(content)
     except ValidationError as e:
@@ -741,7 +806,21 @@ def parse_and_validate_json_object_from_request(
     return instance
 
 
-def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
+def parse_and_validate_json_object_from_request(
+    request: Request, model_type: Type[Model]
+) -> Model:
+    """Parse a JSON object from the body of a twisted HTTP request, then deserialise and
+    validate using the given pydantic model.
+
+    Raises:
+        SynapseError if the request body couldn't be decoded as JSON or
+            if it wasn't a JSON object.
+    """
+    content = parse_json_object_from_request(request, allow_empty_body=False)
+    return validate_json_object(content, model_type)
+
+
+def assert_params_in_dict(body: JsonDict, required: StrCollection) -> None:
     absent = []
     for k in required:
         if k not in body: