diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 384e9c5eb0..ad43bb05ab 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,6 +16,7 @@
import itertools
import logging
+from typing import Dict, Iterable, List, Optional, Tuple
from signedjson.key import decode_verify_key_bytes
@@ -41,16 +42,17 @@ class KeyStore(SQLBaseStore):
@cachedList(
cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
)
- def get_server_verify_keys(self, server_name_and_key_ids):
+ async def get_server_verify_keys(
+ self, server_name_and_key_ids: Iterable[Tuple[str, str]]
+ ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]:
"""
Args:
- server_name_and_key_ids (iterable[Tuple[str, str]]):
+ server_name_and_key_ids:
iterable of (server_name, key-id) tuples to fetch keys for
Returns:
- Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
- map from (server_name, key_id) -> FetchKeyResult, or None if the key is
- unknown
+ A map from (server_name, key_id) -> FetchKeyResult, or None if the
+ key is unknown
"""
keys = {}
@@ -86,14 +88,19 @@ class KeyStore(SQLBaseStore):
_get_keys(txn, batch)
return keys
- return self.db_pool.runInteraction("get_server_verify_keys", _txn)
+ return await self.db_pool.runInteraction("get_server_verify_keys", _txn)
- def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
+ async def store_server_verify_keys(
+ self,
+ from_server: str,
+ ts_added_ms: int,
+ verify_keys: Iterable[Tuple[str, str, FetchKeyResult]],
+ ) -> None:
"""Stores NACL verification keys for remote servers.
Args:
- from_server (str): Where the verification keys were looked up
- ts_added_ms (int): The time to record that the key was added
- verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
+ from_server: Where the verification keys were looked up
+ ts_added_ms: The time to record that the key was added
+ verify_keys:
keys to be stored. Each entry is a triplet of
(server_name, key_id, key).
"""
@@ -115,13 +122,7 @@ class KeyStore(SQLBaseStore):
# param, which is itself the 2-tuple (server_name, key_id).
invalidations.append((server_name, key_id))
- def _invalidate(res):
- f = self._get_server_verify_key.invalidate
- for i in invalidations:
- f((i,))
- return res
-
- return self.db_pool.runInteraction(
+ await self.db_pool.runInteraction(
"store_server_verify_keys",
self.db_pool.simple_upsert_many_txn,
table="server_signature_keys",
@@ -134,24 +135,34 @@ class KeyStore(SQLBaseStore):
"verify_key",
),
value_values=value_values,
- ).addCallback(_invalidate)
+ )
- def store_server_keys_json(
- self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
- ):
+ invalidate = self._get_server_verify_key.invalidate
+ for i in invalidations:
+ invalidate((i,))
+
+ async def store_server_keys_json(
+ self,
+ server_name: str,
+ key_id: str,
+ from_server: str,
+ ts_now_ms: int,
+ ts_expires_ms: int,
+ key_json_bytes: bytes,
+ ) -> None:
"""Stores the JSON bytes for a set of keys from a server
The JSON should be signed by the originating server, the intermediate
server, and by this server. Updates the value for the
(server_name, key_id, from_server) triplet if one already existed.
Args:
- server_name (str): The name of the server.
- key_id (str): The identifer of the key this JSON is for.
- from_server (str): The server this JSON was fetched from.
- ts_now_ms (int): The time now in milliseconds.
- ts_valid_until_ms (int): The time when this json stops being valid.
- key_json (bytes): The encoded JSON.
+ server_name: The name of the server.
+ key_id: The identifer of the key this JSON is for.
+ from_server: The server this JSON was fetched from.
+ ts_now_ms: The time now in milliseconds.
+ ts_valid_until_ms: The time when this json stops being valid.
+ key_json_bytes: The encoded JSON.
"""
- return self.db_pool.simple_upsert(
+ await self.db_pool.simple_upsert(
table="server_keys_json",
keyvalues={
"server_name": server_name,
@@ -169,7 +180,9 @@ class KeyStore(SQLBaseStore):
desc="store_server_keys_json",
)
- def get_server_keys_json(self, server_keys):
+ async def get_server_keys_json(
+ self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
+ ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]:
"""Retrive the key json for a list of server_keys and key ids.
If no keys are found for a given server, key_id and source then
that server, key_id, and source triplet entry will be an empty list.
@@ -178,8 +191,7 @@ class KeyStore(SQLBaseStore):
Args:
server_keys (list): List of (server_name, key_id, source) triplets.
Returns:
- Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
- Dict mapping (server_name, key_id, source) triplets to lists of dicts
+ A mapping from (server_name, key_id, source) triplets to a list of dicts
"""
def _get_server_keys_json_txn(txn):
@@ -205,6 +217,6 @@ class KeyStore(SQLBaseStore):
results[(server_name, key_id, from_server)] = rows
return results
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_server_keys_json", _get_server_keys_json_txn
)
|