summary refs log tree commit diff
path: root/synapse/storage/databases/main/keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/keys.py')
-rw-r--r--synapse/storage/databases/main/keys.py76
1 files changed, 44 insertions, 32 deletions
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py

index 384e9c5eb0..ad43bb05ab 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py
@@ -16,6 +16,7 @@ import itertools import logging +from typing import Dict, Iterable, List, Optional, Tuple from signedjson.key import decode_verify_key_bytes @@ -41,16 +42,17 @@ class KeyStore(SQLBaseStore): @cachedList( cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" ) - def get_server_verify_keys(self, server_name_and_key_ids): + async def get_server_verify_keys( + self, server_name_and_key_ids: Iterable[Tuple[str, str]] + ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]: """ Args: - server_name_and_key_ids (iterable[Tuple[str, str]]): + server_name_and_key_ids: iterable of (server_name, key-id) tuples to fetch keys for Returns: - Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: - map from (server_name, key_id) -> FetchKeyResult, or None if the key is - unknown + A map from (server_name, key_id) -> FetchKeyResult, or None if the + key is unknown """ keys = {} @@ -86,14 +88,19 @@ class KeyStore(SQLBaseStore): _get_keys(txn, batch) return keys - return self.db_pool.runInteraction("get_server_verify_keys", _txn) + return await self.db_pool.runInteraction("get_server_verify_keys", _txn) - def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): + async def store_server_verify_keys( + self, + from_server: str, + ts_added_ms: int, + verify_keys: Iterable[Tuple[str, str, FetchKeyResult]], + ) -> None: """Stores NACL verification keys for remote servers. Args: - from_server (str): Where the verification keys were looked up - ts_added_ms (int): The time to record that the key was added - verify_keys (iterable[tuple[str, str, FetchKeyResult]]): + from_server: Where the verification keys were looked up + ts_added_ms: The time to record that the key was added + verify_keys: keys to be stored. Each entry is a triplet of (server_name, key_id, key). """ @@ -115,13 +122,7 @@ class KeyStore(SQLBaseStore): # param, which is itself the 2-tuple (server_name, key_id). invalidations.append((server_name, key_id)) - def _invalidate(res): - f = self._get_server_verify_key.invalidate - for i in invalidations: - f((i,)) - return res - - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "store_server_verify_keys", self.db_pool.simple_upsert_many_txn, table="server_signature_keys", @@ -134,24 +135,34 @@ class KeyStore(SQLBaseStore): "verify_key", ), value_values=value_values, - ).addCallback(_invalidate) + ) - def store_server_keys_json( - self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes - ): + invalidate = self._get_server_verify_key.invalidate + for i in invalidations: + invalidate((i,)) + + async def store_server_keys_json( + self, + server_name: str, + key_id: str, + from_server: str, + ts_now_ms: int, + ts_expires_ms: int, + key_json_bytes: bytes, + ) -> None: """Stores the JSON bytes for a set of keys from a server The JSON should be signed by the originating server, the intermediate server, and by this server. Updates the value for the (server_name, key_id, from_server) triplet if one already existed. Args: - server_name (str): The name of the server. - key_id (str): The identifer of the key this JSON is for. - from_server (str): The server this JSON was fetched from. - ts_now_ms (int): The time now in milliseconds. - ts_valid_until_ms (int): The time when this json stops being valid. - key_json (bytes): The encoded JSON. + server_name: The name of the server. + key_id: The identifer of the key this JSON is for. + from_server: The server this JSON was fetched from. + ts_now_ms: The time now in milliseconds. + ts_valid_until_ms: The time when this json stops being valid. + key_json_bytes: The encoded JSON. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="server_keys_json", keyvalues={ "server_name": server_name, @@ -169,7 +180,9 @@ class KeyStore(SQLBaseStore): desc="store_server_keys_json", ) - def get_server_keys_json(self, server_keys): + async def get_server_keys_json( + self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] + ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]: """Retrive the key json for a list of server_keys and key ids. If no keys are found for a given server, key_id and source then that server, key_id, and source triplet entry will be an empty list. @@ -178,8 +191,7 @@ class KeyStore(SQLBaseStore): Args: server_keys (list): List of (server_name, key_id, source) triplets. Returns: - Deferred[dict[Tuple[str, str, str|None], list[dict]]]: - Dict mapping (server_name, key_id, source) triplets to lists of dicts + A mapping from (server_name, key_id, source) triplets to a list of dicts """ def _get_server_keys_json_txn(txn): @@ -205,6 +217,6 @@ class KeyStore(SQLBaseStore): results[(server_name, key_id, from_server)] = rows return results - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_server_keys_json", _get_server_keys_json_txn )