diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 313d577f11..a8c8df2bad 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -16,7 +16,7 @@
import abc
import logging
import urllib
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple
import attr
from signedjson.key import (
@@ -41,11 +41,9 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.config.key import TrustedKeyServer
-from synapse.logging.context import (
- PreserveLoggingContext,
- make_deferred_yieldable,
- run_in_background,
-)
+from synapse.events import EventBase
+from synapse.events.utils import prune_event_dict
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict
@@ -72,8 +70,6 @@ class VerifyJsonRequest:
minimum_valid_until_ts: time at which we require the signing key to
be valid. (0 implies we don't care)
- request_name: The name of the request.
-
key_ids: The set of key_ids to that could be used to verify the JSON object
key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
@@ -86,14 +82,32 @@ class VerifyJsonRequest:
"""
server_name = attr.ib(type=str)
- json_object = attr.ib(type=JsonDict)
+ json_object_callback = attr.ib(type=Callable[[], JsonDict])
minimum_valid_until_ts = attr.ib(type=int)
- request_name = attr.ib(type=str)
- key_ids = attr.ib(init=False, type=List[str])
- key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
+ key_ids = attr.ib(type=List[str])
- def __attrs_post_init__(self):
- self.key_ids = signature_ids(self.json_object, self.server_name)
+ @staticmethod
+ def from_json_object(
+ server_name: str, minimum_valid_until_ms: int, json_object: JsonDict
+ ):
+ key_ids = signature_ids(json_object, server_name)
+ return VerifyJsonRequest(
+ server_name, lambda: json_object, minimum_valid_until_ms, key_ids
+ )
+
+ @staticmethod
+ def from_event(
+ server_name: str,
+ minimum_valid_until_ms: int,
+ event: EventBase,
+ ):
+ key_ids = list(event.signatures.get(server_name, []))
+ return VerifyJsonRequest(
+ server_name,
+ lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
+ minimum_valid_until_ms,
+ key_ids,
+ )
class KeyLookupError(ValueError):
@@ -179,8 +193,10 @@ class Keyring:
validity_time: int,
request_name: str,
) -> defer.Deferred:
- request = VerifyJsonRequest(
- server_name, json_object, validity_time, request_name
+ request = VerifyJsonRequest.from_json_object(
+ server_name,
+ validity_time,
+ json_object,
)
return defer.ensureDeferred(self._verify_object(request))
@@ -190,14 +206,32 @@ class Keyring:
return [
defer.ensureDeferred(
self._verify_object(
- VerifyJsonRequest(
- server_name, json_object, validity_time, request_name
+ VerifyJsonRequest.from_json_object(
+ server_name,
+ validity_time,
+ json_object,
)
)
)
for server_name, json_object, validity_time, request_name in server_and_json
]
+ def verify_events_for_server(
+ self, server_and_json: Iterable[Tuple[str, EventBase, int]]
+ ) -> List[defer.Deferred]:
+ return [
+ defer.ensureDeferred(
+ self._verify_object(
+ VerifyJsonRequest.from_event(
+ server_name,
+ validity_time,
+ event,
+ )
+ )
+ )
+ for server_name, event, validity_time in server_and_json
+ ]
+
async def _verify_object(self, verify_request: VerifyJsonRequest):
# TODO: Use a batching thing.
with (await self._server_queue.queue(verify_request.server_name)):
@@ -240,8 +274,9 @@ class Keyring:
for key_id in verify_request.key_ids:
verify_key = found_keys[key_id].verify_key
try:
+ json_object = verify_request.json_object_callback()
verify_signed_json(
- verify_request.json_object,
+ json_object,
verify_request.server_name,
verify_key,
)
@@ -696,37 +731,3 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
keys.update(response_keys)
return keys
-
-
-async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
- """Waits for the key to become available, and then performs a verification
-
- Args:
- verify_request:
-
- Raises:
- SynapseError if there was a problem performing the verification
- """
- server_name = verify_request.server_name
- with PreserveLoggingContext():
- _, key_id, verify_key = await verify_request.key_ready
-
- json_object = verify_request.json_object
-
- try:
- verify_signed_json(json_object, server_name, verify_key)
- except SignatureVerifyException as e:
- logger.debug(
- "Error verifying signature for %s:%s:%s with key %s: %s",
- server_name,
- verify_key.alg,
- verify_key.version,
- encode_verify_key_base64(verify_key),
- str(e),
- )
- raise SynapseError(
- 401,
- "Invalid signature for server %s with key %s:%s: %s"
- % (server_name, verify_key.alg, verify_key.version, str(e)),
- Codes.UNAUTHORIZED,
- )
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 949dcd4614..3fe496dcd3 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -137,11 +137,7 @@ class FederationBase:
return deferreds
-class PduToCheckSig(
- namedtuple(
- "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
- )
-):
+class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
pass
@@ -184,7 +180,6 @@ def _check_sigs_on_pdus(
pdus_to_check = [
PduToCheckSig(
pdu=p,
- redacted_pdu_json=prune_event(p).get_pdu_json(),
sender_domain=get_domain_from_id(p.sender),
deferreds=[],
)
@@ -195,13 +190,12 @@ def _check_sigs_on_pdus(
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
- more_deferreds = keyring.verify_json_objects_for_server(
+ more_deferreds = keyring.verify_events_for_server(
[
(
p.sender_domain,
- p.redacted_pdu_json,
+ p.pdu,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
- p.pdu.event_id,
)
for p in pdus_to_check_sender
]
@@ -230,13 +224,12 @@ def _check_sigs_on_pdus(
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
]
- more_deferreds = keyring.verify_json_objects_for_server(
+ more_deferreds = keyring.verify_events_for_server(
[
(
get_domain_from_id(p.pdu.event_id),
- p.redacted_pdu_json,
+ p.pdu,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
- p.pdu.event_id,
)
for p in pdus_to_check_event_id
]
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 24af3c28ea..e19e9ef5c7 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -215,7 +215,7 @@ class RemoteKey(DirectServeJsonResource):
# ensure the result is sent).
if cache_misses and query_remote_on_cache_miss:
await yieldable_gather_results(
- self.fetcher.get_keys,
+ lambda t: self.fetcher.get_keys(*t),
(
(server_name, list(keys), 0)
for server_name, keys in cache_misses.items()
|