diff --git a/changelog.d/10018.misc b/changelog.d/10018.misc
new file mode 100644
index 0000000000..eaf9f64867
--- /dev/null
+++ b/changelog.d/10018.misc
@@ -0,0 +1 @@
+Reduce memory usage when verifying signatures on large numbers of events at once.
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 5f18ef7748..6fc0712978 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -17,7 +17,7 @@ import abc
import logging
import urllib
from collections import defaultdict
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple
import attr
from signedjson.key import (
@@ -42,6 +42,8 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.config.key import TrustedKeyServer
+from synapse.events import EventBase
+from synapse.events.utils import prune_event_dict
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
@@ -69,7 +71,11 @@ class VerifyJsonRequest:
Attributes:
server_name: The name of the server to verify against.
- json_object: The JSON object to verify.
+ get_json_object: A callback to fetch the JSON object to verify.
+ A callback is used so that creation of the JSON object can be
+ deferred until it is actually needed, e.g. for events we can defer
+ creating the redacted copy. This reduces memory usage when there
+ are large numbers of in-flight requests.
minimum_valid_until_ts: time at which we require the signing key to
be valid. (0 implies we don't care)
@@ -88,14 +94,50 @@ class VerifyJsonRequest:
"""
server_name = attr.ib(type=str)
- json_object = attr.ib(type=JsonDict)
+ get_json_object = attr.ib(type=Callable[[], JsonDict])
minimum_valid_until_ts = attr.ib(type=int)
request_name = attr.ib(type=str)
- key_ids = attr.ib(init=False, type=List[str])
+ key_ids = attr.ib(type=List[str])
key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
- def __attrs_post_init__(self):
- self.key_ids = signature_ids(self.json_object, self.server_name)
+ @staticmethod
+ def from_json_object(
+ server_name: str,
+ json_object: JsonDict,
+ minimum_valid_until_ms: int,
+ request_name: str,
+ ):
+ """Create a VerifyJsonRequest to verify all signatures on a signed JSON
+ object for the given server.
+ """
+ key_ids = signature_ids(json_object, server_name)
+ return VerifyJsonRequest(
+ server_name,
+ lambda: json_object,
+ minimum_valid_until_ms,
+ request_name=request_name,
+ key_ids=key_ids,
+ )
+
+ @staticmethod
+ def from_event(
+ server_name: str,
+ event: EventBase,
+ minimum_valid_until_ms: int,
+ ):
+ """Create a VerifyJsonRequest to verify all signatures on an event
+ object for the given server.
+ """
+ key_ids = list(event.signatures.get(server_name, []))
+ return VerifyJsonRequest(
+ server_name,
+ # We defer creating the redacted json object, as it uses a lot more
+ # memory than the Event object itself.
+ lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
+ minimum_valid_until_ms,
+ request_name=event.event_id,
+ key_ids=key_ids,
+ )
class KeyLookupError(ValueError):
@@ -147,8 +189,13 @@ class Keyring:
Deferred[None]: completes if the the object was correctly signed, otherwise
errbacks with an error
"""
- req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
- requests = (req,)
+ request = VerifyJsonRequest.from_json_object(
+ server_name,
+ json_object,
+ validity_time,
+ request_name,
+ )
+ requests = (request,)
return make_deferred_yieldable(self._verify_objects(requests)[0])
def verify_json_objects_for_server(
@@ -175,10 +222,41 @@ class Keyring:
logcontext.
"""
return self._verify_objects(
- VerifyJsonRequest(server_name, json_object, validity_time, request_name)
+ VerifyJsonRequest.from_json_object(
+ server_name, json_object, validity_time, request_name
+ )
for server_name, json_object, validity_time, request_name in server_and_json
)
+ def verify_events_for_server(
+ self, server_and_events: Iterable[Tuple[str, EventBase, int]]
+ ) -> List[defer.Deferred]:
+ """Bulk verification of signatures on events.
+
+ Args:
+ server_and_events:
+ Iterable of `(server_name, event, validity_time)` tuples.
+
+ `server_name` is the server whose signature on the event we are
+ verifying.
+
+ `event` is the event whose signatures we'll verify for the given
+ `server_name`.
+
+ `validity_time` is a timestamp at which the signing key must be
+ valid.
+
+ Returns:
+ List<Deferred[None]>: for each input triplet, a deferred indicating success
+ or failure to verify that event's signature for the given
+ server_name. The deferreds run their callbacks in the sentinel
+ logcontext.
+ """
+ return self._verify_objects(
+ VerifyJsonRequest.from_event(server_name, event, validity_time)
+ for server_name, event, validity_time in server_and_events
+ )
+
def _verify_objects(
self, verify_requests: Iterable[VerifyJsonRequest]
) -> List[defer.Deferred]:
@@ -892,7 +970,7 @@ async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
with PreserveLoggingContext():
_, key_id, verify_key = await verify_request.key_ready
- json_object = verify_request.json_object
+ json_object = verify_request.get_json_object()
try:
verify_signed_json(json_object, server_name, verify_key)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 949dcd4614..3fe496dcd3 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -137,11 +137,7 @@ class FederationBase:
return deferreds
-class PduToCheckSig(
- namedtuple(
- "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
- )
-):
+class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
pass
@@ -184,7 +180,6 @@ def _check_sigs_on_pdus(
pdus_to_check = [
PduToCheckSig(
pdu=p,
- redacted_pdu_json=prune_event(p).get_pdu_json(),
sender_domain=get_domain_from_id(p.sender),
deferreds=[],
)
@@ -195,13 +190,12 @@ def _check_sigs_on_pdus(
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
- more_deferreds = keyring.verify_json_objects_for_server(
+ more_deferreds = keyring.verify_events_for_server(
[
(
p.sender_domain,
- p.redacted_pdu_json,
+ p.pdu,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
- p.pdu.event_id,
)
for p in pdus_to_check_sender
]
@@ -230,13 +224,12 @@ def _check_sigs_on_pdus(
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
]
- more_deferreds = keyring.verify_json_objects_for_server(
+ more_deferreds = keyring.verify_events_for_server(
[
(
get_domain_from_id(p.pdu.event_id),
- p.redacted_pdu_json,
+ p.pdu,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
- p.pdu.event_id,
)
for p in pdus_to_check_event_id
]
|