diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 30e2742102..515ebbc148 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2017 New Vector Ltd.
+# Copyright 2017, 2018 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,7 +16,6 @@
import hashlib
import logging
-import urllib
from collections import namedtuple
from signedjson.key import (
@@ -40,6 +39,7 @@ from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyclient import fetch_server_key
from synapse.util import logcontext, unwrapFirstError
from synapse.util.logcontext import (
+ LoggingContext,
PreserveLoggingContext,
preserve_fn,
run_in_background,
@@ -216,23 +216,34 @@ class Keyring(object):
servers have completed. Follows the synapse rules of logcontext
preservation.
"""
+ loop_count = 1
while True:
wait_on = [
- self.key_downloads[server_name]
+ (server_name, self.key_downloads[server_name])
for server_name in server_names
if server_name in self.key_downloads
]
- if wait_on:
- with PreserveLoggingContext():
- yield defer.DeferredList(wait_on)
- else:
+ if not wait_on:
break
+ logger.info(
+ "Waiting for existing lookups for %s to complete [loop %i]",
+ [w[0] for w in wait_on], loop_count,
+ )
+ with PreserveLoggingContext():
+ yield defer.DeferredList((w[1] for w in wait_on))
+
+ loop_count += 1
+
+ ctx = LoggingContext.current_context()
def rm(r, server_name_):
- self.key_downloads.pop(server_name_, None)
+ with PreserveLoggingContext(ctx):
+ logger.debug("Releasing key lookup lock on %s", server_name_)
+ self.key_downloads.pop(server_name_, None)
return r
for server_name, deferred in server_to_deferred.items():
+ logger.debug("Got key lookup lock on %s", server_name)
self.key_downloads[server_name] = deferred
deferred.addBoth(rm, server_name)
@@ -382,32 +393,13 @@ class Keyring(object):
@defer.inlineCallbacks
def get_keys_from_server(self, server_name_and_key_ids):
- @defer.inlineCallbacks
- def get_key(server_name, key_ids):
- keys = None
- try:
- keys = yield self.get_server_verify_key_v2_direct(
- server_name, key_ids
- )
- except Exception as e:
- logger.info(
- "Unable to get key %r for %r directly: %s %s",
- key_ids, server_name,
- type(e).__name__, str(e),
- )
-
- if not keys:
- keys = yield self.get_server_verify_key_v1_direct(
- server_name, key_ids
- )
-
- keys = {server_name: keys}
-
- defer.returnValue(keys)
-
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- run_in_background(get_key, server_name, key_ids)
+ run_in_background(
+ self.get_server_verify_key_v2_direct,
+ server_name,
+ key_ids,
+ )
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
@@ -432,7 +424,7 @@ class Keyring(object):
# an incoming request.
query_response = yield self.client.post_json(
destination=perspective_name,
- path=b"/_matrix/key/v2/query",
+ path="/_matrix/key/v2/query",
data={
u"server_keys": {
server_name: {
@@ -512,10 +504,7 @@ class Keyring(object):
continue
(response, tls_certificate) = yield fetch_server_key(
- server_name, self.hs.tls_client_options_factory,
- path=(b"/_matrix/key/v2/server/%s" % (
- urllib.quote(requested_key_id),
- )).encode("ascii"),
+ server_name, self.hs.tls_client_options_factory, requested_key_id
)
if (u"signatures" not in response
@@ -644,78 +633,6 @@ class Keyring(object):
defer.returnValue(results)
- @defer.inlineCallbacks
- def get_server_verify_key_v1_direct(self, server_name, key_ids):
- """Finds a verification key for the server with one of the key ids.
- Args:
- server_name (str): The name of the server to fetch a key for.
- keys_ids (list of str): The key_ids to check for.
- """
-
- # Try to fetch the key from the remote server.
-
- (response, tls_certificate) = yield fetch_server_key(
- server_name, self.hs.tls_client_options_factory
- )
-
- # Check the response.
-
- x509_certificate_bytes = crypto.dump_certificate(
- crypto.FILETYPE_ASN1, tls_certificate
- )
-
- if ("signatures" not in response
- or server_name not in response["signatures"]):
- raise KeyLookupError("Key response not signed by remote server")
-
- if "tls_certificate" not in response:
- raise KeyLookupError("Key response missing TLS certificate")
-
- tls_certificate_b64 = response["tls_certificate"]
-
- if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
- raise KeyLookupError("TLS certificate doesn't match")
-
- # Cache the result in the datastore.
-
- time_now_ms = self.clock.time_msec()
-
- verify_keys = {}
- for key_id, key_base64 in response["verify_keys"].items():
- if is_signing_algorithm_supported(key_id):
- key_bytes = decode_base64(key_base64)
- verify_key = decode_verify_key_bytes(key_id, key_bytes)
- verify_key.time_added = time_now_ms
- verify_keys[key_id] = verify_key
-
- for key_id in response["signatures"][server_name]:
- if key_id not in response["verify_keys"]:
- raise KeyLookupError(
- "Key response must include verification keys for all"
- " signatures"
- )
- if key_id in verify_keys:
- verify_signed_json(
- response,
- server_name,
- verify_keys[key_id]
- )
-
- yield self.store.store_server_certificate(
- server_name,
- server_name,
- time_now_ms,
- tls_certificate,
- )
-
- yield self.store_keys(
- server_name=server_name,
- from_server=server_name,
- verify_keys=verify_keys,
- )
-
- defer.returnValue(verify_keys)
-
def store_keys(self, server_name, from_server, verify_keys):
"""Store a collection of verify keys for a given server
Args:
|