summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/keys.py41
1 files changed, 24 insertions, 17 deletions
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 3b5e0a4fb9..87aeaf71d6 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -113,30 +113,37 @@ class KeyStore(SQLBaseStore):
                 keys[key_id] = key
         defer.returnValue(keys)
 
-    @defer.inlineCallbacks
     def store_server_verify_key(self, server_name, from_server, time_now_ms,
                                 verify_key):
         """Stores a NACL verification key for the given server.
         Args:
             server_name (str): The name of the server.
-            key_id (str): The version of the key for the server.
             from_server (str): Where the verification key was looked up
-            ts_now_ms (int): The time now in milliseconds
-            verification_key (VerifyKey): The NACL verify key.
+            time_now_ms (int): The time now in milliseconds
+            verify_key (nacl.signing.VerifyKey): The NACL verify key.
         """
-        yield self._simple_upsert(
-            table="server_signature_keys",
-            keyvalues={
-                "server_name": server_name,
-                "key_id": "%s:%s" % (verify_key.alg, verify_key.version),
-            },
-            values={
-                "from_server": from_server,
-                "ts_added_ms": time_now_ms,
-                "verify_key": buffer(verify_key.encode()),
-            },
-            desc="store_server_verify_key",
-        )
+        key_id = "%s:%s" % (verify_key.alg, verify_key.version)
+
+        def _txn(txn):
+            self._simple_upsert_txn(
+                txn,
+                table="server_signature_keys",
+                keyvalues={
+                    "server_name": server_name,
+                    "key_id": key_id,
+                },
+                values={
+                    "from_server": from_server,
+                    "ts_added_ms": time_now_ms,
+                    "verify_key": buffer(verify_key.encode()),
+                },
+            )
+            txn.call_after(
+                self._get_server_verify_key.invalidate,
+                (server_name, key_id)
+            )
+
+        return self.runInteraction("store_server_verify_key", _txn)
 
     def store_server_keys_json(self, server_name, key_id, from_server,
                                ts_now_ms, ts_expires_ms, key_json_bytes):