summary refs log tree commit diff
path: root/synapse/http/servlet.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/servlet.py')
-rw-r--r--synapse/http/servlet.py56
1 files changed, 25 insertions, 31 deletions
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py

index 89991e7127..07eb4f439b 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py
@@ -200,36 +200,6 @@ def parse_string( args, name, default, required, allowed_values, encoding ) -def parse_list_from_args( - args: Dict[bytes, List[bytes]], - name: Union[bytes, str], - encoding: Optional[str] = "ascii", -): - """Parse and optionally decode a list of values from request query parameters. - - Args: - args: A dictionary of query parameters from a request. - name: The name of the query parameter to extract values from. If given as bytes, - will be decoded as "ascii". - encoding: An optional encoding that is used to decode each parameter value with. - - Raises: - KeyError: If the given `name` does not exist in `args`. - SynapseError: If an argument was not encoded with the specified `encoding`. - """ - if not isinstance(name, bytes): - name = name.encode("ascii") - args_list = args[name] - - if encoding: - # Decode each argument value - try: - args_list = [value.decode(encoding) for value in args_list] - except ValueError: - raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding)) - - return args_list - def _parse_string_value( value: bytes, @@ -324,6 +294,30 @@ def parse_strings_from_args( return default +@overload +def parse_string_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[str] = None, + required: Literal[True] = True, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> str: + ... + + +@overload +def parse_string_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[str] = None, + required: bool = False, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> Optional[str]: + ... + + def parse_string_from_args( args: Dict[bytes, List[bytes]], name: str, @@ -460,7 +454,7 @@ class RestServlet: """ def register(self, http_server): - """ Register this servlet with the given HTTP server. """ + """Register this servlet with the given HTTP server.""" patterns = getattr(self, "PATTERNS", None) if patterns: for method in ("GET", "PUT", "POST", "DELETE"):