summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/16923.bugfix1
-rw-r--r--synapse/http/servlet.py82
-rw-r--r--synapse/rest/admin/rooms.py36
-rw-r--r--synapse/rest/client/room.py35
-rw-r--r--tests/rest/admin/test_room.py61
-rw-r--r--tests/rest/client/test_rooms.py52
6 files changed, 220 insertions, 47 deletions
diff --git a/changelog.d/16923.bugfix b/changelog.d/16923.bugfix
new file mode 100644
index 0000000000..bd6f24925e
--- /dev/null
+++ b/changelog.d/16923.bugfix
@@ -0,0 +1 @@
+Return `400 M_NOT_JSON` upon receiving invalid JSON in query parameters across various client and admin endpoints, rather than an internal server error.
\ No newline at end of file
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 0ca08038f4..ab12951da8 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -23,6 +23,7 @@
 
 import enum
 import logging
+import urllib.parse as urlparse
 from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
@@ -450,6 +451,87 @@ def parse_string(
     )
 
 
+def parse_json(
+    request: Request,
+    name: str,
+    default: Optional[dict] = None,
+    required: bool = False,
+    encoding: str = "ascii",
+) -> Optional[JsonDict]:
+    """
+    Parse a JSON parameter from the request query string.
+
+    Args:
+        request: the twisted HTTP request.
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent,
+            defaults to None.
+        required: whether to raise a 400 SynapseError if the
+           parameter is absent, defaults to False.
+        encoding: The encoding to decode the string content with.
+
+    Returns:
+        A JSON value, or `default` if the named query parameter was not found
+        and `required` was False.
+
+    Raises:
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present and not a JSON object.
+    """
+    args: Mapping[bytes, Sequence[bytes]] = request.args  # type: ignore
+    return parse_json_from_args(
+        args,
+        name,
+        default,
+        required=required,
+        encoding=encoding,
+    )
+
+
+def parse_json_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    default: Optional[dict] = None,
+    required: bool = False,
+    encoding: str = "ascii",
+) -> Optional[JsonDict]:
+    """
+    Parse a JSON parameter from the request query string.
+
+    Args:
+        args: a mapping of request args as bytes to a list of bytes (e.g. request.args).
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent,
+            defaults to None.
+        required: whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+        encoding: the encoding to decode the string content with.
+
+        A JSON value, or `default` if the named query parameter was not found
+        and `required` was False.
+
+    Raises:
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present and not a JSON object.
+    """
+    name_bytes = name.encode("ascii")
+
+    if name_bytes not in args:
+        if not required:
+            return default
+
+        message = f"Missing required integer query parameter {name}"
+        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM)
+
+    json_str = parse_string_from_args(args, name, required=True, encoding=encoding)
+
+    try:
+        return json_decoder.decode(urlparse.unquote(json_str))
+    except Exception:
+        message = f"Query parameter {name} must be a valid JSON object"
+        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.NOT_JSON)
+
+
 EnumT = TypeVar("EnumT", bound=enum.Enum)
 
 
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 4252f98a6c..0d86a4e15f 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -21,7 +21,6 @@
 import logging
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Tuple, cast
-from urllib import parse as urlparse
 
 import attr
 
@@ -38,6 +37,7 @@ from synapse.http.servlet import (
     assert_params_in_dict,
     parse_enum,
     parse_integer,
+    parse_json,
     parse_json_object_from_request,
     parse_string,
 )
@@ -51,7 +51,6 @@ from synapse.storage.databases.main.room import RoomSortOrder
 from synapse.streams.config import PaginationConfig
 from synapse.types import JsonDict, RoomID, ScheduledTask, UserID, create_requester
 from synapse.types.state import StateFilter
-from synapse.util import json_decoder
 
 if TYPE_CHECKING:
     from synapse.api.auth import Auth
@@ -776,14 +775,8 @@ class RoomEventContextServlet(RestServlet):
         limit = parse_integer(request, "limit", default=10)
 
         # picking the API shape for symmetry with /messages
-        filter_str = parse_string(request, "filter", encoding="utf-8")
-        if filter_str:
-            filter_json = urlparse.unquote(filter_str)
-            event_filter: Optional[Filter] = Filter(
-                self._hs, json_decoder.decode(filter_json)
-            )
-        else:
-            event_filter = None
+        filter_json = parse_json(request, "filter", encoding="utf-8")
+        event_filter = Filter(self._hs, filter_json) if filter_json else None
 
         event_context = await self.room_context_handler.get_event_context(
             requester,
@@ -914,21 +907,16 @@ class RoomMessagesRestServlet(RestServlet):
         )
         # Twisted will have processed the args by now.
         assert request.args is not None
+
+        filter_json = parse_json(request, "filter", encoding="utf-8")
+        event_filter = Filter(self._hs, filter_json) if filter_json else None
+
         as_client_event = b"raw" not in request.args
-        filter_str = parse_string(request, "filter", encoding="utf-8")
-        if filter_str:
-            filter_json = urlparse.unquote(filter_str)
-            event_filter: Optional[Filter] = Filter(
-                self._hs, json_decoder.decode(filter_json)
-            )
-            if (
-                event_filter
-                and event_filter.filter_json.get("event_format", "client")
-                == "federation"
-            ):
-                as_client_event = False
-        else:
-            event_filter = None
+        if (
+            event_filter
+            and event_filter.filter_json.get("event_format", "client") == "federation"
+        ):
+            as_client_event = False
 
         msgs = await self._pagination_handler.get_messages(
             room_id=room_id,
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 4eeadf8779..e4c7dd1a58 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -52,6 +52,7 @@ from synapse.http.servlet import (
     parse_boolean,
     parse_enum,
     parse_integer,
+    parse_json,
     parse_json_object_from_request,
     parse_string,
     parse_strings_from_args,
@@ -65,7 +66,6 @@ from synapse.rest.client.transactions import HttpTransactionCache
 from synapse.streams.config import PaginationConfig
 from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
 from synapse.types.state import StateFilter
-from synapse.util import json_decoder
 from synapse.util.cancellation import cancellable
 from synapse.util.stringutils import parse_and_validate_server_name, random_string
 
@@ -703,21 +703,16 @@ class RoomMessageListRestServlet(RestServlet):
         )
         # Twisted will have processed the args by now.
         assert request.args is not None
+
+        filter_json = parse_json(request, "filter", encoding="utf-8")
+        event_filter = Filter(self._hs, filter_json) if filter_json else None
+
         as_client_event = b"raw" not in request.args
-        filter_str = parse_string(request, "filter", encoding="utf-8")
-        if filter_str:
-            filter_json = urlparse.unquote(filter_str)
-            event_filter: Optional[Filter] = Filter(
-                self._hs, json_decoder.decode(filter_json)
-            )
-            if (
-                event_filter
-                and event_filter.filter_json.get("event_format", "client")
-                == "federation"
-            ):
-                as_client_event = False
-        else:
-            event_filter = None
+        if (
+            event_filter
+            and event_filter.filter_json.get("event_format", "client") == "federation"
+        ):
+            as_client_event = False
 
         msgs = await self.pagination_handler.get_messages(
             room_id=room_id,
@@ -898,14 +893,8 @@ class RoomEventContextServlet(RestServlet):
         limit = parse_integer(request, "limit", default=10)
 
         # picking the API shape for symmetry with /messages
-        filter_str = parse_string(request, "filter", encoding="utf-8")
-        if filter_str:
-            filter_json = urlparse.unquote(filter_str)
-            event_filter: Optional[Filter] = Filter(
-                self._hs, json_decoder.decode(filter_json)
-            )
-        else:
-            event_filter = None
+        filter_json = parse_json(request, "filter", encoding="utf-8")
+        event_filter = Filter(self._hs, filter_json) if filter_json else None
 
         event_context = await self.room_context_handler.get_event_context(
             requester, room_id, event_id, limit, event_filter
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 0b669b6ee7..7562747260 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -21,6 +21,7 @@
 import json
 import time
 import urllib.parse
+from http import HTTPStatus
 from typing import List, Optional
 from unittest.mock import AsyncMock, Mock
 
@@ -2190,6 +2191,33 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
         chunk = channel.json_body["chunk"]
         self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
 
+    def test_room_message_filter_query_validation(self) -> None:
+        # Test json validation in (filter) query parameter.
+        # Does not test the validity of the filter, only the json validation.
+
+        # Check Get with valid json filter parameter, expect 200.
+        valid_filter_str = '{"types": ["m.room.message"]}'
+        channel = self.make_request(
+            "GET",
+            f"/_synapse/admin/v1/rooms/{self.room_id}/messages?dir=b&filter={valid_filter_str}",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+
+        # Check Get with invalid json filter parameter, expect 400 NOT_JSON.
+        invalid_filter_str = "}}}{}"
+        channel = self.make_request(
+            "GET",
+            f"/_synapse/admin/v1/rooms/{self.room_id}/messages?dir=b&filter={invalid_filter_str}",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
+        self.assertEqual(
+            channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
+        )
+
 
 class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
     servlets = [
@@ -2522,6 +2550,39 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
             else:
                 self.fail("Event %s from events_after not found" % j)
 
+    def test_room_event_context_filter_query_validation(self) -> None:
+        # Test json validation in (filter) query parameter.
+        # Does not test the validity of the filter, only the json validation.
+
+        # Create a user with room and event_id.
+        user_id = self.register_user("test", "test")
+        user_tok = self.login("test", "test")
+        room_id = self.helper.create_room_as(user_id, tok=user_tok)
+        event_id = self.helper.send(room_id, "message 1", tok=user_tok)["event_id"]
+
+        # Check Get with valid json filter parameter, expect 200.
+        valid_filter_str = '{"types": ["m.room.message"]}'
+        channel = self.make_request(
+            "GET",
+            f"/_synapse/admin/v1/rooms/{room_id}/context/{event_id}?filter={valid_filter_str}",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+
+        # Check Get with invalid json filter parameter, expect 400 NOT_JSON.
+        invalid_filter_str = "}}}{}"
+        channel = self.make_request(
+            "GET",
+            f"/_synapse/admin/v1/rooms/{room_id}/context/{event_id}?filter={invalid_filter_str}",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
+        self.assertEqual(
+            channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
+        )
+
 
 class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
     servlets = [
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 1364615085..b796163dcb 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -2175,6 +2175,31 @@ class RoomMessageListTestCase(RoomBase):
         chunk = channel.json_body["chunk"]
         self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
 
+    def test_room_message_filter_query_validation(self) -> None:
+        # Test json validation in (filter) query parameter.
+        # Does not test the validity of the filter, only the json validation.
+
+        # Check Get with valid json filter parameter, expect 200.
+        valid_filter_str = '{"types": ["m.room.message"]}'
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room_id}/messages?access_token=x&dir=b&filter={valid_filter_str}",
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+
+        # Check Get with invalid json filter parameter, expect 400 NOT_JSON.
+        invalid_filter_str = "}}}{}"
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room_id}/messages?access_token=x&dir=b&filter={invalid_filter_str}",
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
+        self.assertEqual(
+            channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
+        )
+
 
 class RoomMessageFilterTestCase(RoomBase):
     """Tests /rooms/$room_id/messages REST events."""
@@ -3213,6 +3238,33 @@ class ContextTestCase(unittest.HomeserverTestCase):
         self.assertDictEqual(events_after[0].get("content"), {}, events_after[0])
         self.assertEqual(events_after[1].get("content"), {}, events_after[1])
 
+    def test_room_event_context_filter_query_validation(self) -> None:
+        # Test json validation in (filter) query parameter.
+        # Does not test the validity of the filter, only the json validation.
+        event_id = self.helper.send(self.room_id, "message 7", tok=self.tok)["event_id"]
+
+        # Check Get with valid json filter parameter, expect 200.
+        valid_filter_str = '{"types": ["m.room.message"]}'
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room_id}/context/{event_id}?filter={valid_filter_str}",
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+
+        # Check Get with invalid json filter parameter, expect 400 NOT_JSON.
+        invalid_filter_str = "}}}{}"
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room_id}/context/{event_id}?filter={invalid_filter_str}",
+            access_token=self.tok,
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body)
+        self.assertEqual(
+            channel.json_body["errcode"], Codes.NOT_JSON, channel.json_body
+        )
+
 
 class RoomAliasListTestCase(unittest.HomeserverTestCase):
     servlets = [