diff options
author | Erik Johnston <erik@matrix.org> | 2019-04-17 19:44:40 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2019-04-17 19:44:40 +0100 |
commit | ca90336a6935b36b5761244005b0f68b496d5d79 (patch) | |
tree | 6bbce5eafc0db3b24ccc3b59b051da850382ae09 /synapse/storage/keys.py | |
parent | Add management endpoints for account validity (diff) | |
parent | Merge pull request #5047 from matrix-org/babolivier/account_expiration (diff) | |
download | synapse-ca90336a6935b36b5761244005b0f68b496d5d79.tar.xz |
Merge branch 'develop' of github.com:matrix-org/synapse into babolivier/account_expiration
Diffstat (limited to 'synapse/storage/keys.py')
-rw-r--r-- | synapse/storage/keys.py | 154 |
1 files changed, 58 insertions, 96 deletions
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 8af17921e3..7036541792 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,17 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import hashlib +import itertools import logging import six from signedjson.key import decode_verify_key_bytes -import OpenSSL -from twisted.internet import defer - -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util import batch_iter +from synapse.util.caches.descriptors import cached, cachedList from ._base import SQLBaseStore @@ -38,93 +37,56 @@ else: class KeyStore(SQLBaseStore): - """Persistence for signature verification keys and tls X.509 certificates + """Persistence for signature verification keys """ - @defer.inlineCallbacks - def get_server_certificate(self, server_name): - """Retrieve the TLS X.509 certificate for the given server + @cached() + def _get_server_verify_key(self, server_name_and_key_id): + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" + ) + def get_server_verify_keys(self, server_name_and_key_ids): + """ Args: - server_name (bytes): The name of the server. + server_name_and_key_ids (iterable[Tuple[str, str]]): + iterable of (server_name, key-id) tuples to fetch keys for + Returns: - (OpenSSL.crypto.X509): The tls certificate. + Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]: + map from (server_name, key_id) -> VerifyKey, or None if the key is + unknown """ - tls_certificate_bytes, = yield self._simple_select_one( - table="server_tls_certificates", - keyvalues={"server_name": server_name}, - retcols=("tls_certificate",), - desc="get_server_certificate", - ) - tls_certificate = OpenSSL.crypto.load_certificate( - OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes, - ) - defer.returnValue(tls_certificate) + keys = {} - def store_server_certificate(self, server_name, from_server, time_now_ms, - tls_certificate): - """Stores the TLS X.509 certificate for the given server - Args: - server_name (str): The name of the server. - from_server (str): Where the certificate was looked up - time_now_ms (int): The time now in milliseconds - tls_certificate (OpenSSL.crypto.X509): The X.509 certificate. - """ - tls_certificate_bytes = OpenSSL.crypto.dump_certificate( - OpenSSL.crypto.FILETYPE_ASN1, tls_certificate - ) - fingerprint = hashlib.sha256(tls_certificate_bytes).hexdigest() - return self._simple_upsert( - table="server_tls_certificates", - keyvalues={ - "server_name": server_name, - "fingerprint": fingerprint, - }, - values={ - "from_server": from_server, - "ts_added_ms": time_now_ms, - "tls_certificate": db_binary_type(tls_certificate_bytes), - }, - desc="store_server_certificate", - ) + def _get_keys(txn, batch): + """Processes a batch of keys to fetch, and adds the result to `keys`.""" - @cachedInlineCallbacks() - def _get_server_verify_key(self, server_name, key_id): - verify_key_bytes = yield self._simple_select_one_onecol( - table="server_signature_keys", - keyvalues={ - "server_name": server_name, - "key_id": key_id, - }, - retcol="verify_key", - desc="_get_server_verify_key", - allow_none=True, - ) + # batch_iter always returns tuples so it's safe to do len(batch) + sql = ( + "SELECT server_name, key_id, verify_key FROM server_signature_keys " + "WHERE 1=0" + ) + " OR (server_name=? AND key_id=?)" * len(batch) - if verify_key_bytes: - defer.returnValue(decode_verify_key_bytes( - key_id, bytes(verify_key_bytes) - )) + txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) - @defer.inlineCallbacks - def get_server_verify_keys(self, server_name, key_ids): - """Retrieve the NACL verification key for a given server for the given - key_ids - Args: - server_name (str): The name of the server. - key_ids (iterable[str]): key_ids to try and look up. - Returns: - Deferred: resolves to dict[str, VerifyKey]: map from - key_id to verification key. - """ - keys = {} - for key_id in key_ids: - key = yield self._get_server_verify_key(server_name, key_id) - if key: - keys[key_id] = key - defer.returnValue(keys) - - def store_server_verify_key(self, server_name, from_server, time_now_ms, - verify_key): + for row in txn: + server_name, key_id, key_bytes = row + keys[(server_name, key_id)] = decode_verify_key_bytes( + key_id, bytes(key_bytes) + ) + + def _txn(txn): + for batch in batch_iter(server_name_and_key_ids, 50): + _get_keys(txn, batch) + return keys + + return self.runInteraction("get_server_verify_keys", _txn) + + def store_server_verify_key( + self, server_name, from_server, time_now_ms, verify_key + ): """Stores a NACL verification key for the given server. Args: server_name (str): The name of the server. @@ -139,25 +101,25 @@ class KeyStore(SQLBaseStore): self._simple_upsert_txn( txn, table="server_signature_keys", - keyvalues={ - "server_name": server_name, - "key_id": key_id, - }, + keyvalues={"server_name": server_name, "key_id": key_id}, values={ "from_server": from_server, "ts_added_ms": time_now_ms, "verify_key": db_binary_type(verify_key.encode()), }, ) + # invalidate takes a tuple corresponding to the params of + # _get_server_verify_key. _get_server_verify_key only takes one + # param, which is itself the 2-tuple (server_name, key_id). txn.call_after( - self._get_server_verify_key.invalidate, - (server_name, key_id) + self._get_server_verify_key.invalidate, ((server_name, key_id),) ) return self.runInteraction("store_server_verify_key", _txn) - def store_server_keys_json(self, server_name, key_id, from_server, - ts_now_ms, ts_expires_ms, key_json_bytes): + def store_server_keys_json( + self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes + ): """Stores the JSON bytes for a set of keys from a server The JSON should be signed by the originating server, the intermediate server, and by this server. Updates the value for the @@ -197,9 +159,10 @@ class KeyStore(SQLBaseStore): Args: server_keys (list): List of (server_name, key_id, source) triplets. Returns: - Dict mapping (server_name, key_id, source) triplets to dicts with - "ts_valid_until_ms" and "key_json" keys. + Deferred[dict[Tuple[str, str, str|None], list[dict]]]: + Dict mapping (server_name, key_id, source) triplets to lists of dicts """ + def _get_server_keys_json_txn(txn): results = {} for server_name, key_id, from_server in server_keys: @@ -222,6 +185,5 @@ class KeyStore(SQLBaseStore): ) results[(server_name, key_id, from_server)] = rows return results - return self.runInteraction( - "get_server_keys_json", _get_server_keys_json_txn - ) + + return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn) |