summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/crypto/keyring.py107
-rw-r--r--synapse/federation/federation_base.py17
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py2
3 files changed, 60 insertions, 66 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py

index 313d577f11..a8c8df2bad 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py
@@ -16,7 +16,7 @@ import abc import logging import urllib -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple import attr from signedjson.key import ( @@ -41,11 +41,9 @@ from synapse.api.errors import ( SynapseError, ) from synapse.config.key import TrustedKeyServer -from synapse.logging.context import ( - PreserveLoggingContext, - make_deferred_yieldable, - run_in_background, -) +from synapse.events import EventBase +from synapse.events.utils import prune_event_dict +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.keys import FetchKeyResult from synapse.types import JsonDict @@ -72,8 +70,6 @@ class VerifyJsonRequest: minimum_valid_until_ts: time at which we require the signing key to be valid. (0 implies we don't care) - request_name: The name of the request. - key_ids: The set of key_ids to that could be used to verify the JSON object key_ready (Deferred[str, str, nacl.signing.VerifyKey]): @@ -86,14 +82,32 @@ class VerifyJsonRequest: """ server_name = attr.ib(type=str) - json_object = attr.ib(type=JsonDict) + json_object_callback = attr.ib(type=Callable[[], JsonDict]) minimum_valid_until_ts = attr.ib(type=int) - request_name = attr.ib(type=str) - key_ids = attr.ib(init=False, type=List[str]) - key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred) + key_ids = attr.ib(type=List[str]) - def __attrs_post_init__(self): - self.key_ids = signature_ids(self.json_object, self.server_name) + @staticmethod + def from_json_object( + server_name: str, minimum_valid_until_ms: int, json_object: JsonDict + ): + key_ids = signature_ids(json_object, server_name) + return VerifyJsonRequest( + server_name, lambda: json_object, minimum_valid_until_ms, key_ids + ) + + @staticmethod + def from_event( + server_name: str, + minimum_valid_until_ms: int, + event: EventBase, + ): + key_ids = list(event.signatures.get(server_name, [])) + return VerifyJsonRequest( + server_name, + lambda: prune_event_dict(event.room_version, event.get_pdu_json()), + minimum_valid_until_ms, + key_ids, + ) class KeyLookupError(ValueError): @@ -179,8 +193,10 @@ class Keyring: validity_time: int, request_name: str, ) -> defer.Deferred: - request = VerifyJsonRequest( - server_name, json_object, validity_time, request_name + request = VerifyJsonRequest.from_json_object( + server_name, + validity_time, + json_object, ) return defer.ensureDeferred(self._verify_object(request)) @@ -190,14 +206,32 @@ class Keyring: return [ defer.ensureDeferred( self._verify_object( - VerifyJsonRequest( - server_name, json_object, validity_time, request_name + VerifyJsonRequest.from_json_object( + server_name, + validity_time, + json_object, ) ) ) for server_name, json_object, validity_time, request_name in server_and_json ] + def verify_events_for_server( + self, server_and_json: Iterable[Tuple[str, EventBase, int]] + ) -> List[defer.Deferred]: + return [ + defer.ensureDeferred( + self._verify_object( + VerifyJsonRequest.from_event( + server_name, + validity_time, + event, + ) + ) + ) + for server_name, event, validity_time in server_and_json + ] + async def _verify_object(self, verify_request: VerifyJsonRequest): # TODO: Use a batching thing. with (await self._server_queue.queue(verify_request.server_name)): @@ -240,8 +274,9 @@ class Keyring: for key_id in verify_request.key_ids: verify_key = found_keys[key_id].verify_key try: + json_object = verify_request.json_object_callback() verify_signed_json( - verify_request.json_object, + json_object, verify_request.server_name, verify_key, ) @@ -696,37 +731,3 @@ class ServerKeyFetcher(BaseV2KeyFetcher): keys.update(response_keys) return keys - - -async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None: - """Waits for the key to become available, and then performs a verification - - Args: - verify_request: - - Raises: - SynapseError if there was a problem performing the verification - """ - server_name = verify_request.server_name - with PreserveLoggingContext(): - _, key_id, verify_key = await verify_request.key_ready - - json_object = verify_request.json_object - - try: - verify_signed_json(json_object, server_name, verify_key) - except SignatureVerifyException as e: - logger.debug( - "Error verifying signature for %s:%s:%s with key %s: %s", - server_name, - verify_key.alg, - verify_key.version, - encode_verify_key_base64(verify_key), - str(e), - ) - raise SynapseError( - 401, - "Invalid signature for server %s with key %s:%s: %s" - % (server_name, verify_key.alg, verify_key.version, str(e)), - Codes.UNAUTHORIZED, - ) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 949dcd4614..3fe496dcd3 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py
@@ -137,11 +137,7 @@ class FederationBase: return deferreds -class PduToCheckSig( - namedtuple( - "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"] - ) -): +class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])): pass @@ -184,7 +180,6 @@ def _check_sigs_on_pdus( pdus_to_check = [ PduToCheckSig( pdu=p, - redacted_pdu_json=prune_event(p).get_pdu_json(), sender_domain=get_domain_from_id(p.sender), deferreds=[], ) @@ -195,13 +190,12 @@ def _check_sigs_on_pdus( # (except if its a 3pid invite, in which case it may be sent by any server) pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)] - more_deferreds = keyring.verify_json_objects_for_server( + more_deferreds = keyring.verify_events_for_server( [ ( p.sender_domain, - p.redacted_pdu_json, + p.pdu, p.pdu.origin_server_ts if room_version.enforce_key_validity else 0, - p.pdu.event_id, ) for p in pdus_to_check_sender ] @@ -230,13 +224,12 @@ def _check_sigs_on_pdus( if p.sender_domain != get_domain_from_id(p.pdu.event_id) ] - more_deferreds = keyring.verify_json_objects_for_server( + more_deferreds = keyring.verify_events_for_server( [ ( get_domain_from_id(p.pdu.event_id), - p.redacted_pdu_json, + p.pdu, p.pdu.origin_server_ts if room_version.enforce_key_validity else 0, - p.pdu.event_id, ) for p in pdus_to_check_event_id ] diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 24af3c28ea..e19e9ef5c7 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -215,7 +215,7 @@ class RemoteKey(DirectServeJsonResource): # ensure the result is sent). if cache_misses and query_remote_on_cache_miss: await yieldable_gather_results( - self.fetcher.get_keys, + lambda t: self.fetcher.get_keys(*t), ( (server_name, list(keys), 0) for server_name, keys in cache_misses.items()