summary refs log tree commit diff
path: root/synapse/http/servlet.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/servlet.py')
-rw-r--r--synapse/http/servlet.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index dead02cd5c..0070bd2940 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 """ This module contains base REST classes for constructing REST servlets. """
+import enum
 import logging
 from http import HTTPStatus
 from typing import (
@@ -362,6 +363,7 @@ def parse_string(
     request: Request,
     name: str,
     *,
+    default: Optional[str] = None,
     required: bool = False,
     allowed_values: Optional[Iterable[str]] = None,
     encoding: str = "ascii",
@@ -413,6 +415,74 @@ def parse_string(
     )
 
 
+EnumT = TypeVar("EnumT", bound=enum.Enum)
+
+
+@overload
+def parse_enum(
+    request: Request,
+    name: str,
+    E: Type[EnumT],
+    default: EnumT,
+) -> EnumT:
+    ...
+
+
+@overload
+def parse_enum(
+    request: Request,
+    name: str,
+    E: Type[EnumT],
+    *,
+    required: Literal[True],
+) -> EnumT:
+    ...
+
+
+def parse_enum(
+    request: Request,
+    name: str,
+    E: Type[EnumT],
+    default: Optional[EnumT] = None,
+    required: bool = False,
+) -> Optional[EnumT]:
+    """
+    Parse an enum parameter from the request query string.
+
+    Note that the enum *must only have string values*.
+
+    Args:
+        request: the twisted HTTP request.
+        name: the name of the query parameter.
+        E: the enum which represents valid values
+        default: enum value to use if the parameter is absent, defaults to None.
+        required: whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+
+    Returns:
+        An enum value.
+
+    Raises:
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present, must be one of a list of allowed values and
+            is not one of those allowed values.
+    """
+    # Assert the enum values are strings.
+    assert all(
+        isinstance(e.value, str) for e in E
+    ), "parse_enum only works with string values"
+    str_value = parse_string(
+        request,
+        name,
+        default=default.value if default is not None else None,
+        required=required,
+        allowed_values=[e.value for e in E],
+    )
+    if str_value is None:
+        return None
+    return E(str_value)
+
+
 def _parse_string_value(
     value: bytes,
     allowed_values: Optional[Iterable[str]],