summary refs log tree commit diff
path: root/synapse/crypto
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/crypto')
-rw-r--r--synapse/crypto/keyring.py112
1 files changed, 68 insertions, 44 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index e94e71bdad..2b6b5913bc 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -60,9 +60,9 @@ logger = logging.getLogger(__name__)
 
 
 @attr.s(slots=True, cmp=False)
-class VerifyKeyRequest(object):
+class VerifyJsonRequest(object):
     """
-    A request for a verify key to verify a JSON object.
+    A request to verify a JSON object.
 
     Attributes:
         server_name(str): The name of the server to verify against.
@@ -85,11 +85,15 @@ class VerifyKeyRequest(object):
     """
 
     server_name = attr.ib()
-    key_ids = attr.ib()
     json_object = attr.ib()
     minimum_valid_until_ts = attr.ib()
+    request_name = attr.ib()
+    key_ids = attr.ib(init=False)
     key_ready = attr.ib(default=attr.Factory(defer.Deferred))
 
+    def __attrs_post_init__(self):
+        self.key_ids = signature_ids(self.json_object, self.server_name)
+
 
 class KeyLookupError(ValueError):
     pass
@@ -114,7 +118,9 @@ class Keyring(object):
         # These are regular, logcontext-agnostic Deferreds.
         self.key_downloads = {}
 
-    def verify_json_for_server(self, server_name, json_object, validity_time):
+    def verify_json_for_server(
+        self, server_name, json_object, validity_time, request_name
+    ):
         """Verify that a JSON object has been signed by a given server
 
         Args:
@@ -125,24 +131,31 @@ class Keyring(object):
             validity_time (int): timestamp at which we require the signing key to
                 be valid. (0 implies we don't care)
 
+            request_name (str): an identifier for this json object (eg, an event id)
+                for logging.
+
         Returns:
             Deferred[None]: completes if the the object was correctly signed, otherwise
                 errbacks with an error
         """
-        req = server_name, json_object, validity_time
-
-        return logcontext.make_deferred_yieldable(
-            self.verify_json_objects_for_server((req,))[0]
-        )
+        req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
+        requests = (req,)
+        return logcontext.make_deferred_yieldable(self._verify_objects(requests)[0])
 
     def verify_json_objects_for_server(self, server_and_json):
         """Bulk verifies signatures of json objects, bulk fetching keys as
         necessary.
 
         Args:
-            server_and_json (iterable[Tuple[str, dict, int]):
-                Iterable of triplets of (server_name, json_object, validity_time)
-                validity_time is a timestamp at which the signing key must be valid.
+            server_and_json (iterable[Tuple[str, dict, int, str]):
+                Iterable of (server_name, json_object, validity_time, request_name)
+                tuples.
+
+                validity_time is a timestamp at which the signing key must be
+                valid.
+
+                request_name is an identifier for this json object (eg, an event id)
+                for logging.
 
         Returns:
             List<Deferred[None]>: for each input triplet, a deferred indicating success
@@ -150,38 +163,54 @@ class Keyring(object):
                 server_name. The deferreds run their callbacks in the sentinel
                 logcontext.
         """
-        # a list of VerifyKeyRequests
-        verify_requests = []
+        return self._verify_objects(
+            VerifyJsonRequest(server_name, json_object, validity_time, request_name)
+            for server_name, json_object, validity_time, request_name in server_and_json
+        )
+
+    def _verify_objects(self, verify_requests):
+        """Does the work of verify_json_[objects_]for_server
+
+
+        Args:
+            verify_requests (iterable[VerifyJsonRequest]):
+                Iterable of verification requests.
+
+        Returns:
+            List<Deferred[None]>: for each input item, a deferred indicating success
+                or failure to verify each json object's signature for the given
+                server_name. The deferreds run their callbacks in the sentinel
+                logcontext.
+        """
+        # a list of VerifyJsonRequests which are awaiting a key lookup
+        key_lookups = []
         handle = preserve_fn(_handle_key_deferred)
 
-        def process(server_name, json_object, validity_time):
+        def process(verify_request):
             """Process an entry in the request list
 
-            Given a (server_name, json_object, validity_time) triplet from the request
-            list, adds a key request to verify_requests, and returns a deferred which
+            Adds a key request to key_lookups, and returns a deferred which
             will complete or fail (in the sentinel context) when verification completes.
             """
-            key_ids = signature_ids(json_object, server_name)
-
-            if not key_ids:
+            if not verify_request.key_ids:
                 return defer.fail(
                     SynapseError(
-                        400, "Not signed by %s" % (server_name,), Codes.UNAUTHORIZED
+                        400,
+                        "Not signed by %s" % (verify_request.server_name,),
+                        Codes.UNAUTHORIZED,
                     )
                 )
 
             logger.debug(
-                "Verifying for %s with key_ids %s, min_validity %i",
-                server_name,
-                key_ids,
-                validity_time,
+                "Verifying %s for %s with key_ids %s, min_validity %i",
+                verify_request.request_name,
+                verify_request.server_name,
+                verify_request.key_ids,
+                verify_request.minimum_valid_until_ts,
             )
 
             # add the key request to the queue, but don't start it off yet.
-            verify_request = VerifyKeyRequest(
-                server_name, key_ids, json_object, validity_time
-            )
-            verify_requests.append(verify_request)
+            key_lookups.append(verify_request)
 
             # now run _handle_key_deferred, which will wait for the key request
             # to complete and then do the verification.
@@ -190,13 +219,10 @@ class Keyring(object):
             # wrap it with preserve_fn (aka run_in_background)
             return handle(verify_request)
 
-        results = [
-            process(server_name, json_object, validity_time)
-            for server_name, json_object, validity_time in server_and_json
-        ]
+        results = [process(r) for r in verify_requests]
 
-        if verify_requests:
-            run_in_background(self._start_key_lookups, verify_requests)
+        if key_lookups:
+            run_in_background(self._start_key_lookups, key_lookups)
 
         return results
 
@@ -207,7 +233,7 @@ class Keyring(object):
         Once each fetch completes, verify_request.key_ready will be resolved.
 
         Args:
-            verify_requests (List[VerifyKeyRequest]):
+            verify_requests (List[VerifyJsonRequest]):
         """
 
         try:
@@ -308,7 +334,7 @@ class Keyring(object):
         with a SynapseError if none of the keys are found.
 
         Args:
-            verify_requests (list[VerifyKeyRequest]): list of verify requests
+            verify_requests (list[VerifyJsonRequest]): list of verify requests
         """
 
         remaining_requests = set(
@@ -357,7 +383,7 @@ class Keyring(object):
 
         Args:
             fetcher (KeyFetcher): fetcher to use to fetch the keys
-            remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
+            remaining_requests (set[VerifyJsonRequest]): outstanding key requests.
                 Any successfully-completed requests will be removed from the list.
         """
         # dict[str, dict[str, int]]: keys to fetch.
@@ -376,7 +402,7 @@ class Keyring(object):
                 # the requests.
                 keys_for_server[key_id] = max(
                     keys_for_server.get(key_id, -1),
-                    verify_request.minimum_valid_until_ts
+                    verify_request.minimum_valid_until_ts,
                 )
 
         results = yield fetcher.get_keys(missing_keys)
@@ -386,7 +412,7 @@ class Keyring(object):
             server_name = verify_request.server_name
 
             # see if any of the keys we got this time are sufficient to
-            # complete this VerifyKeyRequest.
+            # complete this VerifyJsonRequest.
             result_keys = results.get(server_name, {})
             for key_id in verify_request.key_ids:
                 fetch_key_result = result_keys.get(key_id)
@@ -454,9 +480,7 @@ class BaseV2KeyFetcher(object):
         self.config = hs.get_config()
 
     @defer.inlineCallbacks
-    def process_v2_response(
-        self, from_server, response_json, time_added_ms
-    ):
+    def process_v2_response(self, from_server, response_json, time_added_ms):
         """Parse a 'Server Keys' structure from the result of a /key request
 
         This is used to parse either the entirety of the response from
@@ -852,7 +876,7 @@ def _handle_key_deferred(verify_request):
     """Waits for the key to become available, and then performs a verification
 
     Args:
-        verify_request (VerifyKeyRequest):
+        verify_request (VerifyJsonRequest):
 
     Returns:
         Deferred[None]