summary refs log tree commit diff
path: root/synapse/crypto/keyring.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/crypto/keyring.py')
-rw-r--r--synapse/crypto/keyring.py343
1 files changed, 204 insertions, 139 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index d7211ee9b3..22ee0fc93f 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,12 +16,11 @@
 
 from synapse.crypto.keyclient import fetch_server_key
 from synapse.api.errors import SynapseError, Codes
-from synapse.util.retryutils import get_retry_limiter
-from synapse.util import unwrapFirstError
-from synapse.util.async import ObservableDeferred
+from synapse.util import unwrapFirstError, logcontext
 from synapse.util.logcontext import (
-    preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
-    preserve_fn
+    PreserveLoggingContext,
+    preserve_fn,
+    run_in_background,
 )
 from synapse.util.metrics import Measure
 
@@ -58,7 +58,8 @@ Attributes:
     json_object(dict): The JSON object to verify.
     deferred(twisted.internet.defer.Deferred):
         A deferred (server_name, key_id, verify_key) tuple that resolves when
-        a verify key has been fetched
+        a verify key has been fetched. The deferreds' callbacks are run with no
+        logcontext.
 """
 
 
@@ -75,31 +76,41 @@ class Keyring(object):
         self.perspective_servers = self.config.perspectives
         self.hs = hs
 
+        # map from server name to Deferred. Has an entry for each server with
+        # an ongoing key download; the Deferred completes once the download
+        # completes.
+        #
+        # These are regular, logcontext-agnostic Deferreds.
         self.key_downloads = {}
 
     def verify_json_for_server(self, server_name, json_object):
-        return self.verify_json_objects_for_server(
-            [(server_name, json_object)]
-        )[0]
+        return logcontext.make_deferred_yieldable(
+            self.verify_json_objects_for_server(
+                [(server_name, json_object)]
+            )[0]
+        )
 
     def verify_json_objects_for_server(self, server_and_json):
-        """Bulk verfies signatures of json objects, bulk fetching keys as
+        """Bulk verifies signatures of json objects, bulk fetching keys as
         necessary.
 
         Args:
             server_and_json (list): List of pairs of (server_name, json_object)
 
         Returns:
-            list of deferreds indicating success or failure to verify each
-            json object's signature for the given server_name.
+            List<Deferred>: for each input pair, a deferred indicating success
+                or failure to verify each json object's signature for the given
+                server_name. The deferreds run their callbacks in the sentinel
+                logcontext.
         """
         verify_requests = []
 
         for server_name, json_object in server_and_json:
-            logger.debug("Verifying for %s", server_name)
 
             key_ids = signature_ids(json_object, server_name)
             if not key_ids:
+                logger.warn("Request from %s: no supported signature keys",
+                            server_name)
                 deferred = defer.fail(SynapseError(
                     400,
                     "Not signed with a supported algorithm",
@@ -108,76 +119,69 @@ class Keyring(object):
             else:
                 deferred = defer.Deferred()
 
+            logger.debug("Verifying for %s with key_ids %s",
+                         server_name, key_ids)
+
             verify_request = VerifyKeyRequest(
                 server_name, key_ids, json_object, deferred
             )
 
             verify_requests.append(verify_request)
 
-        @defer.inlineCallbacks
-        def handle_key_deferred(verify_request):
-            server_name = verify_request.server_name
-            try:
-                _, key_id, verify_key = yield verify_request.deferred
-            except IOError as e:
-                logger.warn(
-                    "Got IOError when downloading keys for %s: %s %s",
-                    server_name, type(e).__name__, str(e.message),
-                )
-                raise SynapseError(
-                    502,
-                    "Error downloading keys for %s" % (server_name,),
-                    Codes.UNAUTHORIZED,
-                )
-            except Exception as e:
-                logger.exception(
-                    "Got Exception when downloading keys for %s: %s %s",
-                    server_name, type(e).__name__, str(e.message),
-                )
-                raise SynapseError(
-                    401,
-                    "No key for %s with id %s" % (server_name, key_ids),
-                    Codes.UNAUTHORIZED,
-                )
+        run_in_background(self._start_key_lookups, verify_requests)
 
-            json_object = verify_request.json_object
+        # Pass those keys to handle_key_deferred so that the json object
+        # signatures can be verified
+        handle = preserve_fn(_handle_key_deferred)
+        return [
+            handle(rq) for rq in verify_requests
+        ]
 
-            try:
-                verify_signed_json(json_object, server_name, verify_key)
-            except:
-                raise SynapseError(
-                    401,
-                    "Invalid signature for server %s with key %s:%s" % (
-                        server_name, verify_key.alg, verify_key.version
-                    ),
-                    Codes.UNAUTHORIZED,
-                )
+    @defer.inlineCallbacks
+    def _start_key_lookups(self, verify_requests):
+        """Sets off the key fetches for each verify request
 
-        server_to_deferred = {
-            server_name: defer.Deferred()
-            for server_name, _ in server_and_json
-        }
+        Once each fetch completes, verify_request.deferred will be resolved.
 
-        with PreserveLoggingContext():
+        Args:
+            verify_requests (List[VerifyKeyRequest]):
+        """
+
+        try:
+            # create a deferred for each server we're going to look up the keys
+            # for; we'll resolve them once we have completed our lookups.
+            # These will be passed into wait_for_previous_lookups to block
+            # any other lookups until we have finished.
+            # The deferreds are called with no logcontext.
+            server_to_deferred = {
+                rq.server_name: defer.Deferred()
+                for rq in verify_requests
+            }
 
             # We want to wait for any previous lookups to complete before
             # proceeding.
-            wait_on_deferred = self.wait_for_previous_lookups(
-                [server_name for server_name, _ in server_and_json],
+            yield self.wait_for_previous_lookups(
+                [rq.server_name for rq in verify_requests],
                 server_to_deferred,
             )
 
             # Actually start fetching keys.
-            wait_on_deferred.addBoth(
-                lambda _: self.get_server_verify_keys(verify_requests)
-            )
+            self._get_server_verify_keys(verify_requests)
 
             # When we've finished fetching all the keys for a given server_name,
             # resolve the deferred passed to `wait_for_previous_lookups` so that
             # any lookups waiting will proceed.
+            #
+            # map from server name to a set of request ids
             server_to_request_ids = {}
 
-            def remove_deferreds(res, server_name, verify_request):
+            for verify_request in verify_requests:
+                server_name = verify_request.server_name
+                request_id = id(verify_request)
+                server_to_request_ids.setdefault(server_name, set()).add(request_id)
+
+            def remove_deferreds(res, verify_request):
+                server_name = verify_request.server_name
                 request_id = id(verify_request)
                 server_to_request_ids[server_name].discard(request_id)
                 if not server_to_request_ids[server_name]:
@@ -187,17 +191,11 @@ class Keyring(object):
                 return res
 
             for verify_request in verify_requests:
-                server_name = verify_request.server_name
-                request_id = id(verify_request)
-                server_to_request_ids.setdefault(server_name, set()).add(request_id)
-                deferred.addBoth(remove_deferreds, server_name, verify_request)
-
-        # Pass those keys to handle_key_deferred so that the json object
-        # signatures can be verified
-        return [
-            preserve_context_over_fn(handle_key_deferred, verify_request)
-            for verify_request in verify_requests
-        ]
+                verify_request.deferred.addBoth(
+                    remove_deferreds, verify_request,
+                )
+        except Exception:
+            logger.exception("Error starting key lookups")
 
     @defer.inlineCallbacks
     def wait_for_previous_lookups(self, server_names, server_to_deferred):
@@ -206,7 +204,13 @@ class Keyring(object):
         Args:
             server_names (list): list of server_names we want to lookup
             server_to_deferred (dict): server_name to deferred which gets
-                resolved once we've finished looking up keys for that server
+                resolved once we've finished looking up keys for that server.
+                The Deferreds should be regular twisted ones which call their
+                callbacks with no logcontext.
+
+        Returns: a Deferred which resolves once all key lookups for the given
+            servers have completed. Follows the synapse rules of logcontext
+            preservation.
         """
         while True:
             wait_on = [
@@ -220,19 +224,23 @@ class Keyring(object):
             else:
                 break
 
+        def rm(r, server_name_):
+            self.key_downloads.pop(server_name_, None)
+            return r
+
         for server_name, deferred in server_to_deferred.items():
-            d = ObservableDeferred(preserve_context_over_deferred(deferred))
-            self.key_downloads[server_name] = d
+            self.key_downloads[server_name] = deferred
+            deferred.addBoth(rm, server_name)
 
-            def rm(r, server_name):
-                self.key_downloads.pop(server_name, None)
-                return r
+    def _get_server_verify_keys(self, verify_requests):
+        """Tries to find at least one key for each verify request
 
-            d.addBoth(rm, server_name)
+        For each verify_request, verify_request.deferred is called back with
+        params (server_name, key_id, VerifyKey) if a key is found, or errbacked
+        with a SynapseError if none of the keys are found.
 
-    def get_server_verify_keys(self, verify_requests):
-        """Takes a dict of KeyGroups and tries to find at least one key for
-        each group.
+        Args:
+            verify_requests (list[VerifyKeyRequest]): list of verify requests
         """
 
         # These are functions that produce keys given a list of key ids
@@ -245,8 +253,11 @@ class Keyring(object):
         @defer.inlineCallbacks
         def do_iterations():
             with Measure(self.clock, "get_server_verify_keys"):
+                # dict[str, dict[str, VerifyKey]]: results so far.
+                # map server_name -> key_id -> VerifyKey
                 merged_results = {}
 
+                # dict[str, set(str)]: keys to fetch for each server
                 missing_keys = {}
                 for verify_request in verify_requests:
                     missing_keys.setdefault(verify_request.server_name, set()).update(
@@ -290,33 +301,46 @@ class Keyring(object):
                     if not missing_keys:
                         break
 
-                for verify_request in requests_missing_keys.values():
-                    verify_request.deferred.errback(SynapseError(
-                        401,
-                        "No key for %s with id %s" % (
-                            verify_request.server_name, verify_request.key_ids,
-                        ),
-                        Codes.UNAUTHORIZED,
-                    ))
+                with PreserveLoggingContext():
+                    for verify_request in requests_missing_keys:
+                        verify_request.deferred.errback(SynapseError(
+                            401,
+                            "No key for %s with id %s" % (
+                                verify_request.server_name, verify_request.key_ids,
+                            ),
+                            Codes.UNAUTHORIZED,
+                        ))
 
         def on_err(err):
-            for verify_request in verify_requests:
-                if not verify_request.deferred.called:
-                    verify_request.deferred.errback(err)
+            with PreserveLoggingContext():
+                for verify_request in verify_requests:
+                    if not verify_request.deferred.called:
+                        verify_request.deferred.errback(err)
 
-        do_iterations().addErrback(on_err)
+        run_in_background(do_iterations).addErrback(on_err)
 
     @defer.inlineCallbacks
     def get_keys_from_store(self, server_name_and_key_ids):
-        res = yield preserve_context_over_deferred(defer.gatherResults(
+        """
+
+        Args:
+            server_name_and_key_ids (list[(str, iterable[str])]):
+                list of (server_name, iterable[key_id]) tuples to fetch keys for
+
+        Returns:
+            Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
+                server_name -> key_id -> VerifyKey
+        """
+        res = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                preserve_fn(self.store.get_server_verify_keys)(
-                    server_name, key_ids
+                run_in_background(
+                    self.store.get_server_verify_keys,
+                    server_name, key_ids,
                 ).addCallback(lambda ks, server: (server, ks), server_name)
                 for server_name, key_ids in server_name_and_key_ids
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         defer.returnValue(dict(res))
 
@@ -333,17 +357,17 @@ class Keyring(object):
                 logger.exception(
                     "Unable to get key from %r: %s %s",
                     perspective_name,
-                    type(e).__name__, str(e.message),
+                    type(e).__name__, str(e),
                 )
                 defer.returnValue({})
 
-        results = yield preserve_context_over_deferred(defer.gatherResults(
+        results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                preserve_fn(get_key)(p_name, p_keys)
+                run_in_background(get_key, p_name, p_keys)
                 for p_name, p_keys in self.perspective_servers.items()
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         union_of_keys = {}
         for result in results:
@@ -356,40 +380,34 @@ class Keyring(object):
     def get_keys_from_server(self, server_name_and_key_ids):
         @defer.inlineCallbacks
         def get_key(server_name, key_ids):
-            limiter = yield get_retry_limiter(
-                server_name,
-                self.clock,
-                self.store,
-            )
-            with limiter:
-                keys = None
-                try:
-                    keys = yield self.get_server_verify_key_v2_direct(
-                        server_name, key_ids
-                    )
-                except Exception as e:
-                    logger.info(
-                        "Unable to get key %r for %r directly: %s %s",
-                        key_ids, server_name,
-                        type(e).__name__, str(e.message),
-                    )
+            keys = None
+            try:
+                keys = yield self.get_server_verify_key_v2_direct(
+                    server_name, key_ids
+                )
+            except Exception as e:
+                logger.info(
+                    "Unable to get key %r for %r directly: %s %s",
+                    key_ids, server_name,
+                    type(e).__name__, str(e),
+                )
 
-                if not keys:
-                    keys = yield self.get_server_verify_key_v1_direct(
-                        server_name, key_ids
-                    )
+            if not keys:
+                keys = yield self.get_server_verify_key_v1_direct(
+                    server_name, key_ids
+                )
 
-                    keys = {server_name: keys}
+                keys = {server_name: keys}
 
             defer.returnValue(keys)
 
-        results = yield preserve_context_over_deferred(defer.gatherResults(
+        results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                preserve_fn(get_key)(server_name, key_ids)
+                run_in_background(get_key, server_name, key_ids)
                 for server_name, key_ids in server_name_and_key_ids
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         merged = {}
         for result in results:
@@ -466,9 +484,10 @@ class Keyring(object):
             for server_name, response_keys in processed_response.items():
                 keys.setdefault(server_name, {}).update(response_keys)
 
-        yield preserve_context_over_deferred(defer.gatherResults(
+        yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                preserve_fn(self.store_keys)(
+                run_in_background(
+                    self.store_keys,
                     server_name=server_name,
                     from_server=perspective_name,
                     verify_keys=response_keys,
@@ -476,7 +495,7 @@ class Keyring(object):
                 for server_name, response_keys in keys.items()
             ],
             consumeErrors=True
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         defer.returnValue(keys)
 
@@ -524,9 +543,10 @@ class Keyring(object):
 
             keys.update(response_keys)
 
-        yield preserve_context_over_deferred(defer.gatherResults(
+        yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                preserve_fn(self.store_keys)(
+                run_in_background(
+                    self.store_keys,
                     server_name=key_server_name,
                     from_server=server_name,
                     verify_keys=verify_keys,
@@ -534,7 +554,7 @@ class Keyring(object):
                 for key_server_name, verify_keys in keys.items()
             ],
             consumeErrors=True
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         defer.returnValue(keys)
 
@@ -600,9 +620,10 @@ class Keyring(object):
         response_keys.update(verify_keys)
         response_keys.update(old_verify_keys)
 
-        yield preserve_context_over_deferred(defer.gatherResults(
+        yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                preserve_fn(self.store.store_server_keys_json)(
+                run_in_background(
+                    self.store.store_server_keys_json,
                     server_name=server_name,
                     key_id=key_id,
                     from_server=server_name,
@@ -613,7 +634,7 @@ class Keyring(object):
                 for key_id in updated_key_ids
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         results[server_name] = response_keys
 
@@ -691,7 +712,6 @@ class Keyring(object):
 
         defer.returnValue(verify_keys)
 
-    @defer.inlineCallbacks
     def store_keys(self, server_name, from_server, verify_keys):
         """Store a collection of verify keys for a given server
         Args:
@@ -702,12 +722,57 @@ class Keyring(object):
             A deferred that completes when the keys are stored.
         """
         # TODO(markjh): Store whether the keys have expired.
-        yield preserve_context_over_deferred(defer.gatherResults(
+        return logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                preserve_fn(self.store.store_server_verify_key)(
+                run_in_background(
+                    self.store.store_server_verify_key,
                     server_name, server_name, key.time_added, key
                 )
                 for key_id, key in verify_keys.items()
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
+
+
+@defer.inlineCallbacks
+def _handle_key_deferred(verify_request):
+    server_name = verify_request.server_name
+    try:
+        with PreserveLoggingContext():
+            _, key_id, verify_key = yield verify_request.deferred
+    except IOError as e:
+        logger.warn(
+            "Got IOError when downloading keys for %s: %s %s",
+            server_name, type(e).__name__, str(e),
+        )
+        raise SynapseError(
+            502,
+            "Error downloading keys for %s" % (server_name,),
+            Codes.UNAUTHORIZED,
+        )
+    except Exception as e:
+        logger.exception(
+            "Got Exception when downloading keys for %s: %s %s",
+            server_name, type(e).__name__, str(e),
+        )
+        raise SynapseError(
+            401,
+            "No key for %s with id %s" % (server_name, verify_request.key_ids),
+            Codes.UNAUTHORIZED,
+        )
+
+    json_object = verify_request.json_object
+
+    logger.debug("Got key %s %s:%s for server %s, verifying" % (
+        key_id, verify_key.alg, verify_key.version, server_name,
+    ))
+    try:
+        verify_signed_json(json_object, server_name, verify_key)
+    except Exception:
+        raise SynapseError(
+            401,
+            "Invalid signature for server %s with key %s:%s" % (
+                server_name, verify_key.alg, verify_key.version
+            ),
+            Codes.UNAUTHORIZED,
+        )