diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index e2340d70d5..36b0362504 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from collections import Counter
from typing import (
TYPE_CHECKING,
Dict,
@@ -577,16 +578,23 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
+ # Generate a count for each algorithm, which is hard-coded to 1.
+ key_query: List[Tuple[str, str, str, int]] = []
+ for user_id, device_keys in content.get("one_time_keys", {}).items():
+ for device_id, algorithm in device_keys.items():
+ key_query.append((user_id, device_id, algorithm, 1))
+
response = await self.handler.on_claim_client_keys(
- origin, content, always_include_fallback_keys=False
+ key_query, always_include_fallback_keys=False
)
return 200, response
class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
"""
- Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
- always includes fallback keys in the response.
+ Identical to the stable endpoint (FederationClientKeysClaimServlet) except
+ it allows for querying for multiple OTKs at once and always includes fallback
+ keys in the response.
"""
PREFIX = FEDERATION_UNSTABLE_PREFIX
@@ -596,8 +604,16 @@ class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
+ # Generate a count for each algorithm.
+ key_query: List[Tuple[str, str, str, int]] = []
+ for user_id, device_keys in content.get("one_time_keys", {}).items():
+ for device_id, algorithms in device_keys.items():
+ counts = Counter(algorithms)
+ for algorithm, count in counts.items():
+ key_query.append((user_id, device_id, algorithm, count))
+
response = await self.handler.on_claim_client_keys(
- origin, content, always_include_fallback_keys=True
+ key_query, always_include_fallback_keys=True
)
return 200, response
@@ -805,6 +821,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationClientKeysQueryServlet,
FederationUserDevicesQueryServlet,
FederationClientKeysClaimServlet,
+ FederationUnstableClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
FederationVersionServlet,
|