diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d1ab95126c..24741b667b 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -564,7 +564,7 @@ class E2eKeysHandler:
async def claim_local_one_time_keys(
self,
- local_query: List[Tuple[str, str, str]],
+ local_query: List[Tuple[str, str, str, int]],
always_include_fallback_keys: bool,
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
"""Claim one time keys for local users.
@@ -581,6 +581,12 @@ class E2eKeysHandler:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
"""
+ # Cap the number of OTKs that can be claimed at once to avoid abuse.
+ local_query = [
+ (user_id, device_id, algorithm, min(count, 5))
+ for user_id, device_id, algorithm, count in local_query
+ ]
+
otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
# If the application services have not provided any keys via the C-S
@@ -607,7 +613,7 @@ class E2eKeysHandler:
# from the appservice for that user ID / device ID. If it is found,
# check if any of the keys match the requested algorithm & are a
# fallback key.
- for user_id, device_id, algorithm in local_query:
+ for user_id, device_id, algorithm, _count in local_query:
# Check if the appservice responded for this query.
as_result = appservice_results.get(user_id, {}).get(device_id, {})
found_otk = False
@@ -630,13 +636,17 @@ class E2eKeysHandler:
.get(device_id, {})
.keys()
)
+ # Note that it doesn't make sense to request more than 1 fallback key
+ # per (user_id, device_id, algorithm).
fallback_query.append((user_id, device_id, algorithm, mark_as_used))
else:
# All fallback keys get marked as used.
fallback_query = [
+ # Note that it doesn't make sense to request more than 1 fallback key
+ # per (user_id, device_id, algorithm).
(user_id, device_id, algorithm, True)
- for user_id, device_id, algorithm in not_found
+ for user_id, device_id, algorithm, count in not_found
]
# For each user that does not have a one-time keys available, see if
@@ -650,18 +660,19 @@ class E2eKeysHandler:
@trace
async def claim_one_time_keys(
self,
- query: Dict[str, Dict[str, Dict[str, str]]],
+ query: Dict[str, Dict[str, Dict[str, int]]],
timeout: Optional[int],
always_include_fallback_keys: bool,
) -> JsonDict:
- local_query: List[Tuple[str, str, str]] = []
- remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
+ local_query: List[Tuple[str, str, str, int]] = []
+ remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
- for user_id, one_time_keys in query.get("one_time_keys", {}).items():
+ for user_id, one_time_keys in query.items():
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
- for device_id, algorithm in one_time_keys.items():
- local_query.append((user_id, device_id, algorithm))
+ for device_id, algorithms in one_time_keys.items():
+ for algorithm, count in algorithms.items():
+ local_query.append((user_id, device_id, algorithm, count))
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = one_time_keys
@@ -692,7 +703,7 @@ class E2eKeysHandler:
device_keys = remote_queries[destination]
try:
remote_result = await self.federation.claim_client_keys(
- destination, {"one_time_keys": device_keys}, timeout=timeout
+ destination, device_keys, timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
|