summary refs log tree commit diff
path: root/synapse/storage/databases/main/keys.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-10-11 13:24:56 -0400
committerGitHub <noreply@github.com>2023-10-11 13:24:56 -0400
commita4904dcb04b31ce8ed0deaa2c5c80657780f6618 (patch)
tree179aedc3390ce9cafcd5f3d78a20644ab8d3dd87 /synapse/storage/databases/main/keys.py
parentHandle content types with parameters. (#16440) (diff)
downloadsynapse-a4904dcb04b31ce8ed0deaa2c5c80657780f6618.tar.xz
Convert simple_select_many_batch, simple_select_many_txn to tuples. (#16444)
Diffstat (limited to 'synapse/storage/databases/main/keys.py')
-rw-r--r--synapse/storage/databases/main/keys.py46
1 files changed, 26 insertions, 20 deletions
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py

index 889c578b9c..ea797864b9 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py
@@ -16,7 +16,7 @@ import itertools import json import logging -from typing import Dict, Iterable, Mapping, Optional, Tuple +from typing import Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes @@ -205,35 +205,39 @@ class KeyStore(CacheInvalidationWorkerStore): If we have multiple entries for a given key ID, returns the most recent. """ - rows = await self.db_pool.simple_select_many_batch( - table="server_keys_json", - column="key_id", - iterable=key_ids, - keyvalues={"server_name": server_name}, - retcols=( - "key_id", - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "key_json", + rows = cast( + List[Tuple[str, str, int, int, Union[bytes, memoryview]]], + await self.db_pool.simple_select_many_batch( + table="server_keys_json", + column="key_id", + iterable=key_ids, + keyvalues={"server_name": server_name}, + retcols=( + "key_id", + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + desc="get_server_keys_json_for_remote", ), - desc="get_server_keys_json_for_remote", ) if not rows: return {} - # We sort the rows so that the most recently added entry is picked up. - rows.sort(key=lambda r: r["ts_added_ms"]) + # We sort the rows by ts_added_ms so that the most recently added entry + # will stomp over older entries in the dictionary. + rows.sort(key=lambda r: r[2]) return { - row["key_id"]: FetchKeyResultForRemote( + key_id: FetchKeyResultForRemote( # Cast to bytes since postgresql returns a memoryview. - key_json=bytes(row["key_json"]), - valid_until_ts=row["ts_valid_until_ms"], - added_ts=row["ts_added_ms"], + key_json=bytes(key_json), + valid_until_ts=ts_valid_until_ms, + added_ts=ts_added_ms, ) - for row in rows + for key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json in rows } async def get_all_server_keys_json_for_remote( @@ -260,6 +264,8 @@ class KeyStore(CacheInvalidationWorkerStore): if not rows: return {} + # We sort the rows by ts_added_ms so that the most recently added entry + # will stomp over older entries in the dictionary. rows.sort(key=lambda r: r["ts_added_ms"]) return {