diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index ba34573d46..0b2d1a78f7 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -235,7 +235,10 @@ class FederationClient(FederationBase):
)
async def claim_client_keys(
- self, destination: str, content: JsonDict, timeout: Optional[int]
+ self,
+ destination: str,
+ query: Dict[str, Dict[str, Dict[str, int]]],
+ timeout: Optional[int],
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.
@@ -247,6 +250,50 @@ class FederationClient(FederationBase):
The JSON object from the response
"""
sent_queries_counter.labels("client_one_time_keys").inc()
+
+ # Convert the query with counts into a stable and unstable query and check
+ # if attempting to claim more than 1 OTK.
+ content: Dict[str, Dict[str, str]] = {}
+ unstable_content: Dict[str, Dict[str, List[str]]] = {}
+ use_unstable = False
+ for user_id, one_time_keys in query.items():
+ for device_id, algorithms in one_time_keys.items():
+ if any(count > 1 for count in algorithms.values()):
+ use_unstable = True
+ if algorithms:
+ # For the stable query, choose only the first algorithm.
+ content.setdefault(user_id, {})[device_id] = next(iter(algorithms))
+ # For the unstable query, repeat each algorithm by count, then
+ # splat those into chain to get a flattened list of all algorithms.
+ #
+ # Converts from {"algo1": 2, "algo2": 2} to ["algo1", "algo1", "algo2"].
+ unstable_content.setdefault(user_id, {})[device_id] = list(
+ itertools.chain(
+ *(
+ itertools.repeat(algorithm, count)
+ for algorithm, count in algorithms.items()
+ )
+ )
+ )
+
+ if use_unstable:
+ try:
+ return await self.transport_layer.claim_client_keys_unstable(
+ destination, unstable_content, timeout
+ )
+ except HttpResponseException as e:
+ # If an error is received that is due to an unrecognised endpoint,
+ # fallback to the v1 endpoint. Otherwise, consider it a legitimate error
+ # and raise.
+ if not is_unknown_endpoint(e):
+ raise
+
+ logger.debug(
+ "Couldn't claim client keys with the unstable API, falling back to the v1 API"
+ )
+ else:
+ logger.debug("Skipping unstable claim client keys API")
+
return await self.transport_layer.claim_client_keys(
destination, content, timeout
)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index c618f3d7a6..ca43c7bfc0 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1005,13 +1005,8 @@ class FederationServer(FederationBase):
@trace
async def on_claim_client_keys(
- self, origin: str, content: JsonDict, always_include_fallback_keys: bool
+ self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool
) -> Dict[str, Any]:
- query = []
- for user_id, device_keys in content.get("one_time_keys", {}).items():
- for device_id, algorithm in device_keys.items():
- query.append((user_id, device_id, algorithm))
-
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = await self._e2e_keys_handler.claim_local_one_time_keys(
query, always_include_fallback_keys=always_include_fallback_keys
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index bedbd23ded..bc70b94f68 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -650,10 +650,10 @@ class TransportLayerClient:
Response:
{
- "device_keys": {
+ "one_time_keys": {
"<user_id>": {
"<device_id>": {
- "<algorithm>:<key_id>": "<key_base64>"
+ "<algorithm>:<key_id>": <OTK JSON>
}
}
}
@@ -669,7 +669,50 @@ class TransportLayerClient:
path = _create_v1_path("/user/keys/claim")
return await self.client.post_json(
- destination=destination, path=path, data=query_content, timeout=timeout
+ destination=destination,
+ path=path,
+ data={"one_time_keys": query_content},
+ timeout=timeout,
+ )
+
+ async def claim_client_keys_unstable(
+ self, destination: str, query_content: JsonDict, timeout: Optional[int]
+ ) -> JsonDict:
+ """Claim one-time keys for a list of devices hosted on a remote server.
+
+ Request:
+ {
+ "one_time_keys": {
+ "<user_id>": {
+ "<device_id>": {"<algorithm>": <count>}
+ }
+ }
+ }
+
+ Response:
+ {
+ "one_time_keys": {
+ "<user_id>": {
+ "<device_id>": {
+ "<algorithm>:<key_id>": <OTK JSON>
+ }
+ }
+ }
+ }
+
+ Args:
+ destination: The server to query.
+ query_content: The user ids to query.
+ Returns:
+ A dict containing the one-time keys.
+ """
+ path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/user/keys/claim")
+
+ return await self.client.post_json(
+ destination=destination,
+ path=path,
+ data={"one_time_keys": query_content},
+ timeout=timeout,
)
async def get_missing_events(
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index e2340d70d5..36b0362504 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from collections import Counter
from typing import (
TYPE_CHECKING,
Dict,
@@ -577,16 +578,23 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
+ # Generate a count for each algorithm, which is hard-coded to 1.
+ key_query: List[Tuple[str, str, str, int]] = []
+ for user_id, device_keys in content.get("one_time_keys", {}).items():
+ for device_id, algorithm in device_keys.items():
+ key_query.append((user_id, device_id, algorithm, 1))
+
response = await self.handler.on_claim_client_keys(
- origin, content, always_include_fallback_keys=False
+ key_query, always_include_fallback_keys=False
)
return 200, response
class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
"""
- Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
- always includes fallback keys in the response.
+ Identical to the stable endpoint (FederationClientKeysClaimServlet) except
+ it allows for querying for multiple OTKs at once and always includes fallback
+ keys in the response.
"""
PREFIX = FEDERATION_UNSTABLE_PREFIX
@@ -596,8 +604,16 @@ class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
+ # Generate a count for each algorithm.
+ key_query: List[Tuple[str, str, str, int]] = []
+ for user_id, device_keys in content.get("one_time_keys", {}).items():
+ for device_id, algorithms in device_keys.items():
+ counts = Counter(algorithms)
+ for algorithm, count in counts.items():
+ key_query.append((user_id, device_id, algorithm, count))
+
response = await self.handler.on_claim_client_keys(
- origin, content, always_include_fallback_keys=True
+ key_query, always_include_fallback_keys=True
)
return 200, response
@@ -805,6 +821,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationClientKeysQueryServlet,
FederationUserDevicesQueryServlet,
FederationClientKeysClaimServlet,
+ FederationUnstableClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
FederationVersionServlet,
|