summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/http/servlet.py179
-rw-r--r--synapse/rest/admin/rooms.py2
-rw-r--r--synapse/rest/client/v1/login.py8
-rw-r--r--synapse/rest/client/v1/room.py4
-rw-r--r--synapse/rest/consent/consent_resource.py9
-rw-r--r--synapse/rest/media/v1/upload_resource.py11
6 files changed, 130 insertions, 83 deletions
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 3f4f2411fc..d61563d39b 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -15,10 +15,12 @@
 """ This module contains base REST classes for constructing REST servlets. """
 
 import logging
-from typing import Iterable, List, Optional, Union, overload
+from typing import Dict, Iterable, List, Optional, overload
 
 from typing_extensions import Literal
 
+from twisted.web.server import Request
+
 from synapse.api.errors import Codes, SynapseError
 from synapse.util import json_decoder
 
@@ -108,13 +110,66 @@ def parse_boolean_from_args(args, name, default=None, required=False):
             return default
 
 
+@overload
+def parse_bytes_from_args(
+    args: Dict[bytes, List[bytes]],
+    name: str,
+    default: Literal[None] = None,
+    required: Literal[True] = True,
+) -> bytes:
+    ...
+
+
+@overload
+def parse_bytes_from_args(
+    args: Dict[bytes, List[bytes]],
+    name: str,
+    default: Optional[bytes] = None,
+    required: bool = False,
+) -> Optional[bytes]:
+    ...
+
+
+def parse_bytes_from_args(
+    args: Dict[bytes, List[bytes]],
+    name: str,
+    default: Optional[bytes] = None,
+    required: bool = False,
+) -> Optional[bytes]:
+    """
+    Parse a string parameter as bytes from the request query string.
+
+    Args:
+        args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent,
+            defaults to None. Must be bytes if encoding is None.
+        required: whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+    Returns:
+        Bytes or the default value.
+
+    Raises:
+        SynapseError if the parameter is absent and required.
+    """
+    name_bytes = name.encode("ascii")
+
+    if name_bytes in args:
+        return args[name_bytes][0]
+    elif required:
+        message = "Missing string query parameter %s" % (name,)
+        raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+
+    return default
+
+
 def parse_string(
-    request,
-    name: Union[bytes, str],
+    request: Request,
+    name: str,
     default: Optional[str] = None,
     required: bool = False,
     allowed_values: Optional[Iterable[str]] = None,
-    encoding: Optional[str] = "ascii",
+    encoding: str = "ascii",
 ):
     """
     Parse a string parameter from the request query string.
@@ -125,66 +180,65 @@ def parse_string(
     Args:
         request: the twisted HTTP request.
         name: the name of the query parameter.
-        default: value to use if the parameter is absent,
-            defaults to None. Must be bytes if encoding is None.
+        default: value to use if the parameter is absent, defaults to None.
         required: whether to raise a 400 SynapseError if the
             parameter is absent, defaults to False.
         allowed_values: List of allowed values for the
             string, or None if any value is allowed, defaults to None. Must be
             the same type as name, if given.
-        encoding : The encoding to decode the string content with.
+        encoding: The encoding to decode the string content with.
+
     Returns:
-        A string value or the default. Unicode if encoding
-        was given, bytes otherwise.
+        A string value or the default.
 
     Raises:
         SynapseError if the parameter is absent and required, or if the
             parameter is present, must be one of a list of allowed values and
             is not one of those allowed values.
     """
+    args = request.args  # type: Dict[bytes, List[bytes]]  # type: ignore
     return parse_string_from_args(
-        request.args, name, default, required, allowed_values, encoding
+        args, name, default, required, allowed_values, encoding
     )
 
 
 def _parse_string_value(
-    value: Union[str, bytes],
+    value: bytes,
     allowed_values: Optional[Iterable[str]],
     name: str,
-    encoding: Optional[str],
-) -> Union[str, bytes]:
-    if encoding:
-        try:
-            value = value.decode(encoding)
-        except ValueError:
-            raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
+    encoding: str,
+) -> str:
+    try:
+        value_str = value.decode(encoding)
+    except ValueError:
+        raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
 
-    if allowed_values is not None and value not in allowed_values:
+    if allowed_values is not None and value_str not in allowed_values:
         message = "Query parameter %r must be one of [%s]" % (
             name,
             ", ".join(repr(v) for v in allowed_values),
         )
         raise SynapseError(400, message)
     else:
-        return value
+        return value_str
 
 
 @overload
 def parse_strings_from_args(
-    args: List[str],
-    name: Union[bytes, str],
+    args: Dict[bytes, List[bytes]],
+    name: str,
     default: Optional[List[str]] = None,
-    required: bool = False,
+    required: Literal[True] = True,
     allowed_values: Optional[Iterable[str]] = None,
-    encoding: Literal[None] = None,
-) -> Optional[List[bytes]]:
+    encoding: str = "ascii",
+) -> List[str]:
     ...
 
 
 @overload
 def parse_strings_from_args(
-    args: List[str],
-    name: Union[bytes, str],
+    args: Dict[bytes, List[bytes]],
+    name: str,
     default: Optional[List[str]] = None,
     required: bool = False,
     allowed_values: Optional[Iterable[str]] = None,
@@ -194,46 +248,40 @@ def parse_strings_from_args(
 
 
 def parse_strings_from_args(
-    args: List[str],
-    name: Union[bytes, str],
+    args: Dict[bytes, List[bytes]],
+    name: str,
     default: Optional[List[str]] = None,
     required: bool = False,
     allowed_values: Optional[Iterable[str]] = None,
-    encoding: Optional[str] = "ascii",
-) -> Optional[List[Union[bytes, str]]]:
+    encoding: str = "ascii",
+) -> Optional[List[str]]:
     """
     Parse a string parameter from the request query string list.
 
-    If encoding is not None, the content of the query param will be
-    decoded to Unicode using the encoding, otherwise it will be encoded
+    The content of the query param will be decoded to Unicode using the encoding.
 
     Args:
-        args: the twisted HTTP request.args list.
+        args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
         name: the name of the query parameter.
-        default: value to use if the parameter is absent,
-            defaults to None. Must be bytes if encoding is None.
-        required : whether to raise a 400 SynapseError if the
+        default: value to use if the parameter is absent, defaults to None.
+        required: whether to raise a 400 SynapseError if the
             parameter is absent, defaults to False.
-        allowed_values (list[bytes|unicode]): List of allowed values for the
-            string, or None if any value is allowed, defaults to None. Must be
-            the same type as name, if given.
+        allowed_values: List of allowed values for the
+            string, or None if any value is allowed, defaults to None.
         encoding: The encoding to decode the string content with.
 
     Returns:
-        A string value or the default. Unicode if encoding
-        was given, bytes otherwise.
+        A string value or the default.
 
     Raises:
         SynapseError if the parameter is absent and required, or if the
             parameter is present, must be one of a list of allowed values and
             is not one of those allowed values.
     """
+    name_bytes = name.encode("ascii")
 
-    if not isinstance(name, bytes):
-        name = name.encode("ascii")
-
-    if name in args:
-        values = args[name]
+    if name_bytes in args:
+        values = args[name_bytes]
 
         return [
             _parse_string_value(value, allowed_values, name=name, encoding=encoding)
@@ -241,36 +289,30 @@ def parse_strings_from_args(
         ]
     else:
         if required:
-            message = "Missing string query parameter %r" % (name)
+            message = "Missing string query parameter %r" % (name,)
             raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
-        else:
-
-            if encoding and isinstance(default, bytes):
-                return default.decode(encoding)
 
-            return default
+        return default
 
 
 def parse_string_from_args(
-    args: List[str],
-    name: Union[bytes, str],
+    args: Dict[bytes, List[bytes]],
+    name: str,
     default: Optional[str] = None,
     required: bool = False,
     allowed_values: Optional[Iterable[str]] = None,
-    encoding: Optional[str] = "ascii",
-) -> Optional[Union[bytes, str]]:
+    encoding: str = "ascii",
+) -> Optional[str]:
     """
     Parse the string parameter from the request query string list
     and return the first result.
 
-    If encoding is not None, the content of the query param will be
-    decoded to Unicode using the encoding, otherwise it will be encoded
+    The content of the query param will be decoded to Unicode using the encoding.
 
     Args:
-        args: the twisted HTTP request.args list.
+        args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
         name: the name of the query parameter.
-        default: value to use if the parameter is absent,
-            defaults to None. Must be bytes if encoding is None.
+        default: value to use if the parameter is absent, defaults to None.
         required: whether to raise a 400 SynapseError if the
             parameter is absent, defaults to False.
         allowed_values: List of allowed values for the
@@ -279,8 +321,7 @@ def parse_string_from_args(
         encoding: The encoding to decode the string content with.
 
     Returns:
-        A string value or the default. Unicode if encoding
-        was given, bytes otherwise.
+        A string value or the default.
 
     Raises:
         SynapseError if the parameter is absent and required, or if the
@@ -291,12 +332,15 @@ def parse_string_from_args(
     strings = parse_strings_from_args(
         args,
         name,
-        default=[default],
+        default=[default] if default is not None else None,
         required=required,
         allowed_values=allowed_values,
         encoding=encoding,
     )
 
+    if strings is None:
+        return None
+
     return strings[0]
 
 
@@ -388,9 +432,8 @@ class RestServlet:
 
     def register(self, http_server):
         """ Register this servlet with the given HTTP server. """
-        if hasattr(self, "PATTERNS"):
-            patterns = self.PATTERNS
-
+        patterns = getattr(self, "PATTERNS", None)
+        if patterns:
             for method in ("GET", "PUT", "POST", "DELETE"):
                 if hasattr(self, "on_%s" % (method,)):
                     servlet_classname = self.__class__.__name__
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index f289ffe3d0..f0cddd2d2c 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -649,7 +649,7 @@ class RoomEventContextServlet(RestServlet):
         limit = parse_integer(request, "limit", default=10)
 
         # picking the API shape for symmetry with /messages
-        filter_str = parse_string(request, b"filter", encoding="utf-8")
+        filter_str = parse_string(request, "filter", encoding="utf-8")
         if filter_str:
             filter_json = urlparse.unquote(filter_str)
             event_filter = Filter(
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 42e709ec14..f6be5f1020 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,7 +14,7 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
@@ -25,6 +25,7 @@ from synapse.http import get_request_uri
 from synapse.http.server import HttpServer, finish_request
 from synapse.http.servlet import (
     RestServlet,
+    parse_bytes_from_args,
     parse_json_object_from_request,
     parse_string,
 )
@@ -437,9 +438,8 @@ class SsoRedirectServlet(RestServlet):
             finish_request(request)
             return
 
-        client_redirect_url = parse_string(
-            request, "redirectUrl", required=True, encoding=None
-        )
+        args = request.args  # type: Dict[bytes, List[bytes]]  # type: ignore
+        client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True)
         sso_url = await self._sso_handler.handle_redirect_request(
             request,
             client_redirect_url,
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 5a9c27f75f..122105854a 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -537,7 +537,7 @@ class RoomMessageListRestServlet(RestServlet):
             self.store, request, default_limit=10
         )
         as_client_event = b"raw" not in request.args
-        filter_str = parse_string(request, b"filter", encoding="utf-8")
+        filter_str = parse_string(request, "filter", encoding="utf-8")
         if filter_str:
             filter_json = urlparse.unquote(filter_str)
             event_filter = Filter(
@@ -652,7 +652,7 @@ class RoomEventContextServlet(RestServlet):
         limit = parse_integer(request, "limit", default=10)
 
         # picking the API shape for symmetry with /messages
-        filter_str = parse_string(request, b"filter", encoding="utf-8")
+        filter_str = parse_string(request, "filter", encoding="utf-8")
         if filter_str:
             filter_json = urlparse.unquote(filter_str)
             event_filter = Filter(
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index b19cd8afc5..e52570cd8e 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -17,6 +17,7 @@ import logging
 from hashlib import sha256
 from http import HTTPStatus
 from os import path
+from typing import Dict, List
 
 import jinja2
 from jinja2 import TemplateNotFound
@@ -24,7 +25,7 @@ from jinja2 import TemplateNotFound
 from synapse.api.errors import NotFoundError, StoreError, SynapseError
 from synapse.config import ConfigError
 from synapse.http.server import DirectServeHtmlResource, respond_with_html
-from synapse.http.servlet import parse_string
+from synapse.http.servlet import parse_bytes_from_args, parse_string
 from synapse.types import UserID
 
 # language to use for the templates. TODO: figure this out from Accept-Language
@@ -116,7 +117,8 @@ class ConsentResource(DirectServeHtmlResource):
         has_consented = False
         public_version = username == ""
         if not public_version:
-            userhmac_bytes = parse_string(request, "h", required=True, encoding=None)
+            args = request.args  # type: Dict[bytes, List[bytes]]
+            userhmac_bytes = parse_bytes_from_args(args, "h", required=True)
 
             self._check_hash(username, userhmac_bytes)
 
@@ -152,7 +154,8 @@ class ConsentResource(DirectServeHtmlResource):
         """
         version = parse_string(request, "v", required=True)
         username = parse_string(request, "u", required=True)
-        userhmac = parse_string(request, "h", required=True, encoding=None)
+        args = request.args  # type: Dict[bytes, List[bytes]]
+        userhmac = parse_bytes_from_args(args, "h", required=True)
 
         self._check_hash(username, userhmac)
 
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 024a105bf2..62dc4aae2d 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -14,13 +14,13 @@
 # limitations under the License.
 
 import logging
-from typing import IO, TYPE_CHECKING
+from typing import IO, TYPE_CHECKING, Dict, List, Optional
 
 from twisted.web.server import Request
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import DirectServeJsonResource, respond_with_json
-from synapse.http.servlet import parse_string
+from synapse.http.servlet import parse_bytes_from_args
 from synapse.http.site import SynapseRequest
 from synapse.rest.media.v1.media_storage import SpamMediaException
 
@@ -61,10 +61,11 @@ class UploadResource(DirectServeJsonResource):
                 errcode=Codes.TOO_LARGE,
             )
 
-        upload_name = parse_string(request, b"filename", encoding=None)
-        if upload_name:
+        args = request.args  # type: Dict[bytes, List[bytes]]  # type: ignore
+        upload_name_bytes = parse_bytes_from_args(args, "filename")
+        if upload_name_bytes:
             try:
-                upload_name = upload_name.decode("utf8")
+                upload_name = upload_name_bytes.decode("utf8")  # type: Optional[str]
             except UnicodeDecodeError:
                 raise SynapseError(
                     msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400