diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index dead02cd5c..0070bd2940 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -13,6 +13,7 @@
# limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """
+import enum
import logging
from http import HTTPStatus
from typing import (
@@ -362,6 +363,7 @@ def parse_string(
request: Request,
name: str,
*,
+ default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
@@ -413,6 +415,74 @@ def parse_string(
)
+EnumT = TypeVar("EnumT", bound=enum.Enum)
+
+
+@overload
+def parse_enum(
+ request: Request,
+ name: str,
+ E: Type[EnumT],
+ default: EnumT,
+) -> EnumT:
+ ...
+
+
+@overload
+def parse_enum(
+ request: Request,
+ name: str,
+ E: Type[EnumT],
+ *,
+ required: Literal[True],
+) -> EnumT:
+ ...
+
+
+def parse_enum(
+ request: Request,
+ name: str,
+ E: Type[EnumT],
+ default: Optional[EnumT] = None,
+ required: bool = False,
+) -> Optional[EnumT]:
+ """
+ Parse an enum parameter from the request query string.
+
+ Note that the enum *must only have string values*.
+
+ Args:
+ request: the twisted HTTP request.
+ name: the name of the query parameter.
+ E: the enum which represents valid values
+ default: enum value to use if the parameter is absent, defaults to None.
+ required: whether to raise a 400 SynapseError if the
+ parameter is absent, defaults to False.
+
+ Returns:
+ An enum value.
+
+ Raises:
+ SynapseError if the parameter is absent and required, or if the
+ parameter is present, must be one of a list of allowed values and
+ is not one of those allowed values.
+ """
+ # Assert the enum values are strings.
+ assert all(
+ isinstance(e.value, str) for e in E
+ ), "parse_enum only works with string values"
+ str_value = parse_string(
+ request,
+ name,
+ default=default.value if default is not None else None,
+ required=required,
+ allowed_values=[e.value for e in E],
+ )
+ if str_value is None:
+ return None
+ return E(str_value)
+
+
def _parse_string_value(
value: bytes,
allowed_values: Optional[Iterable[str]],
|