diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index ba34573d46..0b2d1a78f7 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -235,7 +235,10 @@ class FederationClient(FederationBase):
)
async def claim_client_keys(
- self, destination: str, content: JsonDict, timeout: Optional[int]
+ self,
+ destination: str,
+ query: Dict[str, Dict[str, Dict[str, int]]],
+ timeout: Optional[int],
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.
@@ -247,6 +250,50 @@ class FederationClient(FederationBase):
The JSON object from the response
"""
sent_queries_counter.labels("client_one_time_keys").inc()
+
+ # Convert the query with counts into a stable and unstable query and check
+ # if attempting to claim more than 1 OTK.
+ content: Dict[str, Dict[str, str]] = {}
+ unstable_content: Dict[str, Dict[str, List[str]]] = {}
+ use_unstable = False
+ for user_id, one_time_keys in query.items():
+ for device_id, algorithms in one_time_keys.items():
+ if any(count > 1 for count in algorithms.values()):
+ use_unstable = True
+ if algorithms:
+ # For the stable query, choose only the first algorithm.
+ content.setdefault(user_id, {})[device_id] = next(iter(algorithms))
+ # For the unstable query, repeat each algorithm by count, then
+ # splat those into chain to get a flattened list of all algorithms.
+ #
+ # Converts from {"algo1": 2, "algo2": 2} to ["algo1", "algo1", "algo2"].
+ unstable_content.setdefault(user_id, {})[device_id] = list(
+ itertools.chain(
+ *(
+ itertools.repeat(algorithm, count)
+ for algorithm, count in algorithms.items()
+ )
+ )
+ )
+
+ if use_unstable:
+ try:
+ return await self.transport_layer.claim_client_keys_unstable(
+ destination, unstable_content, timeout
+ )
+ except HttpResponseException as e:
+ # If an error is received that is due to an unrecognised endpoint,
+ # fallback to the v1 endpoint. Otherwise, consider it a legitimate error
+ # and raise.
+ if not is_unknown_endpoint(e):
+ raise
+
+ logger.debug(
+ "Couldn't claim client keys with the unstable API, falling back to the v1 API"
+ )
+ else:
+ logger.debug("Skipping unstable claim client keys API")
+
return await self.transport_layer.claim_client_keys(
destination, content, timeout
)
|