summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/replication/slave/storage/keys.py16
-rw-r--r--synapse/storage/keys.py34
2 files changed, 26 insertions, 24 deletions
diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py
index c1c895439d..dd2ae49e48 100644
--- a/synapse/replication/slave/storage/keys.py
+++ b/synapse/replication/slave/storage/keys.py
@@ -13,17 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
 from ._base import BaseSlavedStore
 from synapse.storage import DataStore
 from synapse.storage.keys import KeyStore
 
 
 class SlavedKeyStore(BaseSlavedStore):
-    # TODO: use the cached version and invalidate deleted tokens
-    get_all_server_verify_keys = defer.inlineCallbacks(KeyStore.__dict__[
-        "get_all_server_verify_keys"
-    ].orig)
+    _get_server_verify_key = KeyStore.__dict__[
+        "_get_server_verify_key"
+    ]
 
     get_server_verify_keys = DataStore.get_server_verify_keys.__func__
+    store_server_verify_key = DataStore.store_server_verify_key.__func__
+
+    get_server_certificate = DataStore.get_server_certificate.__func__
+    store_server_certificate = DataStore.store_server_certificate.__func__
+
+    get_server_keys_json = DataStore.get_server_keys_json.__func__
+    store_server_keys_json = DataStore.store_server_keys_json.__func__
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 1195efec08..86b37b9ddd 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -78,22 +78,22 @@ class KeyStore(SQLBaseStore):
         )
 
     @cachedInlineCallbacks()
-    def get_all_server_verify_keys(self, server_name):
-        rows = yield self._simple_select_list(
+    def _get_server_verify_key(self, server_name, key_id):
+        verify_key_bytes = yield self._simple_select_one_onecol(
             table="server_signature_keys",
             keyvalues={
                 "server_name": server_name,
+                "key_id": key_id,
             },
-            retcols=["key_id", "verify_key"],
-            desc="get_all_server_verify_keys",
+            retcol="verify_key",
+            desc="_get_server_verify_key",
+            allow_none=True,
         )
 
-        defer.returnValue({
-            row["key_id"]: decode_verify_key_bytes(
-                row["key_id"], str(row["verify_key"])
-            )
-            for row in rows
-        })
+        if verify_key_bytes:
+            defer.returnValue(decode_verify_key_bytes(
+                key_id, str(verify_key_bytes)
+            ))
 
     @defer.inlineCallbacks
     def get_server_verify_keys(self, server_name, key_ids):
@@ -105,12 +105,12 @@ class KeyStore(SQLBaseStore):
         Returns:
             (list of VerifyKey): The verification keys.
         """
-        keys = yield self.get_all_server_verify_keys(server_name)
-        defer.returnValue({
-            k: keys[k]
-            for k in key_ids
-            if k in keys and keys[k]
-        })
+        keys = {}
+        for key_id in key_ids:
+            key = yield self._get_server_verify_key(server_name, key_id)
+            if key:
+                keys[key_id] = key
+        defer.returnValue(keys)
 
     @defer.inlineCallbacks
     def store_server_verify_key(self, server_name, from_server, time_now_ms,
@@ -137,8 +137,6 @@ class KeyStore(SQLBaseStore):
             desc="store_server_verify_key",
         )
 
-        self.get_all_server_verify_keys.invalidate((server_name,))
-
     def store_server_keys_json(self, server_name, key_id, from_server,
                                ts_now_ms, ts_expires_ms, key_json_bytes):
         """Stores the JSON bytes for a set of keys from a server