diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index f597157581..19820886f5 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,15 +13,20 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, Set
+import re
+from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
from signedjson.sign import sign_json
-from synapse.api.errors import Codes, SynapseError
+from twisted.web.server import Request
+
from synapse.crypto.keyring import ServerKeyFetcher
-from synapse.http.server import DirectServeJsonResource, respond_with_json
-from synapse.http.servlet import parse_integer, parse_json_object_from_request
-from synapse.http.site import SynapseRequest
+from synapse.http.server import HttpServer
+from synapse.http.servlet import (
+ RestServlet,
+ parse_integer,
+ parse_json_object_from_request,
+)
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
@@ -32,7 +37,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class RemoteKey(DirectServeJsonResource):
+class RemoteKey(RestServlet):
"""HTTP resource for retrieving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
@@ -88,11 +93,7 @@ class RemoteKey(DirectServeJsonResource):
}
"""
- isLeaf = True
-
def __init__(self, hs: "HomeServer"):
- super().__init__()
-
self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
@@ -101,47 +102,52 @@ class RemoteKey(DirectServeJsonResource):
)
self.config = hs.config
- async def _async_render_GET(self, request: SynapseRequest) -> None:
- assert request.postpath is not None
- if len(request.postpath) == 1:
- (server,) = request.postpath
- query: dict = {server.decode("ascii"): {}}
- elif len(request.postpath) == 2:
- server, key_id = request.postpath
+ def register(self, http_server: HttpServer) -> None:
+ http_server.register_paths(
+ "GET",
+ (
+ re.compile(
+ "^/_matrix/key/v2/query/(?P<server>[^/]*)(/(?P<key_id>[^/]*))?$"
+ ),
+ ),
+ self.on_GET,
+ self.__class__.__name__,
+ )
+ http_server.register_paths(
+ "POST",
+ (re.compile("^/_matrix/key/v2/query$"),),
+ self.on_POST,
+ self.__class__.__name__,
+ )
+
+ async def on_GET(
+ self, request: Request, server: str, key_id: Optional[str] = None
+ ) -> Tuple[int, JsonDict]:
+ if server and key_id:
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
arguments = {}
if minimum_valid_until_ts is not None:
arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
- query = {server.decode("ascii"): {key_id.decode("ascii"): arguments}}
+ query = {server: {key_id: arguments}}
else:
- raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND)
+ query = {server: {}}
- await self.query_keys(request, query, query_remote_on_cache_miss=True)
+ return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
- async def _async_render_POST(self, request: SynapseRequest) -> None:
+ async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
query = content["server_keys"]
- await self.query_keys(request, query, query_remote_on_cache_miss=True)
+ return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
async def query_keys(
- self,
- request: SynapseRequest,
- query: JsonDict,
- query_remote_on_cache_miss: bool = False,
- ) -> None:
+ self, query: JsonDict, query_remote_on_cache_miss: bool = False
+ ) -> JsonDict:
logger.info("Handling query for keys %r", query)
store_queries = []
for server_name, key_ids in query.items():
- if (
- self.federation_domain_whitelist is not None
- and server_name not in self.federation_domain_whitelist
- ):
- logger.debug("Federation denied with %s", server_name)
- continue
-
if not key_ids:
key_ids = (None,)
for key_id in key_ids:
@@ -153,21 +159,28 @@ class RemoteKey(DirectServeJsonResource):
time_now_ms = self.clock.time_msec()
- # Note that the value is unused.
+ # Map server_name->key_id->int. Note that the value of the init is unused.
+ # XXX: why don't we just use a set?
cache_misses: Dict[str, Dict[str, int]] = {}
for (server_name, key_id, _), key_results in cached.items():
results = [(result["ts_added_ms"], result) for result in key_results]
- if not results and key_id is not None:
- cache_misses.setdefault(server_name, {})[key_id] = 0
+ if key_id is None:
+ # all keys were requested. Just return what we have without worrying
+ # about validity
+ for _, result in results:
+ # Cast to bytes since postgresql returns a memoryview.
+ json_results.add(bytes(result["key_json"]))
continue
- if key_id is not None:
+ miss = False
+ if not results:
+ miss = True
+ else:
ts_added_ms, most_recent_result = max(results)
ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
req_key = query.get(server_name, {}).get(key_id, {})
req_valid_until = req_key.get("minimum_valid_until_ts")
- miss = False
if req_valid_until is not None:
if ts_valid_until_ms < req_valid_until:
logger.debug(
@@ -211,19 +224,20 @@ class RemoteKey(DirectServeJsonResource):
ts_valid_until_ms,
time_now_ms,
)
-
- if miss:
- cache_misses.setdefault(server_name, {})[key_id] = 0
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(most_recent_result["key_json"]))
- else:
- for _, result in results:
- # Cast to bytes since postgresql returns a memoryview.
- json_results.add(bytes(result["key_json"]))
+
+ if miss and query_remote_on_cache_miss:
+ # only bother attempting to fetch keys from servers on our whitelist
+ if (
+ self.federation_domain_whitelist is None
+ or server_name in self.federation_domain_whitelist
+ ):
+ cache_misses.setdefault(server_name, {})[key_id] = 0
# If there is a cache miss, request the missing keys, then recurse (and
# ensure the result is sent).
- if cache_misses and query_remote_on_cache_miss:
+ if cache_misses:
await yieldable_gather_results(
lambda t: self.fetcher.get_keys(*t),
(
@@ -231,7 +245,7 @@ class RemoteKey(DirectServeJsonResource):
for server_name, keys in cache_misses.items()
),
)
- await self.query_keys(request, query, query_remote_on_cache_miss=False)
+ return await self.query_keys(query, query_remote_on_cache_miss=False)
else:
signed_keys = []
for key_json_raw in json_results:
@@ -243,6 +257,4 @@ class RemoteKey(DirectServeJsonResource):
signed_keys.append(key_json)
- response = {"server_keys": signed_keys}
-
- respond_with_json(request, 200, response, canonical_json=True)
+ return {"server_keys": signed_keys}
|