diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 89991e7127..07eb4f439b 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -200,36 +200,6 @@ def parse_string(
args, name, default, required, allowed_values, encoding
)
-def parse_list_from_args(
- args: Dict[bytes, List[bytes]],
- name: Union[bytes, str],
- encoding: Optional[str] = "ascii",
-):
- """Parse and optionally decode a list of values from request query parameters.
-
- Args:
- args: A dictionary of query parameters from a request.
- name: The name of the query parameter to extract values from. If given as bytes,
- will be decoded as "ascii".
- encoding: An optional encoding that is used to decode each parameter value with.
-
- Raises:
- KeyError: If the given `name` does not exist in `args`.
- SynapseError: If an argument was not encoded with the specified `encoding`.
- """
- if not isinstance(name, bytes):
- name = name.encode("ascii")
- args_list = args[name]
-
- if encoding:
- # Decode each argument value
- try:
- args_list = [value.decode(encoding) for value in args_list]
- except ValueError:
- raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
-
- return args_list
-
def _parse_string_value(
value: bytes,
@@ -324,6 +294,30 @@ def parse_strings_from_args(
return default
+@overload
+def parse_string_from_args(
+ args: Dict[bytes, List[bytes]],
+ name: str,
+ default: Optional[str] = None,
+ required: Literal[True] = True,
+ allowed_values: Optional[Iterable[str]] = None,
+ encoding: str = "ascii",
+) -> str:
+ ...
+
+
+@overload
+def parse_string_from_args(
+ args: Dict[bytes, List[bytes]],
+ name: str,
+ default: Optional[str] = None,
+ required: bool = False,
+ allowed_values: Optional[Iterable[str]] = None,
+ encoding: str = "ascii",
+) -> Optional[str]:
+ ...
+
+
def parse_string_from_args(
args: Dict[bytes, List[bytes]],
name: str,
@@ -460,7 +454,7 @@ class RestServlet:
"""
def register(self, http_server):
- """ Register this servlet with the given HTTP server. """
+ """Register this servlet with the given HTTP server."""
patterns = getattr(self, "PATTERNS", None)
if patterns:
for method in ("GET", "PUT", "POST", "DELETE"):
|