summary refs log tree commit diff
path: root/synapse/crypto
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/crypto')
-rw-r--r--synapse/crypto/keyring.py166
1 files changed, 86 insertions, 80 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 5012c10ee8..d7211ee9b3 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -22,6 +22,7 @@ from synapse.util.logcontext import (
     preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
     preserve_fn
 )
+from synapse.util.metrics import Measure
 
 from twisted.internet import defer
 
@@ -61,6 +62,10 @@ Attributes:
 """
 
 
+class KeyLookupError(ValueError):
+    pass
+
+
 class Keyring(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()
@@ -239,59 +244,60 @@ class Keyring(object):
 
         @defer.inlineCallbacks
         def do_iterations():
-            merged_results = {}
+            with Measure(self.clock, "get_server_verify_keys"):
+                merged_results = {}
 
-            missing_keys = {}
-            for verify_request in verify_requests:
-                missing_keys.setdefault(verify_request.server_name, set()).update(
-                    verify_request.key_ids
-                )
-
-            for fn in key_fetch_fns:
-                results = yield fn(missing_keys.items())
-                merged_results.update(results)
-
-                # We now need to figure out which verify requests we have keys
-                # for and which we don't
                 missing_keys = {}
-                requests_missing_keys = []
                 for verify_request in verify_requests:
-                    server_name = verify_request.server_name
-                    result_keys = merged_results[server_name]
-
-                    if verify_request.deferred.called:
-                        # We've already called this deferred, which probably
-                        # means that we've already found a key for it.
-                        continue
-
-                    for key_id in verify_request.key_ids:
-                        if key_id in result_keys:
-                            with PreserveLoggingContext():
-                                verify_request.deferred.callback((
-                                    server_name,
-                                    key_id,
-                                    result_keys[key_id],
-                                ))
-                            break
-                    else:
-                        # The else block is only reached if the loop above
-                        # doesn't break.
-                        missing_keys.setdefault(server_name, set()).update(
-                            verify_request.key_ids
-                        )
-                        requests_missing_keys.append(verify_request)
-
-                if not missing_keys:
-                    break
-
-            for verify_request in requests_missing_keys.values():
-                verify_request.deferred.errback(SynapseError(
-                    401,
-                    "No key for %s with id %s" % (
-                        verify_request.server_name, verify_request.key_ids,
-                    ),
-                    Codes.UNAUTHORIZED,
-                ))
+                    missing_keys.setdefault(verify_request.server_name, set()).update(
+                        verify_request.key_ids
+                    )
+
+                for fn in key_fetch_fns:
+                    results = yield fn(missing_keys.items())
+                    merged_results.update(results)
+
+                    # We now need to figure out which verify requests we have keys
+                    # for and which we don't
+                    missing_keys = {}
+                    requests_missing_keys = []
+                    for verify_request in verify_requests:
+                        server_name = verify_request.server_name
+                        result_keys = merged_results[server_name]
+
+                        if verify_request.deferred.called:
+                            # We've already called this deferred, which probably
+                            # means that we've already found a key for it.
+                            continue
+
+                        for key_id in verify_request.key_ids:
+                            if key_id in result_keys:
+                                with PreserveLoggingContext():
+                                    verify_request.deferred.callback((
+                                        server_name,
+                                        key_id,
+                                        result_keys[key_id],
+                                    ))
+                                break
+                        else:
+                            # The else block is only reached if the loop above
+                            # doesn't break.
+                            missing_keys.setdefault(server_name, set()).update(
+                                verify_request.key_ids
+                            )
+                            requests_missing_keys.append(verify_request)
+
+                    if not missing_keys:
+                        break
+
+                for verify_request in requests_missing_keys.values():
+                    verify_request.deferred.errback(SynapseError(
+                        401,
+                        "No key for %s with id %s" % (
+                            verify_request.server_name, verify_request.key_ids,
+                        ),
+                        Codes.UNAUTHORIZED,
+                    ))
 
         def on_err(err):
             for verify_request in verify_requests:
@@ -302,15 +308,15 @@ class Keyring(object):
 
     @defer.inlineCallbacks
     def get_keys_from_store(self, server_name_and_key_ids):
-        res = yield defer.gatherResults(
+        res = yield preserve_context_over_deferred(defer.gatherResults(
             [
-                self.store.get_server_verify_keys(
+                preserve_fn(self.store.get_server_verify_keys)(
                     server_name, key_ids
                 ).addCallback(lambda ks, server: (server, ks), server_name)
                 for server_name, key_ids in server_name_and_key_ids
             ],
             consumeErrors=True,
-        ).addErrback(unwrapFirstError)
+        )).addErrback(unwrapFirstError)
 
         defer.returnValue(dict(res))
 
@@ -331,13 +337,13 @@ class Keyring(object):
                 )
                 defer.returnValue({})
 
-        results = yield defer.gatherResults(
+        results = yield preserve_context_over_deferred(defer.gatherResults(
             [
-                get_key(p_name, p_keys)
+                preserve_fn(get_key)(p_name, p_keys)
                 for p_name, p_keys in self.perspective_servers.items()
             ],
             consumeErrors=True,
-        ).addErrback(unwrapFirstError)
+        )).addErrback(unwrapFirstError)
 
         union_of_keys = {}
         for result in results:
@@ -363,7 +369,7 @@ class Keyring(object):
                     )
                 except Exception as e:
                     logger.info(
-                        "Unable to getting key %r for %r directly: %s %s",
+                        "Unable to get key %r for %r directly: %s %s",
                         key_ids, server_name,
                         type(e).__name__, str(e.message),
                     )
@@ -377,13 +383,13 @@ class Keyring(object):
 
             defer.returnValue(keys)
 
-        results = yield defer.gatherResults(
+        results = yield preserve_context_over_deferred(defer.gatherResults(
             [
-                get_key(server_name, key_ids)
+                preserve_fn(get_key)(server_name, key_ids)
                 for server_name, key_ids in server_name_and_key_ids
             ],
             consumeErrors=True,
-        ).addErrback(unwrapFirstError)
+        )).addErrback(unwrapFirstError)
 
         merged = {}
         for result in results:
@@ -425,7 +431,7 @@ class Keyring(object):
         for response in responses:
             if (u"signatures" not in response
                     or perspective_name not in response[u"signatures"]):
-                raise ValueError(
+                raise KeyLookupError(
                     "Key response not signed by perspective server"
                     " %r" % (perspective_name,)
                 )
@@ -448,7 +454,7 @@ class Keyring(object):
                     list(response[u"signatures"][perspective_name]),
                     list(perspective_keys)
                 )
-                raise ValueError(
+                raise KeyLookupError(
                     "Response not signed with a known key for perspective"
                     " server %r" % (perspective_name,)
                 )
@@ -460,9 +466,9 @@ class Keyring(object):
             for server_name, response_keys in processed_response.items():
                 keys.setdefault(server_name, {}).update(response_keys)
 
-        yield defer.gatherResults(
+        yield preserve_context_over_deferred(defer.gatherResults(
             [
-                self.store_keys(
+                preserve_fn(self.store_keys)(
                     server_name=server_name,
                     from_server=perspective_name,
                     verify_keys=response_keys,
@@ -470,7 +476,7 @@ class Keyring(object):
                 for server_name, response_keys in keys.items()
             ],
             consumeErrors=True
-        ).addErrback(unwrapFirstError)
+        )).addErrback(unwrapFirstError)
 
         defer.returnValue(keys)
 
@@ -491,10 +497,10 @@ class Keyring(object):
 
             if (u"signatures" not in response
                     or server_name not in response[u"signatures"]):
-                raise ValueError("Key response not signed by remote server")
+                raise KeyLookupError("Key response not signed by remote server")
 
             if "tls_fingerprints" not in response:
-                raise ValueError("Key response missing TLS fingerprints")
+                raise KeyLookupError("Key response missing TLS fingerprints")
 
             certificate_bytes = crypto.dump_certificate(
                 crypto.FILETYPE_ASN1, tls_certificate
@@ -508,7 +514,7 @@ class Keyring(object):
                     response_sha256_fingerprints.add(fingerprint[u"sha256"])
 
             if sha256_fingerprint_b64 not in response_sha256_fingerprints:
-                raise ValueError("TLS certificate not allowed by fingerprints")
+                raise KeyLookupError("TLS certificate not allowed by fingerprints")
 
             response_keys = yield self.process_v2_response(
                 from_server=server_name,
@@ -518,7 +524,7 @@ class Keyring(object):
 
             keys.update(response_keys)
 
-        yield defer.gatherResults(
+        yield preserve_context_over_deferred(defer.gatherResults(
             [
                 preserve_fn(self.store_keys)(
                     server_name=key_server_name,
@@ -528,7 +534,7 @@ class Keyring(object):
                 for key_server_name, verify_keys in keys.items()
             ],
             consumeErrors=True
-        ).addErrback(unwrapFirstError)
+        )).addErrback(unwrapFirstError)
 
         defer.returnValue(keys)
 
@@ -560,14 +566,14 @@ class Keyring(object):
         server_name = response_json["server_name"]
         if only_from_server:
             if server_name != from_server:
-                raise ValueError(
+                raise KeyLookupError(
                     "Expected a response for server %r not %r" % (
                         from_server, server_name
                     )
                 )
         for key_id in response_json["signatures"].get(server_name, {}):
             if key_id not in response_json["verify_keys"]:
-                raise ValueError(
+                raise KeyLookupError(
                     "Key response must include verification keys for all"
                     " signatures"
                 )
@@ -594,7 +600,7 @@ class Keyring(object):
         response_keys.update(verify_keys)
         response_keys.update(old_verify_keys)
 
-        yield defer.gatherResults(
+        yield preserve_context_over_deferred(defer.gatherResults(
             [
                 preserve_fn(self.store.store_server_keys_json)(
                     server_name=server_name,
@@ -607,7 +613,7 @@ class Keyring(object):
                 for key_id in updated_key_ids
             ],
             consumeErrors=True,
-        ).addErrback(unwrapFirstError)
+        )).addErrback(unwrapFirstError)
 
         results[server_name] = response_keys
 
@@ -635,15 +641,15 @@ class Keyring(object):
 
         if ("signatures" not in response
                 or server_name not in response["signatures"]):
-            raise ValueError("Key response not signed by remote server")
+            raise KeyLookupError("Key response not signed by remote server")
 
         if "tls_certificate" not in response:
-            raise ValueError("Key response missing TLS certificate")
+            raise KeyLookupError("Key response missing TLS certificate")
 
         tls_certificate_b64 = response["tls_certificate"]
 
         if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
-            raise ValueError("TLS certificate doesn't match")
+            raise KeyLookupError("TLS certificate doesn't match")
 
         # Cache the result in the datastore.
 
@@ -659,7 +665,7 @@ class Keyring(object):
 
         for key_id in response["signatures"][server_name]:
             if key_id not in response["verify_keys"]:
-                raise ValueError(
+                raise KeyLookupError(
                     "Key response must include verification keys for all"
                     " signatures"
                 )
@@ -696,7 +702,7 @@ class Keyring(object):
             A deferred that completes when the keys are stored.
         """
         # TODO(markjh): Store whether the keys have expired.
-        yield defer.gatherResults(
+        yield preserve_context_over_deferred(defer.gatherResults(
             [
                 preserve_fn(self.store.store_server_verify_key)(
                     server_name, server_name, key.time_added, key
@@ -704,4 +710,4 @@ class Keyring(object):
                 for key_id, key in verify_keys.items()
             ],
             consumeErrors=True,
-        ).addErrback(unwrapFirstError)
+        )).addErrback(unwrapFirstError)