diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index dead02cd5c..5d79d31579 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -13,11 +13,11 @@
# limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """
+import enum
import logging
from http import HTTPStatus
from typing import (
TYPE_CHECKING,
- Iterable,
List,
Mapping,
Optional,
@@ -37,7 +37,7 @@ from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http import redact_uri
from synapse.http.server import HttpServer
-from synapse.types import JsonDict, RoomAlias, RoomID
+from synapse.types import JsonDict, RoomAlias, RoomID, StrCollection
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -339,7 +339,7 @@ def parse_string(
name: str,
default: str,
*,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> str:
...
@@ -351,7 +351,7 @@ def parse_string(
name: str,
*,
required: Literal[True],
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> str:
...
@@ -362,8 +362,9 @@ def parse_string(
request: Request,
name: str,
*,
+ default: Optional[str] = None,
required: bool = False,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
@@ -374,7 +375,7 @@ def parse_string(
name: str,
default: Optional[str] = None,
required: bool = False,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
"""
@@ -413,9 +414,77 @@ def parse_string(
)
+EnumT = TypeVar("EnumT", bound=enum.Enum)
+
+
+@overload
+def parse_enum(
+ request: Request,
+ name: str,
+ E: Type[EnumT],
+ default: EnumT,
+) -> EnumT:
+ ...
+
+
+@overload
+def parse_enum(
+ request: Request,
+ name: str,
+ E: Type[EnumT],
+ *,
+ required: Literal[True],
+) -> EnumT:
+ ...
+
+
+def parse_enum(
+ request: Request,
+ name: str,
+ E: Type[EnumT],
+ default: Optional[EnumT] = None,
+ required: bool = False,
+) -> Optional[EnumT]:
+ """
+ Parse an enum parameter from the request query string.
+
+ Note that the enum *must only have string values*.
+
+ Args:
+ request: the twisted HTTP request.
+ name: the name of the query parameter.
+ E: the enum which represents valid values
+ default: enum value to use if the parameter is absent, defaults to None.
+ required: whether to raise a 400 SynapseError if the
+ parameter is absent, defaults to False.
+
+ Returns:
+ An enum value.
+
+ Raises:
+ SynapseError if the parameter is absent and required, or if the
+ parameter is present, must be one of a list of allowed values and
+ is not one of those allowed values.
+ """
+ # Assert the enum values are strings.
+ assert all(
+ isinstance(e.value, str) for e in E
+ ), "parse_enum only works with string values"
+ str_value = parse_string(
+ request,
+ name,
+ default=default.value if default is not None else None,
+ required=required,
+ allowed_values=[e.value for e in E],
+ )
+ if str_value is None:
+ return None
+ return E(str_value)
+
+
def _parse_string_value(
value: bytes,
- allowed_values: Optional[Iterable[str]],
+ allowed_values: Optional[StrCollection],
name: str,
encoding: str,
) -> str:
@@ -441,7 +510,7 @@ def parse_strings_from_args(
args: Mapping[bytes, Sequence[bytes]],
name: str,
*,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
...
@@ -453,7 +522,7 @@ def parse_strings_from_args(
name: str,
default: List[str],
*,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> List[str]:
...
@@ -465,7 +534,7 @@ def parse_strings_from_args(
name: str,
*,
required: Literal[True],
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> List[str]:
...
@@ -478,7 +547,7 @@ def parse_strings_from_args(
default: Optional[List[str]] = None,
*,
required: bool = False,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
...
@@ -489,7 +558,7 @@ def parse_strings_from_args(
name: str,
default: Optional[List[str]] = None,
required: bool = False,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
"""
@@ -540,7 +609,7 @@ def parse_string_from_args(
name: str,
default: Optional[str] = None,
*,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
@@ -553,7 +622,7 @@ def parse_string_from_args(
default: Optional[str] = None,
*,
required: Literal[True],
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> str:
...
@@ -565,7 +634,7 @@ def parse_string_from_args(
name: str,
default: Optional[str] = None,
required: bool = False,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
@@ -576,7 +645,7 @@ def parse_string_from_args(
name: str,
default: Optional[str] = None,
required: bool = False,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
"""
@@ -708,17 +777,13 @@ def parse_json_object_from_request(
Model = TypeVar("Model", bound=BaseModel)
-def parse_and_validate_json_object_from_request(
- request: Request, model_type: Type[Model]
-) -> Model:
- """Parse a JSON object from the body of a twisted HTTP request, then deserialise and
- validate using the given pydantic model.
+def validate_json_object(content: JsonDict, model_type: Type[Model]) -> Model:
+ """Validate a deserialized JSON object using the given pydantic model.
Raises:
SynapseError if the request body couldn't be decoded as JSON or
if it wasn't a JSON object.
"""
- content = parse_json_object_from_request(request, allow_empty_body=False)
try:
instance = model_type.parse_obj(content)
except ValidationError as e:
@@ -741,7 +806,21 @@ def parse_and_validate_json_object_from_request(
return instance
-def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
+def parse_and_validate_json_object_from_request(
+ request: Request, model_type: Type[Model]
+) -> Model:
+ """Parse a JSON object from the body of a twisted HTTP request, then deserialise and
+ validate using the given pydantic model.
+
+ Raises:
+ SynapseError if the request body couldn't be decoded as JSON or
+ if it wasn't a JSON object.
+ """
+ content = parse_json_object_from_request(request, allow_empty_body=False)
+ return validate_json_object(content, model_type)
+
+
+def assert_params_in_dict(body: JsonDict, required: StrCollection) -> None:
absent = []
for k in required:
if k not in body:
|