summary refs log tree commit diff
diff options
context:
space:
mode:
authorGordan Trevis <GitHub@gordantrevis.me>2024-04-16 21:12:36 +0200
committerGitHub <noreply@github.com>2024-04-16 19:12:36 +0000
commitf0d6f140479d24754993b7fcaeb33e07f26e1c88 (patch)
tree8b861d191f72def055b3796d477aaa3f9b4442a1
parentMerge branch 'master' into develop (diff)
downloadsynapse-f0d6f140479d24754993b7fcaeb33e07f26e1c88.tar.xz
Parse Integer negative value validation (#16920)
-rw-r--r--changelog.d/16920.bugfix1
-rw-r--r--synapse/http/servlet.py90
-rw-r--r--synapse/rest/admin/federation.py38
-rw-r--r--synapse/rest/admin/media.py54
-rw-r--r--synapse/rest/admin/statistics.py34
-rw-r--r--synapse/rest/admin/users.py18
-rw-r--r--synapse/rest/client/room.py2
-rw-r--r--synapse/rest/media/preview_url_resource.py5
-rw-r--r--tests/rest/admin/test_media.py5
9 files changed, 89 insertions, 158 deletions
diff --git a/changelog.d/16920.bugfix b/changelog.d/16920.bugfix
new file mode 100644
index 0000000000..460f4f7160
--- /dev/null
+++ b/changelog.d/16920.bugfix
@@ -0,0 +1 @@
+Adds validation to ensure that the `limit` parameter on `/publicRooms` is non-negative.
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index b73d06f1d3..0ca08038f4 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -19,7 +19,8 @@
 #
 #
 
-""" This module contains base REST classes for constructing REST servlets. """
+"""This module contains base REST classes for constructing REST servlets."""
+
 import enum
 import logging
 from http import HTTPStatus
@@ -65,17 +66,49 @@ def parse_integer(request: Request, name: str, default: int) -> int: ...
 
 
 @overload
-def parse_integer(request: Request, name: str, *, required: Literal[True]) -> int: ...
+def parse_integer(
+    request: Request, name: str, *, default: int, negative: bool
+) -> int: ...
+
+
+@overload
+def parse_integer(
+    request: Request, name: str, *, default: int, negative: bool = False
+) -> int: ...
+
+
+@overload
+def parse_integer(
+    request: Request, name: str, *, required: Literal[True], negative: bool = False
+) -> int: ...
+
+
+@overload
+def parse_integer(
+    request: Request, name: str, *, default: Literal[None], negative: bool = False
+) -> None: ...
+
+
+@overload
+def parse_integer(request: Request, name: str, *, negative: bool) -> Optional[int]: ...
 
 
 @overload
 def parse_integer(
-    request: Request, name: str, default: Optional[int] = None, required: bool = False
+    request: Request,
+    name: str,
+    default: Optional[int] = None,
+    required: bool = False,
+    negative: bool = False,
 ) -> Optional[int]: ...
 
 
 def parse_integer(
-    request: Request, name: str, default: Optional[int] = None, required: bool = False
+    request: Request,
+    name: str,
+    default: Optional[int] = None,
+    required: bool = False,
+    negative: bool = False,
 ) -> Optional[int]:
     """Parse an integer parameter from the request string
 
@@ -85,16 +118,17 @@ def parse_integer(
         default: value to use if the parameter is absent, defaults to None.
         required: whether to raise a 400 SynapseError if the parameter is absent,
             defaults to False.
-
+        negative: whether to allow negative integers, defaults to True.
     Returns:
         An int value or the default.
 
     Raises:
-        SynapseError: if the parameter is absent and required, or if the
-            parameter is present and not an integer.
+        SynapseError: if the parameter is absent and required, if the
+            parameter is present and not an integer, or if the
+            parameter is illegitimate negative.
     """
     args: Mapping[bytes, Sequence[bytes]] = request.args  # type: ignore
-    return parse_integer_from_args(args, name, default, required)
+    return parse_integer_from_args(args, name, default, required, negative)
 
 
 @overload
@@ -120,6 +154,7 @@ def parse_integer_from_args(
     name: str,
     default: Optional[int] = None,
     required: bool = False,
+    negative: bool = False,
 ) -> Optional[int]: ...
 
 
@@ -128,6 +163,7 @@ def parse_integer_from_args(
     name: str,
     default: Optional[int] = None,
     required: bool = False,
+    negative: bool = True,
 ) -> Optional[int]:
     """Parse an integer parameter from the request string
 
@@ -137,33 +173,37 @@ def parse_integer_from_args(
         default: value to use if the parameter is absent, defaults to None.
         required: whether to raise a 400 SynapseError if the parameter is absent,
             defaults to False.
+        negative: whether to allow negative integers, defaults to True.
 
     Returns:
         An int value or the default.
 
     Raises:
-        SynapseError: if the parameter is absent and required, or if the
-            parameter is present and not an integer.
+        SynapseError: if the parameter is absent and required, if the
+            parameter is present and not an integer, or if the
+            parameter is illegitimate negative.
     """
     name_bytes = name.encode("ascii")
 
-    if name_bytes in args:
-        try:
-            return int(args[name_bytes][0])
-        except Exception:
-            message = "Query parameter %r must be an integer" % (name,)
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
-            )
-    else:
-        if required:
-            message = "Missing integer query parameter %r" % (name,)
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
-            )
-        else:
+    if name_bytes not in args:
+        if not required:
             return default
 
+        message = f"Missing required integer query parameter {name}"
+        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM)
+
+    try:
+        integer = int(args[name_bytes][0])
+    except Exception:
+        message = f"Query parameter {name} must be an integer"
+        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM)
+
+    if not negative and integer < 0:
+        message = f"Query parameter {name} must be a positive integer."
+        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM)
+
+    return integer
+
 
 @overload
 def parse_boolean(request: Request, name: str, default: bool) -> bool: ...
diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py
index 045153e0cb..14ab4644cb 100644
--- a/synapse/rest/admin/federation.py
+++ b/synapse/rest/admin/federation.py
@@ -23,7 +23,7 @@ from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.constants import Direction
-from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.api.errors import NotFoundError, SynapseError
 from synapse.federation.transport.server import Authenticator
 from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
@@ -61,22 +61,8 @@ class ListDestinationsRestServlet(RestServlet):
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self._auth, request)
 
-        start = parse_integer(request, "from", default=0)
-        limit = parse_integer(request, "limit", default=100)
-
-        if start < 0:
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Query parameter from must be a string representing a positive integer.",
-                errcode=Codes.INVALID_PARAM,
-            )
-
-        if limit < 0:
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Query parameter limit must be a string representing a positive integer.",
-                errcode=Codes.INVALID_PARAM,
-            )
+        start = parse_integer(request, "from", default=0, negative=False)
+        limit = parse_integer(request, "limit", default=100, negative=False)
 
         destination = parse_string(request, "destination")
 
@@ -195,22 +181,8 @@ class DestinationMembershipRestServlet(RestServlet):
         if not await self._store.is_destination_known(destination):
             raise NotFoundError("Unknown destination")
 
-        start = parse_integer(request, "from", default=0)
-        limit = parse_integer(request, "limit", default=100)
-
-        if start < 0:
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Query parameter from must be a string representing a positive integer.",
-                errcode=Codes.INVALID_PARAM,
-            )
-
-        if limit < 0:
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Query parameter limit must be a string representing a positive integer.",
-                errcode=Codes.INVALID_PARAM,
-            )
+        start = parse_integer(request, "from", default=0, negative=False)
+        limit = parse_integer(request, "limit", default=100, negative=False)
 
         direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
 
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 27f0808658..a05b7252ec 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -311,29 +311,17 @@ class DeleteMediaByDateSize(RestServlet):
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        before_ts = parse_integer(request, "before_ts", required=True)
-        size_gt = parse_integer(request, "size_gt", default=0)
+        before_ts = parse_integer(request, "before_ts", required=True, negative=False)
+        size_gt = parse_integer(request, "size_gt", default=0, negative=False)
         keep_profiles = parse_boolean(request, "keep_profiles", default=True)
 
-        if before_ts < 0:
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Query parameter before_ts must be a positive integer.",
-                errcode=Codes.INVALID_PARAM,
-            )
-        elif before_ts < 30000000000:  # Dec 1970 in milliseconds, Aug 2920 in seconds
+        if before_ts < 30000000000:  # Dec 1970 in milliseconds, Aug 2920 in seconds
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "Query parameter before_ts you provided is from the year 1970. "
                 + "Double check that you are providing a timestamp in milliseconds.",
                 errcode=Codes.INVALID_PARAM,
             )
-        if size_gt < 0:
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Query parameter size_gt must be a string representing a positive integer.",
-                errcode=Codes.INVALID_PARAM,
-            )
 
         # This check is useless, we keep it for the legacy endpoint only.
         if server_name is not None and self.server_name != server_name:
@@ -389,22 +377,8 @@ class UserMediaRestServlet(RestServlet):
         if user is None:
             raise NotFoundError("Unknown user")
 
-        start = parse_integer(request, "from", default=0)
-        limit = parse_integer(request, "limit", default=100)
-
-        if start < 0:
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Query parameter from must be a string representing a positive integer.",
-                errcode=Codes.INVALID_PARAM,
-            )
-
-        if limit < 0:
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Query parameter limit must be a string representing a positive integer.",
-                errcode=Codes.INVALID_PARAM,
-            )
+        start = parse_integer(request, "from", default=0, negative=False)
+        limit = parse_integer(request, "limit", default=100, negative=False)
 
         # If neither `order_by` nor `dir` is set, set the default order
         # to newest media is on top for backward compatibility.
@@ -447,22 +421,8 @@ class UserMediaRestServlet(RestServlet):
         if user is None:
             raise NotFoundError("Unknown user")
 
-        start = parse_integer(request, "from", default=0)
-        limit = parse_integer(request, "limit", default=100)
-
-        if start < 0:
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Query parameter from must be a string representing a positive integer.",
-                errcode=Codes.INVALID_PARAM,
-            )
-
-        if limit < 0:
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Query parameter limit must be a string representing a positive integer.",
-                errcode=Codes.INVALID_PARAM,
-            )
+        start = parse_integer(request, "from", default=0, negative=False)
+        limit = parse_integer(request, "limit", default=100, negative=False)
 
         # If neither `order_by` nor `dir` is set, set the default order
         # to newest media is on top for backward compatibility.
diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py
index 832f20402e..dc27a41dd9 100644
--- a/synapse/rest/admin/statistics.py
+++ b/synapse/rest/admin/statistics.py
@@ -63,38 +63,12 @@ class UserMediaStatisticsRestServlet(RestServlet):
             ),
         )
 
-        start = parse_integer(request, "from", default=0)
-        if start < 0:
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Query parameter from must be a string representing a positive integer.",
-                errcode=Codes.INVALID_PARAM,
-            )
-
-        limit = parse_integer(request, "limit", default=100)
-        if limit < 0:
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Query parameter limit must be a string representing a positive integer.",
-                errcode=Codes.INVALID_PARAM,
-            )
+        start = parse_integer(request, "from", default=0, negative=False)
+        limit = parse_integer(request, "limit", default=100, negative=False)
+        from_ts = parse_integer(request, "from_ts", default=0, negative=False)
+        until_ts = parse_integer(request, "until_ts", negative=False)
 
-        from_ts = parse_integer(request, "from_ts", default=0)
-        if from_ts < 0:
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Query parameter from_ts must be a string representing a positive integer.",
-                errcode=Codes.INVALID_PARAM,
-            )
-
-        until_ts = parse_integer(request, "until_ts")
         if until_ts is not None:
-            if until_ts < 0:
-                raise SynapseError(
-                    HTTPStatus.BAD_REQUEST,
-                    "Query parameter until_ts must be a string representing a positive integer.",
-                    errcode=Codes.INVALID_PARAM,
-                )
             if until_ts <= from_ts:
                 raise SynapseError(
                     HTTPStatus.BAD_REQUEST,
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 4e34e46512..5bf12c4979 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -90,22 +90,8 @@ class UsersRestServletV2(RestServlet):
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        start = parse_integer(request, "from", default=0)
-        limit = parse_integer(request, "limit", default=100)
-
-        if start < 0:
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Query parameter from must be a string representing a positive integer.",
-                errcode=Codes.INVALID_PARAM,
-            )
-
-        if limit < 0:
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Query parameter limit must be a string representing a positive integer.",
-                errcode=Codes.INVALID_PARAM,
-            )
+        start = parse_integer(request, "from", default=0, negative=False)
+        limit = parse_integer(request, "limit", default=100, negative=False)
 
         user_id = parse_string(request, "user_id")
         name = parse_string(request, "name", encoding="utf-8")
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 65dedb8b92..4eeadf8779 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -499,7 +499,7 @@ class PublicRoomListRestServlet(RestServlet):
             if server:
                 raise e
 
-        limit: Optional[int] = parse_integer(request, "limit", 0)
+        limit: Optional[int] = parse_integer(request, "limit", 0, negative=False)
         since_token = parse_string(request, "since")
 
         if limit == 0:
diff --git a/synapse/rest/media/preview_url_resource.py b/synapse/rest/media/preview_url_resource.py
index 6724986fcc..bfeff2179b 100644
--- a/synapse/rest/media/preview_url_resource.py
+++ b/synapse/rest/media/preview_url_resource.py
@@ -72,9 +72,6 @@ class PreviewUrlResource(RestServlet):
         # XXX: if get_user_by_req fails, what should we do in an async render?
         requester = await self.auth.get_user_by_req(request)
         url = parse_string(request, "url", required=True)
-        ts = parse_integer(request, "ts")
-        if ts is None:
-            ts = self.clock.time_msec()
-
+        ts = parse_integer(request, "ts", default=self.clock.time_msec())
         og = await self.url_previewer.preview(url, requester.user, ts)
         respond_with_json_bytes(request, 200, og, send_cors=True)
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 493e1d1919..f378165513 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -277,7 +277,8 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
         self.assertEqual(
-            "Missing integer query parameter 'before_ts'", channel.json_body["error"]
+            "Missing required integer query parameter before_ts",
+            channel.json_body["error"],
         )
 
     def test_invalid_parameter(self) -> None:
@@ -320,7 +321,7 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
         self.assertEqual(
-            "Query parameter size_gt must be a string representing a positive integer.",
+            "Query parameter size_gt must be a positive integer.",
             channel.json_body["error"],
         )