summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_client.py49
-rw-r--r--synapse/federation/federation_server.py7
-rw-r--r--synapse/federation/transport/client.py49
-rw-r--r--synapse/federation/transport/server/federation.py25
4 files changed, 116 insertions, 14 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index ba34573d46..0b2d1a78f7 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -235,7 +235,10 @@ class FederationClient(FederationBase):
         )
 
     async def claim_client_keys(
-        self, destination: str, content: JsonDict, timeout: Optional[int]
+        self,
+        destination: str,
+        query: Dict[str, Dict[str, Dict[str, int]]],
+        timeout: Optional[int],
     ) -> JsonDict:
         """Claims one-time keys for a device hosted on a remote server.
 
@@ -247,6 +250,50 @@ class FederationClient(FederationBase):
             The JSON object from the response
         """
         sent_queries_counter.labels("client_one_time_keys").inc()
+
+        # Convert the query with counts into a stable and unstable query and check
+        # if attempting to claim more than 1 OTK.
+        content: Dict[str, Dict[str, str]] = {}
+        unstable_content: Dict[str, Dict[str, List[str]]] = {}
+        use_unstable = False
+        for user_id, one_time_keys in query.items():
+            for device_id, algorithms in one_time_keys.items():
+                if any(count > 1 for count in algorithms.values()):
+                    use_unstable = True
+                if algorithms:
+                    # For the stable query, choose only the first algorithm.
+                    content.setdefault(user_id, {})[device_id] = next(iter(algorithms))
+                    # For the unstable query, repeat each algorithm by count, then
+                    # splat those into chain to get a flattened list of all algorithms.
+                    #
+                    # Converts from {"algo1": 2, "algo2": 2} to ["algo1", "algo1", "algo2"].
+                    unstable_content.setdefault(user_id, {})[device_id] = list(
+                        itertools.chain(
+                            *(
+                                itertools.repeat(algorithm, count)
+                                for algorithm, count in algorithms.items()
+                            )
+                        )
+                    )
+
+        if use_unstable:
+            try:
+                return await self.transport_layer.claim_client_keys_unstable(
+                    destination, unstable_content, timeout
+                )
+            except HttpResponseException as e:
+                # If an error is received that is due to an unrecognised endpoint,
+                # fallback to the v1 endpoint. Otherwise, consider it a legitimate error
+                # and raise.
+                if not is_unknown_endpoint(e):
+                    raise
+
+            logger.debug(
+                "Couldn't claim client keys with the unstable API, falling back to the v1 API"
+            )
+        else:
+            logger.debug("Skipping unstable claim client keys API")
+
         return await self.transport_layer.claim_client_keys(
             destination, content, timeout
         )
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index c618f3d7a6..ca43c7bfc0 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1005,13 +1005,8 @@ class FederationServer(FederationBase):
 
     @trace
     async def on_claim_client_keys(
-        self, origin: str, content: JsonDict, always_include_fallback_keys: bool
+        self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool
     ) -> Dict[str, Any]:
-        query = []
-        for user_id, device_keys in content.get("one_time_keys", {}).items():
-            for device_id, algorithm in device_keys.items():
-                query.append((user_id, device_id, algorithm))
-
         log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
         results = await self._e2e_keys_handler.claim_local_one_time_keys(
             query, always_include_fallback_keys=always_include_fallback_keys
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index bedbd23ded..bc70b94f68 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -650,10 +650,10 @@ class TransportLayerClient:
 
         Response:
             {
-              "device_keys": {
+              "one_time_keys": {
                 "<user_id>": {
                   "<device_id>": {
-                    "<algorithm>:<key_id>": "<key_base64>"
+                    "<algorithm>:<key_id>": <OTK JSON>
                   }
                 }
               }
@@ -669,7 +669,50 @@ class TransportLayerClient:
         path = _create_v1_path("/user/keys/claim")
 
         return await self.client.post_json(
-            destination=destination, path=path, data=query_content, timeout=timeout
+            destination=destination,
+            path=path,
+            data={"one_time_keys": query_content},
+            timeout=timeout,
+        )
+
+    async def claim_client_keys_unstable(
+        self, destination: str, query_content: JsonDict, timeout: Optional[int]
+    ) -> JsonDict:
+        """Claim one-time keys for a list of devices hosted on a remote server.
+
+        Request:
+            {
+              "one_time_keys": {
+                "<user_id>": {
+                  "<device_id>": {"<algorithm>": <count>}
+                }
+              }
+            }
+
+        Response:
+            {
+              "one_time_keys": {
+                "<user_id>": {
+                  "<device_id>": {
+                    "<algorithm>:<key_id>": <OTK JSON>
+                  }
+                }
+              }
+            }
+
+        Args:
+            destination: The server to query.
+            query_content: The user ids to query.
+        Returns:
+            A dict containing the one-time keys.
+        """
+        path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/user/keys/claim")
+
+        return await self.client.post_json(
+            destination=destination,
+            path=path,
+            data={"one_time_keys": query_content},
+            timeout=timeout,
         )
 
     async def get_missing_events(
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index e2340d70d5..36b0362504 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -12,6 +12,7 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 import logging
+from collections import Counter
 from typing import (
     TYPE_CHECKING,
     Dict,
@@ -577,16 +578,23 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
     async def on_POST(
         self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
     ) -> Tuple[int, JsonDict]:
+        # Generate a count for each algorithm, which is hard-coded to 1.
+        key_query: List[Tuple[str, str, str, int]] = []
+        for user_id, device_keys in content.get("one_time_keys", {}).items():
+            for device_id, algorithm in device_keys.items():
+                key_query.append((user_id, device_id, algorithm, 1))
+
         response = await self.handler.on_claim_client_keys(
-            origin, content, always_include_fallback_keys=False
+            key_query, always_include_fallback_keys=False
         )
         return 200, response
 
 
 class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
     """
-    Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
-    always includes fallback keys in the response.
+    Identical to the stable endpoint (FederationClientKeysClaimServlet) except
+    it allows for querying for multiple OTKs at once and always includes fallback
+    keys in the response.
     """
 
     PREFIX = FEDERATION_UNSTABLE_PREFIX
@@ -596,8 +604,16 @@ class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
     async def on_POST(
         self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
     ) -> Tuple[int, JsonDict]:
+        # Generate a count for each algorithm.
+        key_query: List[Tuple[str, str, str, int]] = []
+        for user_id, device_keys in content.get("one_time_keys", {}).items():
+            for device_id, algorithms in device_keys.items():
+                counts = Counter(algorithms)
+                for algorithm, count in counts.items():
+                    key_query.append((user_id, device_id, algorithm, count))
+
         response = await self.handler.on_claim_client_keys(
-            origin, content, always_include_fallback_keys=True
+            key_query, always_include_fallback_keys=True
         )
         return 200, response
 
@@ -805,6 +821,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
     FederationClientKeysQueryServlet,
     FederationUserDevicesQueryServlet,
     FederationClientKeysClaimServlet,
+    FederationUnstableClientKeysClaimServlet,
     FederationThirdPartyInviteExchangeServlet,
     On3pidBindServlet,
     FederationVersionServlet,