summary refs log tree commit diff
path: root/synapse/handlers/e2e_keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/e2e_keys.py')
-rw-r--r--synapse/handlers/e2e_keys.py31
1 files changed, 21 insertions, 10 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d1ab95126c..24741b667b 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -564,7 +564,7 @@ class E2eKeysHandler:
 
     async def claim_local_one_time_keys(
         self,
-        local_query: List[Tuple[str, str, str]],
+        local_query: List[Tuple[str, str, str, int]],
         always_include_fallback_keys: bool,
     ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
         """Claim one time keys for local users.
@@ -581,6 +581,12 @@ class E2eKeysHandler:
             An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
         """
 
+        # Cap the number of OTKs that can be claimed at once to avoid abuse.
+        local_query = [
+            (user_id, device_id, algorithm, min(count, 5))
+            for user_id, device_id, algorithm, count in local_query
+        ]
+
         otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
 
         # If the application services have not provided any keys via the C-S
@@ -607,7 +613,7 @@ class E2eKeysHandler:
             # from the appservice for that user ID / device ID. If it is found,
             # check if any of the keys match the requested algorithm & are a
             # fallback key.
-            for user_id, device_id, algorithm in local_query:
+            for user_id, device_id, algorithm, _count in local_query:
                 # Check if the appservice responded for this query.
                 as_result = appservice_results.get(user_id, {}).get(device_id, {})
                 found_otk = False
@@ -630,13 +636,17 @@ class E2eKeysHandler:
                         .get(device_id, {})
                         .keys()
                     )
+                    # Note that it doesn't make sense to request more than 1 fallback key
+                    # per (user_id, device_id, algorithm).
                     fallback_query.append((user_id, device_id, algorithm, mark_as_used))
 
         else:
             # All fallback keys get marked as used.
             fallback_query = [
+                # Note that it doesn't make sense to request more than 1 fallback key
+                # per (user_id, device_id, algorithm).
                 (user_id, device_id, algorithm, True)
-                for user_id, device_id, algorithm in not_found
+                for user_id, device_id, algorithm, count in not_found
             ]
 
         # For each user that does not have a one-time keys available, see if
@@ -650,18 +660,19 @@ class E2eKeysHandler:
     @trace
     async def claim_one_time_keys(
         self,
-        query: Dict[str, Dict[str, Dict[str, str]]],
+        query: Dict[str, Dict[str, Dict[str, int]]],
         timeout: Optional[int],
         always_include_fallback_keys: bool,
     ) -> JsonDict:
-        local_query: List[Tuple[str, str, str]] = []
-        remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
+        local_query: List[Tuple[str, str, str, int]] = []
+        remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
 
-        for user_id, one_time_keys in query.get("one_time_keys", {}).items():
+        for user_id, one_time_keys in query.items():
             # we use UserID.from_string to catch invalid user ids
             if self.is_mine(UserID.from_string(user_id)):
-                for device_id, algorithm in one_time_keys.items():
-                    local_query.append((user_id, device_id, algorithm))
+                for device_id, algorithms in one_time_keys.items():
+                    for algorithm, count in algorithms.items():
+                        local_query.append((user_id, device_id, algorithm, count))
             else:
                 domain = get_domain_from_id(user_id)
                 remote_queries.setdefault(domain, {})[user_id] = one_time_keys
@@ -692,7 +703,7 @@ class E2eKeysHandler:
             device_keys = remote_queries[destination]
             try:
                 remote_result = await self.federation.claim_client_keys(
-                    destination, {"one_time_keys": device_keys}, timeout=timeout
+                    destination, device_keys, timeout=timeout
                 )
                 for user_id, keys in remote_result["one_time_keys"].items():
                     if user_id in device_keys: