diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/handlers/test_e2e_keys.py | 96 | ||||
-rw-r--r-- | tests/test_utils/__init__.py | 4 |
2 files changed, 98 insertions, 2 deletions
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index ddcf3ee348..734ed84d78 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -13,8 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Iterable from unittest import mock +from parameterized import parameterized from signedjson import key as key, sign as sign from twisted.internet import defer @@ -23,6 +25,7 @@ from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import Codes, SynapseError from tests import unittest +from tests.test_utils import make_awaitable class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): @@ -765,6 +768,8 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): remote_user_id = "@test:other" local_user_id = "@test:test" + # Pretend we're sharing a room with the user we're querying. If not, + # `_query_devices_for_destination` will return early. self.store.get_rooms_for_user = mock.Mock( return_value=defer.succeed({"some_room_id"}) ) @@ -831,3 +836,94 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): } }, ) + + @parameterized.expand( + [ + # The remote homeserver's response indicates that this user has 0/1/2 devices. + ([],), + (["device_1"],), + (["device_1", "device_2"],), + ] + ) + def test_query_all_devices_caches_result(self, device_ids: Iterable[str]): + """Test that requests for all of a remote user's devices are cached. + + We do this by asserting that only one call over federation was made, and that + the two queries to the local homeserver produce the same response. + """ + local_user_id = "@test:test" + remote_user_id = "@test:other" + request_body = {"device_keys": {remote_user_id: []}} + + response_devices = [ + { + "device_id": device_id, + "keys": { + "algorithms": ["dummy"], + "device_id": device_id, + "keys": {f"dummy:{device_id}": "dummy"}, + "signatures": {device_id: {f"dummy:{device_id}": "dummy"}}, + "unsigned": {}, + "user_id": "@test:other", + }, + } + for device_id in device_ids + ] + + response_body = { + "devices": response_devices, + "user_id": remote_user_id, + "stream_id": 12345, # an integer, according to the spec + } + + e2e_handler = self.hs.get_e2e_keys_handler() + + # Pretend we're sharing a room with the user we're querying. If not, + # `_query_devices_for_destination` will return early. + mock_get_rooms = mock.patch.object( + self.store, + "get_rooms_for_user", + new_callable=mock.MagicMock, + return_value=make_awaitable(["some_room_id"]), + ) + mock_request = mock.patch.object( + self.hs.get_federation_client(), + "query_user_devices", + new_callable=mock.MagicMock, + return_value=make_awaitable(response_body), + ) + + with mock_get_rooms, mock_request as mocked_federation_request: + # Make the first query and sanity check it succeeds. + response_1 = self.get_success( + e2e_handler.query_devices( + request_body, + timeout=10, + from_user_id=local_user_id, + from_device_id="some_device_id", + ) + ) + self.assertEqual(response_1["failures"], {}) + + # We should have made a federation request to do so. + mocked_federation_request.assert_called_once() + + # Reset the mock so we can prove we don't make a second federation request. + mocked_federation_request.reset_mock() + + # Repeat the query. + response_2 = self.get_success( + e2e_handler.query_devices( + request_body, + timeout=10, + from_user_id=local_user_id, + from_device_id="some_device_id", + ) + ) + self.assertEqual(response_2["failures"], {}) + + # We should not have made a second federation request. + mocked_federation_request.assert_not_called() + + # The two requests to the local homeserver should be identical. + self.assertEqual(response_1, response_2) diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 15ac2bfeba..f05a373aa0 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -19,7 +19,7 @@ import sys import warnings from asyncio import Future from binascii import unhexlify -from typing import Any, Awaitable, Callable, TypeVar +from typing import Awaitable, Callable, TypeVar from unittest.mock import Mock import attr @@ -46,7 +46,7 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: raise Exception("awaitable has not yet completed") -def make_awaitable(result: Any) -> Awaitable[Any]: +def make_awaitable(result: TV) -> Awaitable[TV]: """ Makes an awaitable, suitable for mocking an `async` function. This uses Futures as they can be awaited multiple times so can be returned |