summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2023-10-27 15:00:30 +0100
committerRichard van der Hoff <richard@matrix.org>2023-10-27 15:00:30 +0100
commitdd45ba4d6781cd7ce57260797d6c08285ec03a97 (patch)
tree8257b49bef80786a5f9ddb959eeab021bf05ba82
parentFix cross-worker ratelimiting (#16558) (diff)
downloadsynapse-dd45ba4d6781cd7ce57260797d6c08285ec03a97.tar.xz
Fix types for OTK claims
We don't know for certain that keys will be `JsonDict`s -- indeed if the key is
not signed it will just be a string. Fix up the types to reflect this.
-rw-r--r--synapse/federation/federation_server.py4
-rw-r--r--synapse/handlers/appservice.py2
-rw-r--r--synapse/handlers/e2e_keys.py7
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py25
4 files changed, 24 insertions, 14 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 3b27925517..d4dbfd3dca 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -84,7 +84,7 @@ from synapse.replication.http.federation import (
 from synapse.storage.databases.main.lock import Lock
 from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
 from synapse.storage.roommember import MemberSummary
-from synapse.types import JsonDict, StateMap, get_domain_from_id
+from synapse.types import JsonDict, JsonSerializable, StateMap, get_domain_from_id
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
 from synapse.util.caches.response_cache import ResponseCache
@@ -1004,7 +1004,7 @@ class FederationServer(FederationBase):
             query, always_include_fallback_keys=always_include_fallback_keys
         )
 
-        json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+        json_result: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {}
         for result in results:
             for user_id, device_keys in result.items():
                 for device_id, keys in device_keys.items():
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 873dadc3bd..2893b55b92 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -861,7 +861,7 @@ class ApplicationServicesHandler:
 
         Returns:
             A tuple of:
-                A map of user ID -> a map device ID -> a map of key ID -> JSON.
+                A map of user ID -> a map device ID -> a map of key ID -> key.
 
                 A copy of the input which has not been fulfilled (either because
                 they are not appservice users or the appservice does not support
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 8c6432035d..9e51a34d70 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -32,6 +32,7 @@ from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
 from synapse.types import (
     JsonDict,
     JsonMapping,
+    JsonSerializable,
     UserID,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
@@ -560,7 +561,7 @@ class E2eKeysHandler:
         self,
         local_query: List[Tuple[str, str, str, int]],
         always_include_fallback_keys: bool,
-    ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
+    ) -> Iterable[Mapping[str, Mapping[str, Mapping[str, JsonSerializable]]]]:
         """Claim one time keys for local users.
 
         1. Attempt to claim OTKs from the database.
@@ -572,7 +573,7 @@ class E2eKeysHandler:
             always_include_fallback_keys: True to always include fallback keys.
 
         Returns:
-            An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
+            An iterable of maps of user ID -> a map device ID -> a map of key ID -> key.
         """
 
         # Cap the number of OTKs that can be claimed at once to avoid abuse.
@@ -680,7 +681,7 @@ class E2eKeysHandler:
         )
 
         # A map of user ID -> device ID -> key ID -> key.
-        json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+        json_result: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {}
         for result in results:
             for user_id, device_keys in result.items():
                 for device_id, keys in device_keys.items():
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index f70f95eeba..a8aa9ab198 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -52,7 +52,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.types import JsonDict, JsonMapping
+from synapse.types import JsonDict, JsonMapping, JsonSerializable
 from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.cancellation import cancellable
@@ -1112,7 +1112,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     async def claim_e2e_one_time_keys(
         self, query_list: Iterable[Tuple[str, str, str, int]]
     ) -> Tuple[
-        Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
+        Dict[str, Dict[str, Dict[str, JsonSerializable]]],
+        List[Tuple[str, str, str, int]],
     ]:
         """Take a list of one time keys out of the database.
 
@@ -1121,7 +1122,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         Returns:
             A tuple pf:
-                A map of user ID -> a map device ID -> a map of key ID -> JSON.
+                A map of user ID -> a map device ID -> a map of key ID -> key
 
                 A copy of the input which has not been fulfilled.
         """
@@ -1214,7 +1215,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
             ]
 
-        results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+        results: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {}
         missing: List[Tuple[str, str, str, int]] = []
         for user_id, device_id, algorithm, count in query_list:
             if self.database_engine.supports_returning:
@@ -1240,7 +1241,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                     device_id, {}
                 )
                 for claim_row in claim_rows:
-                    device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+                    # The shape of the key depends on the algorithm: it is a dict for
+                    # signed_curve25519, or a string for curve25519. In general, it
+                    # is whatever the client chose to upload, since we dont validate it.
+                    decoded_key: JsonSerializable = json_decoder.decode(claim_row[1])
+                    device_results[claim_row[0]] = decoded_key
             # Did we get enough OTKs?
             count -= len(claim_rows)
             if count:
@@ -1250,7 +1255,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
     async def claim_e2e_fallback_keys(
         self, query_list: Iterable[Tuple[str, str, str, bool]]
-    ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+    ) -> Dict[str, Dict[str, Dict[str, JsonSerializable]]]:
         """Take a list of fallback keys out of the database.
 
         Args:
@@ -1260,7 +1265,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         Returns:
             A map of user ID -> a map device ID -> a map of key ID -> JSON.
         """
-        results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+        results: Dict[str, Dict[str, Dict[str, JsonSerializable]]] = {}
         for user_id, device_id, algorithm, mark_as_used in query_list:
             row = await self.db_pool.simple_select_one(
                 table="e2e_fallback_keys_json",
@@ -1298,7 +1303,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 )
 
             device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
-            device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
+            # The shape of the key depends on the algorithm: it is a dict for
+            # signed_curve25519, or a string for curve25519. In general, it
+            # is whatever the client chose to upload, since we dont validate it.
+            decoded_key: JsonSerializable = json_decoder.decode(key_json)
+            device_results[f"{algorithm}:{key_id}"] = decoded_key
 
         return results