summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/16183.misc1
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py39
2 files changed, 30 insertions, 10 deletions
diff --git a/changelog.d/16183.misc b/changelog.d/16183.misc
new file mode 100644
index 0000000000..305d5baa6e
--- /dev/null
+++ b/changelog.d/16183.misc
@@ -0,0 +1 @@
+Improve error reporting of invalid data passed to `/_matrix/key/v2/query`.
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 981fd1f58a..0aaa838d04 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -16,6 +16,7 @@ import logging
 import re
 from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
 
+from pydantic import Extra, StrictInt, StrictStr
 from signedjson.sign import sign_json
 
 from twisted.web.server import Request
@@ -24,9 +25,10 @@ from synapse.crypto.keyring import ServerKeyFetcher
 from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
+    parse_and_validate_json_object_from_request,
     parse_integer,
-    parse_json_object_from_request,
 )
+from synapse.rest.models import RequestBodyModel
 from synapse.storage.keys import FetchKeyResultForRemote
 from synapse.types import JsonDict
 from synapse.util import json_decoder
@@ -38,6 +40,13 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class _KeyQueryCriteriaDataModel(RequestBodyModel):
+    class Config:
+        extra = Extra.allow
+
+    minimum_valid_until_ts: Optional[StrictInt]
+
+
 class RemoteKey(RestServlet):
     """HTTP resource for retrieving the TLS certificate and NACL signature
     verification keys for a collection of servers. Checks that the reported
@@ -96,6 +105,9 @@ class RemoteKey(RestServlet):
 
     CATEGORY = "Federation requests"
 
+    class PostBody(RequestBodyModel):
+        server_keys: Dict[StrictStr, Dict[StrictStr, _KeyQueryCriteriaDataModel]]
+
     def __init__(self, hs: "HomeServer"):
         self.fetcher = ServerKeyFetcher(hs)
         self.store = hs.get_datastores().main
@@ -137,24 +149,29 @@ class RemoteKey(RestServlet):
             )
 
             minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
-            arguments = {}
-            if minimum_valid_until_ts is not None:
-                arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
-            query = {server: {key_id: arguments}}
+            query = {
+                server: {
+                    key_id: _KeyQueryCriteriaDataModel(
+                        minimum_valid_until_ts=minimum_valid_until_ts
+                    )
+                }
+            }
         else:
             query = {server: {}}
 
         return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
 
     async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
-        content = parse_json_object_from_request(request)
+        content = parse_and_validate_json_object_from_request(request, self.PostBody)
 
-        query = content["server_keys"]
+        query = content.server_keys
 
         return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
 
     async def query_keys(
-        self, query: JsonDict, query_remote_on_cache_miss: bool = False
+        self,
+        query: Dict[str, Dict[str, _KeyQueryCriteriaDataModel]],
+        query_remote_on_cache_miss: bool = False,
     ) -> JsonDict:
         logger.info("Handling query for keys %r", query)
 
@@ -196,8 +213,10 @@ class RemoteKey(RestServlet):
             else:
                 ts_added_ms = key_result.added_ts
                 ts_valid_until_ms = key_result.valid_until_ts
-                req_key = query.get(server_name, {}).get(key_id, {})
-                req_valid_until = req_key.get("minimum_valid_until_ts")
+                req_key = query.get(server_name, {}).get(
+                    key_id, _KeyQueryCriteriaDataModel(minimum_valid_until_ts=None)
+                )
+                req_valid_until = req_key.minimum_valid_until_ts
                 if req_valid_until is not None:
                     if ts_valid_until_ms < req_valid_until:
                         logger.debug(