diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 076b9287c6..a2cf3a96c6 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -236,6 +236,7 @@ class FederationClient(FederationBase):
async def claim_client_keys(
self,
+ user: UserID,
destination: str,
query: Dict[str, Dict[str, Dict[str, int]]],
timeout: Optional[int],
@@ -243,6 +244,7 @@ class FederationClient(FederationBase):
"""Claims one-time keys for a device hosted on a remote server.
Args:
+ user: The user id of the requesting user
destination: Domain name of the remote homeserver
content: The query content.
@@ -279,7 +281,7 @@ class FederationClient(FederationBase):
if use_unstable:
try:
return await self.transport_layer.claim_client_keys_unstable(
- destination, unstable_content, timeout
+ user, destination, unstable_content, timeout
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
@@ -295,7 +297,7 @@ class FederationClient(FederationBase):
logger.debug("Skipping unstable claim client keys API")
return await self.transport_layer.claim_client_keys(
- destination, content, timeout
+ user, destination, content, timeout
)
@trace
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 1cfc4446c4..0b17f713ea 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -45,7 +45,7 @@ from synapse.events import EventBase, make_event_from_dict
from synapse.federation.units import Transaction
from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser
from synapse.http.types import QueryParams
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
from synapse.util import ExceptionBundle
if TYPE_CHECKING:
@@ -630,7 +630,11 @@ class TransportLayerClient:
)
async def claim_client_keys(
- self, destination: str, query_content: JsonDict, timeout: Optional[int]
+ self,
+ user: UserID,
+ destination: str,
+ query_content: JsonDict,
+ timeout: Optional[int],
) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.
@@ -655,6 +659,7 @@ class TransportLayerClient:
}
Args:
+ user: the user_id of the requesting user
destination: The server to query.
query_content: The user ids to query.
Returns:
@@ -671,7 +676,11 @@ class TransportLayerClient:
)
async def claim_client_keys_unstable(
- self, destination: str, query_content: JsonDict, timeout: Optional[int]
+ self,
+ user: UserID,
+ destination: str,
+ query_content: JsonDict,
+ timeout: Optional[int],
) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.
@@ -696,6 +705,7 @@ class TransportLayerClient:
}
Args:
+ user: the user_id of the requesting user
destination: The server to query.
query_content: The user ids to query.
Returns:
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 24741b667b..ad075497c8 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -661,6 +661,7 @@ class E2eKeysHandler:
async def claim_one_time_keys(
self,
query: Dict[str, Dict[str, Dict[str, int]]],
+ user: UserID,
timeout: Optional[int],
always_include_fallback_keys: bool,
) -> JsonDict:
@@ -703,7 +704,7 @@ class E2eKeysHandler:
device_keys = remote_queries[destination]
try:
remote_result = await self.federation.claim_client_keys(
- destination, device_keys, timeout=timeout
+ user, destination, device_keys, timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 9bbab5e624..413edd8a4d 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -287,7 +287,7 @@ class OneTimeKeyServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- await self.auth.get_user_by_req(request, allow_guest=True)
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
@@ -298,7 +298,7 @@ class OneTimeKeyServlet(RestServlet):
query.setdefault(user_id, {})[device_id] = {algorithm: 1}
result = await self.e2e_keys_handler.claim_one_time_keys(
- query, timeout, always_include_fallback_keys=False
+ query, requester.user, timeout, always_include_fallback_keys=False
)
return 200, result
@@ -335,7 +335,7 @@ class UnstableOneTimeKeyServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- await self.auth.get_user_by_req(request, allow_guest=True)
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
@@ -346,7 +346,7 @@ class UnstableOneTimeKeyServlet(RestServlet):
query.setdefault(user_id, {})[device_id] = Counter(algorithms)
result = await self.e2e_keys_handler.claim_one_time_keys(
- query, timeout, always_include_fallback_keys=True
+ query, requester.user, timeout, always_include_fallback_keys=True
)
return 200, result
|