diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index ddcf3ee348..734ed84d78 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -13,8 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Iterable
from unittest import mock
+from parameterized import parameterized
from signedjson import key as key, sign as sign
from twisted.internet import defer
@@ -23,6 +25,7 @@ from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
from tests import unittest
+from tests.test_utils import make_awaitable
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
@@ -765,6 +768,8 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_user_id = "@test:other"
local_user_id = "@test:test"
+ # Pretend we're sharing a room with the user we're querying. If not,
+ # `_query_devices_for_destination` will return early.
self.store.get_rooms_for_user = mock.Mock(
return_value=defer.succeed({"some_room_id"})
)
@@ -831,3 +836,94 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
}
},
)
+
+ @parameterized.expand(
+ [
+ # The remote homeserver's response indicates that this user has 0/1/2 devices.
+ ([],),
+ (["device_1"],),
+ (["device_1", "device_2"],),
+ ]
+ )
+ def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
+ """Test that requests for all of a remote user's devices are cached.
+
+ We do this by asserting that only one call over federation was made, and that
+ the two queries to the local homeserver produce the same response.
+ """
+ local_user_id = "@test:test"
+ remote_user_id = "@test:other"
+ request_body = {"device_keys": {remote_user_id: []}}
+
+ response_devices = [
+ {
+ "device_id": device_id,
+ "keys": {
+ "algorithms": ["dummy"],
+ "device_id": device_id,
+ "keys": {f"dummy:{device_id}": "dummy"},
+ "signatures": {device_id: {f"dummy:{device_id}": "dummy"}},
+ "unsigned": {},
+ "user_id": "@test:other",
+ },
+ }
+ for device_id in device_ids
+ ]
+
+ response_body = {
+ "devices": response_devices,
+ "user_id": remote_user_id,
+ "stream_id": 12345, # an integer, according to the spec
+ }
+
+ e2e_handler = self.hs.get_e2e_keys_handler()
+
+ # Pretend we're sharing a room with the user we're querying. If not,
+ # `_query_devices_for_destination` will return early.
+ mock_get_rooms = mock.patch.object(
+ self.store,
+ "get_rooms_for_user",
+ new_callable=mock.MagicMock,
+ return_value=make_awaitable(["some_room_id"]),
+ )
+ mock_request = mock.patch.object(
+ self.hs.get_federation_client(),
+ "query_user_devices",
+ new_callable=mock.MagicMock,
+ return_value=make_awaitable(response_body),
+ )
+
+ with mock_get_rooms, mock_request as mocked_federation_request:
+ # Make the first query and sanity check it succeeds.
+ response_1 = self.get_success(
+ e2e_handler.query_devices(
+ request_body,
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+ self.assertEqual(response_1["failures"], {})
+
+ # We should have made a federation request to do so.
+ mocked_federation_request.assert_called_once()
+
+ # Reset the mock so we can prove we don't make a second federation request.
+ mocked_federation_request.reset_mock()
+
+ # Repeat the query.
+ response_2 = self.get_success(
+ e2e_handler.query_devices(
+ request_body,
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+ self.assertEqual(response_2["failures"], {})
+
+ # We should not have made a second federation request.
+ mocked_federation_request.assert_not_called()
+
+ # The two requests to the local homeserver should be identical.
+ self.assertEqual(response_1, response_2)
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 15ac2bfeba..f05a373aa0 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -19,7 +19,7 @@ import sys
import warnings
from asyncio import Future
from binascii import unhexlify
-from typing import Any, Awaitable, Callable, TypeVar
+from typing import Awaitable, Callable, TypeVar
from unittest.mock import Mock
import attr
@@ -46,7 +46,7 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
raise Exception("awaitable has not yet completed")
-def make_awaitable(result: Any) -> Awaitable[Any]:
+def make_awaitable(result: TV) -> Awaitable[TV]:
"""
Makes an awaitable, suitable for mocking an `async` function.
This uses Futures as they can be awaited multiple times so can be returned
|