diff options
Diffstat (limited to 'synapse/rest')
-rw-r--r-- | synapse/rest/key/v2/remote_key_resource.py | 39 |
1 files changed, 29 insertions, 10 deletions
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 981fd1f58a..0aaa838d04 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -16,6 +16,7 @@ import logging import re from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple +from pydantic import Extra, StrictInt, StrictStr from signedjson.sign import sign_json from twisted.web.server import Request @@ -24,9 +25,10 @@ from synapse.crypto.keyring import ServerKeyFetcher from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, + parse_and_validate_json_object_from_request, parse_integer, - parse_json_object_from_request, ) +from synapse.rest.models import RequestBodyModel from synapse.storage.keys import FetchKeyResultForRemote from synapse.types import JsonDict from synapse.util import json_decoder @@ -38,6 +40,13 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class _KeyQueryCriteriaDataModel(RequestBodyModel): + class Config: + extra = Extra.allow + + minimum_valid_until_ts: Optional[StrictInt] + + class RemoteKey(RestServlet): """HTTP resource for retrieving the TLS certificate and NACL signature verification keys for a collection of servers. Checks that the reported @@ -96,6 +105,9 @@ class RemoteKey(RestServlet): CATEGORY = "Federation requests" + class PostBody(RequestBodyModel): + server_keys: Dict[StrictStr, Dict[StrictStr, _KeyQueryCriteriaDataModel]] + def __init__(self, hs: "HomeServer"): self.fetcher = ServerKeyFetcher(hs) self.store = hs.get_datastores().main @@ -137,24 +149,29 @@ class RemoteKey(RestServlet): ) minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") - arguments = {} - if minimum_valid_until_ts is not None: - arguments["minimum_valid_until_ts"] = minimum_valid_until_ts - query = {server: {key_id: arguments}} + query = { + server: { + key_id: _KeyQueryCriteriaDataModel( + minimum_valid_until_ts=minimum_valid_until_ts + ) + } + } else: query = {server: {}} return 200, await self.query_keys(query, query_remote_on_cache_miss=True) async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) + content = parse_and_validate_json_object_from_request(request, self.PostBody) - query = content["server_keys"] + query = content.server_keys return 200, await self.query_keys(query, query_remote_on_cache_miss=True) async def query_keys( - self, query: JsonDict, query_remote_on_cache_miss: bool = False + self, + query: Dict[str, Dict[str, _KeyQueryCriteriaDataModel]], + query_remote_on_cache_miss: bool = False, ) -> JsonDict: logger.info("Handling query for keys %r", query) @@ -196,8 +213,10 @@ class RemoteKey(RestServlet): else: ts_added_ms = key_result.added_ts ts_valid_until_ms = key_result.valid_until_ts - req_key = query.get(server_name, {}).get(key_id, {}) - req_valid_until = req_key.get("minimum_valid_until_ts") + req_key = query.get(server_name, {}).get( + key_id, _KeyQueryCriteriaDataModel(minimum_valid_until_ts=None) + ) + req_valid_until = req_key.minimum_valid_until_ts if req_valid_until is not None: if ts_valid_until_ms < req_valid_until: logger.debug( |