summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/10437.misc1
-rw-r--r--synapse/federation/transport/server.py13
-rw-r--r--synapse/http/servlet.py220
-rw-r--r--synapse/rest/client/v1/room.py2
-rw-r--r--synapse/storage/databases/main/stats.py2
-rw-r--r--tests/rest/admin/test_media.py4
6 files changed, 176 insertions, 66 deletions
diff --git a/changelog.d/10437.misc b/changelog.d/10437.misc
new file mode 100644
index 0000000000..a557578499
--- /dev/null
+++ b/changelog.d/10437.misc
@@ -0,0 +1 @@
+Improve servlet type hints.
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 2974d4d0cc..5e059d6e09 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -984,7 +984,7 @@ class PublicRoomList(BaseFederationServlet):
         limit = parse_integer_from_args(query, "limit", 0)
         since_token = parse_string_from_args(query, "since", None)
         include_all_networks = parse_boolean_from_args(
-            query, "include_all_networks", False
+            query, "include_all_networks", default=False
         )
         third_party_instance_id = parse_string_from_args(
             query, "third_party_instance_id", None
@@ -1908,16 +1908,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
         suggested_only = parse_boolean_from_args(query, "suggested_only", default=False)
         max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space")
 
-        exclude_rooms = []
-        if b"exclude_rooms" in query:
-            try:
-                exclude_rooms = [
-                    room_id.decode("ascii") for room_id in query[b"exclude_rooms"]
-                ]
-            except Exception:
-                raise SynapseError(
-                    400, "Bad query parameter for exclude_rooms", Codes.INVALID_PARAM
-                )
+        exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[])
 
         return 200, await self.handler.federation_space_summary(
             origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index cf45b6623b..732a1e6aeb 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -14,47 +14,86 @@
 
 """ This module contains base REST classes for constructing REST servlets. """
 import logging
-from typing import Dict, Iterable, List, Optional, overload
+from typing import Iterable, List, Mapping, Optional, Sequence, overload
 
 from typing_extensions import Literal
 
 from twisted.web.server import Request
 
 from synapse.api.errors import Codes, SynapseError
+from synapse.types import JsonDict
 from synapse.util import json_decoder
 
 logger = logging.getLogger(__name__)
 
 
-def parse_integer(request, name, default=None, required=False):
+@overload
+def parse_integer(request: Request, name: str, default: int) -> int:
+    ...
+
+
+@overload
+def parse_integer(request: Request, name: str, *, required: Literal[True]) -> int:
+    ...
+
+
+@overload
+def parse_integer(
+    request: Request, name: str, default: Optional[int] = None, required: bool = False
+) -> Optional[int]:
+    ...
+
+
+def parse_integer(
+    request: Request, name: str, default: Optional[int] = None, required: bool = False
+) -> Optional[int]:
     """Parse an integer parameter from the request string
 
     Args:
         request: the twisted HTTP request.
-        name (bytes/unicode): the name of the query parameter.
-        default (int|None): value to use if the parameter is absent, defaults
-            to None.
-        required (bool): whether to raise a 400 SynapseError if the
-            parameter is absent, defaults to False.
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent, defaults to None.
+        required: whether to raise a 400 SynapseError if the parameter is absent,
+            defaults to False.
 
     Returns:
-        int|None: An int value or the default.
+        An int value or the default.
 
     Raises:
         SynapseError: if the parameter is absent and required, or if the
             parameter is present and not an integer.
     """
-    return parse_integer_from_args(request.args, name, default, required)
+    args: Mapping[bytes, Sequence[bytes]] = request.args  # type: ignore
+    return parse_integer_from_args(args, name, default, required)
 
 
-def parse_integer_from_args(args, name, default=None, required=False):
+def parse_integer_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    default: Optional[int] = None,
+    required: bool = False,
+) -> Optional[int]:
+    """Parse an integer parameter from the request string
+
+    Args:
+        args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent, defaults to None.
+        required: whether to raise a 400 SynapseError if the parameter is absent,
+            defaults to False.
+
+    Returns:
+        An int value or the default.
 
-    if not isinstance(name, bytes):
-        name = name.encode("ascii")
+    Raises:
+        SynapseError: if the parameter is absent and required, or if the
+            parameter is present and not an integer.
+    """
+    name_bytes = name.encode("ascii")
 
-    if name in args:
+    if name_bytes in args:
         try:
-            return int(args[name][0])
+            return int(args[name_bytes][0])
         except Exception:
             message = "Query parameter %r must be an integer" % (name,)
             raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
@@ -66,36 +105,102 @@ def parse_integer_from_args(args, name, default=None, required=False):
             return default
 
 
-def parse_boolean(request, name, default=None, required=False):
+@overload
+def parse_boolean(request: Request, name: str, default: bool) -> bool:
+    ...
+
+
+@overload
+def parse_boolean(request: Request, name: str, *, required: Literal[True]) -> bool:
+    ...
+
+
+@overload
+def parse_boolean(
+    request: Request, name: str, default: Optional[bool] = None, required: bool = False
+) -> Optional[bool]:
+    ...
+
+
+def parse_boolean(
+    request: Request, name: str, default: Optional[bool] = None, required: bool = False
+) -> Optional[bool]:
     """Parse a boolean parameter from the request query string
 
     Args:
         request: the twisted HTTP request.
-        name (bytes/unicode): the name of the query parameter.
-        default (bool|None): value to use if the parameter is absent, defaults
-            to None.
-        required (bool): whether to raise a 400 SynapseError if the
-            parameter is absent, defaults to False.
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent, defaults to None.
+        required: whether to raise a 400 SynapseError if the parameter is absent,
+            defaults to False.
 
     Returns:
-        bool|None: A bool value or the default.
+        A bool value or the default.
 
     Raises:
         SynapseError: if the parameter is absent and required, or if the
             parameter is present and not one of "true" or "false".
     """
+    args: Mapping[bytes, Sequence[bytes]] = request.args  # type: ignore
+    return parse_boolean_from_args(args, name, default, required)
 
-    return parse_boolean_from_args(request.args, name, default, required)
 
+@overload
+def parse_boolean_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    default: bool,
+) -> bool:
+    ...
 
-def parse_boolean_from_args(args, name, default=None, required=False):
 
-    if not isinstance(name, bytes):
-        name = name.encode("ascii")
+@overload
+def parse_boolean_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    *,
+    required: Literal[True],
+) -> bool:
+    ...
+
 
-    if name in args:
+@overload
+def parse_boolean_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    default: Optional[bool] = None,
+    required: bool = False,
+) -> Optional[bool]:
+    ...
+
+
+def parse_boolean_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    default: Optional[bool] = None,
+    required: bool = False,
+) -> Optional[bool]:
+    """Parse a boolean parameter from the request query string
+
+    Args:
+        args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent, defaults to None.
+        required: whether to raise a 400 SynapseError if the parameter is absent,
+            defaults to False.
+
+    Returns:
+        A bool value or the default.
+
+    Raises:
+        SynapseError: if the parameter is absent and required, or if the
+            parameter is present and not one of "true" or "false".
+    """
+    name_bytes = name.encode("ascii")
+
+    if name_bytes in args:
         try:
-            return {b"true": True, b"false": False}[args[name][0]]
+            return {b"true": True, b"false": False}[args[name_bytes][0]]
         except Exception:
             message = (
                 "Boolean query parameter %r must be one of ['true', 'false']"
@@ -111,7 +216,7 @@ def parse_boolean_from_args(args, name, default=None, required=False):
 
 @overload
 def parse_bytes_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[bytes] = None,
 ) -> Optional[bytes]:
@@ -120,7 +225,7 @@ def parse_bytes_from_args(
 
 @overload
 def parse_bytes_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Literal[None] = None,
     *,
@@ -131,7 +236,7 @@ def parse_bytes_from_args(
 
 @overload
 def parse_bytes_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[bytes] = None,
     required: bool = False,
@@ -140,7 +245,7 @@ def parse_bytes_from_args(
 
 
 def parse_bytes_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[bytes] = None,
     required: bool = False,
@@ -241,7 +346,7 @@ def parse_string(
             parameter is present, must be one of a list of allowed values and
             is not one of those allowed values.
     """
-    args: Dict[bytes, List[bytes]] = request.args  # type: ignore
+    args: Mapping[bytes, Sequence[bytes]] = request.args  # type: ignore
     return parse_string_from_args(
         args,
         name,
@@ -275,9 +380,8 @@ def _parse_string_value(
 
 @overload
 def parse_strings_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
-    default: Optional[List[str]] = None,
     *,
     allowed_values: Optional[Iterable[str]] = None,
     encoding: str = "ascii",
@@ -287,9 +391,20 @@ def parse_strings_from_args(
 
 @overload
 def parse_strings_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    default: List[str],
+    *,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: str = "ascii",
+) -> List[str]:
+    ...
+
+
+@overload
+def parse_strings_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
-    default: Optional[List[str]] = None,
     *,
     required: Literal[True],
     allowed_values: Optional[Iterable[str]] = None,
@@ -300,7 +415,7 @@ def parse_strings_from_args(
 
 @overload
 def parse_strings_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[List[str]] = None,
     *,
@@ -312,7 +427,7 @@ def parse_strings_from_args(
 
 
 def parse_strings_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[List[str]] = None,
     required: bool = False,
@@ -361,7 +476,7 @@ def parse_strings_from_args(
 
 @overload
 def parse_string_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[str] = None,
     *,
@@ -373,7 +488,7 @@ def parse_string_from_args(
 
 @overload
 def parse_string_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[str] = None,
     *,
@@ -386,7 +501,7 @@ def parse_string_from_args(
 
 @overload
 def parse_string_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[str] = None,
     required: bool = False,
@@ -397,7 +512,7 @@ def parse_string_from_args(
 
 
 def parse_string_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[str] = None,
     required: bool = False,
@@ -445,13 +560,14 @@ def parse_string_from_args(
     return strings[0]
 
 
-def parse_json_value_from_request(request, allow_empty_body=False):
+def parse_json_value_from_request(
+    request: Request, allow_empty_body: bool = False
+) -> Optional[JsonDict]:
     """Parse a JSON value from the body of a twisted HTTP request.
 
     Args:
         request: the twisted HTTP request.
-        allow_empty_body (bool): if True, an empty body will be accepted and
-            turned into None
+        allow_empty_body: if True, an empty body will be accepted and turned into None
 
     Returns:
         The JSON value.
@@ -460,7 +576,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
         SynapseError if the request body couldn't be decoded as JSON.
     """
     try:
-        content_bytes = request.content.read()
+        content_bytes = request.content.read()  # type: ignore
     except Exception:
         raise SynapseError(400, "Error reading JSON content.")
 
@@ -476,13 +592,15 @@ def parse_json_value_from_request(request, allow_empty_body=False):
     return content
 
 
-def parse_json_object_from_request(request, allow_empty_body=False):
+def parse_json_object_from_request(
+    request: Request, allow_empty_body: bool = False
+) -> JsonDict:
     """Parse a JSON object from the body of a twisted HTTP request.
 
     Args:
         request: the twisted HTTP request.
-        allow_empty_body (bool): if True, an empty body will be accepted and
-            turned into an empty dict.
+        allow_empty_body: if True, an empty body will be accepted and turned into
+            an empty dict.
 
     Raises:
         SynapseError if the request body couldn't be decoded as JSON or
@@ -493,14 +611,14 @@ def parse_json_object_from_request(request, allow_empty_body=False):
     if allow_empty_body and content is None:
         return {}
 
-    if type(content) != dict:
+    if not isinstance(content, dict):
         message = "Content must be a JSON object."
         raise SynapseError(400, message, errcode=Codes.BAD_JSON)
 
     return content
 
 
-def assert_params_in_dict(body, required):
+def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
     absent = []
     for k in required:
         if k not in body:
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 5d309a534c..25ba52c624 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -754,7 +754,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
             if server:
                 raise e
 
-        limit = parse_integer(request, "limit", 0)
+        limit: Optional[int] = parse_integer(request, "limit", 0)
         since_token = parse_string(request, "since")
 
         if limit == 0:
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 0f9aa54ca9..889e0d3625 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -650,7 +650,7 @@ class StatsStore(StateDeltasStore):
         order_by: Optional[str] = UserSortOrder.USER_ID.value,
         direction: Optional[str] = "f",
         search_term: Optional[str] = None,
-    ) -> Tuple[List[JsonDict], Dict[str, int]]:
+    ) -> Tuple[List[JsonDict], int]:
         """Function to retrieve a paginated list of users and their uploaded local media
         (size and number). This will return a json list of users and the
         total number of users matching the filter criteria.
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 6fee0f95b6..7198fd293f 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -261,7 +261,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
         self.assertEqual(
-            "Missing integer query parameter b'before_ts'", channel.json_body["error"]
+            "Missing integer query parameter 'before_ts'", channel.json_body["error"]
         )
 
     def test_invalid_parameter(self):
@@ -303,7 +303,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
         self.assertEqual(
-            "Boolean query parameter b'keep_profiles' must be one of ['true', 'false']",
+            "Boolean query parameter 'keep_profiles' must be one of ['true', 'false']",
             channel.json_body["error"],
         )