summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/crypto/keyring.py98
-rw-r--r--synapse/federation/federation_base.py17
2 files changed, 93 insertions, 22 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 5f18ef7748..6fc0712978 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -17,7 +17,7 @@ import abc
 import logging
 import urllib
 from collections import defaultdict
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple
 
 import attr
 from signedjson.key import (
@@ -42,6 +42,8 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.config.key import TrustedKeyServer
+from synapse.events import EventBase
+from synapse.events.utils import prune_event_dict
 from synapse.logging.context import (
     PreserveLoggingContext,
     make_deferred_yieldable,
@@ -69,7 +71,11 @@ class VerifyJsonRequest:
     Attributes:
         server_name: The name of the server to verify against.
 
-        json_object: The JSON object to verify.
+        get_json_object: A callback to fetch the JSON object to verify.
+            A callback is used to allow deferring the creation of the JSON
+            object to verify until needed, e.g. for events we can defer
+            creating the redacted copy. This reduces the memory usage when
+            there are large numbers of in flight requests.
 
         minimum_valid_until_ts: time at which we require the signing key to
             be valid. (0 implies we don't care)
@@ -88,14 +94,50 @@ class VerifyJsonRequest:
     """
 
     server_name = attr.ib(type=str)
-    json_object = attr.ib(type=JsonDict)
+    get_json_object = attr.ib(type=Callable[[], JsonDict])
     minimum_valid_until_ts = attr.ib(type=int)
     request_name = attr.ib(type=str)
-    key_ids = attr.ib(init=False, type=List[str])
+    key_ids = attr.ib(type=List[str])
     key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
 
-    def __attrs_post_init__(self):
-        self.key_ids = signature_ids(self.json_object, self.server_name)
+    @staticmethod
+    def from_json_object(
+        server_name: str,
+        json_object: JsonDict,
+        minimum_valid_until_ms: int,
+        request_name: str,
+    ):
+        """Create a VerifyJsonRequest to verify all signatures on a signed JSON
+        object for the given server.
+        """
+        key_ids = signature_ids(json_object, server_name)
+        return VerifyJsonRequest(
+            server_name,
+            lambda: json_object,
+            minimum_valid_until_ms,
+            request_name=request_name,
+            key_ids=key_ids,
+        )
+
+    @staticmethod
+    def from_event(
+        server_name: str,
+        event: EventBase,
+        minimum_valid_until_ms: int,
+    ):
+        """Create a VerifyJsonRequest to verify all signatures on an event
+        object for the given server.
+        """
+        key_ids = list(event.signatures.get(server_name, []))
+        return VerifyJsonRequest(
+            server_name,
+            # We defer creating the redacted json object, as it uses a lot more
+            # memory than the Event object itself.
+            lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
+            minimum_valid_until_ms,
+            request_name=event.event_id,
+            key_ids=key_ids,
+        )
 
 
 class KeyLookupError(ValueError):
@@ -147,8 +189,13 @@ class Keyring:
             Deferred[None]: completes if the the object was correctly signed, otherwise
                 errbacks with an error
         """
-        req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
-        requests = (req,)
+        request = VerifyJsonRequest.from_json_object(
+            server_name,
+            json_object,
+            validity_time,
+            request_name,
+        )
+        requests = (request,)
         return make_deferred_yieldable(self._verify_objects(requests)[0])
 
     def verify_json_objects_for_server(
@@ -175,10 +222,41 @@ class Keyring:
                 logcontext.
         """
         return self._verify_objects(
-            VerifyJsonRequest(server_name, json_object, validity_time, request_name)
+            VerifyJsonRequest.from_json_object(
+                server_name, json_object, validity_time, request_name
+            )
             for server_name, json_object, validity_time, request_name in server_and_json
         )
 
+    def verify_events_for_server(
+        self, server_and_events: Iterable[Tuple[str, EventBase, int]]
+    ) -> List[defer.Deferred]:
+        """Bulk verification of signatures on events.
+
+        Args:
+            server_and_events:
+                Iterable of `(server_name, event, validity_time)` tuples.
+
+                `server_name` is which server we are verifying the signature for
+                on the event.
+
+                `event` is the event that we'll verify the signatures of for
+                the given `server_name`.
+
+                `validity_time` is a timestamp at which the signing key must be
+                valid.
+
+        Returns:
+            List<Deferred[None]>: for each input triplet, a deferred indicating success
+                or failure to verify each event's signature for the given
+                server_name. The deferreds run their callbacks in the sentinel
+                logcontext.
+        """
+        return self._verify_objects(
+            VerifyJsonRequest.from_event(server_name, event, validity_time)
+            for server_name, event, validity_time in server_and_events
+        )
+
     def _verify_objects(
         self, verify_requests: Iterable[VerifyJsonRequest]
     ) -> List[defer.Deferred]:
@@ -892,7 +970,7 @@ async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
     with PreserveLoggingContext():
         _, key_id, verify_key = await verify_request.key_ready
 
-    json_object = verify_request.json_object
+    json_object = verify_request.get_json_object()
 
     try:
         verify_signed_json(json_object, server_name, verify_key)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 949dcd4614..3fe496dcd3 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -137,11 +137,7 @@ class FederationBase:
         return deferreds
 
 
-class PduToCheckSig(
-    namedtuple(
-        "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
-    )
-):
+class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
     pass
 
 
@@ -184,7 +180,6 @@ def _check_sigs_on_pdus(
     pdus_to_check = [
         PduToCheckSig(
             pdu=p,
-            redacted_pdu_json=prune_event(p).get_pdu_json(),
             sender_domain=get_domain_from_id(p.sender),
             deferreds=[],
         )
@@ -195,13 +190,12 @@ def _check_sigs_on_pdus(
     # (except if its a 3pid invite, in which case it may be sent by any server)
     pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
 
-    more_deferreds = keyring.verify_json_objects_for_server(
+    more_deferreds = keyring.verify_events_for_server(
         [
             (
                 p.sender_domain,
-                p.redacted_pdu_json,
+                p.pdu,
                 p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
-                p.pdu.event_id,
             )
             for p in pdus_to_check_sender
         ]
@@ -230,13 +224,12 @@ def _check_sigs_on_pdus(
             if p.sender_domain != get_domain_from_id(p.pdu.event_id)
         ]
 
-        more_deferreds = keyring.verify_json_objects_for_server(
+        more_deferreds = keyring.verify_events_for_server(
             [
                 (
                     get_domain_from_id(p.pdu.event_id),
-                    p.redacted_pdu_json,
+                    p.pdu,
                     p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
-                    p.pdu.event_id,
                 )
                 for p in pdus_to_check_event_id
             ]