summary refs log tree commit diff
path: root/synapse/rest/key/v2/remote_key_resource.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/key/v2/remote_key_resource.py')
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py114
1 files changed, 63 insertions, 51 deletions
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py

index f597157581..19820886f5 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,15 +13,20 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, Set +import re +from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple from signedjson.sign import sign_json -from synapse.api.errors import Codes, SynapseError +from twisted.web.server import Request + from synapse.crypto.keyring import ServerKeyFetcher -from synapse.http.server import DirectServeJsonResource, respond_with_json -from synapse.http.servlet import parse_integer, parse_json_object_from_request -from synapse.http.site import SynapseRequest +from synapse.http.server import HttpServer +from synapse.http.servlet import ( + RestServlet, + parse_integer, + parse_json_object_from_request, +) from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import yieldable_gather_results @@ -32,7 +37,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class RemoteKey(DirectServeJsonResource): +class RemoteKey(RestServlet): """HTTP resource for retrieving the TLS certificate and NACL signature verification keys for a collection of servers. Checks that the reported X.509 TLS certificate matches the one used in the HTTPS connection. Checks @@ -88,11 +93,7 @@ class RemoteKey(DirectServeJsonResource): } """ - isLeaf = True - def __init__(self, hs: "HomeServer"): - super().__init__() - self.fetcher = ServerKeyFetcher(hs) self.store = hs.get_datastores().main self.clock = hs.get_clock() @@ -101,47 +102,52 @@ class RemoteKey(DirectServeJsonResource): ) self.config = hs.config - async def _async_render_GET(self, request: SynapseRequest) -> None: - assert request.postpath is not None - if len(request.postpath) == 1: - (server,) = request.postpath - query: dict = {server.decode("ascii"): {}} - elif len(request.postpath) == 2: - server, key_id = request.postpath + def register(self, http_server: HttpServer) -> None: + http_server.register_paths( + "GET", + ( + re.compile( + "^/_matrix/key/v2/query/(?P<server>[^/]*)(/(?P<key_id>[^/]*))?$" + ), + ), + self.on_GET, + self.__class__.__name__, + ) + http_server.register_paths( + "POST", + (re.compile("^/_matrix/key/v2/query$"),), + self.on_POST, + self.__class__.__name__, + ) + + async def on_GET( + self, request: Request, server: str, key_id: Optional[str] = None + ) -> Tuple[int, JsonDict]: + if server and key_id: minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") arguments = {} if minimum_valid_until_ts is not None: arguments["minimum_valid_until_ts"] = minimum_valid_until_ts - query = {server.decode("ascii"): {key_id.decode("ascii"): arguments}} + query = {server: {key_id: arguments}} else: - raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND) + query = {server: {}} - await self.query_keys(request, query, query_remote_on_cache_miss=True) + return 200, await self.query_keys(query, query_remote_on_cache_miss=True) - async def _async_render_POST(self, request: SynapseRequest) -> None: + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) query = content["server_keys"] - await self.query_keys(request, query, query_remote_on_cache_miss=True) + return 200, await self.query_keys(query, query_remote_on_cache_miss=True) async def query_keys( - self, - request: SynapseRequest, - query: JsonDict, - query_remote_on_cache_miss: bool = False, - ) -> None: + self, query: JsonDict, query_remote_on_cache_miss: bool = False + ) -> JsonDict: logger.info("Handling query for keys %r", query) store_queries = [] for server_name, key_ids in query.items(): - if ( - self.federation_domain_whitelist is not None - and server_name not in self.federation_domain_whitelist - ): - logger.debug("Federation denied with %s", server_name) - continue - if not key_ids: key_ids = (None,) for key_id in key_ids: @@ -153,21 +159,28 @@ class RemoteKey(DirectServeJsonResource): time_now_ms = self.clock.time_msec() - # Note that the value is unused. + # Map server_name->key_id->int. Note that the value of the init is unused. + # XXX: why don't we just use a set? cache_misses: Dict[str, Dict[str, int]] = {} for (server_name, key_id, _), key_results in cached.items(): results = [(result["ts_added_ms"], result) for result in key_results] - if not results and key_id is not None: - cache_misses.setdefault(server_name, {})[key_id] = 0 + if key_id is None: + # all keys were requested. Just return what we have without worrying + # about validity + for _, result in results: + # Cast to bytes since postgresql returns a memoryview. + json_results.add(bytes(result["key_json"])) continue - if key_id is not None: + miss = False + if not results: + miss = True + else: ts_added_ms, most_recent_result = max(results) ts_valid_until_ms = most_recent_result["ts_valid_until_ms"] req_key = query.get(server_name, {}).get(key_id, {}) req_valid_until = req_key.get("minimum_valid_until_ts") - miss = False if req_valid_until is not None: if ts_valid_until_ms < req_valid_until: logger.debug( @@ -211,19 +224,20 @@ class RemoteKey(DirectServeJsonResource): ts_valid_until_ms, time_now_ms, ) - - if miss: - cache_misses.setdefault(server_name, {})[key_id] = 0 # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(most_recent_result["key_json"])) - else: - for _, result in results: - # Cast to bytes since postgresql returns a memoryview. - json_results.add(bytes(result["key_json"])) + + if miss and query_remote_on_cache_miss: + # only bother attempting to fetch keys from servers on our whitelist + if ( + self.federation_domain_whitelist is None + or server_name in self.federation_domain_whitelist + ): + cache_misses.setdefault(server_name, {})[key_id] = 0 # If there is a cache miss, request the missing keys, then recurse (and # ensure the result is sent). - if cache_misses and query_remote_on_cache_miss: + if cache_misses: await yieldable_gather_results( lambda t: self.fetcher.get_keys(*t), ( @@ -231,7 +245,7 @@ class RemoteKey(DirectServeJsonResource): for server_name, keys in cache_misses.items() ), ) - await self.query_keys(request, query, query_remote_on_cache_miss=False) + return await self.query_keys(query, query_remote_on_cache_miss=False) else: signed_keys = [] for key_json_raw in json_results: @@ -243,6 +257,4 @@ class RemoteKey(DirectServeJsonResource): signed_keys.append(key_json) - response = {"server_keys": signed_keys} - - respond_with_json(request, 200, response, canonical_json=True) + return {"server_keys": signed_keys}