summary refs log tree commit diff
path: root/synapse/http/servlet.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-02-01 16:35:24 -0500
committerGitHub <noreply@github.com>2023-02-01 21:35:24 +0000
commit1182ae50635db94d3c9c47990a0befcbf6306b62 (patch)
tree56bdbf809b884428af3f6fe28e18dc44472077ac /synapse/http/servlet.py
parentAttempt to delete more duplicate rows in receipts_linearized table. (#14915) (diff)
downloadsynapse-1182ae50635db94d3c9c47990a0befcbf6306b62.tar.xz
Add helper to parse an enum from query args & use it. (#14956)
The `parse_enum` helper pulls an enum value from the query string
(by delegating down to the parse_string helper with values generated
from the enum).

This is used to pull out "f" and "b" in most places and then we thread
the resulting Direction enum throughout more code.
Diffstat (limited to 'synapse/http/servlet.py')
-rw-r--r--synapse/http/servlet.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index dead02cd5c..0070bd2940 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 """ This module contains base REST classes for constructing REST servlets. """
+import enum
 import logging
 from http import HTTPStatus
 from typing import (
@@ -362,6 +363,7 @@ def parse_string(
     request: Request,
     name: str,
     *,
+    default: Optional[str] = None,
     required: bool = False,
     allowed_values: Optional[Iterable[str]] = None,
     encoding: str = "ascii",
@@ -413,6 +415,74 @@ def parse_string(
     )
 
 
+EnumT = TypeVar("EnumT", bound=enum.Enum)
+
+
+@overload
+def parse_enum(
+    request: Request,
+    name: str,
+    E: Type[EnumT],
+    default: EnumT,
+) -> EnumT:
+    ...
+
+
+@overload
+def parse_enum(
+    request: Request,
+    name: str,
+    E: Type[EnumT],
+    *,
+    required: Literal[True],
+) -> EnumT:
+    ...
+
+
+def parse_enum(
+    request: Request,
+    name: str,
+    E: Type[EnumT],
+    default: Optional[EnumT] = None,
+    required: bool = False,
+) -> Optional[EnumT]:
+    """
+    Parse an enum parameter from the request query string.
+
+    Note that the enum *must only have string values*.
+
+    Args:
+        request: the twisted HTTP request.
+        name: the name of the query parameter.
+        E: the enum which represents valid values
+        default: enum value to use if the parameter is absent, defaults to None.
+        required: whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+
+    Returns:
+        An enum value.
+
+    Raises:
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present, must be one of a list of allowed values and
+            is not one of those allowed values.
+    """
+    # Assert the enum values are strings.
+    assert all(
+        isinstance(e.value, str) for e in E
+    ), "parse_enum only works with string values"
+    str_value = parse_string(
+        request,
+        name,
+        default=default.value if default is not None else None,
+        required=required,
+        allowed_values=[e.value for e in E],
+    )
+    if str_value is None:
+        return None
+    return E(str_value)
+
+
 def _parse_string_value(
     value: bytes,
     allowed_values: Optional[Iterable[str]],