summary refs log tree commit diff
path: root/synapse/storage/databases/main/keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/keys.py')
-rw-r--r--synapse/storage/databases/main/keys.py24
1 files changed, 13 insertions, 11 deletions
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 1c0a049c55..ad43bb05ab 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,7 +16,7 @@
 
 import itertools
 import logging
-from typing import Iterable, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple
 
 from signedjson.key import decode_verify_key_bytes
 
@@ -42,16 +42,17 @@ class KeyStore(SQLBaseStore):
     @cachedList(
         cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
     )
-    def get_server_verify_keys(self, server_name_and_key_ids):
+    async def get_server_verify_keys(
+        self, server_name_and_key_ids: Iterable[Tuple[str, str]]
+    ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]:
         """
         Args:
-            server_name_and_key_ids (iterable[Tuple[str, str]]):
+            server_name_and_key_ids:
                 iterable of (server_name, key-id) tuples to fetch keys for
 
         Returns:
-            Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
-                map from (server_name, key_id) -> FetchKeyResult, or None if the key is
-                unknown
+            A map from (server_name, key_id) -> FetchKeyResult, or None if the
+            key is unknown
         """
         keys = {}
 
@@ -87,7 +88,7 @@ class KeyStore(SQLBaseStore):
                 _get_keys(txn, batch)
             return keys
 
-        return self.db_pool.runInteraction("get_server_verify_keys", _txn)
+        return await self.db_pool.runInteraction("get_server_verify_keys", _txn)
 
     async def store_server_verify_keys(
         self,
@@ -179,7 +180,9 @@ class KeyStore(SQLBaseStore):
             desc="store_server_keys_json",
         )
 
-    def get_server_keys_json(self, server_keys):
+    async def get_server_keys_json(
+        self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
+    ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]:
         """Retrive the key json for a list of server_keys and key ids.
         If no keys are found for a given server, key_id and source then
         that server, key_id, and source triplet entry will be an empty list.
@@ -188,8 +191,7 @@ class KeyStore(SQLBaseStore):
         Args:
             server_keys (list): List of (server_name, key_id, source) triplets.
         Returns:
-            Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
-                Dict mapping (server_name, key_id, source) triplets to lists of dicts
+            A mapping from (server_name, key_id, source) triplets to a list of dicts
         """
 
         def _get_server_keys_json_txn(txn):
@@ -215,6 +217,6 @@ class KeyStore(SQLBaseStore):
                 results[(server_name, key_id, from_server)] = rows
             return results
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_server_keys_json", _get_server_keys_json_txn
         )