summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/config/server.py4
-rw-r--r--synapse/crypto/keyclient.py6
-rw-r--r--synapse/crypto/keyring.py75
-rw-r--r--synapse/rest/key/v2/__init__.py10
-rw-r--r--synapse/rest/key/v2/local_key_resource.py9
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py174
-rw-r--r--synapse/storage/keys.py35
7 files changed, 252 insertions, 61 deletions
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 050ab90403..a26fb115f2 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -62,7 +62,7 @@ class ServerConfig(Config):
         server_group.add_argument("--old-signing-key-path",
                                   help="The old signing keys")
         server_group.add_argument("--key-refresh-interval",
-                                  default=24 * 60 * 60 * 1000, # 1 Day
+                                  default=24 * 60 * 60 * 1000,  # 1 Day
                                   help="How long a key response is valid for."
                                        " Used to set the exipiry in /key/v2/."
                                        " Controls how frequently servers will"
@@ -156,5 +156,5 @@ class ServerConfig(Config):
             args.old_signing_key_path = base_key_name + ".old.signing.keys"
 
         if not os.path.exists(args.old_signing_key_path):
-            with open(args.old_signing_key_path, "w") as old_signing_key_file:
+            with open(args.old_signing_key_path, "w"):
                 pass
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index 2452c7a26e..4911f0896b 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -26,7 +26,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 KEY_API_V1 = b"/_matrix/key/v1/"
-KEY_API_V2 = b"/_matrix/key/v2/local"
+
 
 @defer.inlineCallbacks
 def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
@@ -94,8 +94,8 @@ class SynapseKeyClientProtocol(HTTPClient):
         if status != b"200":
             # logger.info("Non-200 response from %s: %s %s",
             #            self.transport.getHost(), status, message)
-            error = SynapseKeyClientError("Non-200 response %r from %r" %
-                (status, self.host)
+            error = SynapseKeyClientError(
+                "Non-200 response %r from %r" % (status, self.host)
             )
             error.status = status
             self.errback(error)
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 5528d0a280..17ac66731c 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -15,7 +15,9 @@
 
 from synapse.crypto.keyclient import fetch_server_key
 from twisted.internet import defer
-from syutil.crypto.jsonsign import verify_signed_json, signature_ids
+from syutil.crypto.jsonsign import (
+    verify_signed_json, signature_ids, sign_json, encode_canonical_json
+)
 from syutil.crypto.signing_key import (
     is_signing_algorithm_supported, decode_verify_key_bytes
 )
@@ -26,6 +28,8 @@ from synapse.util.retryutils import get_retry_limiter
 
 from OpenSSL import crypto
 
+import urllib
+import hashlib
 import logging
 
 
@@ -37,6 +41,7 @@ class Keyring(object):
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
         self.client = hs.get_http_client()
+        self.config = hs.get_config()
         self.perspective_servers = {}
         self.hs = hs
 
@@ -127,7 +132,6 @@ class Keyring(object):
                 server_name, key_ids
             )
 
-
         for key_id in key_ids:
             if key_id in keys:
                 defer.returnValue(keys[key_id])
@@ -142,17 +146,18 @@ class Keyring(object):
             perspective_name, self.clock, self.store
         )
 
-        responses = yield self.client.post_json(
-            destination=perspective_name,
-            path=b"/_matrix/key/v2/query",
-            data={u"server_keys": {server_name: list(key_ids)}},
-        )
+        with limiter:
+            responses = yield self.client.post_json(
+                destination=perspective_name,
+                path=b"/_matrix/key/v2/query",
+                data={u"server_keys": {server_name: list(key_ids)}},
+            )
 
-        keys = dict()
+        keys = {}
 
         for response in responses:
             if (u"signatures" not in response
-                or perspective_name not in response[u"signatures"]):
+                    or perspective_name not in response[u"signatures"]):
                 raise ValueError(
                     "Key response not signed by perspective server"
                     " %r" % (perspective_name,)
@@ -181,7 +186,9 @@ class Keyring(object):
                     " server %r" % (perspective_name,)
                 )
 
-            response_keys = process_v2_response(self, server_name, key_ids)
+            response_keys = yield self.process_v2_response(
+                server_name, perspective_name, response
+            )
 
             keys.update(response_keys)
 
@@ -202,15 +209,15 @@ class Keyring(object):
             if requested_key_id in keys:
                 continue
 
-            (response_json, tls_certificate) = yield fetch_server_key(
+            (response, tls_certificate) = yield fetch_server_key(
                 server_name, self.hs.tls_context_factory,
-                path="/_matrix/key/v2/server/%s" % (
+                path=(b"/_matrix/key/v2/server/%s" % (
                     urllib.quote(requested_key_id),
-                ),
+                )).encode("ascii"),
             )
 
             if (u"signatures" not in response
-                or server_name not in response[u"signatures"]):
+                    or server_name not in response[u"signatures"]):
                 raise ValueError("Key response not signed by remote server")
 
             if "tls_fingerprints" not in response:
@@ -223,17 +230,18 @@ class Keyring(object):
             sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)
 
             response_sha256_fingerprints = set()
-            for fingerprint in response_json[u"tls_fingerprints"]:
+            for fingerprint in response[u"tls_fingerprints"]:
                 if u"sha256" in fingerprint:
                     response_sha256_fingerprints.add(fingerprint[u"sha256"])
 
-            if sha256_fingerprint not in response_sha256_fingerprints:
+            if sha256_fingerprint_b64 not in response_sha256_fingerprints:
                 raise ValueError("TLS certificate not allowed by fingerprints")
 
             response_keys = yield self.process_v2_response(
                 server_name=server_name,
                 from_server=server_name,
-                response_json=response_json,
+                requested_id=requested_key_id,
+                response_json=response,
             )
 
             keys.update(response_keys)
@@ -244,19 +252,15 @@ class Keyring(object):
             verify_keys=keys,
         )
 
-        for key_id in key_ids:
-            if key_id in verify_keys:
-                defer.returnValue(verify_keys[key_id])
-                return
-
-        raise ValueError("No verification key found for given key ids")
+        defer.returnValue(keys)
 
     @defer.inlineCallbacks
-    def process_v2_response(self, server_name, from_server, json_response):
-        time_now_ms = clock.time_msec()
+    def process_v2_response(self, server_name, from_server, response_json,
+                            requested_id=None):
+        time_now_ms = self.clock.time_msec()
         response_keys = {}
         verify_keys = {}
-        for key_id, key_data in response["verify_keys"].items():
+        for key_id, key_data in response_json["verify_keys"].items():
             if is_signing_algorithm_supported(key_id):
                 key_base64 = key_data["key"]
                 key_bytes = decode_base64(key_base64)
@@ -264,7 +268,7 @@ class Keyring(object):
                 verify_keys[key_id] = verify_key
 
         old_verify_keys = {}
-        for key_id, key_data in response["verify_keys"].items():
+        for key_id, key_data in response_json["old_verify_keys"].items():
             if is_signing_algorithm_supported(key_id):
                 key_base64 = key_data["key"]
                 key_bytes = decode_base64(key_base64)
@@ -273,21 +277,21 @@ class Keyring(object):
                 verify_key.time_added = time_now_ms
                 old_verify_keys[key_id] = verify_key
 
-        for key_id in response["signatures"][server_name]:
-            if key_id not in response["verify_keys"]:
+        for key_id in response_json["signatures"][server_name]:
+            if key_id not in response_json["verify_keys"]:
                 raise ValueError(
                     "Key response must include verification keys for all"
                     " signatures"
                 )
             if key_id in verify_keys:
                 verify_signed_json(
-                    response,
+                    response_json,
                     server_name,
                     verify_keys[key_id]
                 )
 
         signed_key_json = sign_json(
-            response,
+            response_json,
             self.config.server_name,
             self.config.signing_key[0],
         )
@@ -295,7 +299,9 @@ class Keyring(object):
         signed_key_json_bytes = encode_canonical_json(signed_key_json)
         ts_valid_until_ms = signed_key_json[u"valid_until"]
 
-        updated_key_ids = set([requested_key_id])
+        updated_key_ids = set()
+        if requested_id is not None:
+            updated_key_ids.add(requested_id)
         updated_key_ids.update(verify_keys)
         updated_key_ids.update(old_verify_keys)
 
@@ -307,8 +313,8 @@ class Keyring(object):
                 server_name=server_name,
                 key_id=key_id,
                 from_server=server_name,
-                ts_now_ms=ts_now_ms,
-                ts_valid_until_ms=valid_until,
+                ts_now_ms=time_now_ms,
+                ts_expires_ms=ts_valid_until_ms,
                 key_json_bytes=signed_key_json_bytes,
             )
 
@@ -373,7 +379,6 @@ class Keyring(object):
                     verify_keys[key_id]
                 )
 
-
         yield self.store.store_server_certificate(
             server_name,
             server_name,
diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py
index b79ed02590..1c14791b09 100644
--- a/synapse/rest/key/v2/__init__.py
+++ b/synapse/rest/key/v2/__init__.py
@@ -13,7 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.web.resource import Resource
 from .local_key_resource import LocalKey
+from .remote_key_resource import RemoteKey
 
-class KeyApiV2Resource(LocalKey):
-    pass
+
+class KeyApiV2Resource(Resource):
+    def __init__(self, hs):
+        Resource.__init__(self)
+        self.putChild("server", LocalKey(hs))
+        self.putChild("query", RemoteKey(hs))
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index 1c0e0717c1..982a460962 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -31,7 +31,7 @@ class LocalKey(Resource):
     """HTTP resource containing encoding the TLS X.509 certificate and NACL
     signature verification keys for this server::
 
-        GET /_matrix/key/v2/ HTTP/1.1
+        GET /_matrix/key/v2/server/a.key.id HTTP/1.1
 
         HTTP/1.1 200 OK
         Content-Type: application/json
@@ -56,6 +56,8 @@ class LocalKey(Resource):
         }
     """
 
+    isLeaf = True
+
     def __init__(self, hs):
         self.version_string = hs.version_string
         self.config = hs.config
@@ -68,7 +70,6 @@ class LocalKey(Resource):
         self.expires = int(time_now_msec + refresh_interval)
         self.response_body = encode_canonical_json(self.response_json_object())
 
-
     def response_json_object(self):
         verify_keys = {}
         for key in self.config.signing_key:
@@ -120,7 +121,3 @@ class LocalKey(Resource):
             request, 200, self.response_body,
             version_string=self.version_string
         )
-
-    def getChild(self, name, request):
-        if name == '':
-            return self
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
new file mode 100644
index 0000000000..cf6f2c2e73
--- /dev/null
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -0,0 +1,174 @@
+from synapse.http.server import request_handler, respond_with_json_bytes
+from synapse.api.errors import SynapseError, Codes
+
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+from twisted.internet import defer
+
+
+from io import BytesIO
+import json
+import logging
+logger = logging.getLogger(__name__)
+
+
+class RemoteKey(Resource):
+    """HTTP resource for retreiving the TLS certificate and NACL signature
+    verification keys for a collection of servers. Checks that the reported
+    X.509 TLS certificate matches the one used in the HTTPS connection. Checks
+    that the NACL signature for the remote server is valid. Returns a dict of
+    JSON signed by both the remote server and by this server.
+
+    Supports individual GET APIs and a bulk query POST API.
+
+    Requsts:
+
+    GET /_matrix/key/v2/query/remote.server.example.com HTTP/1.1
+
+    GET /_matrix/key/v2/query/remote.server.example.com/a.key.id HTTP/1.1
+
+    POST /_matrix/v2/query HTTP/1.1
+    Content-Type: application/json
+    {
+        "server_keys": { "remote.server.example.com": ["a.key.id"] }
+    }
+
+    Response:
+
+    HTTP/1.1 200 OK
+    Content-Type: application/json
+    {
+        "server_keys": [
+            {
+                "server_name": "remote.server.example.com"
+                "valid_until": # posix timestamp
+                "verify_keys": {
+                    "a.key.id": { # The identifier for a key.
+                        key: "" # base64 encoded verification key.
+                    }
+                }
+                "old_verify_keys": {
+                    "an.old.key.id": { # The identifier for an old key.
+                        key: "", # base64 encoded key
+                        expired: 0, # when th e
+                    }
+                }
+                "tls_fingerprints": [
+                    { "sha256": # fingerprint }
+                ]
+                "signatures": {
+                    "remote.server.example.com": {...}
+                    "this.server.example.com": {...}
+                }
+            }
+        ]
+    }
+    """
+
+    isLeaf = True
+
+    def __init__(self, hs):
+        self.keyring = hs.get_keyring()
+        self.store = hs.get_datastore()
+        self.version_string = hs.version_string
+        self.clock = hs.get_clock()
+
+    def render_GET(self, request):
+        self.async_render_GET(request)
+        return NOT_DONE_YET
+
+    @request_handler
+    @defer.inlineCallbacks
+    def async_render_GET(self, request):
+        if len(request.postpath) == 1:
+            server, = request.postpath
+            query = {server: [None]}
+        elif len(request.postpath) == 2:
+            server, key_id = request.postpath
+            query = {server: [key_id]}
+        else:
+            raise SynapseError(
+                404, "Not found %r" % request.postpath, Codes.NOT_FOUND
+            )
+        yield self.query_keys(request, query, query_remote_on_cache_miss=True)
+
+    def render_POST(self, request):
+        self.async_render_POST(request)
+        return NOT_DONE_YET
+
+    @request_handler
+    @defer.inlineCallbacks
+    def async_render_POST(self, request):
+        try:
+            content = json.loads(request.content.read())
+            if type(content) != dict:
+                raise ValueError()
+        except ValueError:
+            raise SynapseError(
+                400, "Content must be JSON object.", errcode=Codes.NOT_JSON
+            )
+
+        query = content["server_keys"]
+
+        yield self.query_keys(request, query, query_remote_on_cache_miss=True)
+
+    @defer.inlineCallbacks
+    def query_keys(self, request, query, query_remote_on_cache_miss=False):
+        store_queries = []
+        for server_name, key_ids in query.items():
+            for key_id in key_ids:
+                store_queries.append((server_name, key_id, None))
+
+        cached = yield self.store.get_server_keys_json(store_queries)
+
+        json_results = []
+
+        time_now_ms = self.clock.time_msec()
+
+        cache_misses = dict()
+        for (server_name, key_id, from_server), results in cached.items():
+            results = [
+                (result["ts_added_ms"], result) for result in results
+                if result["ts_valid_until_ms"] > time_now_ms
+            ]
+
+            if not results:
+                if key_id is not None:
+                    cache_misses.setdefault(server_name, set()).add(key_id)
+                continue
+
+            if key_id is not None:
+                most_recent_result = max(results)
+                json_results.append(most_recent_result[-1]["key_json"])
+            else:
+                for result in results:
+                    json_results.append(result[-1]["key_json"])
+
+        if cache_misses and query_remote_on_cache_miss:
+            for server_name, key_ids in cache_misses.items():
+                try:
+                    yield self.keyring.get_server_verify_key_v2_direct(
+                        server_name, key_ids
+                    )
+                except:
+                    logger.exception("Failed to get key for %r", server_name)
+                    pass
+            yield self.query_keys(
+                request, query, query_remote_on_cache_miss=False
+            )
+        else:
+            result_io = BytesIO()
+            result_io.write(b"{\"server_keys\":")
+            sep = b"["
+            for json_bytes in json_results:
+                result_io.write(sep)
+                result_io.write(json_bytes)
+                sep = b","
+            if sep == b"[":
+                result_io.write(sep)
+            result_io.write(b"]}")
+
+            respond_with_json_bytes(
+                request, 200, result_io.getvalue(),
+                version_string=self.version_string
+            )
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 8b08d42859..22b158d71e 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -140,8 +140,8 @@ class KeyStore(SQLBaseStore):
                 "key_id": key_id,
                 "from_server": from_server,
                 "ts_added_ms": ts_now_ms,
-                "ts_valid_until_ms": ts_valid_until_ms,
-                "key_json": key_json_bytes,
+                "ts_valid_until_ms": ts_expires_ms,
+                "key_json": buffer(key_json_bytes),
             },
             or_replace=True,
         )
@@ -149,9 +149,9 @@ class KeyStore(SQLBaseStore):
     def get_server_keys_json(self, server_keys):
         """Retrive the key json for a list of server_keys and key ids.
         If no keys are found for a given server, key_id and source then
-        that server, key_id, and source triplet will be missing from the
-        returned dictionary. The JSON is returned as a byte array so that it
-        can be efficiently used in an HTTP response.
+        that server, key_id, and source triplet entry will be an empty list.
+        The JSON is returned as a byte array so that it can be efficiently
+        used in an HTTP response.
         Args:
             server_keys (list): List of (server_name, key_id, source) triplets.
         Returns:
@@ -161,16 +161,25 @@ class KeyStore(SQLBaseStore):
         def _get_server_keys_json_txn(txn):
             results = {}
             for server_name, key_id, from_server in server_keys:
-                rows = _simple_select_list_txn(
-                    keyvalues={
-                        "server_name": server_name,
-                        "key_id": key_id,
-                        "from_server": from_server,
-                    },
-                    retcols=("ts_valid_until_ms", "key_json"),
+                keyvalues = {"server_name": server_name}
+                if key_id is not None:
+                    keyvalues["key_id"] = key_id
+                if from_server is not None:
+                    keyvalues["from_server"] = from_server
+                rows = self._simple_select_list_txn(
+                    txn,
+                    "server_keys_json",
+                    keyvalues=keyvalues,
+                    retcols=(
+                        "key_id",
+                        "from_server",
+                        "ts_added_ms",
+                        "ts_valid_until_ms",
+                        "key_json",
+                    ),
                 )
                 results[(server_name, key_id, from_server)] = rows
             return results
-        return runInteraction(
+        return self.runInteraction(
             "get_server_keys_json", _get_server_keys_json_txn
         )