summary refs log tree commit diff
path: root/synapse/federation/federation_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/federation_client.py')
-rw-r--r--synapse/federation/federation_client.py49
1 files changed, 48 insertions, 1 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index ba34573d46..0b2d1a78f7 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -235,7 +235,10 @@ class FederationClient(FederationBase):
         )
 
     async def claim_client_keys(
-        self, destination: str, content: JsonDict, timeout: Optional[int]
+        self,
+        destination: str,
+        query: Dict[str, Dict[str, Dict[str, int]]],
+        timeout: Optional[int],
     ) -> JsonDict:
         """Claims one-time keys for a device hosted on a remote server.
 
@@ -247,6 +250,50 @@ class FederationClient(FederationBase):
             The JSON object from the response
         """
         sent_queries_counter.labels("client_one_time_keys").inc()
+
+        # Convert the query with counts into a stable and unstable query and check
+        # if attempting to claim more than 1 OTK.
+        content: Dict[str, Dict[str, str]] = {}
+        unstable_content: Dict[str, Dict[str, List[str]]] = {}
+        use_unstable = False
+        for user_id, one_time_keys in query.items():
+            for device_id, algorithms in one_time_keys.items():
+                if any(count > 1 for count in algorithms.values()):
+                    use_unstable = True
+                if algorithms:
+                    # For the stable query, choose only the first algorithm.
+                    content.setdefault(user_id, {})[device_id] = next(iter(algorithms))
+                    # For the unstable query, repeat each algorithm by count, then
+                    # splat those into chain to get a flattened list of all algorithms.
+                    #
+                    # Converts from {"algo1": 2, "algo2": 2} to ["algo1", "algo1", "algo2"].
+                    unstable_content.setdefault(user_id, {})[device_id] = list(
+                        itertools.chain(
+                            *(
+                                itertools.repeat(algorithm, count)
+                                for algorithm, count in algorithms.items()
+                            )
+                        )
+                    )
+
+        if use_unstable:
+            try:
+                return await self.transport_layer.claim_client_keys_unstable(
+                    destination, unstable_content, timeout
+                )
+            except HttpResponseException as e:
+                # If an error is received that is due to an unrecognised endpoint,
+                # fallback to the v1 endpoint. Otherwise, consider it a legitimate error
+                # and raise.
+                if not is_unknown_endpoint(e):
+                    raise
+
+            logger.debug(
+                "Couldn't claim client keys with the unstable API, falling back to the v1 API"
+            )
+        else:
+            logger.debug("Skipping unstable claim client keys API")
+
         return await self.transport_layer.claim_client_keys(
             destination, content, timeout
         )