diff options
Diffstat (limited to 'synapse/storage/keys.py')
-rw-r--r-- | synapse/storage/keys.py | 49 |
1 files changed, 30 insertions, 19 deletions
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 5bdf497b93..49b8e37cfd 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore +from _base import SQLBaseStore, cachedInlineCallbacks from twisted.internet import defer @@ -71,6 +71,24 @@ class KeyStore(SQLBaseStore): desc="store_server_certificate", ) + @cachedInlineCallbacks() + def get_all_server_verify_keys(self, server_name): + rows = yield self._simple_select_list( + table="server_signature_keys", + keyvalues={ + "server_name": server_name, + }, + retcols=["key_id", "verify_key"], + desc="get_all_server_verify_keys", + ) + + defer.returnValue({ + row["key_id"]: decode_verify_key_bytes( + row["key_id"], str(row["verify_key"]) + ) + for row in rows + }) + @defer.inlineCallbacks def get_server_verify_keys(self, server_name, key_ids): """Retrieve the NACL verification key for a given server for the given @@ -81,24 +99,14 @@ class KeyStore(SQLBaseStore): Returns: (list of VerifyKey): The verification keys. """ - sql = ( - "SELECT key_id, verify_key FROM server_signature_keys" - " WHERE server_name = ?" - " AND key_id in (" + ",".join("?" for key_id in key_ids) + ")" - ) - - rows = yield self._execute_and_decode( - "get_server_verify_keys", sql, server_name, *key_ids - ) - - keys = [] - for row in rows: - key_id = row["key_id"] - key_bytes = row["verify_key"] - key = decode_verify_key_bytes(key_id, str(key_bytes)) - keys.append(key) - defer.returnValue(keys) + keys = yield self.get_all_server_verify_keys(server_name) + defer.returnValue({ + k: keys[k] + for k in key_ids + if k in keys and keys[k] + }) + @defer.inlineCallbacks def store_server_verify_key(self, server_name, from_server, time_now_ms, verify_key): """Stores a NACL verification key for the given server. @@ -109,7 +117,7 @@ class KeyStore(SQLBaseStore): ts_now_ms (int): The time now in milliseconds verification_key (VerifyKey): The NACL verify key. """ - return self._simple_upsert( + yield self._simple_upsert( table="server_signature_keys", keyvalues={ "server_name": server_name, @@ -123,6 +131,8 @@ class KeyStore(SQLBaseStore): desc="store_server_verify_key", ) + self.get_all_server_verify_keys.invalidate((server_name,)) + def store_server_keys_json(self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes): """Stores the JSON bytes for a set of keys from a server @@ -152,6 +162,7 @@ class KeyStore(SQLBaseStore): "ts_valid_until_ms": ts_expires_ms, "key_json": buffer(key_json_bytes), }, + desc="store_server_keys_json", ) def get_server_keys_json(self, server_keys): |