diff options
Diffstat (limited to 'synapse/http/servlet.py')
-rw-r--r-- | synapse/http/servlet.py | 179 |
1 files changed, 111 insertions, 68 deletions
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 3f4f2411fc..d61563d39b 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -15,10 +15,12 @@ """ This module contains base REST classes for constructing REST servlets. """ import logging -from typing import Iterable, List, Optional, Union, overload +from typing import Dict, Iterable, List, Optional, overload from typing_extensions import Literal +from twisted.web.server import Request + from synapse.api.errors import Codes, SynapseError from synapse.util import json_decoder @@ -108,13 +110,66 @@ def parse_boolean_from_args(args, name, default=None, required=False): return default +@overload +def parse_bytes_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Literal[None] = None, + required: Literal[True] = True, +) -> bytes: + ... + + +@overload +def parse_bytes_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[bytes] = None, + required: bool = False, +) -> Optional[bytes]: + ... + + +def parse_bytes_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[bytes] = None, + required: bool = False, +) -> Optional[bytes]: + """ + Parse a string parameter as bytes from the request query string. + + Args: + args: A mapping of request args as bytes to a list of bytes (e.g. request.args). + name: the name of the query parameter. + default: value to use if the parameter is absent, + defaults to None. Must be bytes if encoding is None. + required: whether to raise a 400 SynapseError if the + parameter is absent, defaults to False. + Returns: + Bytes or the default value. + + Raises: + SynapseError if the parameter is absent and required. + """ + name_bytes = name.encode("ascii") + + if name_bytes in args: + return args[name_bytes][0] + elif required: + message = "Missing string query parameter %s" % (name,) + raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + + return default + + def parse_string( - request, - name: Union[bytes, str], + request: Request, + name: str, default: Optional[str] = None, required: bool = False, allowed_values: Optional[Iterable[str]] = None, - encoding: Optional[str] = "ascii", + encoding: str = "ascii", ): """ Parse a string parameter from the request query string. @@ -125,66 +180,65 @@ def parse_string( Args: request: the twisted HTTP request. name: the name of the query parameter. - default: value to use if the parameter is absent, - defaults to None. Must be bytes if encoding is None. + default: value to use if the parameter is absent, defaults to None. required: whether to raise a 400 SynapseError if the parameter is absent, defaults to False. allowed_values: List of allowed values for the string, or None if any value is allowed, defaults to None. Must be the same type as name, if given. - encoding : The encoding to decode the string content with. + encoding: The encoding to decode the string content with. + Returns: - A string value or the default. Unicode if encoding - was given, bytes otherwise. + A string value or the default. Raises: SynapseError if the parameter is absent and required, or if the parameter is present, must be one of a list of allowed values and is not one of those allowed values. """ + args = request.args # type: Dict[bytes, List[bytes]] # type: ignore return parse_string_from_args( - request.args, name, default, required, allowed_values, encoding + args, name, default, required, allowed_values, encoding ) def _parse_string_value( - value: Union[str, bytes], + value: bytes, allowed_values: Optional[Iterable[str]], name: str, - encoding: Optional[str], -) -> Union[str, bytes]: - if encoding: - try: - value = value.decode(encoding) - except ValueError: - raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding)) + encoding: str, +) -> str: + try: + value_str = value.decode(encoding) + except ValueError: + raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding)) - if allowed_values is not None and value not in allowed_values: + if allowed_values is not None and value_str not in allowed_values: message = "Query parameter %r must be one of [%s]" % ( name, ", ".join(repr(v) for v in allowed_values), ) raise SynapseError(400, message) else: - return value + return value_str @overload def parse_strings_from_args( - args: List[str], - name: Union[bytes, str], + args: Dict[bytes, List[bytes]], + name: str, default: Optional[List[str]] = None, - required: bool = False, + required: Literal[True] = True, allowed_values: Optional[Iterable[str]] = None, - encoding: Literal[None] = None, -) -> Optional[List[bytes]]: + encoding: str = "ascii", +) -> List[str]: ... @overload def parse_strings_from_args( - args: List[str], - name: Union[bytes, str], + args: Dict[bytes, List[bytes]], + name: str, default: Optional[List[str]] = None, required: bool = False, allowed_values: Optional[Iterable[str]] = None, @@ -194,46 +248,40 @@ def parse_strings_from_args( def parse_strings_from_args( - args: List[str], - name: Union[bytes, str], + args: Dict[bytes, List[bytes]], + name: str, default: Optional[List[str]] = None, required: bool = False, allowed_values: Optional[Iterable[str]] = None, - encoding: Optional[str] = "ascii", -) -> Optional[List[Union[bytes, str]]]: + encoding: str = "ascii", +) -> Optional[List[str]]: """ Parse a string parameter from the request query string list. - If encoding is not None, the content of the query param will be - decoded to Unicode using the encoding, otherwise it will be encoded + The content of the query param will be decoded to Unicode using the encoding. Args: - args: the twisted HTTP request.args list. + args: A mapping of request args as bytes to a list of bytes (e.g. request.args). name: the name of the query parameter. - default: value to use if the parameter is absent, - defaults to None. Must be bytes if encoding is None. - required : whether to raise a 400 SynapseError if the + default: value to use if the parameter is absent, defaults to None. + required: whether to raise a 400 SynapseError if the parameter is absent, defaults to False. - allowed_values (list[bytes|unicode]): List of allowed values for the - string, or None if any value is allowed, defaults to None. Must be - the same type as name, if given. + allowed_values: List of allowed values for the + string, or None if any value is allowed, defaults to None. encoding: The encoding to decode the string content with. Returns: - A string value or the default. Unicode if encoding - was given, bytes otherwise. + A string value or the default. Raises: SynapseError if the parameter is absent and required, or if the parameter is present, must be one of a list of allowed values and is not one of those allowed values. """ + name_bytes = name.encode("ascii") - if not isinstance(name, bytes): - name = name.encode("ascii") - - if name in args: - values = args[name] + if name_bytes in args: + values = args[name_bytes] return [ _parse_string_value(value, allowed_values, name=name, encoding=encoding) @@ -241,36 +289,30 @@ def parse_strings_from_args( ] else: if required: - message = "Missing string query parameter %r" % (name) + message = "Missing string query parameter %r" % (name,) raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) - else: - - if encoding and isinstance(default, bytes): - return default.decode(encoding) - return default + return default def parse_string_from_args( - args: List[str], - name: Union[bytes, str], + args: Dict[bytes, List[bytes]], + name: str, default: Optional[str] = None, required: bool = False, allowed_values: Optional[Iterable[str]] = None, - encoding: Optional[str] = "ascii", -) -> Optional[Union[bytes, str]]: + encoding: str = "ascii", +) -> Optional[str]: """ Parse the string parameter from the request query string list and return the first result. - If encoding is not None, the content of the query param will be - decoded to Unicode using the encoding, otherwise it will be encoded + The content of the query param will be decoded to Unicode using the encoding. Args: - args: the twisted HTTP request.args list. + args: A mapping of request args as bytes to a list of bytes (e.g. request.args). name: the name of the query parameter. - default: value to use if the parameter is absent, - defaults to None. Must be bytes if encoding is None. + default: value to use if the parameter is absent, defaults to None. required: whether to raise a 400 SynapseError if the parameter is absent, defaults to False. allowed_values: List of allowed values for the @@ -279,8 +321,7 @@ def parse_string_from_args( encoding: The encoding to decode the string content with. Returns: - A string value or the default. Unicode if encoding - was given, bytes otherwise. + A string value or the default. Raises: SynapseError if the parameter is absent and required, or if the @@ -291,12 +332,15 @@ def parse_string_from_args( strings = parse_strings_from_args( args, name, - default=[default], + default=[default] if default is not None else None, required=required, allowed_values=allowed_values, encoding=encoding, ) + if strings is None: + return None + return strings[0] @@ -388,9 +432,8 @@ class RestServlet: def register(self, http_server): """ Register this servlet with the given HTTP server. """ - if hasattr(self, "PATTERNS"): - patterns = self.PATTERNS - + patterns = getattr(self, "PATTERNS", None) + if patterns: for method in ("GET", "PUT", "POST", "DELETE"): if hasattr(self, "on_%s" % (method,)): servlet_classname = self.__class__.__name__ |