diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 981fd1f58a..0aaa838d04 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -16,6 +16,7 @@ import logging
import re
from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
+from pydantic import Extra, StrictInt, StrictStr
from signedjson.sign import sign_json
from twisted.web.server import Request
@@ -24,9 +25,10 @@ from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
+ parse_and_validate_json_object_from_request,
parse_integer,
- parse_json_object_from_request,
)
+from synapse.rest.models import RequestBodyModel
from synapse.storage.keys import FetchKeyResultForRemote
from synapse.types import JsonDict
from synapse.util import json_decoder
@@ -38,6 +40,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class _KeyQueryCriteriaDataModel(RequestBodyModel):
+ class Config:
+ extra = Extra.allow
+
+ minimum_valid_until_ts: Optional[StrictInt]
+
+
class RemoteKey(RestServlet):
"""HTTP resource for retrieving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
@@ -96,6 +105,9 @@ class RemoteKey(RestServlet):
CATEGORY = "Federation requests"
+ class PostBody(RequestBodyModel):
+ server_keys: Dict[StrictStr, Dict[StrictStr, _KeyQueryCriteriaDataModel]]
+
def __init__(self, hs: "HomeServer"):
self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastores().main
@@ -137,24 +149,29 @@ class RemoteKey(RestServlet):
)
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
- arguments = {}
- if minimum_valid_until_ts is not None:
- arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
- query = {server: {key_id: arguments}}
+ query = {
+ server: {
+ key_id: _KeyQueryCriteriaDataModel(
+ minimum_valid_until_ts=minimum_valid_until_ts
+ )
+ }
+ }
else:
query = {server: {}}
return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
+ content = parse_and_validate_json_object_from_request(request, self.PostBody)
- query = content["server_keys"]
+ query = content.server_keys
return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
async def query_keys(
- self, query: JsonDict, query_remote_on_cache_miss: bool = False
+ self,
+ query: Dict[str, Dict[str, _KeyQueryCriteriaDataModel]],
+ query_remote_on_cache_miss: bool = False,
) -> JsonDict:
logger.info("Handling query for keys %r", query)
@@ -196,8 +213,10 @@ class RemoteKey(RestServlet):
else:
ts_added_ms = key_result.added_ts
ts_valid_until_ms = key_result.valid_until_ts
- req_key = query.get(server_name, {}).get(key_id, {})
- req_valid_until = req_key.get("minimum_valid_until_ts")
+ req_key = query.get(server_name, {}).get(
+ key_id, _KeyQueryCriteriaDataModel(minimum_valid_until_ts=None)
+ )
+ req_valid_until = req_key.minimum_valid_until_ts
if req_valid_until is not None:
if ts_valid_until_ms < req_valid_until:
logger.debug(
|