summary refs log tree commit diff
path: root/synapse/crypto
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/crypto')
-rw-r--r--synapse/crypto/keyclient.py17
-rw-r--r--synapse/crypto/keyring.py223
2 files changed, 131 insertions, 109 deletions
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index c2bd64d6c2..f1fd488b90 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -13,14 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+from synapse.util import logcontext
 from twisted.web.http import HTTPClient
 from twisted.internet.protocol import Factory
 from twisted.internet import defer, reactor
 from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util.logcontext import (
-    preserve_context_over_fn, preserve_context_over_deferred
-)
 import simplejson as json
 import logging
 
@@ -43,14 +40,10 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
 
     for i in range(5):
         try:
-            protocol = yield preserve_context_over_fn(
-                endpoint.connect, factory
-            )
-            server_response, server_certificate = yield preserve_context_over_deferred(
-                protocol.remote_key
-            )
-            defer.returnValue((server_response, server_certificate))
-            return
+            with logcontext.PreserveLoggingContext():
+                protocol = yield endpoint.connect(factory)
+                server_response, server_certificate = yield protocol.remote_key
+                defer.returnValue((server_response, server_certificate))
         except SynapseKeyClientError as e:
             logger.exception("Error getting key for %r" % (server_name,))
             if e.status.startswith("4"):
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 51851d04e5..054bac456d 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -18,7 +18,7 @@ from synapse.crypto.keyclient import fetch_server_key
 from synapse.api.errors import SynapseError, Codes
 from synapse.util import unwrapFirstError, logcontext
 from synapse.util.logcontext import (
-    preserve_context_over_fn, PreserveLoggingContext,
+    PreserveLoggingContext,
     preserve_fn
 )
 from synapse.util.metrics import Measure
@@ -57,7 +57,8 @@ Attributes:
     json_object(dict): The JSON object to verify.
     deferred(twisted.internet.defer.Deferred):
         A deferred (server_name, key_id, verify_key) tuple that resolves when
-        a verify key has been fetched
+        a verify key has been fetched. The deferreds' callbacks are run with no
+        logcontext.
 """
 
 
@@ -82,9 +83,11 @@ class Keyring(object):
         self.key_downloads = {}
 
     def verify_json_for_server(self, server_name, json_object):
-        return self.verify_json_objects_for_server(
-            [(server_name, json_object)]
-        )[0]
+        return logcontext.make_deferred_yieldable(
+            self.verify_json_objects_for_server(
+                [(server_name, json_object)]
+            )[0]
+        )
 
     def verify_json_objects_for_server(self, server_and_json):
         """Bulk verifies signatures of json objects, bulk fetching keys as
@@ -94,8 +97,10 @@ class Keyring(object):
             server_and_json (list): List of pairs of (server_name, json_object)
 
         Returns:
-            list of deferreds indicating success or failure to verify each
-            json object's signature for the given server_name.
+            List<Deferred>: for each input pair, a deferred indicating success
+                or failure to verify each json object's signature for the given
+                server_name. The deferreds run their callbacks in the sentinel
+                logcontext.
         """
         verify_requests = []
 
@@ -122,93 +127,71 @@ class Keyring(object):
 
             verify_requests.append(verify_request)
 
-        @defer.inlineCallbacks
-        def handle_key_deferred(verify_request):
-            server_name = verify_request.server_name
-            try:
-                _, key_id, verify_key = yield verify_request.deferred
-            except IOError as e:
-                logger.warn(
-                    "Got IOError when downloading keys for %s: %s %s",
-                    server_name, type(e).__name__, str(e.message),
-                )
-                raise SynapseError(
-                    502,
-                    "Error downloading keys for %s" % (server_name,),
-                    Codes.UNAUTHORIZED,
-                )
-            except Exception as e:
-                logger.exception(
-                    "Got Exception when downloading keys for %s: %s %s",
-                    server_name, type(e).__name__, str(e.message),
-                )
-                raise SynapseError(
-                    401,
-                    "No key for %s with id %s" % (server_name, key_ids),
-                    Codes.UNAUTHORIZED,
-                )
+        preserve_fn(self._start_key_lookups)(verify_requests)
 
-            json_object = verify_request.json_object
+        # Pass those keys to handle_key_deferred so that the json object
+        # signatures can be verified
+        handle = preserve_fn(_handle_key_deferred)
+        return [
+            handle(rq) for rq in verify_requests
+        ]
 
-            logger.debug("Got key %s %s:%s for server %s, verifying" % (
-                key_id, verify_key.alg, verify_key.version, server_name,
-            ))
-            try:
-                verify_signed_json(json_object, server_name, verify_key)
-            except:
-                raise SynapseError(
-                    401,
-                    "Invalid signature for server %s with key %s:%s" % (
-                        server_name, verify_key.alg, verify_key.version
-                    ),
-                    Codes.UNAUTHORIZED,
-                )
+    @defer.inlineCallbacks
+    def _start_key_lookups(self, verify_requests):
+        """Sets off the key fetches for each verify request
+
+        Once each fetch completes, verify_request.deferred will be resolved.
+
+        Args:
+            verify_requests (List[VerifyKeyRequest]):
+        """
 
+        # create a deferred for each server we're going to look up the keys
+        # for; we'll resolve them once we have completed our lookups.
+        # These will be passed into wait_for_previous_lookups to block
+        # any other lookups until we have finished.
+        # The deferreds are called with no logcontext.
         server_to_deferred = {
-            server_name: defer.Deferred()
-            for server_name, _ in server_and_json
+            rq.server_name: defer.Deferred()
+            for rq in verify_requests
         }
 
-        with PreserveLoggingContext():
+        # We want to wait for any previous lookups to complete before
+        # proceeding.
+        yield self.wait_for_previous_lookups(
+            [rq.server_name for rq in verify_requests],
+            server_to_deferred,
+        )
 
-            # We want to wait for any previous lookups to complete before
-            # proceeding.
-            wait_on_deferred = self.wait_for_previous_lookups(
-                [server_name for server_name, _ in server_and_json],
-                server_to_deferred,
-            )
+        # Actually start fetching keys.
+        self._get_server_verify_keys(verify_requests)
 
-            # Actually start fetching keys.
-            wait_on_deferred.addBoth(
-                lambda _: self.get_server_verify_keys(verify_requests)
-            )
+        # When we've finished fetching all the keys for a given server_name,
+        # resolve the deferred passed to `wait_for_previous_lookups` so that
+        # any lookups waiting will proceed.
+        #
+        # map from server name to a set of request ids
+        server_to_request_ids = {}
 
-            # When we've finished fetching all the keys for a given server_name,
-            # resolve the deferred passed to `wait_for_previous_lookups` so that
-            # any lookups waiting will proceed.
-            server_to_request_ids = {}
-
-            def remove_deferreds(res, server_name, verify_request):
-                request_id = id(verify_request)
-                server_to_request_ids[server_name].discard(request_id)
-                if not server_to_request_ids[server_name]:
-                    d = server_to_deferred.pop(server_name, None)
-                    if d:
-                        d.callback(None)
-                return res
-
-            for verify_request in verify_requests:
-                server_name = verify_request.server_name
-                request_id = id(verify_request)
-                server_to_request_ids.setdefault(server_name, set()).add(request_id)
-                deferred.addBoth(remove_deferreds, server_name, verify_request)
+        for verify_request in verify_requests:
+            server_name = verify_request.server_name
+            request_id = id(verify_request)
+            server_to_request_ids.setdefault(server_name, set()).add(request_id)
 
-        # Pass those keys to handle_key_deferred so that the json object
-        # signatures can be verified
-        return [
-            preserve_context_over_fn(handle_key_deferred, verify_request)
-            for verify_request in verify_requests
-        ]
+        def remove_deferreds(res, verify_request):
+            server_name = verify_request.server_name
+            request_id = id(verify_request)
+            server_to_request_ids[server_name].discard(request_id)
+            if not server_to_request_ids[server_name]:
+                d = server_to_deferred.pop(server_name, None)
+                if d:
+                    d.callback(None)
+            return res
+
+        for verify_request in verify_requests:
+            verify_request.deferred.addBoth(
+                remove_deferreds, verify_request,
+            )
 
     @defer.inlineCallbacks
     def wait_for_previous_lookups(self, server_names, server_to_deferred):
@@ -245,7 +228,7 @@ class Keyring(object):
             self.key_downloads[server_name] = deferred
             deferred.addBoth(rm, server_name)
 
-    def get_server_verify_keys(self, verify_requests):
+    def _get_server_verify_keys(self, verify_requests):
         """Tries to find at least one key for each verify request
 
         For each verify_request, verify_request.deferred is called back with
@@ -314,21 +297,23 @@ class Keyring(object):
                     if not missing_keys:
                         break
 
-                for verify_request in requests_missing_keys.values():
-                    verify_request.deferred.errback(SynapseError(
-                        401,
-                        "No key for %s with id %s" % (
-                            verify_request.server_name, verify_request.key_ids,
-                        ),
-                        Codes.UNAUTHORIZED,
-                    ))
+                with PreserveLoggingContext():
+                    for verify_request in requests_missing_keys:
+                        verify_request.deferred.errback(SynapseError(
+                            401,
+                            "No key for %s with id %s" % (
+                                verify_request.server_name, verify_request.key_ids,
+                            ),
+                            Codes.UNAUTHORIZED,
+                        ))
 
         def on_err(err):
-            for verify_request in verify_requests:
-                if not verify_request.deferred.called:
-                    verify_request.deferred.errback(err)
+            with PreserveLoggingContext():
+                for verify_request in verify_requests:
+                    if not verify_request.deferred.called:
+                        verify_request.deferred.errback(err)
 
-        do_iterations().addErrback(on_err)
+        preserve_fn(do_iterations)().addErrback(on_err)
 
     @defer.inlineCallbacks
     def get_keys_from_store(self, server_name_and_key_ids):
@@ -738,3 +723,47 @@ class Keyring(object):
             ],
             consumeErrors=True,
         ).addErrback(unwrapFirstError))
+
+
+@defer.inlineCallbacks
+def _handle_key_deferred(verify_request):
+    server_name = verify_request.server_name
+    try:
+        with PreserveLoggingContext():
+            _, key_id, verify_key = yield verify_request.deferred
+    except IOError as e:
+        logger.warn(
+            "Got IOError when downloading keys for %s: %s %s",
+            server_name, type(e).__name__, str(e.message),
+        )
+        raise SynapseError(
+            502,
+            "Error downloading keys for %s" % (server_name,),
+            Codes.UNAUTHORIZED,
+        )
+    except Exception as e:
+        logger.exception(
+            "Got Exception when downloading keys for %s: %s %s",
+            server_name, type(e).__name__, str(e.message),
+        )
+        raise SynapseError(
+            401,
+            "No key for %s with id %s" % (server_name, verify_request.key_ids),
+            Codes.UNAUTHORIZED,
+        )
+
+    json_object = verify_request.json_object
+
+    logger.debug("Got key %s %s:%s for server %s, verifying" % (
+        key_id, verify_key.alg, verify_key.version, server_name,
+    ))
+    try:
+        verify_signed_json(json_object, server_name, verify_key)
+    except:
+        raise SynapseError(
+            401,
+            "Invalid signature for server %s with key %s:%s" % (
+                server_name, verify_key.alg, verify_key.version
+            ),
+            Codes.UNAUTHORIZED,
+        )