summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/15663.misc1
-rw-r--r--synapse/federation/federation_client.py6
-rw-r--r--synapse/federation/transport/client.py16
-rw-r--r--synapse/handlers/e2e_keys.py3
-rw-r--r--synapse/rest/client/keys.py8
-rw-r--r--tests/handlers/test_e2e_keys.py16
6 files changed, 39 insertions, 11 deletions
diff --git a/changelog.d/15663.misc b/changelog.d/15663.misc
new file mode 100644
index 0000000000..cc5f801543
--- /dev/null
+++ b/changelog.d/15663.misc
@@ -0,0 +1 @@
+Add requesting user id parameter to key claim methods in `TransportLayerClient`.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 076b9287c6..a2cf3a96c6 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -236,6 +236,7 @@ class FederationClient(FederationBase):
 
     async def claim_client_keys(
         self,
+        user: UserID,
         destination: str,
         query: Dict[str, Dict[str, Dict[str, int]]],
         timeout: Optional[int],
@@ -243,6 +244,7 @@ class FederationClient(FederationBase):
         """Claims one-time keys for a device hosted on a remote server.
 
         Args:
+            user: The user id of the requesting user
             destination: Domain name of the remote homeserver
             content: The query content.
 
@@ -279,7 +281,7 @@ class FederationClient(FederationBase):
         if use_unstable:
             try:
                 return await self.transport_layer.claim_client_keys_unstable(
-                    destination, unstable_content, timeout
+                    user, destination, unstable_content, timeout
                 )
             except HttpResponseException as e:
                 # If an error is received that is due to an unrecognised endpoint,
@@ -295,7 +297,7 @@ class FederationClient(FederationBase):
             logger.debug("Skipping unstable claim client keys API")
 
         return await self.transport_layer.claim_client_keys(
-            destination, content, timeout
+            user, destination, content, timeout
         )
 
     @trace
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 1cfc4446c4..0b17f713ea 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -45,7 +45,7 @@ from synapse.events import EventBase, make_event_from_dict
 from synapse.federation.units import Transaction
 from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser
 from synapse.http.types import QueryParams
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
 from synapse.util import ExceptionBundle
 
 if TYPE_CHECKING:
@@ -630,7 +630,11 @@ class TransportLayerClient:
         )
 
     async def claim_client_keys(
-        self, destination: str, query_content: JsonDict, timeout: Optional[int]
+        self,
+        user: UserID,
+        destination: str,
+        query_content: JsonDict,
+        timeout: Optional[int],
     ) -> JsonDict:
         """Claim one-time keys for a list of devices hosted on a remote server.
 
@@ -655,6 +659,7 @@ class TransportLayerClient:
             }
 
         Args:
+            user: the user_id of the requesting user
             destination: The server to query.
             query_content: The user ids to query.
         Returns:
@@ -671,7 +676,11 @@ class TransportLayerClient:
         )
 
     async def claim_client_keys_unstable(
-        self, destination: str, query_content: JsonDict, timeout: Optional[int]
+        self,
+        user: UserID,
+        destination: str,
+        query_content: JsonDict,
+        timeout: Optional[int],
     ) -> JsonDict:
         """Claim one-time keys for a list of devices hosted on a remote server.
 
@@ -696,6 +705,7 @@ class TransportLayerClient:
             }
 
         Args:
+            user: the user_id of the requesting user
             destination: The server to query.
             query_content: The user ids to query.
         Returns:
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 24741b667b..ad075497c8 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -661,6 +661,7 @@ class E2eKeysHandler:
     async def claim_one_time_keys(
         self,
         query: Dict[str, Dict[str, Dict[str, int]]],
+        user: UserID,
         timeout: Optional[int],
         always_include_fallback_keys: bool,
     ) -> JsonDict:
@@ -703,7 +704,7 @@ class E2eKeysHandler:
             device_keys = remote_queries[destination]
             try:
                 remote_result = await self.federation.claim_client_keys(
-                    destination, device_keys, timeout=timeout
+                    user, destination, device_keys, timeout=timeout
                 )
                 for user_id, keys in remote_result["one_time_keys"].items():
                     if user_id in device_keys:
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 9bbab5e624..413edd8a4d 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -287,7 +287,7 @@ class OneTimeKeyServlet(RestServlet):
         self.e2e_keys_handler = hs.get_e2e_keys_handler()
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        await self.auth.get_user_by_req(request, allow_guest=True)
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
 
@@ -298,7 +298,7 @@ class OneTimeKeyServlet(RestServlet):
                 query.setdefault(user_id, {})[device_id] = {algorithm: 1}
 
         result = await self.e2e_keys_handler.claim_one_time_keys(
-            query, timeout, always_include_fallback_keys=False
+            query, requester.user, timeout, always_include_fallback_keys=False
         )
         return 200, result
 
@@ -335,7 +335,7 @@ class UnstableOneTimeKeyServlet(RestServlet):
         self.e2e_keys_handler = hs.get_e2e_keys_handler()
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        await self.auth.get_user_by_req(request, allow_guest=True)
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
 
@@ -346,7 +346,7 @@ class UnstableOneTimeKeyServlet(RestServlet):
                 query.setdefault(user_id, {})[device_id] = Counter(algorithms)
 
         result = await self.e2e_keys_handler.claim_one_time_keys(
-            query, timeout, always_include_fallback_keys=True
+            query, requester.user, timeout, always_include_fallback_keys=True
         )
         return 200, result
 
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 72d0584061..2eaffe511e 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -27,7 +27,7 @@ from synapse.appservice import ApplicationService
 from synapse.handlers.device import DeviceHandler
 from synapse.server import HomeServer
 from synapse.storage.databases.main.appservice import _make_exclusive_regex
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
 from synapse.util import Clock
 
 from tests import unittest
@@ -45,6 +45,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.handler = hs.get_e2e_keys_handler()
         self.store = self.hs.get_datastores().main
+        self.requester = UserID.from_string(f"@test_requester:{self.hs.hostname}")
 
     def test_query_local_devices_no_devices(self) -> None:
         """If the user has no devices, we expect an empty list."""
@@ -161,6 +162,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         res2 = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=False,
             )
@@ -206,6 +208,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=False,
             )
@@ -225,6 +228,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=False,
             )
@@ -274,6 +278,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=False,
             )
@@ -286,6 +291,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=False,
             )
@@ -307,6 +313,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=False,
             )
@@ -348,6 +355,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=True,
             )
@@ -370,6 +378,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=True,
             )
@@ -1080,6 +1089,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id_1: {"alg1": 1}, device_id_2: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=False,
             )
@@ -1125,6 +1135,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id_1: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=True,
             )
@@ -1169,6 +1180,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id_1: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=True,
             )
@@ -1202,6 +1214,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id_1: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=True,
             )
@@ -1229,6 +1242,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {local_user: {device_id_1: {"alg1": 1}}},
+                self.requester,
                 timeout=None,
                 always_include_fallback_keys=True,
             )