summary refs log tree commit diff
path: root/synapse/storage/keys.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-06-19 17:20:58 +0100
committerErik Johnston <erik@matrix.org>2015-06-19 17:20:58 +0100
commitb39b294d1f51e80338f2b8992f4c61952cc6462c (patch)
tree81a1948af744046bb740764bdd48956748f584f4 /synapse/storage/keys.py
parentMerge branch 'develop' of github.com:matrix-org/synapse into erikj/persist_ev... (diff)
downloadsynapse-b39b294d1f51e80338f2b8992f4c61952cc6462c.tar.xz
Properly cache get_server_verify_keys
Diffstat (limited to 'synapse/storage/keys.py')
-rw-r--r--synapse/storage/keys.py46
1 files changed, 15 insertions, 31 deletions
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 88a5642924..2902e35181 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -71,24 +71,24 @@ class KeyStore(SQLBaseStore):
             desc="store_server_certificate",
         )
 
-    @cached(num_args=2)
+    @cached()
     @defer.inlineCallbacks
-    def get_server_verify_key(self, server_name, key_id):
-        key_bytes = yield self._simple_select_one_onecol(
+    def get_all_server_verify_keys(self, server_name):
+        rows = yield self._simple_select_list(
             table="server_signature_keys",
             keyvalues={
                 "server_name": server_name,
-                "key_id": key_id
             },
-            retcol="verify_key",
-            desc="get_server_verify_key",
-            allow_none=True,
+            retcols=["key_id", "verify_key"],
+            desc="get_all_server_verify_keys",
         )
 
-        if key_bytes:
-            defer.returnValue(decode_verify_key_bytes(key_id, str(key_bytes)))
-        else:
-            defer.returnValue(None)
+        defer.returnValue({
+            row["key_id"]: decode_verify_key_bytes(
+                row["key_id"], str(row["verify_key"])
+            )
+            for row in rows
+        })
 
     @defer.inlineCallbacks
     def get_server_verify_keys(self, server_name, key_ids):
@@ -100,23 +100,8 @@ class KeyStore(SQLBaseStore):
         Returns:
             (list of VerifyKey): The verification keys.
         """
-        sql = (
-            "SELECT key_id, verify_key FROM server_signature_keys"
-            " WHERE server_name = ?"
-            " AND key_id in (" + ",".join("?" for key_id in key_ids) + ")"
-        )
-
-        rows = yield self._execute_and_decode(
-            "get_server_verify_keys", sql, server_name, *key_ids
-        )
-
-        keys = []
-        for row in rows:
-            key_id = row["key_id"]
-            key_bytes = row["verify_key"]
-            key = decode_verify_key_bytes(key_id, str(key_bytes))
-            keys.append(key)
-        defer.returnValue(keys)
+        keys = yield self.get_all_server_verify_keys(server_name)
+        defer.returnValue([keys[k] for k in key_ids if k in keys])
 
     @defer.inlineCallbacks
     def store_server_verify_key(self, server_name, from_server, time_now_ms,
@@ -129,12 +114,11 @@ class KeyStore(SQLBaseStore):
             ts_now_ms (int): The time now in milliseconds
             verification_key (VerifyKey): The NACL verify key.
         """
-        key_id = "%s:%s" % (verify_key.alg, verify_key.version)
         yield self._simple_upsert(
             table="server_signature_keys",
             keyvalues={
                 "server_name": server_name,
-                "key_id": key_id,
+                "key_id": "%s:%s" % (verify_key.alg, verify_key.version),
             },
             values={
                 "from_server": from_server,
@@ -144,7 +128,7 @@ class KeyStore(SQLBaseStore):
             desc="store_server_verify_key",
         )
 
-        self.get_server_verify_key.invalidate(server_name, key_id)
+        self.get_all_server_verify_keys.invalidate(server_name)
 
     def store_server_keys_json(self, server_name, key_id, from_server,
                                ts_now_ms, ts_expires_ms, key_json_bytes):