diff options
author | Richard van der Hoff <richard@matrix.org> | 2023-10-27 15:00:30 +0100 |
---|---|---|
committer | Richard van der Hoff <richard@matrix.org> | 2023-10-27 15:00:30 +0100 |
commit | dd45ba4d6781cd7ce57260797d6c08285ec03a97 (patch) | |
tree | 8257b49bef80786a5f9ddb959eeab021bf05ba82 | |
parent | Fix cross-worker ratelimiting (#16558) (diff) | |
download | synapse-dd45ba4d6781cd7ce57260797d6c08285ec03a97.tar.xz |
Fix types for OTK claims
We don't know for certain that keys will be `JsonDict`s -- indeed if the key is not signed it will just be a string. Fix up the types to reflect this.
-rw-r--r-- | synapse/federation/federation_server.py | 4 | ||||
-rw-r--r-- | synapse/handlers/appservice.py | 2 | ||||
-rw-r--r-- | synapse/handlers/e2e_keys.py | 7 | ||||
-rw-r--r-- | synapse/storage/databases/main/end_to_end_keys.py | 25 |
4 files changed, 24 insertions, 14 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 3b27925517..d4dbfd3dca 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -84,7 +84,7 @@ from synapse.replication.http.federation import ( from synapse.storage.databases.main.lock import Lock from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary from synapse.storage.roommember import MemberSummary -from synapse.types import JsonDict, StateMap, get_domain_from_id +from synapse.types import JsonDict, JsonSerializable, StateMap, get_domain_from_id from synapse.util import unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results from synapse.util.caches.response_cache import ResponseCache @@ -1004,7 +1004,7 @@ class FederationServer(FederationBase): query, always_include_fallback_keys=always_include_fallback_keys ) - json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + json_result: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {} for result in results: for user_id, device_keys in result.items(): for device_id, keys in device_keys.items(): diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 873dadc3bd..2893b55b92 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -861,7 +861,7 @@ class ApplicationServicesHandler: Returns: A tuple of: - A map of user ID -> a map device ID -> a map of key ID -> JSON. + A map of user ID -> a map device ID -> a map of key ID -> key. A copy of the input which has not been fulfilled (either because they are not appservice users or the appservice does not support diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 8c6432035d..9e51a34d70 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -32,6 +32,7 @@ from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace from synapse.types import ( JsonDict, JsonMapping, + JsonSerializable, UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, @@ -560,7 +561,7 @@ class E2eKeysHandler: self, local_query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool, - ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]: + ) -> Iterable[Mapping[str, Mapping[str, Mapping[str, JsonSerializable]]]]: """Claim one time keys for local users. 1. Attempt to claim OTKs from the database. @@ -572,7 +573,7 @@ class E2eKeysHandler: always_include_fallback_keys: True to always include fallback keys. Returns: - An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. + An iterable of maps of user ID -> a map device ID -> a map of key ID -> key. """ # Cap the number of OTKs that can be claimed at once to avoid abuse. @@ -680,7 +681,7 @@ class E2eKeysHandler: ) # A map of user ID -> device ID -> key ID -> key. - json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + json_result: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {} for result in results: for user_id, device_keys in result.items(): for device_id, keys in device_keys.items(): diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index f70f95eeba..a8aa9ab198 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -52,7 +52,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.types import JsonDict, JsonMapping +from synapse.types import JsonDict, JsonMapping, JsonSerializable from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.cancellation import cancellable @@ -1112,7 +1112,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker async def claim_e2e_one_time_keys( self, query_list: Iterable[Tuple[str, str, str, int]] ) -> Tuple[ - Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]] + Dict[str, Dict[str, Dict[str, JsonSerializable]]], + List[Tuple[str, str, str, int]], ]: """Take a list of one time keys out of the database. @@ -1121,7 +1122,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker Returns: A tuple pf: - A map of user ID -> a map device ID -> a map of key ID -> JSON. + A map of user ID -> a map device ID -> a map of key ID -> key A copy of the input which has not been fulfilled. """ @@ -1214,7 +1215,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows ] - results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + results: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {} missing: List[Tuple[str, str, str, int]] = [] for user_id, device_id, algorithm, count in query_list: if self.database_engine.supports_returning: @@ -1240,7 +1241,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker device_id, {} ) for claim_row in claim_rows: - device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) + # The shape of the key depends on the algorithm: it is a dict for + # signed_curve25519, or a string for curve25519. In general, it + # is whatever the client chose to upload, since we dont validate it. + decoded_key: JsonSerializable = json_decoder.decode(claim_row[1]) + device_results[claim_row[0]] = decoded_key # Did we get enough OTKs? count -= len(claim_rows) if count: @@ -1250,7 +1255,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker async def claim_e2e_fallback_keys( self, query_list: Iterable[Tuple[str, str, str, bool]] - ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]: + ) -> Dict[str, Dict[str, Dict[str, JsonSerializable]]]: """Take a list of fallback keys out of the database. Args: @@ -1260,7 +1265,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker Returns: A map of user ID -> a map device ID -> a map of key ID -> JSON. """ - results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + results: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {} for user_id, device_id, algorithm, mark_as_used in query_list: row = await self.db_pool.simple_select_one( table="e2e_fallback_keys_json", @@ -1298,7 +1303,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ) device_results = results.setdefault(user_id, {}).setdefault(device_id, {}) - device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json) + # The shape of the key depends on the algorithm: it is a dict for + # signed_curve25519, or a string for curve25519. In general, it + # is whatever the client chose to upload, since we dont validate it. + decoded_key: JsonSerializable = json_decoder.decode(key_json) + device_results[f"{algorithm}:{key_id}"] = decoded_key return results |