summary refs log tree commit diff
path: root/synapse/federation/transport
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-04-27 12:57:46 -0400
committerGitHub <noreply@github.com>2023-04-27 12:57:46 -0400
commit57aeeb308b39c4fd455682966eabc9c0fa17c65d (patch)
tree3b59e2a367f7894a2adfca66c6579fe317723a39 /synapse/federation/transport
parentAdd type hints to schema deltas (#15497) (diff)
downloadsynapse-57aeeb308b39c4fd455682966eabc9c0fa17c65d.tar.xz
Add support for claiming multiple OTKs at once. (#15468)
MSC3983 provides a way to request multiple OTKs at once from appservices,
this extends this concept to the Client-Server API.

Note that this will likely be spit out into a separate MSC, but is currently part of
MSC3983.
Diffstat (limited to 'synapse/federation/transport')
-rw-r--r--synapse/federation/transport/client.py49
-rw-r--r--synapse/federation/transport/server/federation.py25
2 files changed, 67 insertions, 7 deletions
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index bedbd23ded..bc70b94f68 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -650,10 +650,10 @@ class TransportLayerClient:
 
         Response:
             {
-              "device_keys": {
+              "one_time_keys": {
                 "<user_id>": {
                   "<device_id>": {
-                    "<algorithm>:<key_id>": "<key_base64>"
+                    "<algorithm>:<key_id>": <OTK JSON>
                   }
                 }
               }
@@ -669,7 +669,50 @@ class TransportLayerClient:
         path = _create_v1_path("/user/keys/claim")
 
         return await self.client.post_json(
-            destination=destination, path=path, data=query_content, timeout=timeout
+            destination=destination,
+            path=path,
+            data={"one_time_keys": query_content},
+            timeout=timeout,
+        )
+
+    async def claim_client_keys_unstable(
+        self, destination: str, query_content: JsonDict, timeout: Optional[int]
+    ) -> JsonDict:
+        """Claim one-time keys for a list of devices hosted on a remote server.
+
+        Request:
+            {
+              "one_time_keys": {
+                "<user_id>": {
+                  "<device_id>": {"<algorithm>": <count>}
+                }
+              }
+            }
+
+        Response:
+            {
+              "one_time_keys": {
+                "<user_id>": {
+                  "<device_id>": {
+                    "<algorithm>:<key_id>": <OTK JSON>
+                  }
+                }
+              }
+            }
+
+        Args:
+            destination: The server to query.
+            query_content: The user ids to query.
+        Returns:
+            A dict containing the one-time keys.
+        """
+        path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/user/keys/claim")
+
+        return await self.client.post_json(
+            destination=destination,
+            path=path,
+            data={"one_time_keys": query_content},
+            timeout=timeout,
         )
 
     async def get_missing_events(
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index e2340d70d5..36b0362504 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -12,6 +12,7 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 import logging
+from collections import Counter
 from typing import (
     TYPE_CHECKING,
     Dict,
@@ -577,16 +578,23 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
     async def on_POST(
         self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
     ) -> Tuple[int, JsonDict]:
+        # Generate a count for each algorithm, which is hard-coded to 1.
+        key_query: List[Tuple[str, str, str, int]] = []
+        for user_id, device_keys in content.get("one_time_keys", {}).items():
+            for device_id, algorithm in device_keys.items():
+                key_query.append((user_id, device_id, algorithm, 1))
+
         response = await self.handler.on_claim_client_keys(
-            origin, content, always_include_fallback_keys=False
+            key_query, always_include_fallback_keys=False
         )
         return 200, response
 
 
 class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
     """
-    Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
-    always includes fallback keys in the response.
+    Identical to the stable endpoint (FederationClientKeysClaimServlet) except
+    it allows for querying for multiple OTKs at once and always includes fallback
+    keys in the response.
     """
 
     PREFIX = FEDERATION_UNSTABLE_PREFIX
@@ -596,8 +604,16 @@ class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
     async def on_POST(
         self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
     ) -> Tuple[int, JsonDict]:
+        # Generate a count for each algorithm.
+        key_query: List[Tuple[str, str, str, int]] = []
+        for user_id, device_keys in content.get("one_time_keys", {}).items():
+            for device_id, algorithms in device_keys.items():
+                counts = Counter(algorithms)
+                for algorithm, count in counts.items():
+                    key_query.append((user_id, device_id, algorithm, count))
+
         response = await self.handler.on_claim_client_keys(
-            origin, content, always_include_fallback_keys=True
+            key_query, always_include_fallback_keys=True
         )
         return 200, response
 
@@ -805,6 +821,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
     FederationClientKeysQueryServlet,
     FederationUserDevicesQueryServlet,
     FederationClientKeysClaimServlet,
+    FederationUnstableClientKeysClaimServlet,
     FederationThirdPartyInviteExchangeServlet,
     On3pidBindServlet,
     FederationVersionServlet,