diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 6fc0712978..c840ffca71 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -16,8 +16,7 @@
import abc
import logging
import urllib
-from collections import defaultdict
-from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple
import attr
from signedjson.key import (
@@ -44,17 +43,12 @@ from synapse.api.errors import (
from synapse.config.key import TrustedKeyServer
from synapse.events import EventBase
from synapse.events.utils import prune_event_dict
-from synapse.logging.context import (
- PreserveLoggingContext,
- make_deferred_yieldable,
- preserve_fn,
- run_in_background,
-)
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import yieldable_gather_results
-from synapse.util.metrics import Measure
+from synapse.util.batching_queue import BatchingQueue
from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING:
@@ -80,32 +74,19 @@ class VerifyJsonRequest:
minimum_valid_until_ts: time at which we require the signing key to
be valid. (0 implies we don't care)
- request_name: The name of the request.
-
key_ids: The set of key_ids to that could be used to verify the JSON object
-
- key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
- A deferred (server_name, key_id, verify_key) tuple that resolves when
- a verify key has been fetched. The deferreds' callbacks are run with no
- logcontext.
-
- If we are unable to find a key which satisfies the request, the deferred
- errbacks with an M_UNAUTHORIZED SynapseError.
"""
server_name = attr.ib(type=str)
get_json_object = attr.ib(type=Callable[[], JsonDict])
minimum_valid_until_ts = attr.ib(type=int)
- request_name = attr.ib(type=str)
key_ids = attr.ib(type=List[str])
- key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
@staticmethod
def from_json_object(
server_name: str,
json_object: JsonDict,
minimum_valid_until_ms: int,
- request_name: str,
):
"""Create a VerifyJsonRequest to verify all signatures on a signed JSON
object for the given server.
@@ -115,7 +96,6 @@ class VerifyJsonRequest:
server_name,
lambda: json_object,
minimum_valid_until_ms,
- request_name=request_name,
key_ids=key_ids,
)
@@ -135,16 +115,48 @@ class VerifyJsonRequest:
# memory than the Event object itself.
lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
minimum_valid_until_ms,
- request_name=event.event_id,
key_ids=key_ids,
)
+ def to_fetch_key_request(self) -> "_FetchKeyRequest":
+ """Create a key fetch request for all keys needed to satisfy the
+ verification request.
+ """
+ return _FetchKeyRequest(
+ server_name=self.server_name,
+ minimum_valid_until_ts=self.minimum_valid_until_ts,
+ key_ids=self.key_ids,
+ )
+
class KeyLookupError(ValueError):
pass
+@attr.s(slots=True)
+class _FetchKeyRequest:
+ """A request for keys for a given server.
+
+ We will continue to try and fetch until we have all the keys listed under
+ `key_ids` (with an appropriate `valid_until_ts` property) or we run out of
+ places to fetch keys from.
+
+ Attributes:
+ server_name: The name of the server that owns the keys.
+ minimum_valid_until_ts: The timestamp which the keys must be valid until.
+ key_ids: The IDs of the keys to attempt to fetch
+ """
+
+ server_name = attr.ib(type=str)
+ minimum_valid_until_ts = attr.ib(type=int)
+ key_ids = attr.ib(type=List[str])
+
+
class Keyring:
+ """Handles verifying signed JSON objects and fetching the keys needed to do
+ so.
+ """
+
def __init__(
self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
):
@@ -158,22 +170,22 @@ class Keyring:
)
self._key_fetchers = key_fetchers
- # map from server name to Deferred. Has an entry for each server with
- # an ongoing key download; the Deferred completes once the download
- # completes.
- #
- # These are regular, logcontext-agnostic Deferreds.
- self.key_downloads = {} # type: Dict[str, defer.Deferred]
+ self._server_queue = BatchingQueue(
+ "keyring_server",
+ clock=hs.get_clock(),
+ process_batch_callback=self._inner_fetch_key_requests,
+ ) # type: BatchingQueue[_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]]
- def verify_json_for_server(
+ async def verify_json_for_server(
self,
server_name: str,
json_object: JsonDict,
validity_time: int,
- request_name: str,
- ) -> defer.Deferred:
+ ) -> None:
"""Verify that a JSON object has been signed by a given server
+ Completes if the the object was correctly signed, otherwise raises.
+
Args:
server_name: name of the server which must have signed this object
@@ -181,52 +193,45 @@ class Keyring:
validity_time: timestamp at which we require the signing key to
be valid. (0 implies we don't care)
-
- request_name: an identifier for this json object (eg, an event id)
- for logging.
-
- Returns:
- Deferred[None]: completes if the the object was correctly signed, otherwise
- errbacks with an error
"""
request = VerifyJsonRequest.from_json_object(
server_name,
json_object,
validity_time,
- request_name,
)
- requests = (request,)
- return make_deferred_yieldable(self._verify_objects(requests)[0])
+ return await self.process_request(request)
def verify_json_objects_for_server(
- self, server_and_json: Iterable[Tuple[str, dict, int, str]]
+ self, server_and_json: Iterable[Tuple[str, dict, int]]
) -> List[defer.Deferred]:
"""Bulk verifies signatures of json objects, bulk fetching keys as
necessary.
Args:
server_and_json:
- Iterable of (server_name, json_object, validity_time, request_name)
+ Iterable of (server_name, json_object, validity_time)
tuples.
validity_time is a timestamp at which the signing key must be
valid.
- request_name is an identifier for this json object (eg, an event id)
- for logging.
-
Returns:
List<Deferred[None]>: for each input triplet, a deferred indicating success
or failure to verify each json object's signature for the given
server_name. The deferreds run their callbacks in the sentinel
logcontext.
"""
- return self._verify_objects(
- VerifyJsonRequest.from_json_object(
- server_name, json_object, validity_time, request_name
+ return [
+ run_in_background(
+ self.process_request,
+ VerifyJsonRequest.from_json_object(
+ server_name,
+ json_object,
+ validity_time,
+ ),
)
- for server_name, json_object, validity_time, request_name in server_and_json
- )
+ for server_name, json_object, validity_time in server_and_json
+ ]
def verify_events_for_server(
self, server_and_events: Iterable[Tuple[str, EventBase, int]]
@@ -252,321 +257,223 @@ class Keyring:
server_name. The deferreds run their callbacks in the sentinel
logcontext.
"""
- return self._verify_objects(
- VerifyJsonRequest.from_event(server_name, event, validity_time)
+ return [
+ run_in_background(
+ self.process_request,
+ VerifyJsonRequest.from_event(
+ server_name,
+ event,
+ validity_time,
+ ),
+ )
for server_name, event, validity_time in server_and_events
- )
-
- def _verify_objects(
- self, verify_requests: Iterable[VerifyJsonRequest]
- ) -> List[defer.Deferred]:
- """Does the work of verify_json_[objects_]for_server
-
-
- Args:
- verify_requests: Iterable of verification requests.
+ ]
- Returns:
- List<Deferred[None]>: for each input item, a deferred indicating success
- or failure to verify each json object's signature for the given
- server_name. The deferreds run their callbacks in the sentinel
- logcontext.
+ async def process_request(self, verify_request: VerifyJsonRequest) -> None:
+ """Processes the `VerifyJsonRequest`. Raises if the object is not signed
+ by the server, the signatures don't match or we failed to fetch the
+ necessary keys.
"""
- # a list of VerifyJsonRequests which are awaiting a key lookup
- key_lookups = []
- handle = preserve_fn(_handle_key_deferred)
-
- def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
- """Process an entry in the request list
-
- Adds a key request to key_lookups, and returns a deferred which
- will complete or fail (in the sentinel context) when verification completes.
- """
- if not verify_request.key_ids:
- return defer.fail(
- SynapseError(
- 400,
- "Not signed by %s" % (verify_request.server_name,),
- Codes.UNAUTHORIZED,
- )
- )
- logger.debug(
- "Verifying %s for %s with key_ids %s, min_validity %i",
- verify_request.request_name,
- verify_request.server_name,
- verify_request.key_ids,
- verify_request.minimum_valid_until_ts,
+ if not verify_request.key_ids:
+ raise SynapseError(
+ 400,
+ f"Not signed by {verify_request.server_name}",
+ Codes.UNAUTHORIZED,
)
- # add the key request to the queue, but don't start it off yet.
- key_lookups.append(verify_request)
-
- # now run _handle_key_deferred, which will wait for the key request
- # to complete and then do the verification.
- #
- # We want _handle_key_request to log to the right context, so we
- # wrap it with preserve_fn (aka run_in_background)
- return handle(verify_request)
-
- results = [process(r) for r in verify_requests]
-
- if key_lookups:
- run_in_background(self._start_key_lookups, key_lookups)
-
- return results
-
- async def _start_key_lookups(
- self, verify_requests: List[VerifyJsonRequest]
- ) -> None:
- """Sets off the key fetches for each verify request
-
- Once each fetch completes, verify_request.key_ready will be resolved.
-
- Args:
- verify_requests:
- """
-
- try:
- # map from server name to a set of outstanding request ids
- server_to_request_ids = {} # type: Dict[str, Set[int]]
-
- for verify_request in verify_requests:
- server_name = verify_request.server_name
- request_id = id(verify_request)
- server_to_request_ids.setdefault(server_name, set()).add(request_id)
-
- # Wait for any previous lookups to complete before proceeding.
- await self.wait_for_previous_lookups(server_to_request_ids.keys())
-
- # take out a lock on each of the servers by sticking a Deferred in
- # key_downloads
- for server_name in server_to_request_ids.keys():
- self.key_downloads[server_name] = defer.Deferred()
- logger.debug("Got key lookup lock on %s", server_name)
-
- # When we've finished fetching all the keys for a given server_name,
- # drop the lock by resolving the deferred in key_downloads.
- def drop_server_lock(server_name):
- d = self.key_downloads.pop(server_name)
- d.callback(None)
-
- def lookup_done(res, verify_request):
- server_name = verify_request.server_name
- server_requests = server_to_request_ids[server_name]
- server_requests.remove(id(verify_request))
-
- # if there are no more requests for this server, we can drop the lock.
- if not server_requests:
- logger.debug("Releasing key lookup lock on %s", server_name)
- drop_server_lock(server_name)
-
- return res
+ # Add the keys we need to verify to the queue for retrieval. We queue
+ # up requests for the same server so we don't end up with many in flight
+ # requests for the same keys.
+ key_request = verify_request.to_fetch_key_request()
+ found_keys_by_server = await self._server_queue.add_to_queue(
+ key_request, key=verify_request.server_name
+ )
- for verify_request in verify_requests:
- verify_request.key_ready.addBoth(lookup_done, verify_request)
+ # Since we batch up requests the returned set of keys may contain keys
+ # from other servers, so we pull out only the ones we care about.s
+ found_keys = found_keys_by_server.get(verify_request.server_name, {})
- # Actually start fetching keys.
- self._get_server_verify_keys(verify_requests)
- except Exception:
- logger.exception("Error starting key lookups")
+ # Verify each signature we got valid keys for, raising if we can't
+ # verify any of them.
+ verified = False
+ for key_id in verify_request.key_ids:
+ key_result = found_keys.get(key_id)
+ if not key_result:
+ continue
- async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
- """Waits for any previous key lookups for the given servers to finish.
+ if key_result.valid_until_ts < verify_request.minimum_valid_until_ts:
+ continue
- Args:
- server_names: list of servers which we want to look up
+ verify_key = key_result.verify_key
+ json_object = verify_request.get_json_object()
+ try:
+ verify_signed_json(
+ json_object,
+ verify_request.server_name,
+ verify_key,
+ )
+ verified = True
+ except SignatureVerifyException as e:
+ logger.debug(
+ "Error verifying signature for %s:%s:%s with key %s: %s",
+ verify_request.server_name,
+ verify_key.alg,
+ verify_key.version,
+ encode_verify_key_base64(verify_key),
+ str(e),
+ )
+ raise SynapseError(
+ 401,
+ "Invalid signature for server %s with key %s:%s: %s"
+ % (
+ verify_request.server_name,
+ verify_key.alg,
+ verify_key.version,
+ str(e),
+ ),
+ Codes.UNAUTHORIZED,
+ )
- Returns:
- Resolves once all key lookups for the given servers have
- completed. Follows the synapse rules of logcontext preservation.
- """
- loop_count = 1
- while True:
- wait_on = [
- (server_name, self.key_downloads[server_name])
- for server_name in server_names
- if server_name in self.key_downloads
- ]
- if not wait_on:
- break
- logger.info(
- "Waiting for existing lookups for %s to complete [loop %i]",
- [w[0] for w in wait_on],
- loop_count,
+ if not verified:
+ raise SynapseError(
+ 401,
+ f"Failed to find any key to satisfy: {key_request}",
+ Codes.UNAUTHORIZED,
)
- with PreserveLoggingContext():
- await defer.DeferredList((w[1] for w in wait_on))
- loop_count += 1
+ async def _inner_fetch_key_requests(
+ self, requests: List[_FetchKeyRequest]
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
+ """Processing function for the queue of `_FetchKeyRequest`."""
+
+ logger.debug("Starting fetch for %s", requests)
+
+ # First we need to deduplicate requests for the same key. We do this by
+ # taking the *maximum* requested `minimum_valid_until_ts` for each pair
+ # of server name/key ID.
+ server_to_key_to_ts = {} # type: Dict[str, Dict[str, int]]
+ for request in requests:
+ by_server = server_to_key_to_ts.setdefault(request.server_name, {})
+ for key_id in request.key_ids:
+ existing_ts = by_server.get(key_id, 0)
+ by_server[key_id] = max(request.minimum_valid_until_ts, existing_ts)
+
+ deduped_requests = [
+ _FetchKeyRequest(server_name, minimum_valid_ts, [key_id])
+ for server_name, by_server in server_to_key_to_ts.items()
+ for key_id, minimum_valid_ts in by_server.items()
+ ]
+
+ logger.debug("Deduplicated key requests to %s", deduped_requests)
+
+ # For each key we call `_inner_verify_request` which will handle
+ # fetching each key. Note these shouldn't throw if we fail to contact
+ # other servers etc.
+ results_per_request = await yieldable_gather_results(
+ self._inner_fetch_key_request,
+ deduped_requests,
+ )
- def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
- """Tries to find at least one key for each verify request
+ # We now convert the returned list of results into a map from server
+ # name to key ID to FetchKeyResult, to return.
+ to_return = {} # type: Dict[str, Dict[str, FetchKeyResult]]
+ for (request, results) in zip(deduped_requests, results_per_request):
+ to_return_by_server = to_return.setdefault(request.server_name, {})
+ for key_id, key_result in results.items():
+ existing = to_return_by_server.get(key_id)
+ if not existing or existing.valid_until_ts < key_result.valid_until_ts:
+ to_return_by_server[key_id] = key_result
- For each verify_request, verify_request.key_ready is called back with
- params (server_name, key_id, VerifyKey) if a key is found, or errbacked
- with a SynapseError if none of the keys are found.
+ return to_return
- Args:
- verify_requests: list of verify requests
+ async def _inner_fetch_key_request(
+ self, verify_request: _FetchKeyRequest
+ ) -> Dict[str, FetchKeyResult]:
+ """Attempt to fetch the given key by calling each key fetcher one by
+ one.
"""
+ logger.debug("Starting fetch for %s", verify_request)
- remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
+ found_keys: Dict[str, FetchKeyResult] = {}
+ missing_key_ids = set(verify_request.key_ids)
- async def do_iterations():
- try:
- with Measure(self.clock, "get_server_verify_keys"):
- for f in self._key_fetchers:
- if not remaining_requests:
- return
- await self._attempt_key_fetches_with_fetcher(
- f, remaining_requests
- )
-
- # look for any requests which weren't satisfied
- while remaining_requests:
- verify_request = remaining_requests.pop()
- rq_str = (
- "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)"
- % (
- verify_request.server_name,
- verify_request.key_ids,
- verify_request.minimum_valid_until_ts,
- )
- )
-
- # If we run the errback immediately, it may cancel our
- # loggingcontext while we are still in it, so instead we
- # schedule it for the next time round the reactor.
- #
- # (this also ensures that we don't get a stack overflow if we
- # has a massive queue of lookups waiting for this server).
- self.clock.call_later(
- 0,
- verify_request.key_ready.errback,
- SynapseError(
- 401,
- "Failed to find any key to satisfy %s" % (rq_str,),
- Codes.UNAUTHORIZED,
- ),
- )
- except Exception as err:
- # we don't really expect to get here, because any errors should already
- # have been caught and logged. But if we do, let's log the error and make
- # sure that all of the deferreds are resolved.
- logger.error("Unexpected error in _get_server_verify_keys: %s", err)
- with PreserveLoggingContext():
- for verify_request in remaining_requests:
- if not verify_request.key_ready.called:
- verify_request.key_ready.errback(err)
-
- run_in_background(do_iterations)
-
- async def _attempt_key_fetches_with_fetcher(
- self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
- ):
- """Use a key fetcher to attempt to satisfy some key requests
+ for fetcher in self._key_fetchers:
+ if not missing_key_ids:
+ break
- Args:
- fetcher: fetcher to use to fetch the keys
- remaining_requests: outstanding key requests.
- Any successfully-completed requests will be removed from the list.
- """
- # The keys to fetch.
- # server_name -> key_id -> min_valid_ts
- missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]]
-
- for verify_request in remaining_requests:
- # any completed requests should already have been removed
- assert not verify_request.key_ready.called
- keys_for_server = missing_keys[verify_request.server_name]
-
- for key_id in verify_request.key_ids:
- # If we have several requests for the same key, then we only need to
- # request that key once, but we should do so with the greatest
- # min_valid_until_ts of the requests, so that we can satisfy all of
- # the requests.
- keys_for_server[key_id] = max(
- keys_for_server.get(key_id, -1),
- verify_request.minimum_valid_until_ts,
- )
+ logger.debug("Getting keys from %s for %s", fetcher, verify_request)
+ keys = await fetcher.get_keys(
+ verify_request.server_name,
+ list(missing_key_ids),
+ verify_request.minimum_valid_until_ts,
+ )
- results = await fetcher.get_keys(missing_keys)
+ for key_id, key in keys.items():
+ if not key:
+ continue
- completed = []
- for verify_request in remaining_requests:
- server_name = verify_request.server_name
+ # If we already have a result for the given key ID we keep the
+ # one with the highest `valid_until_ts`.
+ existing_key = found_keys.get(key_id)
+ if existing_key:
+ if key.valid_until_ts <= existing_key.valid_until_ts:
+ continue
- # see if any of the keys we got this time are sufficient to
- # complete this VerifyJsonRequest.
- result_keys = results.get(server_name, {})
- for key_id in verify_request.key_ids:
- fetch_key_result = result_keys.get(key_id)
- if not fetch_key_result:
- # we didn't get a result for this key
- continue
+ # We always store the returned key even if it doesn't the
+ # `minimum_valid_until_ts` requirement, as some verification
+ # requests may still be able to be satisfied by it.
+ #
+ # We still keep looking for the key from other fetchers in that
+ # case though.
+ found_keys[key_id] = key
- if (
- fetch_key_result.valid_until_ts
- < verify_request.minimum_valid_until_ts
- ):
- # key was not valid at this point
+ if key.valid_until_ts < verify_request.minimum_valid_until_ts:
continue
- # we have a valid key for this request. If we run the callback
- # immediately, it may cancel our loggingcontext while we are still in
- # it, so instead we schedule it for the next time round the reactor.
- #
- # (this also ensures that we don't get a stack overflow if we had
- # a massive queue of lookups waiting for this server).
- logger.debug(
- "Found key %s:%s for %s",
- server_name,
- key_id,
- verify_request.request_name,
- )
- self.clock.call_later(
- 0,
- verify_request.key_ready.callback,
- (server_name, key_id, fetch_key_result.verify_key),
- )
- completed.append(verify_request)
- break
+ missing_key_ids.discard(key_id)
- remaining_requests.difference_update(completed)
+ return found_keys
class KeyFetcher(metaclass=abc.ABCMeta):
- @abc.abstractmethod
+ def __init__(self, hs: "HomeServer"):
+ self._queue = BatchingQueue(
+ self.__class__.__name__, hs.get_clock(), self._fetch_keys
+ )
+
async def get_keys(
- self, keys_to_fetch: Dict[str, Dict[str, int]]
- ) -> Dict[str, Dict[str, FetchKeyResult]]:
- """
- Args:
- keys_to_fetch:
- the keys to be fetched. server_name -> key_id -> min_valid_ts
+ self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+ ) -> Dict[str, FetchKeyResult]:
+ results = await self._queue.add_to_queue(
+ _FetchKeyRequest(
+ server_name=server_name,
+ key_ids=key_ids,
+ minimum_valid_until_ts=minimum_valid_until_ts,
+ )
+ )
+ return results.get(server_name, {})
- Returns:
- Map from server_name -> key_id -> FetchKeyResult
- """
- raise NotImplementedError
+ @abc.abstractmethod
+ async def _fetch_keys(
+ self, keys_to_fetch: List[_FetchKeyRequest]
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
+ pass
class StoreKeyFetcher(KeyFetcher):
"""KeyFetcher impl which fetches keys from our data store"""
def __init__(self, hs: "HomeServer"):
- self.store = hs.get_datastore()
+ super().__init__(hs)
- async def get_keys(
- self, keys_to_fetch: Dict[str, Dict[str, int]]
- ) -> Dict[str, Dict[str, FetchKeyResult]]:
- """see KeyFetcher.get_keys"""
+ self.store = hs.get_datastore()
+ async def _fetch_keys(self, keys_to_fetch: List[_FetchKeyRequest]):
key_ids_to_fetch = (
- (server_name, key_id)
- for server_name, keys_for_server in keys_to_fetch.items()
- for key_id in keys_for_server.keys()
+ (queue_value.server_name, key_id)
+ for queue_value in keys_to_fetch
+ for key_id in queue_value.key_ids
)
res = await self.store.get_server_verify_keys(key_ids_to_fetch)
@@ -578,6 +485,8 @@ class StoreKeyFetcher(KeyFetcher):
class BaseV2KeyFetcher(KeyFetcher):
def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
self.store = hs.get_datastore()
self.config = hs.config
@@ -685,10 +594,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
self.client = hs.get_federation_http_client()
self.key_servers = self.config.key_servers
- async def get_keys(
- self, keys_to_fetch: Dict[str, Dict[str, int]]
+ async def _fetch_keys(
+ self, keys_to_fetch: List[_FetchKeyRequest]
) -> Dict[str, Dict[str, FetchKeyResult]]:
- """see KeyFetcher.get_keys"""
+ """see KeyFetcher._fetch_keys"""
async def get_key(key_server: TrustedKeyServer) -> Dict:
try:
@@ -724,12 +633,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return union_of_keys
async def get_server_verify_key_v2_indirect(
- self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
+ self, keys_to_fetch: List[_FetchKeyRequest], key_server: TrustedKeyServer
) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
keys_to_fetch:
- the keys to be fetched. server_name -> key_id -> min_valid_ts
+ the keys to be fetched.
key_server: notary server to query for the keys
@@ -743,7 +652,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
perspective_name = key_server.server_name
logger.info(
"Requesting keys %s from notary server %s",
- keys_to_fetch.items(),
+ keys_to_fetch,
perspective_name,
)
@@ -753,11 +662,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
path="/_matrix/key/v2/query",
data={
"server_keys": {
- server_name: {
- key_id: {"minimum_valid_until_ts": min_valid_ts}
- for key_id, min_valid_ts in server_keys.items()
+ queue_value.server_name: {
+ key_id: {
+ "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
+ }
+ for key_id in queue_value.key_ids
}
- for server_name, server_keys in keys_to_fetch.items()
+ for queue_value in keys_to_fetch
}
},
)
@@ -858,7 +769,20 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
self.client = hs.get_federation_http_client()
async def get_keys(
- self, keys_to_fetch: Dict[str, Dict[str, int]]
+ self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+ ) -> Dict[str, FetchKeyResult]:
+ results = await self._queue.add_to_queue(
+ _FetchKeyRequest(
+ server_name=server_name,
+ key_ids=key_ids,
+ minimum_valid_until_ts=minimum_valid_until_ts,
+ ),
+ key=server_name,
+ )
+ return results.get(server_name, {})
+
+ async def _fetch_keys(
+ self, keys_to_fetch: List[_FetchKeyRequest]
) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
@@ -871,8 +795,10 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
results = {}
- async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
- server_name, key_ids = key_to_fetch_item
+ async def get_key(key_to_fetch_item: _FetchKeyRequest) -> None:
+ server_name = key_to_fetch_item.server_name
+ key_ids = key_to_fetch_item.key_ids
+
try:
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
results[server_name] = keys
@@ -883,7 +809,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except Exception:
logger.exception("Error getting keys %s from %s", key_ids, server_name)
- await yieldable_gather_results(get_key, keys_to_fetch.items())
+ await yieldable_gather_results(get_key, keys_to_fetch)
return results
async def get_server_verify_key_v2_direct(
@@ -955,37 +881,3 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
keys.update(response_keys)
return keys
-
-
-async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
- """Waits for the key to become available, and then performs a verification
-
- Args:
- verify_request:
-
- Raises:
- SynapseError if there was a problem performing the verification
- """
- server_name = verify_request.server_name
- with PreserveLoggingContext():
- _, key_id, verify_key = await verify_request.key_ready
-
- json_object = verify_request.get_json_object()
-
- try:
- verify_signed_json(json_object, server_name, verify_key)
- except SignatureVerifyException as e:
- logger.debug(
- "Error verifying signature for %s:%s:%s with key %s: %s",
- server_name,
- verify_key.alg,
- verify_key.version,
- encode_verify_key_base64(verify_key),
- str(e),
- )
- raise SynapseError(
- 401,
- "Invalid signature for server %s with key %s:%s: %s"
- % (server_name, verify_key.alg, verify_key.version, str(e)),
- Codes.UNAUTHORIZED,
- )
|