diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/crypto/keyring.py | 59 | ||||
-rw-r--r-- | synapse/storage/keys.py | 65 |
2 files changed, 53 insertions, 71 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 5cc98542ce..badb5254ea 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -453,10 +453,11 @@ class Keyring(object): raise_from(KeyLookupError("Remote server returned an error"), e) keys = {} + added_keys = [] - responses = query_response["server_keys"] + time_now_ms = self.clock.time_msec() - for response in responses: + for response in query_response["server_keys"]: if ( u"signatures" not in response or perspective_name not in response[u"signatures"] @@ -492,21 +493,13 @@ class Keyring(object): ) server_name = response["server_name"] + added_keys.extend( + (server_name, key_id, key) for key_id, key in processed_response.items() + ) keys.setdefault(server_name, {}).update(processed_response) - yield logcontext.make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self.store_keys, - server_name=server_name, - from_server=perspective_name, - verify_keys=response_keys, - ) - for server_name, response_keys in keys.items() - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) + yield self.store.store_server_verify_keys( + perspective_name, time_now_ms, added_keys ) defer.returnValue(keys) @@ -519,6 +512,7 @@ class Keyring(object): if requested_key_id in keys: continue + time_now_ms = self.clock.time_msec() try: response = yield self.client.get_json( destination=server_name, @@ -548,12 +542,13 @@ class Keyring(object): requested_ids=[requested_key_id], response_json=response, ) - + yield self.store.store_server_verify_keys( + server_name, + time_now_ms, + ((server_name, key_id, key) for key_id, key in response_keys.items()), + ) keys.update(response_keys) - yield self.store_keys( - server_name=server_name, from_server=server_name, verify_keys=keys - ) defer.returnValue({server_name: keys}) @defer.inlineCallbacks @@ -650,32 +645,6 @@ class Keyring(object): defer.returnValue(response_keys) - def store_keys(self, server_name, from_server, verify_keys): - """Store a collection of verify keys for a given server - Args: - server_name(str): The name of the server the keys are for. - from_server(str): The server the keys were downloaded from. - verify_keys(dict): A mapping of key_id to VerifyKey. - Returns: - A deferred that completes when the keys are stored. - """ - # TODO(markjh): Store whether the keys have expired. - return logcontext.make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self.store.store_server_verify_key, - server_name, - server_name, - key.time_added, - key, - ) - for key_id, key in verify_keys.items() - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) - ) - @defer.inlineCallbacks def _handle_key_deferred(verify_request): diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 7036541792..3c5f52009b 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -84,38 +84,51 @@ class KeyStore(SQLBaseStore): return self.runInteraction("get_server_verify_keys", _txn) - def store_server_verify_key( - self, server_name, from_server, time_now_ms, verify_key - ): - """Stores a NACL verification key for the given server. + def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): + """Stores NACL verification keys for remote servers. Args: - server_name (str): The name of the server. - from_server (str): Where the verification key was looked up - time_now_ms (int): The time now in milliseconds - verify_key (nacl.signing.VerifyKey): The NACL verify key. + from_server (str): Where the verification keys were looked up + ts_added_ms (int): The time to record that the key was added + verify_keys (iterable[tuple[str, str, nacl.signing.VerifyKey]]): + keys to be stored. Each entry is a triplet of + (server_name, key_id, key). """ - key_id = "%s:%s" % (verify_key.alg, verify_key.version) - - # XXX fix this to not need a lock (#3819) - def _txn(txn): - self._simple_upsert_txn( - txn, - table="server_signature_keys", - keyvalues={"server_name": server_name, "key_id": key_id}, - values={ - "from_server": from_server, - "ts_added_ms": time_now_ms, - "verify_key": db_binary_type(verify_key.encode()), - }, + key_values = [] + value_values = [] + invalidations = [] + for server_name, key_id, verify_key in verify_keys: + key_values.append((server_name, key_id)) + value_values.append( + ( + from_server, + ts_added_ms, + db_binary_type(verify_key.encode()), + ) ) # invalidate takes a tuple corresponding to the params of # _get_server_verify_key. _get_server_verify_key only takes one # param, which is itself the 2-tuple (server_name, key_id). - txn.call_after( - self._get_server_verify_key.invalidate, ((server_name, key_id),) - ) - - return self.runInteraction("store_server_verify_key", _txn) + invalidations.append((server_name, key_id)) + + def _invalidate(res): + f = self._get_server_verify_key.invalidate + for i in invalidations: + f((i, )) + return res + + return self.runInteraction( + "store_server_verify_keys", + self._simple_upsert_many_txn, + table="server_signature_keys", + key_names=("server_name", "key_id"), + key_values=key_values, + value_names=( + "from_server", + "ts_added_ms", + "verify_key", + ), + value_values=value_values, + ).addCallback(_invalidate) def store_server_keys_json( self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes |