summary refs log tree commit diff
path: root/synapse/crypto
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/crypto')
-rw-r--r--synapse/crypto/keyclient.py6
-rw-r--r--synapse/crypto/keyring.py75
2 files changed, 43 insertions, 38 deletions
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index 2452c7a26e..4911f0896b 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -26,7 +26,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 KEY_API_V1 = b"/_matrix/key/v1/"
-KEY_API_V2 = b"/_matrix/key/v2/local"
+
 
 @defer.inlineCallbacks
 def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
@@ -94,8 +94,8 @@ class SynapseKeyClientProtocol(HTTPClient):
         if status != b"200":
             # logger.info("Non-200 response from %s: %s %s",
             #            self.transport.getHost(), status, message)
-            error = SynapseKeyClientError("Non-200 response %r from %r" %
-                (status, self.host)
+            error = SynapseKeyClientError(
+                "Non-200 response %r from %r" % (status, self.host)
             )
             error.status = status
             self.errback(error)
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 5528d0a280..17ac66731c 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -15,7 +15,9 @@
 
 from synapse.crypto.keyclient import fetch_server_key
 from twisted.internet import defer
-from syutil.crypto.jsonsign import verify_signed_json, signature_ids
+from syutil.crypto.jsonsign import (
+    verify_signed_json, signature_ids, sign_json, encode_canonical_json
+)
 from syutil.crypto.signing_key import (
     is_signing_algorithm_supported, decode_verify_key_bytes
 )
@@ -26,6 +28,8 @@ from synapse.util.retryutils import get_retry_limiter
 
 from OpenSSL import crypto
 
+import urllib
+import hashlib
 import logging
 
 
@@ -37,6 +41,7 @@ class Keyring(object):
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
         self.client = hs.get_http_client()
+        self.config = hs.get_config()
         self.perspective_servers = {}
         self.hs = hs
 
@@ -127,7 +132,6 @@ class Keyring(object):
                 server_name, key_ids
             )
 
-
         for key_id in key_ids:
             if key_id in keys:
                 defer.returnValue(keys[key_id])
@@ -142,17 +146,18 @@ class Keyring(object):
             perspective_name, self.clock, self.store
         )
 
-        responses = yield self.client.post_json(
-            destination=perspective_name,
-            path=b"/_matrix/key/v2/query",
-            data={u"server_keys": {server_name: list(key_ids)}},
-        )
+        with limiter:
+            responses = yield self.client.post_json(
+                destination=perspective_name,
+                path=b"/_matrix/key/v2/query",
+                data={u"server_keys": {server_name: list(key_ids)}},
+            )
 
-        keys = dict()
+        keys = {}
 
         for response in responses:
             if (u"signatures" not in response
-                or perspective_name not in response[u"signatures"]):
+                    or perspective_name not in response[u"signatures"]):
                 raise ValueError(
                     "Key response not signed by perspective server"
                     " %r" % (perspective_name,)
@@ -181,7 +186,9 @@ class Keyring(object):
                     " server %r" % (perspective_name,)
                 )
 
-            response_keys = process_v2_response(self, server_name, key_ids)
+            response_keys = yield self.process_v2_response(
+                server_name, perspective_name, response
+            )
 
             keys.update(response_keys)
 
@@ -202,15 +209,15 @@ class Keyring(object):
             if requested_key_id in keys:
                 continue
 
-            (response_json, tls_certificate) = yield fetch_server_key(
+            (response, tls_certificate) = yield fetch_server_key(
                 server_name, self.hs.tls_context_factory,
-                path="/_matrix/key/v2/server/%s" % (
+                path=(b"/_matrix/key/v2/server/%s" % (
                     urllib.quote(requested_key_id),
-                ),
+                )).encode("ascii"),
             )
 
             if (u"signatures" not in response
-                or server_name not in response[u"signatures"]):
+                    or server_name not in response[u"signatures"]):
                 raise ValueError("Key response not signed by remote server")
 
             if "tls_fingerprints" not in response:
@@ -223,17 +230,18 @@ class Keyring(object):
             sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)
 
             response_sha256_fingerprints = set()
-            for fingerprint in response_json[u"tls_fingerprints"]:
+            for fingerprint in response[u"tls_fingerprints"]:
                 if u"sha256" in fingerprint:
                     response_sha256_fingerprints.add(fingerprint[u"sha256"])
 
-            if sha256_fingerprint not in response_sha256_fingerprints:
+            if sha256_fingerprint_b64 not in response_sha256_fingerprints:
                 raise ValueError("TLS certificate not allowed by fingerprints")
 
             response_keys = yield self.process_v2_response(
                 server_name=server_name,
                 from_server=server_name,
-                response_json=response_json,
+                requested_id=requested_key_id,
+                response_json=response,
             )
 
             keys.update(response_keys)
@@ -244,19 +252,15 @@ class Keyring(object):
             verify_keys=keys,
         )
 
-        for key_id in key_ids:
-            if key_id in verify_keys:
-                defer.returnValue(verify_keys[key_id])
-                return
-
-        raise ValueError("No verification key found for given key ids")
+        defer.returnValue(keys)
 
     @defer.inlineCallbacks
-    def process_v2_response(self, server_name, from_server, json_response):
-        time_now_ms = clock.time_msec()
+    def process_v2_response(self, server_name, from_server, response_json,
+                            requested_id=None):
+        time_now_ms = self.clock.time_msec()
         response_keys = {}
         verify_keys = {}
-        for key_id, key_data in response["verify_keys"].items():
+        for key_id, key_data in response_json["verify_keys"].items():
             if is_signing_algorithm_supported(key_id):
                 key_base64 = key_data["key"]
                 key_bytes = decode_base64(key_base64)
@@ -264,7 +268,7 @@ class Keyring(object):
                 verify_keys[key_id] = verify_key
 
         old_verify_keys = {}
-        for key_id, key_data in response["verify_keys"].items():
+        for key_id, key_data in response_json["old_verify_keys"].items():
             if is_signing_algorithm_supported(key_id):
                 key_base64 = key_data["key"]
                 key_bytes = decode_base64(key_base64)
@@ -273,21 +277,21 @@ class Keyring(object):
                 verify_key.time_added = time_now_ms
                 old_verify_keys[key_id] = verify_key
 
-        for key_id in response["signatures"][server_name]:
-            if key_id not in response["verify_keys"]:
+        for key_id in response_json["signatures"][server_name]:
+            if key_id not in response_json["verify_keys"]:
                 raise ValueError(
                     "Key response must include verification keys for all"
                     " signatures"
                 )
             if key_id in verify_keys:
                 verify_signed_json(
-                    response,
+                    response_json,
                     server_name,
                     verify_keys[key_id]
                 )
 
         signed_key_json = sign_json(
-            response,
+            response_json,
             self.config.server_name,
             self.config.signing_key[0],
         )
@@ -295,7 +299,9 @@ class Keyring(object):
         signed_key_json_bytes = encode_canonical_json(signed_key_json)
         ts_valid_until_ms = signed_key_json[u"valid_until"]
 
-        updated_key_ids = set([requested_key_id])
+        updated_key_ids = set()
+        if requested_id is not None:
+            updated_key_ids.add(requested_id)
         updated_key_ids.update(verify_keys)
         updated_key_ids.update(old_verify_keys)
 
@@ -307,8 +313,8 @@ class Keyring(object):
                 server_name=server_name,
                 key_id=key_id,
                 from_server=server_name,
-                ts_now_ms=ts_now_ms,
-                ts_valid_until_ms=valid_until,
+                ts_now_ms=time_now_ms,
+                ts_expires_ms=ts_valid_until_ms,
                 key_json_bytes=signed_key_json_bytes,
             )
 
@@ -373,7 +379,6 @@ class Keyring(object):
                     verify_keys[key_id]
                 )
 
-
         yield self.store.store_server_certificate(
             server_name,
             server_name,