From 2b35626b6b7aed52a626734a5a85fe77c847251d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 12 Sep 2023 11:08:04 +0100 Subject: Refactor storing of server keys (#16261) --- synapse/storage/databases/main/keys.py | 219 +++++++++++---------------------- 1 file changed, 72 insertions(+), 147 deletions(-) (limited to 'synapse/storage/databases/main') diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 57aa4921e1..41563371dc 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -16,14 +16,17 @@ import itertools import json import logging -from typing import Dict, Iterable, Mapping, Optional, Tuple +from typing import Dict, Iterable, Optional, Tuple +from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 +from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote from synapse.storage.types import Cursor +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter @@ -36,162 +39,84 @@ db_binary_type = memoryview class KeyStore(CacheInvalidationWorkerStore): """Persistence for signature verification keys""" - @cached() - def _get_server_signature_key( - self, server_name_and_key_id: Tuple[str, str] - ) -> FetchKeyResult: - raise NotImplementedError() - - @cachedList( - cached_method_name="_get_server_signature_key", - list_name="server_name_and_key_ids", - ) - async def get_server_signature_keys( - self, server_name_and_key_ids: Iterable[Tuple[str, str]] - ) -> Dict[Tuple[str, str], FetchKeyResult]: - """ - Args: - server_name_and_key_ids: - iterable of (server_name, key-id) tuples to fetch keys for - - Returns: - A map from (server_name, key_id) -> FetchKeyResult, or None if the - key is unknown - """ - keys = {} - - def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None: - """Processes a batch of keys to fetch, and adds the result to `keys`.""" - - # batch_iter always returns tuples so it's safe to do len(batch) - sql = """ - SELECT server_name, key_id, verify_key, ts_valid_until_ms - FROM server_signature_keys WHERE 1=0 - """ + " OR (server_name=? AND key_id=?)" * len( - batch - ) - - txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) - - for row in txn: - server_name, key_id, key_bytes, ts_valid_until_ms = row - - if ts_valid_until_ms is None: - # Old keys may be stored with a ts_valid_until_ms of null, - # in which case we treat this as if it was set to `0`, i.e. - # it won't match key requests that define a minimum - # `ts_valid_until_ms`. - ts_valid_until_ms = 0 - - keys[(server_name, key_id)] = FetchKeyResult( - verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), - valid_until_ts=ts_valid_until_ms, - ) - - def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: - for batch in batch_iter(server_name_and_key_ids, 50): - _get_keys(txn, batch) - return keys - - return await self.db_pool.runInteraction("get_server_signature_keys", _txn) - - async def store_server_signature_keys( + async def store_server_keys_response( self, + server_name: str, from_server: str, ts_added_ms: int, - verify_keys: Mapping[Tuple[str, str], FetchKeyResult], + verify_keys: Dict[str, FetchKeyResult], + response_json: JsonDict, ) -> None: - """Stores NACL verification keys for remote servers. + """Stores the keys for the given server that we got from `from_server`. + Args: - from_server: Where the verification keys were looked up - ts_added_ms: The time to record that the key was added - verify_keys: - keys to be stored. Each entry is a triplet of - (server_name, key_id, key). + server_name: The owner of the keys + from_server: Which server we got the keys from + ts_added_ms: When we're adding the keys + verify_keys: The decoded keys + response_json: The full *signed* response JSON that contains the keys. """ - key_values = [] - value_values = [] - invalidations = [] - for (server_name, key_id), fetch_result in verify_keys.items(): - key_values.append((server_name, key_id)) - value_values.append( - ( - from_server, - ts_added_ms, - fetch_result.valid_until_ts, - db_binary_type(fetch_result.verify_key.encode()), - ) - ) - # invalidate takes a tuple corresponding to the params of - # _get_server_signature_key. _get_server_signature_key only takes one - # param, which is itself the 2-tuple (server_name, key_id). - invalidations.append((server_name, key_id)) - await self.db_pool.simple_upsert_many( - table="server_signature_keys", - key_names=("server_name", "key_id"), - key_values=key_values, - value_names=( - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "verify_key", - ), - value_values=value_values, - desc="store_server_signature_keys", - ) + key_json_bytes = encode_canonical_json(response_json) + + def store_server_keys_response_txn(txn: LoggingTransaction) -> None: + self.db_pool.simple_upsert_many_txn( + txn, + table="server_signature_keys", + key_names=("server_name", "key_id"), + key_values=[(server_name, key_id) for key_id in verify_keys], + value_names=( + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "verify_key", + ), + value_values=[ + ( + from_server, + ts_added_ms, + fetch_result.valid_until_ts, + db_binary_type(fetch_result.verify_key.encode()), + ) + for fetch_result in verify_keys.values() + ], + ) - invalidate = self._get_server_signature_key.invalidate - for i in invalidations: - invalidate((i,)) + self.db_pool.simple_upsert_many_txn( + txn, + table="server_keys_json", + key_names=("server_name", "key_id", "from_server"), + key_values=[ + (server_name, key_id, from_server) for key_id in verify_keys + ], + value_names=( + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + value_values=[ + ( + ts_added_ms, + fetch_result.valid_until_ts, + db_binary_type(key_json_bytes), + ) + for fetch_result in verify_keys.values() + ], + ) - async def store_server_keys_json( - self, - server_name: str, - key_id: str, - from_server: str, - ts_now_ms: int, - ts_expires_ms: int, - key_json_bytes: bytes, - ) -> None: - """Stores the JSON bytes for a set of keys from a server - The JSON should be signed by the originating server, the intermediate - server, and by this server. Updates the value for the - (server_name, key_id, from_server) triplet if one already existed. - Args: - server_name: The name of the server. - key_id: The identifier of the key this JSON is for. - from_server: The server this JSON was fetched from. - ts_now_ms: The time now in milliseconds. - ts_valid_until_ms: The time when this json stops being valid. - key_json_bytes: The encoded JSON. - """ - await self.db_pool.simple_upsert( - table="server_keys_json", - keyvalues={ - "server_name": server_name, - "key_id": key_id, - "from_server": from_server, - }, - values={ - "server_name": server_name, - "key_id": key_id, - "from_server": from_server, - "ts_added_ms": ts_now_ms, - "ts_valid_until_ms": ts_expires_ms, - "key_json": db_binary_type(key_json_bytes), - }, - desc="store_server_keys_json", - ) + # invalidate takes a tuple corresponding to the params of + # _get_server_keys_json. _get_server_keys_json only takes one + # param, which is itself the 2-tuple (server_name, key_id). + for key_id in verify_keys: + self._invalidate_cache_and_stream( + txn, self._get_server_keys_json, ((server_name, key_id),) + ) + self._invalidate_cache_and_stream( + txn, self.get_server_key_json_for_remote, (server_name, key_id) + ) - # invalidate takes a tuple corresponding to the params of - # _get_server_keys_json. _get_server_keys_json only takes one - # param, which is itself the 2-tuple (server_name, key_id). - await self.invalidate_cache_and_stream( - "_get_server_keys_json", ((server_name, key_id),) - ) - await self.invalidate_cache_and_stream( - "get_server_key_json_for_remote", (server_name, key_id) + await self.db_pool.runInteraction( + "store_server_keys_response", store_server_keys_response_txn ) @cached() -- cgit 1.4.1