diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 4cf5549143..aff69c5f83 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -101,10 +101,10 @@ class Keyring(object):
server_name(str): The name of the server to fetch a key for.
keys_ids (list of str): The key_ids to check for.
"""
- cached = yield self.store.get_server_verify_key(server_name, key_ids[0])
+ cached = yield self.store.get_server_verify_keys(server_name, key_ids)
if cached:
- defer.returnValue(cached)
+ defer.returnValue(cached[0])
return
download = self.key_downloads.get(server_name)
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 88a5642924..2902e35181 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -71,24 +71,24 @@ class KeyStore(SQLBaseStore):
desc="store_server_certificate",
)
- @cached(num_args=2)
+ @cached()
@defer.inlineCallbacks
- def get_server_verify_key(self, server_name, key_id):
- key_bytes = yield self._simple_select_one_onecol(
+ def get_all_server_verify_keys(self, server_name):
+ rows = yield self._simple_select_list(
table="server_signature_keys",
keyvalues={
"server_name": server_name,
- "key_id": key_id
},
- retcol="verify_key",
- desc="get_server_verify_key",
- allow_none=True,
+ retcols=["key_id", "verify_key"],
+ desc="get_all_server_verify_keys",
)
- if key_bytes:
- defer.returnValue(decode_verify_key_bytes(key_id, str(key_bytes)))
- else:
- defer.returnValue(None)
+ defer.returnValue({
+ row["key_id"]: decode_verify_key_bytes(
+ row["key_id"], str(row["verify_key"])
+ )
+ for row in rows
+ })
@defer.inlineCallbacks
def get_server_verify_keys(self, server_name, key_ids):
@@ -100,23 +100,8 @@ class KeyStore(SQLBaseStore):
Returns:
(list of VerifyKey): The verification keys.
"""
- sql = (
- "SELECT key_id, verify_key FROM server_signature_keys"
- " WHERE server_name = ?"
- " AND key_id in (" + ",".join("?" for key_id in key_ids) + ")"
- )
-
- rows = yield self._execute_and_decode(
- "get_server_verify_keys", sql, server_name, *key_ids
- )
-
- keys = []
- for row in rows:
- key_id = row["key_id"]
- key_bytes = row["verify_key"]
- key = decode_verify_key_bytes(key_id, str(key_bytes))
- keys.append(key)
- defer.returnValue(keys)
+ keys = yield self.get_all_server_verify_keys(server_name)
+ defer.returnValue([keys[k] for k in key_ids if k in keys])
@defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms,
@@ -129,12 +114,11 @@ class KeyStore(SQLBaseStore):
ts_now_ms (int): The time now in milliseconds
verification_key (VerifyKey): The NACL verify key.
"""
- key_id = "%s:%s" % (verify_key.alg, verify_key.version)
yield self._simple_upsert(
table="server_signature_keys",
keyvalues={
"server_name": server_name,
- "key_id": key_id,
+ "key_id": "%s:%s" % (verify_key.alg, verify_key.version),
},
values={
"from_server": from_server,
@@ -144,7 +128,7 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key",
)
- self.get_server_verify_key.invalidate(server_name, key_id)
+ self.get_all_server_verify_keys.invalidate(server_name)
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):
|