diff options
Diffstat (limited to 'synapse/handlers/e2e_keys.py')
-rw-r--r-- | synapse/handlers/e2e_keys.py | 31 |
1 files changed, 21 insertions, 10 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d1ab95126c..24741b667b 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -564,7 +564,7 @@ class E2eKeysHandler: async def claim_local_one_time_keys( self, - local_query: List[Tuple[str, str, str]], + local_query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool, ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]: """Claim one time keys for local users. @@ -581,6 +581,12 @@ class E2eKeysHandler: An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. """ + # Cap the number of OTKs that can be claimed at once to avoid abuse. + local_query = [ + (user_id, device_id, algorithm, min(count, 5)) + for user_id, device_id, algorithm, count in local_query + ] + otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query) # If the application services have not provided any keys via the C-S @@ -607,7 +613,7 @@ class E2eKeysHandler: # from the appservice for that user ID / device ID. If it is found, # check if any of the keys match the requested algorithm & are a # fallback key. - for user_id, device_id, algorithm in local_query: + for user_id, device_id, algorithm, _count in local_query: # Check if the appservice responded for this query. as_result = appservice_results.get(user_id, {}).get(device_id, {}) found_otk = False @@ -630,13 +636,17 @@ class E2eKeysHandler: .get(device_id, {}) .keys() ) + # Note that it doesn't make sense to request more than 1 fallback key + # per (user_id, device_id, algorithm). fallback_query.append((user_id, device_id, algorithm, mark_as_used)) else: # All fallback keys get marked as used. fallback_query = [ + # Note that it doesn't make sense to request more than 1 fallback key + # per (user_id, device_id, algorithm). (user_id, device_id, algorithm, True) - for user_id, device_id, algorithm in not_found + for user_id, device_id, algorithm, count in not_found ] # For each user that does not have a one-time keys available, see if @@ -650,18 +660,19 @@ class E2eKeysHandler: @trace async def claim_one_time_keys( self, - query: Dict[str, Dict[str, Dict[str, str]]], + query: Dict[str, Dict[str, Dict[str, int]]], timeout: Optional[int], always_include_fallback_keys: bool, ) -> JsonDict: - local_query: List[Tuple[str, str, str]] = [] - remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {} + local_query: List[Tuple[str, str, str, int]] = [] + remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {} - for user_id, one_time_keys in query.get("one_time_keys", {}).items(): + for user_id, one_time_keys in query.items(): # we use UserID.from_string to catch invalid user ids if self.is_mine(UserID.from_string(user_id)): - for device_id, algorithm in one_time_keys.items(): - local_query.append((user_id, device_id, algorithm)) + for device_id, algorithms in one_time_keys.items(): + for algorithm, count in algorithms.items(): + local_query.append((user_id, device_id, algorithm, count)) else: domain = get_domain_from_id(user_id) remote_queries.setdefault(domain, {})[user_id] = one_time_keys @@ -692,7 +703,7 @@ class E2eKeysHandler: device_keys = remote_queries[destination] try: remote_result = await self.federation.claim_client_keys( - destination, {"one_time_keys": device_keys}, timeout=timeout + destination, device_keys, timeout=timeout ) for user_id, keys in remote_result["one_time_keys"].items(): if user_id in device_keys: |