diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py
index b79ed02590..1c14791b09 100644
--- a/synapse/rest/key/v2/__init__.py
+++ b/synapse/rest/key/v2/__init__.py
@@ -13,7 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.web.resource import Resource
from .local_key_resource import LocalKey
+from .remote_key_resource import RemoteKey
-class KeyApiV2Resource(LocalKey):
- pass
+
+class KeyApiV2Resource(Resource):
+ def __init__(self, hs):
+ Resource.__init__(self)
+ self.putChild("server", LocalKey(hs))
+ self.putChild("query", RemoteKey(hs))
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index 1c0e0717c1..982a460962 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -31,7 +31,7 @@ class LocalKey(Resource):
"""HTTP resource containing encoding the TLS X.509 certificate and NACL
signature verification keys for this server::
- GET /_matrix/key/v2/ HTTP/1.1
+ GET /_matrix/key/v2/server/a.key.id HTTP/1.1
HTTP/1.1 200 OK
Content-Type: application/json
@@ -56,6 +56,8 @@ class LocalKey(Resource):
}
"""
+ isLeaf = True
+
def __init__(self, hs):
self.version_string = hs.version_string
self.config = hs.config
@@ -68,7 +70,6 @@ class LocalKey(Resource):
self.expires = int(time_now_msec + refresh_interval)
self.response_body = encode_canonical_json(self.response_json_object())
-
def response_json_object(self):
verify_keys = {}
for key in self.config.signing_key:
@@ -120,7 +121,3 @@ class LocalKey(Resource):
request, 200, self.response_body,
version_string=self.version_string
)
-
- def getChild(self, name, request):
- if name == '':
- return self
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
new file mode 100644
index 0000000000..cf6f2c2e73
--- /dev/null
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -0,0 +1,174 @@
+from synapse.http.server import request_handler, respond_with_json_bytes
+from synapse.api.errors import SynapseError, Codes
+
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+from twisted.internet import defer
+
+
+from io import BytesIO
+import json
+import logging
+logger = logging.getLogger(__name__)
+
+
+class RemoteKey(Resource):
+ """HTTP resource for retreiving the TLS certificate and NACL signature
+ verification keys for a collection of servers. Checks that the reported
+ X.509 TLS certificate matches the one used in the HTTPS connection. Checks
+ that the NACL signature for the remote server is valid. Returns a dict of
+ JSON signed by both the remote server and by this server.
+
+ Supports individual GET APIs and a bulk query POST API.
+
+ Requsts:
+
+ GET /_matrix/key/v2/query/remote.server.example.com HTTP/1.1
+
+ GET /_matrix/key/v2/query/remote.server.example.com/a.key.id HTTP/1.1
+
+ POST /_matrix/v2/query HTTP/1.1
+ Content-Type: application/json
+ {
+ "server_keys": { "remote.server.example.com": ["a.key.id"] }
+ }
+
+ Response:
+
+ HTTP/1.1 200 OK
+ Content-Type: application/json
+ {
+ "server_keys": [
+ {
+ "server_name": "remote.server.example.com"
+ "valid_until": # posix timestamp
+ "verify_keys": {
+ "a.key.id": { # The identifier for a key.
+ key: "" # base64 encoded verification key.
+ }
+ }
+ "old_verify_keys": {
+ "an.old.key.id": { # The identifier for an old key.
+ key: "", # base64 encoded key
+ expired: 0, # when th e
+ }
+ }
+ "tls_fingerprints": [
+ { "sha256": # fingerprint }
+ ]
+ "signatures": {
+ "remote.server.example.com": {...}
+ "this.server.example.com": {...}
+ }
+ }
+ ]
+ }
+ """
+
+ isLeaf = True
+
+ def __init__(self, hs):
+ self.keyring = hs.get_keyring()
+ self.store = hs.get_datastore()
+ self.version_string = hs.version_string
+ self.clock = hs.get_clock()
+
+ def render_GET(self, request):
+ self.async_render_GET(request)
+ return NOT_DONE_YET
+
+ @request_handler
+ @defer.inlineCallbacks
+ def async_render_GET(self, request):
+ if len(request.postpath) == 1:
+ server, = request.postpath
+ query = {server: [None]}
+ elif len(request.postpath) == 2:
+ server, key_id = request.postpath
+ query = {server: [key_id]}
+ else:
+ raise SynapseError(
+ 404, "Not found %r" % request.postpath, Codes.NOT_FOUND
+ )
+ yield self.query_keys(request, query, query_remote_on_cache_miss=True)
+
+ def render_POST(self, request):
+ self.async_render_POST(request)
+ return NOT_DONE_YET
+
+ @request_handler
+ @defer.inlineCallbacks
+ def async_render_POST(self, request):
+ try:
+ content = json.loads(request.content.read())
+ if type(content) != dict:
+ raise ValueError()
+ except ValueError:
+ raise SynapseError(
+ 400, "Content must be JSON object.", errcode=Codes.NOT_JSON
+ )
+
+ query = content["server_keys"]
+
+ yield self.query_keys(request, query, query_remote_on_cache_miss=True)
+
+ @defer.inlineCallbacks
+ def query_keys(self, request, query, query_remote_on_cache_miss=False):
+ store_queries = []
+ for server_name, key_ids in query.items():
+ for key_id in key_ids:
+ store_queries.append((server_name, key_id, None))
+
+ cached = yield self.store.get_server_keys_json(store_queries)
+
+ json_results = []
+
+ time_now_ms = self.clock.time_msec()
+
+ cache_misses = dict()
+ for (server_name, key_id, from_server), results in cached.items():
+ results = [
+ (result["ts_added_ms"], result) for result in results
+ if result["ts_valid_until_ms"] > time_now_ms
+ ]
+
+ if not results:
+ if key_id is not None:
+ cache_misses.setdefault(server_name, set()).add(key_id)
+ continue
+
+ if key_id is not None:
+ most_recent_result = max(results)
+ json_results.append(most_recent_result[-1]["key_json"])
+ else:
+ for result in results:
+ json_results.append(result[-1]["key_json"])
+
+ if cache_misses and query_remote_on_cache_miss:
+ for server_name, key_ids in cache_misses.items():
+ try:
+ yield self.keyring.get_server_verify_key_v2_direct(
+ server_name, key_ids
+ )
+ except:
+ logger.exception("Failed to get key for %r", server_name)
+ pass
+ yield self.query_keys(
+ request, query, query_remote_on_cache_miss=False
+ )
+ else:
+ result_io = BytesIO()
+ result_io.write(b"{\"server_keys\":")
+ sep = b"["
+ for json_bytes in json_results:
+ result_io.write(sep)
+ result_io.write(json_bytes)
+ sep = b","
+ if sep == b"[":
+ result_io.write(sep)
+ result_io.write(b"]}")
+
+ respond_with_json_bytes(
+ request, 200, result_io.getvalue(),
+ version_string=self.version_string
+ )
|