summary refs log tree commit diff
path: root/synapse/storage/databases/main/keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/keys.py')
-rw-r--r--synapse/storage/databases/main/keys.py28
1 files changed, 16 insertions, 12 deletions
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 384e9c5eb0..fadcad51e7 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,6 +16,7 @@
 
 import itertools
 import logging
+from typing import Iterable, Tuple
 
 from signedjson.key import decode_verify_key_bytes
 
@@ -88,12 +89,17 @@ class KeyStore(SQLBaseStore):
 
         return self.db_pool.runInteraction("get_server_verify_keys", _txn)
 
-    def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
+    async def store_server_verify_keys(
+        self,
+        from_server: str,
+        ts_added_ms: int,
+        verify_keys: Iterable[Tuple[str, str, FetchKeyResult]],
+    ) -> None:
         """Stores NACL verification keys for remote servers.
         Args:
-            from_server (str): Where the verification keys were looked up
-            ts_added_ms (int): The time to record that the key was added
-            verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
+            from_server: Where the verification keys were looked up
+            ts_added_ms: The time to record that the key was added
+            verify_keys:
                 keys to be stored. Each entry is a triplet of
                 (server_name, key_id, key).
         """
@@ -115,13 +121,7 @@ class KeyStore(SQLBaseStore):
             # param, which is itself the 2-tuple (server_name, key_id).
             invalidations.append((server_name, key_id))
 
-        def _invalidate(res):
-            f = self._get_server_verify_key.invalidate
-            for i in invalidations:
-                f((i,))
-            return res
-
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "store_server_verify_keys",
             self.db_pool.simple_upsert_many_txn,
             table="server_signature_keys",
@@ -134,7 +134,11 @@ class KeyStore(SQLBaseStore):
                 "verify_key",
             ),
             value_values=value_values,
-        ).addCallback(_invalidate)
+        )
+
+        invalidate = self._get_server_verify_key.invalidate
+        for i in invalidations:
+            invalidate((i,))
 
     def store_server_keys_json(
         self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes