summary refs log tree commit diff
path: root/synapse/crypto
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/crypto')
-rw-r--r--synapse/crypto/keyclient.py3
-rw-r--r--synapse/crypto/keyring.py83
2 files changed, 49 insertions, 37 deletions
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index 784d02f122..54b83da9d8 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -36,6 +36,7 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
 
     factory = SynapseKeyClientFactory()
     factory.path = path
+    factory.host = server_name
     endpoint = matrix_federation_endpoint(
         reactor, server_name, ssl_context_factory, timeout=30
     )
@@ -81,6 +82,8 @@ class SynapseKeyClientProtocol(HTTPClient):
         self.host = self.transport.getHost()
         logger.debug("Connected to %s", self.host)
         self.sendCommand(b"GET", self.path)
+        if self.host:
+            self.sendHeader(b"Host", self.host)
         self.endHeaders()
         self.timer = reactor.callLater(
             self.timeout,
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index cddec0b2bc..d08ee0aa91 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -18,6 +18,10 @@ from synapse.api.errors import SynapseError, Codes
 from synapse.util.retryutils import get_retry_limiter
 from synapse.util import unwrapFirstError
 from synapse.util.async import ObservableDeferred
+from synapse.util.logcontext import (
+    preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
+    preserve_fn
+)
 
 from twisted.internet import defer
 
@@ -142,40 +146,43 @@ class Keyring(object):
             for server_name, _ in server_and_json
         }
 
-        # We want to wait for any previous lookups to complete before
-        # proceeding.
-        wait_on_deferred = self.wait_for_previous_lookups(
-            [server_name for server_name, _ in server_and_json],
-            server_to_deferred,
-        )
+        with PreserveLoggingContext():
 
-        # Actually start fetching keys.
-        wait_on_deferred.addBoth(
-            lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
-        )
+            # We want to wait for any previous lookups to complete before
+            # proceeding.
+            wait_on_deferred = self.wait_for_previous_lookups(
+                [server_name for server_name, _ in server_and_json],
+                server_to_deferred,
+            )
 
-        # When we've finished fetching all the keys for a given server_name,
-        # resolve the deferred passed to `wait_for_previous_lookups` so that
-        # any lookups waiting will proceed.
-        server_to_gids = {}
+            # Actually start fetching keys.
+            wait_on_deferred.addBoth(
+                lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
+            )
+
+            # When we've finished fetching all the keys for a given server_name,
+            # resolve the deferred passed to `wait_for_previous_lookups` so that
+            # any lookups waiting will proceed.
+            server_to_gids = {}
 
-        def remove_deferreds(res, server_name, group_id):
-            server_to_gids[server_name].discard(group_id)
-            if not server_to_gids[server_name]:
-                d = server_to_deferred.pop(server_name, None)
-                if d:
-                    d.callback(None)
-            return res
+            def remove_deferreds(res, server_name, group_id):
+                server_to_gids[server_name].discard(group_id)
+                if not server_to_gids[server_name]:
+                    d = server_to_deferred.pop(server_name, None)
+                    if d:
+                        d.callback(None)
+                return res
 
-        for g_id, deferred in deferreds.items():
-            server_name = group_id_to_group[g_id].server_name
-            server_to_gids.setdefault(server_name, set()).add(g_id)
-            deferred.addBoth(remove_deferreds, server_name, g_id)
+            for g_id, deferred in deferreds.items():
+                server_name = group_id_to_group[g_id].server_name
+                server_to_gids.setdefault(server_name, set()).add(g_id)
+                deferred.addBoth(remove_deferreds, server_name, g_id)
 
         # Pass those keys to handle_key_deferred so that the json object
         # signatures can be verified
         return [
-            handle_key_deferred(
+            preserve_context_over_fn(
+                handle_key_deferred,
                 group_id_to_group[g_id],
                 deferreds[g_id],
             )
@@ -198,12 +205,13 @@ class Keyring(object):
                 if server_name in self.key_downloads
             ]
             if wait_on:
-                yield defer.DeferredList(wait_on)
+                with PreserveLoggingContext():
+                    yield defer.DeferredList(wait_on)
             else:
                 break
 
         for server_name, deferred in server_to_deferred.items():
-            d = ObservableDeferred(deferred)
+            d = ObservableDeferred(preserve_context_over_deferred(deferred))
             self.key_downloads[server_name] = d
 
             def rm(r, server_name):
@@ -244,12 +252,13 @@ class Keyring(object):
                 for group in group_id_to_group.values():
                     for key_id in group.key_ids:
                         if key_id in merged_results[group.server_name]:
-                            group_id_to_deferred[group.group_id].callback((
-                                group.group_id,
-                                group.server_name,
-                                key_id,
-                                merged_results[group.server_name][key_id],
-                            ))
+                            with PreserveLoggingContext():
+                                group_id_to_deferred[group.group_id].callback((
+                                    group.group_id,
+                                    group.server_name,
+                                    key_id,
+                                    merged_results[group.server_name][key_id],
+                                ))
                             break
                     else:
                         missing_groups.setdefault(
@@ -504,7 +513,7 @@ class Keyring(object):
 
         yield defer.gatherResults(
             [
-                self.store_keys(
+                preserve_fn(self.store_keys)(
                     server_name=key_server_name,
                     from_server=server_name,
                     verify_keys=verify_keys,
@@ -573,7 +582,7 @@ class Keyring(object):
 
         yield defer.gatherResults(
             [
-                self.store.store_server_keys_json(
+                preserve_fn(self.store.store_server_keys_json)(
                     server_name=server_name,
                     key_id=key_id,
                     from_server=server_name,
@@ -675,7 +684,7 @@ class Keyring(object):
         # TODO(markjh): Store whether the keys have expired.
         yield defer.gatherResults(
             [
-                self.store.store_server_verify_key(
+                preserve_fn(self.store.store_server_verify_key)(
                     server_name, server_name, key.time_added, key
                 )
                 for key_id, key in verify_keys.items()