summary refs log tree commit diff
path: root/synapse/crypto/keyring.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/crypto/keyring.py')
-rw-r--r--synapse/crypto/keyring.py286
1 files changed, 137 insertions, 149 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index d8ba870cca..5cc98542ce 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -56,9 +56,9 @@ from synapse.util.retryutils import NotRetryingDestination
 logger = logging.getLogger(__name__)
 
 
-VerifyKeyRequest = namedtuple("VerifyRequest", (
-    "server_name", "key_ids", "json_object", "deferred"
-))
+VerifyKeyRequest = namedtuple(
+    "VerifyRequest", ("server_name", "key_ids", "json_object", "deferred")
+)
 """
 A request for a verify key to verify a JSON object.
 
@@ -96,9 +96,7 @@ class Keyring(object):
 
     def verify_json_for_server(self, server_name, json_object):
         return logcontext.make_deferred_yieldable(
-            self.verify_json_objects_for_server(
-                [(server_name, json_object)]
-            )[0]
+            self.verify_json_objects_for_server([(server_name, json_object)])[0]
         )
 
     def verify_json_objects_for_server(self, server_and_json):
@@ -130,18 +128,15 @@ class Keyring(object):
             if not key_ids:
                 return defer.fail(
                     SynapseError(
-                        400,
-                        "Not signed by %s" % (server_name,),
-                        Codes.UNAUTHORIZED,
+                        400, "Not signed by %s" % (server_name,), Codes.UNAUTHORIZED
                     )
                 )
 
-            logger.debug("Verifying for %s with key_ids %s",
-                         server_name, key_ids)
+            logger.debug("Verifying for %s with key_ids %s", server_name, key_ids)
 
             # add the key request to the queue, but don't start it off yet.
             verify_request = VerifyKeyRequest(
-                server_name, key_ids, json_object, defer.Deferred(),
+                server_name, key_ids, json_object, defer.Deferred()
             )
             verify_requests.append(verify_request)
 
@@ -179,15 +174,13 @@ class Keyring(object):
             # any other lookups until we have finished.
             # The deferreds are called with no logcontext.
             server_to_deferred = {
-                rq.server_name: defer.Deferred()
-                for rq in verify_requests
+                rq.server_name: defer.Deferred() for rq in verify_requests
             }
 
             # We want to wait for any previous lookups to complete before
             # proceeding.
             yield self.wait_for_previous_lookups(
-                [rq.server_name for rq in verify_requests],
-                server_to_deferred,
+                [rq.server_name for rq in verify_requests], server_to_deferred
             )
 
             # Actually start fetching keys.
@@ -216,9 +209,7 @@ class Keyring(object):
                 return res
 
             for verify_request in verify_requests:
-                verify_request.deferred.addBoth(
-                    remove_deferreds, verify_request,
-                )
+                verify_request.deferred.addBoth(remove_deferreds, verify_request)
         except Exception:
             logger.exception("Error starting key lookups")
 
@@ -248,7 +239,8 @@ class Keyring(object):
                 break
             logger.info(
                 "Waiting for existing lookups for %s to complete [loop %i]",
-                [w[0] for w in wait_on], loop_count,
+                [w[0] for w in wait_on],
+                loop_count,
             )
             with PreserveLoggingContext():
                 yield defer.DeferredList((w[1] for w in wait_on))
@@ -335,13 +327,14 @@ class Keyring(object):
 
                 with PreserveLoggingContext():
                     for verify_request in requests_missing_keys:
-                        verify_request.deferred.errback(SynapseError(
-                            401,
-                            "No key for %s with id %s" % (
-                                verify_request.server_name, verify_request.key_ids,
-                            ),
-                            Codes.UNAUTHORIZED,
-                        ))
+                        verify_request.deferred.errback(
+                            SynapseError(
+                                401,
+                                "No key for %s with id %s"
+                                % (verify_request.server_name, verify_request.key_ids),
+                                Codes.UNAUTHORIZED,
+                            )
+                        )
 
         def on_err(err):
             with PreserveLoggingContext():
@@ -383,25 +376,26 @@ class Keyring(object):
                 )
                 defer.returnValue(result)
             except KeyLookupError as e:
-                logger.warning(
-                    "Key lookup failed from %r: %s", perspective_name, e,
-                )
+                logger.warning("Key lookup failed from %r: %s", perspective_name, e)
             except Exception as e:
                 logger.exception(
                     "Unable to get key from %r: %s %s",
                     perspective_name,
-                    type(e).__name__, str(e),
+                    type(e).__name__,
+                    str(e),
                 )
 
             defer.returnValue({})
 
-        results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
-            [
-                run_in_background(get_key, p_name, p_keys)
-                for p_name, p_keys in self.perspective_servers.items()
-            ],
-            consumeErrors=True,
-        ).addErrback(unwrapFirstError))
+        results = yield logcontext.make_deferred_yieldable(
+            defer.gatherResults(
+                [
+                    run_in_background(get_key, p_name, p_keys)
+                    for p_name, p_keys in self.perspective_servers.items()
+                ],
+                consumeErrors=True,
+            ).addErrback(unwrapFirstError)
+        )
 
         union_of_keys = {}
         for result in results:
@@ -412,32 +406,30 @@ class Keyring(object):
 
     @defer.inlineCallbacks
     def get_keys_from_server(self, server_name_and_key_ids):
-        results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
-            [
-                run_in_background(
-                    self.get_server_verify_key_v2_direct,
-                    server_name,
-                    key_ids,
-                )
-                for server_name, key_ids in server_name_and_key_ids
-            ],
-            consumeErrors=True,
-        ).addErrback(unwrapFirstError))
+        results = yield logcontext.make_deferred_yieldable(
+            defer.gatherResults(
+                [
+                    run_in_background(
+                        self.get_server_verify_key_v2_direct, server_name, key_ids
+                    )
+                    for server_name, key_ids in server_name_and_key_ids
+                ],
+                consumeErrors=True,
+            ).addErrback(unwrapFirstError)
+        )
 
         merged = {}
         for result in results:
             merged.update(result)
 
-        defer.returnValue({
-            server_name: keys
-            for server_name, keys in merged.items()
-            if keys
-        })
+        defer.returnValue(
+            {server_name: keys for server_name, keys in merged.items() if keys}
+        )
 
     @defer.inlineCallbacks
-    def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
-                                          perspective_name,
-                                          perspective_keys):
+    def get_server_verify_key_v2_indirect(
+        self, server_names_and_key_ids, perspective_name, perspective_keys
+    ):
         # TODO(mark): Set the minimum_valid_until_ts to that needed by
         # the events being validated or the current time if validating
         # an incoming request.
@@ -448,9 +440,7 @@ class Keyring(object):
                 data={
                     u"server_keys": {
                         server_name: {
-                            key_id: {
-                                u"minimum_valid_until_ts": 0
-                            } for key_id in key_ids
+                            key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids
                         }
                         for server_name, key_ids in server_names_and_key_ids
                     }
@@ -458,21 +448,19 @@ class Keyring(object):
                 long_retries=True,
             )
         except (NotRetryingDestination, RequestSendFailed) as e:
-            raise_from(
-                KeyLookupError("Failed to connect to remote server"), e,
-            )
+            raise_from(KeyLookupError("Failed to connect to remote server"), e)
         except HttpResponseException as e:
-            raise_from(
-                KeyLookupError("Remote server returned an error"), e,
-            )
+            raise_from(KeyLookupError("Remote server returned an error"), e)
 
         keys = {}
 
         responses = query_response["server_keys"]
 
         for response in responses:
-            if (u"signatures" not in response
-                    or perspective_name not in response[u"signatures"]):
+            if (
+                u"signatures" not in response
+                or perspective_name not in response[u"signatures"]
+            ):
                 raise KeyLookupError(
                     "Key response not signed by perspective server"
                     " %r" % (perspective_name,)
@@ -482,9 +470,7 @@ class Keyring(object):
             for key_id in response[u"signatures"][perspective_name]:
                 if key_id in perspective_keys:
                     verify_signed_json(
-                        response,
-                        perspective_name,
-                        perspective_keys[key_id]
+                        response, perspective_name, perspective_keys[key_id]
                     )
                     verified = True
 
@@ -494,7 +480,7 @@ class Keyring(object):
                     " known key, signed with: %r, known keys: %r",
                     perspective_name,
                     list(response[u"signatures"][perspective_name]),
-                    list(perspective_keys)
+                    list(perspective_keys),
                 )
                 raise KeyLookupError(
                     "Response not signed with a known key for perspective"
@@ -508,18 +494,20 @@ class Keyring(object):
 
             keys.setdefault(server_name, {}).update(processed_response)
 
-        yield logcontext.make_deferred_yieldable(defer.gatherResults(
-            [
-                run_in_background(
-                    self.store_keys,
-                    server_name=server_name,
-                    from_server=perspective_name,
-                    verify_keys=response_keys,
-                )
-                for server_name, response_keys in keys.items()
-            ],
-            consumeErrors=True
-        ).addErrback(unwrapFirstError))
+        yield logcontext.make_deferred_yieldable(
+            defer.gatherResults(
+                [
+                    run_in_background(
+                        self.store_keys,
+                        server_name=server_name,
+                        from_server=perspective_name,
+                        verify_keys=response_keys,
+                    )
+                    for server_name, response_keys in keys.items()
+                ],
+                consumeErrors=True,
+            ).addErrback(unwrapFirstError)
+        )
 
         defer.returnValue(keys)
 
@@ -534,26 +522,26 @@ class Keyring(object):
             try:
                 response = yield self.client.get_json(
                     destination=server_name,
-                    path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id),
+                    path="/_matrix/key/v2/server/"
+                    + urllib.parse.quote(requested_key_id),
                     ignore_backoff=True,
                 )
             except (NotRetryingDestination, RequestSendFailed) as e:
-                raise_from(
-                    KeyLookupError("Failed to connect to remote server"), e,
-                )
+                raise_from(KeyLookupError("Failed to connect to remote server"), e)
             except HttpResponseException as e:
-                raise_from(
-                    KeyLookupError("Remote server returned an error"), e,
-                )
+                raise_from(KeyLookupError("Remote server returned an error"), e)
 
-            if (u"signatures" not in response
-                    or server_name not in response[u"signatures"]):
+            if (
+                u"signatures" not in response
+                or server_name not in response[u"signatures"]
+            ):
                 raise KeyLookupError("Key response not signed by remote server")
 
             if response["server_name"] != server_name:
-                raise KeyLookupError("Expected a response for server %r not %r" % (
-                    server_name, response["server_name"]
-                ))
+                raise KeyLookupError(
+                    "Expected a response for server %r not %r"
+                    % (server_name, response["server_name"])
+                )
 
             response_keys = yield self.process_v2_response(
                 from_server=server_name,
@@ -564,16 +552,12 @@ class Keyring(object):
             keys.update(response_keys)
 
         yield self.store_keys(
-            server_name=server_name,
-            from_server=server_name,
-            verify_keys=keys,
+            server_name=server_name, from_server=server_name, verify_keys=keys
         )
         defer.returnValue({server_name: keys})
 
     @defer.inlineCallbacks
-    def process_v2_response(
-        self, from_server, response_json, requested_ids=[],
-    ):
+    def process_v2_response(self, from_server, response_json, requested_ids=[]):
         """Parse a 'Server Keys' structure from the result of a /key request
 
         This is used to parse either the entirety of the response from
@@ -627,20 +611,13 @@ class Keyring(object):
         for key_id in response_json["signatures"].get(server_name, {}):
             if key_id not in response_json["verify_keys"]:
                 raise KeyLookupError(
-                    "Key response must include verification keys for all"
-                    " signatures"
+                    "Key response must include verification keys for all" " signatures"
                 )
             if key_id in verify_keys:
-                verify_signed_json(
-                    response_json,
-                    server_name,
-                    verify_keys[key_id]
-                )
+                verify_signed_json(response_json, server_name, verify_keys[key_id])
 
         signed_key_json = sign_json(
-            response_json,
-            self.config.server_name,
-            self.config.signing_key[0],
+            response_json, self.config.server_name, self.config.signing_key[0]
         )
 
         signed_key_json_bytes = encode_canonical_json(signed_key_json)
@@ -653,21 +630,23 @@ class Keyring(object):
         response_keys.update(verify_keys)
         response_keys.update(old_verify_keys)
 
-        yield logcontext.make_deferred_yieldable(defer.gatherResults(
-            [
-                run_in_background(
-                    self.store.store_server_keys_json,
-                    server_name=server_name,
-                    key_id=key_id,
-                    from_server=from_server,
-                    ts_now_ms=time_now_ms,
-                    ts_expires_ms=ts_valid_until_ms,
-                    key_json_bytes=signed_key_json_bytes,
-                )
-                for key_id in updated_key_ids
-            ],
-            consumeErrors=True,
-        ).addErrback(unwrapFirstError))
+        yield logcontext.make_deferred_yieldable(
+            defer.gatherResults(
+                [
+                    run_in_background(
+                        self.store.store_server_keys_json,
+                        server_name=server_name,
+                        key_id=key_id,
+                        from_server=from_server,
+                        ts_now_ms=time_now_ms,
+                        ts_expires_ms=ts_valid_until_ms,
+                        key_json_bytes=signed_key_json_bytes,
+                    )
+                    for key_id in updated_key_ids
+                ],
+                consumeErrors=True,
+            ).addErrback(unwrapFirstError)
+        )
 
         defer.returnValue(response_keys)
 
@@ -681,16 +660,21 @@ class Keyring(object):
             A deferred that completes when the keys are stored.
         """
         # TODO(markjh): Store whether the keys have expired.
-        return logcontext.make_deferred_yieldable(defer.gatherResults(
-            [
-                run_in_background(
-                    self.store.store_server_verify_key,
-                    server_name, server_name, key.time_added, key
-                )
-                for key_id, key in verify_keys.items()
-            ],
-            consumeErrors=True,
-        ).addErrback(unwrapFirstError))
+        return logcontext.make_deferred_yieldable(
+            defer.gatherResults(
+                [
+                    run_in_background(
+                        self.store.store_server_verify_key,
+                        server_name,
+                        server_name,
+                        key.time_added,
+                        key,
+                    )
+                    for key_id, key in verify_keys.items()
+                ],
+                consumeErrors=True,
+            ).addErrback(unwrapFirstError)
+        )
 
 
 @defer.inlineCallbacks
@@ -713,17 +697,19 @@ def _handle_key_deferred(verify_request):
     except KeyLookupError as e:
         logger.warn(
             "Failed to download keys for %s: %s %s",
-            server_name, type(e).__name__, str(e),
+            server_name,
+            type(e).__name__,
+            str(e),
         )
         raise SynapseError(
-            502,
-            "Error downloading keys for %s" % (server_name,),
-            Codes.UNAUTHORIZED,
+            502, "Error downloading keys for %s" % (server_name,), Codes.UNAUTHORIZED
         )
     except Exception as e:
         logger.exception(
             "Got Exception when downloading keys for %s: %s %s",
-            server_name, type(e).__name__, str(e),
+            server_name,
+            type(e).__name__,
+            str(e),
         )
         raise SynapseError(
             401,
@@ -733,22 +719,24 @@ def _handle_key_deferred(verify_request):
 
     json_object = verify_request.json_object
 
-    logger.debug("Got key %s %s:%s for server %s, verifying" % (
-        key_id, verify_key.alg, verify_key.version, server_name,
-    ))
+    logger.debug(
+        "Got key %s %s:%s for server %s, verifying"
+        % (key_id, verify_key.alg, verify_key.version, server_name)
+    )
     try:
         verify_signed_json(json_object, server_name, verify_key)
     except SignatureVerifyException as e:
         logger.debug(
             "Error verifying signature for %s:%s:%s with key %s: %s",
-            server_name, verify_key.alg, verify_key.version,
+            server_name,
+            verify_key.alg,
+            verify_key.version,
             encode_verify_key_base64(verify_key),
             str(e),
         )
         raise SynapseError(
             401,
-            "Invalid signature for server %s with key %s:%s: %s" % (
-                server_name, verify_key.alg, verify_key.version, str(e),
-            ),
+            "Invalid signature for server %s with key %s:%s: %s"
+            % (server_name, verify_key.alg, verify_key.version, str(e)),
             Codes.UNAUTHORIZED,
         )