summary refs log tree commit diff
path: root/synapse/crypto/keyring.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/crypto/keyring.py')
-rw-r--r--synapse/crypto/keyring.py98
1 files changed, 88 insertions, 10 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 5f18ef7748..6fc0712978 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -17,7 +17,7 @@ import abc
 import logging
 import urllib
 from collections import defaultdict
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple
 
 import attr
 from signedjson.key import (
@@ -42,6 +42,8 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.config.key import TrustedKeyServer
+from synapse.events import EventBase
+from synapse.events.utils import prune_event_dict
 from synapse.logging.context import (
     PreserveLoggingContext,
     make_deferred_yieldable,
@@ -69,7 +71,11 @@ class VerifyJsonRequest:
     Attributes:
         server_name: The name of the server to verify against.
 
-        json_object: The JSON object to verify.
+        get_json_object: A callback to fetch the JSON object to verify.
+            A callback is used to allow deferring the creation of the JSON
+            object to verify until needed, e.g. for events we can defer
+            creating the redacted copy. This reduces the memory usage when
+            there are large numbers of in flight requests.
 
         minimum_valid_until_ts: time at which we require the signing key to
             be valid. (0 implies we don't care)
@@ -88,14 +94,50 @@ class VerifyJsonRequest:
     """
 
     server_name = attr.ib(type=str)
-    json_object = attr.ib(type=JsonDict)
+    get_json_object = attr.ib(type=Callable[[], JsonDict])
     minimum_valid_until_ts = attr.ib(type=int)
     request_name = attr.ib(type=str)
-    key_ids = attr.ib(init=False, type=List[str])
+    key_ids = attr.ib(type=List[str])
     key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
 
-    def __attrs_post_init__(self):
-        self.key_ids = signature_ids(self.json_object, self.server_name)
+    @staticmethod
+    def from_json_object(
+        server_name: str,
+        json_object: JsonDict,
+        minimum_valid_until_ms: int,
+        request_name: str,
+    ):
+        """Create a VerifyJsonRequest to verify all signatures on a signed JSON
+        object for the given server.
+        """
+        key_ids = signature_ids(json_object, server_name)
+        return VerifyJsonRequest(
+            server_name,
+            lambda: json_object,
+            minimum_valid_until_ms,
+            request_name=request_name,
+            key_ids=key_ids,
+        )
+
+    @staticmethod
+    def from_event(
+        server_name: str,
+        event: EventBase,
+        minimum_valid_until_ms: int,
+    ):
+        """Create a VerifyJsonRequest to verify all signatures on an event
+        object for the given server.
+        """
+        key_ids = list(event.signatures.get(server_name, []))
+        return VerifyJsonRequest(
+            server_name,
+            # We defer creating the redacted json object, as it uses a lot more
+            # memory than the Event object itself.
+            lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
+            minimum_valid_until_ms,
+            request_name=event.event_id,
+            key_ids=key_ids,
+        )
 
 
 class KeyLookupError(ValueError):
@@ -147,8 +189,13 @@ class Keyring:
             Deferred[None]: completes if the the object was correctly signed, otherwise
                 errbacks with an error
         """
-        req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
-        requests = (req,)
+        request = VerifyJsonRequest.from_json_object(
+            server_name,
+            json_object,
+            validity_time,
+            request_name,
+        )
+        requests = (request,)
         return make_deferred_yieldable(self._verify_objects(requests)[0])
 
     def verify_json_objects_for_server(
@@ -175,10 +222,41 @@ class Keyring:
                 logcontext.
         """
         return self._verify_objects(
-            VerifyJsonRequest(server_name, json_object, validity_time, request_name)
+            VerifyJsonRequest.from_json_object(
+                server_name, json_object, validity_time, request_name
+            )
             for server_name, json_object, validity_time, request_name in server_and_json
         )
 
+    def verify_events_for_server(
+        self, server_and_events: Iterable[Tuple[str, EventBase, int]]
+    ) -> List[defer.Deferred]:
+        """Bulk verification of signatures on events.
+
+        Args:
+            server_and_events:
+                Iterable of `(server_name, event, validity_time)` tuples.
+
+                `server_name` is which server we are verifying the signature for
+                on the event.
+
+                `event` is the event that we'll verify the signatures of for
+                the given `server_name`.
+
+                `validity_time` is a timestamp at which the signing key must be
+                valid.
+
+        Returns:
+            List<Deferred[None]>: for each input triplet, a deferred indicating success
+                or failure to verify each event's signature for the given
+                server_name. The deferreds run their callbacks in the sentinel
+                logcontext.
+        """
+        return self._verify_objects(
+            VerifyJsonRequest.from_event(server_name, event, validity_time)
+            for server_name, event, validity_time in server_and_events
+        )
+
     def _verify_objects(
         self, verify_requests: Iterable[VerifyJsonRequest]
     ) -> List[defer.Deferred]:
@@ -892,7 +970,7 @@ async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
     with PreserveLoggingContext():
         _, key_id, verify_key = await verify_request.key_ready
 
-    json_object = verify_request.json_object
+    json_object = verify_request.get_json_object()
 
     try:
         verify_signed_json(json_object, server_name, verify_key)