summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/errors.py8
-rw-r--r--synapse/rest/client/keys.py16
2 files changed, 23 insertions, 1 deletions
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index dc662bca83..9480f448d7 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -147,6 +147,14 @@ class SynapseError(CodeMessageException):
         return cs_error(self.msg, self.errcode)
 
 
+class InvalidAPICallError(SynapseError):
+    """You called an existing API endpoint, but fed that endpoint
+    invalid or incomplete data."""
+
+    def __init__(self, msg: str):
+        super().__init__(HTTPStatus.BAD_REQUEST, msg, Codes.BAD_JSON)
+
+
 class ProxiedRequestError(SynapseError):
     """An error from a general matrix endpoint, eg. from a proxied Matrix API call.
 
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index d0d9d30d40..012491f597 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -15,8 +15,9 @@
 # limitations under the License.
 
 import logging
+from typing import Any
 
-from synapse.api.errors import SynapseError
+from synapse.api.errors import InvalidAPICallError, SynapseError
 from synapse.http.servlet import (
     RestServlet,
     parse_integer,
@@ -163,6 +164,19 @@ class KeyQueryServlet(RestServlet):
         device_id = requester.device_id
         timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
+
+        device_keys = body.get("device_keys")
+        if not isinstance(device_keys, dict):
+            raise InvalidAPICallError("'device_keys' must be a JSON object")
+
+        def is_list_of_strings(values: Any) -> bool:
+            return isinstance(values, list) and all(isinstance(v, str) for v in values)
+
+        if any(not is_list_of_strings(keys) for keys in device_keys.values()):
+            raise InvalidAPICallError(
+                "'device_keys' values must be a list of strings",
+            )
+
         result = await self.e2e_keys_handler.query_devices(
             body, timeout, user_id, device_id
         )