diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 2b4faee4c1..8709394b97 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -15,7 +15,9 @@
from synapse.crypto.keyclient import fetch_server_key
from twisted.internet import defer
-from syutil.crypto.jsonsign import verify_signed_json, signature_ids
+from syutil.crypto.jsonsign import (
+ verify_signed_json, signature_ids, sign_json, encode_canonical_json
+)
from syutil.crypto.signing_key import (
is_signing_algorithm_supported, decode_verify_key_bytes
)
@@ -28,6 +30,8 @@ from synapse.util.async import create_observer
from OpenSSL import crypto
+import urllib
+import hashlib
import logging
@@ -38,6 +42,9 @@ class Keyring(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
+ self.client = hs.get_http_client()
+ self.config = hs.get_config()
+ self.perspective_servers = self.config.perspectives
self.hs = hs
self.key_downloads = {}
@@ -89,12 +96,11 @@ class Keyring(object):
@defer.inlineCallbacks
def get_server_verify_key(self, server_name, key_ids):
"""Finds a verification key for the server with one of the key ids.
+ Trys to fetch the key from a trusted perspective server first.
Args:
- server_name (str): The name of the server to fetch a key for.
+ server_name(str): The name of the server to fetch a key for.
keys_ids (list of str): The key_ids to check for.
"""
-
- # Check the datastore to see if we have one cached.
cached = yield self.store.get_server_verify_keys(server_name, key_ids)
if cached:
@@ -117,7 +123,29 @@ class Keyring(object):
@defer.inlineCallbacks
def _get_server_verify_key_impl(self, server_name, key_ids):
- # Try to fetch the key from the remote server.
+ keys = None
+
+ perspective_results = []
+ for perspective_name, perspective_keys in self.perspective_servers.items():
+ @defer.inlineCallbacks
+ def get_key():
+ try:
+ result = yield self.get_server_verify_key_v2_indirect(
+ server_name, key_ids, perspective_name, perspective_keys
+ )
+ defer.returnValue(result)
+ except:
+ logging.info(
+ "Unable to getting key %r for %r from %r",
+ key_ids, server_name, perspective_name,
+ )
+ perspective_results.append(get_key())
+
+ perspective_results = yield defer.gatherResults(perspective_results)
+
+ for results in perspective_results:
+ if results is not None:
+ keys = results
limiter = yield get_retry_limiter(
server_name,
@@ -126,10 +154,234 @@ class Keyring(object):
)
with limiter:
+ if keys is None:
+ try:
+ keys = yield self.get_server_verify_key_v2_direct(
+ server_name, key_ids
+ )
+ except:
+ pass
+
+ keys = yield self.get_server_verify_key_v1_direct(
+ server_name, key_ids
+ )
+
+ for key_id in key_ids:
+ if key_id in keys:
+ defer.returnValue(keys[key_id])
+ return
+ raise ValueError("No verification key found for given key ids")
+
+ @defer.inlineCallbacks
+ def get_server_verify_key_v2_indirect(self, server_name, key_ids,
+ perspective_name,
+ perspective_keys):
+ limiter = yield get_retry_limiter(
+ perspective_name, self.clock, self.store
+ )
+
+ with limiter:
+ # TODO(mark): Set the minimum_valid_until_ts to that needed by
+ # the events being validated or the current time if validating
+ # an incoming request.
+ responses = yield self.client.post_json(
+ destination=perspective_name,
+ path=b"/_matrix/key/v2/query",
+ data={
+ u"server_keys": {
+ server_name: {
+ key_id: {
+ u"minimum_valid_until_ts": 0
+ } for key_id in key_ids
+ }
+ }
+ },
+ )
+
+ keys = {}
+
+ for response in responses:
+ if (u"signatures" not in response
+ or perspective_name not in response[u"signatures"]):
+ raise ValueError(
+ "Key response not signed by perspective server"
+ " %r" % (perspective_name,)
+ )
+
+ verified = False
+ for key_id in response[u"signatures"][perspective_name]:
+ if key_id in perspective_keys:
+ verify_signed_json(
+ response,
+ perspective_name,
+ perspective_keys[key_id]
+ )
+ verified = True
+
+ if not verified:
+ logging.info(
+ "Response from perspective server %r not signed with a"
+ " known key, signed with: %r, known keys: %r",
+ perspective_name,
+ list(response[u"signatures"][perspective_name]),
+ list(perspective_keys)
+ )
+ raise ValueError(
+ "Response not signed with a known key for perspective"
+ " server %r" % (perspective_name,)
+ )
+
+ response_keys = yield self.process_v2_response(
+ server_name, perspective_name, response
+ )
+
+ keys.update(response_keys)
+
+ yield self.store_keys(
+ server_name=server_name,
+ from_server=perspective_name,
+ verify_keys=keys,
+ )
+
+ defer.returnValue(keys)
+
+ @defer.inlineCallbacks
+ def get_server_verify_key_v2_direct(self, server_name, key_ids):
+
+ keys = {}
+
+ for requested_key_id in key_ids:
+ if requested_key_id in keys:
+ continue
+
(response, tls_certificate) = yield fetch_server_key(
- server_name, self.hs.tls_context_factory
+ server_name, self.hs.tls_context_factory,
+ path=(b"/_matrix/key/v2/server/%s" % (
+ urllib.quote(requested_key_id),
+ )).encode("ascii"),
+ )
+
+ if (u"signatures" not in response
+ or server_name not in response[u"signatures"]):
+ raise ValueError("Key response not signed by remote server")
+
+ if "tls_fingerprints" not in response:
+ raise ValueError("Key response missing TLS fingerprints")
+
+ certificate_bytes = crypto.dump_certificate(
+ crypto.FILETYPE_ASN1, tls_certificate
+ )
+ sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
+ sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)
+
+ response_sha256_fingerprints = set()
+ for fingerprint in response[u"tls_fingerprints"]:
+ if u"sha256" in fingerprint:
+ response_sha256_fingerprints.add(fingerprint[u"sha256"])
+
+ if sha256_fingerprint_b64 not in response_sha256_fingerprints:
+ raise ValueError("TLS certificate not allowed by fingerprints")
+
+ response_keys = yield self.process_v2_response(
+ server_name=server_name,
+ from_server=server_name,
+ requested_id=requested_key_id,
+ response_json=response,
+ )
+
+ keys.update(response_keys)
+
+ yield self.store_keys(
+ server_name=server_name,
+ from_server=server_name,
+ verify_keys=keys,
+ )
+
+ defer.returnValue(keys)
+
+ @defer.inlineCallbacks
+ def process_v2_response(self, server_name, from_server, response_json,
+ requested_id=None):
+ time_now_ms = self.clock.time_msec()
+ response_keys = {}
+ verify_keys = {}
+ for key_id, key_data in response_json["verify_keys"].items():
+ if is_signing_algorithm_supported(key_id):
+ key_base64 = key_data["key"]
+ key_bytes = decode_base64(key_base64)
+ verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_key.time_added = time_now_ms
+ verify_keys[key_id] = verify_key
+
+ old_verify_keys = {}
+ for key_id, key_data in response_json["old_verify_keys"].items():
+ if is_signing_algorithm_supported(key_id):
+ key_base64 = key_data["key"]
+ key_bytes = decode_base64(key_base64)
+ verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_key.expired = key_data["expired_ts"]
+ verify_key.time_added = time_now_ms
+ old_verify_keys[key_id] = verify_key
+
+ for key_id in response_json["signatures"][server_name]:
+ if key_id not in response_json["verify_keys"]:
+ raise ValueError(
+ "Key response must include verification keys for all"
+ " signatures"
+ )
+ if key_id in verify_keys:
+ verify_signed_json(
+ response_json,
+ server_name,
+ verify_keys[key_id]
+ )
+
+ signed_key_json = sign_json(
+ response_json,
+ self.config.server_name,
+ self.config.signing_key[0],
+ )
+
+ signed_key_json_bytes = encode_canonical_json(signed_key_json)
+ ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
+
+ updated_key_ids = set()
+ if requested_id is not None:
+ updated_key_ids.add(requested_id)
+ updated_key_ids.update(verify_keys)
+ updated_key_ids.update(old_verify_keys)
+
+ response_keys.update(verify_keys)
+ response_keys.update(old_verify_keys)
+
+ for key_id in updated_key_ids:
+ yield self.store.store_server_keys_json(
+ server_name=server_name,
+ key_id=key_id,
+ from_server=server_name,
+ ts_now_ms=time_now_ms,
+ ts_expires_ms=ts_valid_until_ms,
+ key_json_bytes=signed_key_json_bytes,
)
+ defer.returnValue(response_keys)
+
+ raise ValueError("No verification key found for given key ids")
+
+ @defer.inlineCallbacks
+ def get_server_verify_key_v1_direct(self, server_name, key_ids):
+ """Finds a verification key for the server with one of the key ids.
+ Args:
+ server_name (str): The name of the server to fetch a key for.
+ keys_ids (list of str): The key_ids to check for.
+ """
+
+ # Try to fetch the key from the remote server.
+
+ (response, tls_certificate) = yield fetch_server_key(
+ server_name, self.hs.tls_context_factory
+ )
+
# Check the response.
x509_certificate_bytes = crypto.dump_certificate(
@@ -148,11 +400,16 @@ class Keyring(object):
if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
raise ValueError("TLS certificate doesn't match")
+ # Cache the result in the datastore.
+
+ time_now_ms = self.clock.time_msec()
+
verify_keys = {}
for key_id, key_base64 in response["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_key.time_added = time_now_ms
verify_keys[key_id] = verify_key
for key_id in response["signatures"][server_name]:
@@ -168,10 +425,6 @@ class Keyring(object):
verify_keys[key_id]
)
- # Cache the result in the datastore.
-
- time_now_ms = self.clock.time_msec()
-
yield self.store.store_server_certificate(
server_name,
server_name,
@@ -179,14 +432,26 @@ class Keyring(object):
tls_certificate,
)
+ yield self.store_keys(
+ server_name=server_name,
+ from_server=server_name,
+ verify_keys=verify_keys,
+ )
+
+ defer.returnValue(verify_keys)
+
+ @defer.inlineCallbacks
+ def store_keys(self, server_name, from_server, verify_keys):
+ """Store a collection of verify keys for a given server
+ Args:
+ server_name(str): The name of the server the keys are for.
+ from_server(str): The server the keys were downloaded from.
+ verify_keys(dict): A mapping of key_id to VerifyKey.
+ Returns:
+ A deferred that completes when the keys are stored.
+ """
for key_id, key in verify_keys.items():
+ # TODO(markjh): Store whether the keys have expired.
yield self.store.store_server_verify_key(
- server_name, server_name, time_now_ms, key
+ server_name, server_name, key.time_added, key
)
-
- for key_id in key_ids:
- if key_id in verify_keys:
- defer.returnValue(verify_keys[key_id])
- return
-
- raise ValueError("No verification key found for given key ids")
|